# Source code for hwtLib.amba.axi_comp.cache.lru_array

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from hwt.code import If, Or, Concat
from hwt.code_utils import rename_signal
from hwt.hdl.constants import WRITE, READ
from hwt.hdl.types.defs import BIT
from hwt.hdl.types.struct import HStruct
from hwt.interfaces.std import VectSignal, Handshaked, HandshakeSync
from hwt.interfaces.utils import addClkRstn, propagateClkRstn
from hwt.math import log2ceil
from hwt.pyUtils.arrayQuery import flatten, grouper
from hwt.synthesizer.hObjList import HObjList
from hwt.synthesizer.param import Param
from hwtLib.amba.axi_comp.cache.addrTypeConfig import CacheAddrTypeConfig
from hwtLib.amba.axi_comp.cache.pseudo_lru import PseudoLru
from hwtLib.common_nonstd_interfaces.addr_data_hs import AddrDataHs
from hwtLib.common_nonstd_interfaces.addr_hs import AddrHs
from hwtLib.logic.binToOneHot import binToOneHot
from hwtLib.mem.ramXor import RamXorSingleClock


# extract variant with id
class IndexWayHs(HandshakeSync):
    """
    Handshake-synchronized interface which carries an index and,
    depending on the configuration, also an id and a way number.

    :ivar ID_WIDTH: width of the optional "id" signal, the signal is not
        declared if this is 0
    :ivar INDEX_WIDTH: width of the "index" signal
    :ivar WAY_CNT: number of ways, the "way" signal is declared only
        if this is greater than 1
    """

    def _config(self):
        self.ID_WIDTH = Param(0)
        self.INDEX_WIDTH = Param(10)
        self.WAY_CNT = Param(4)

    def _declr(self):
        id_width = self.ID_WIDTH
        if id_width:
            self.id = VectSignal(id_width)

        self.index = VectSignal(self.INDEX_WIDTH)

        way_cnt = self.WAY_CNT
        if way_cnt > 1:
            # there is nothing to select from if there is just a single way
            self.way = VectSignal(log2ceil(way_cnt - 1))

        # the handshake signals are declared after the payload signals
        HandshakeSync._declr(self)
class AxiCacheLruArray(CacheAddrTypeConfig):
    """
    A memory storing Tree-PLRU records with multiple ports.
    The accesses made through the individual ports are merged together.
    The victim_req port also marks the returned way as the last used one.
    While the set port is active all other ports are blocked and their pending
    updates are discarded. The set port is meant to be used for an initialization
    of the array/cache.

    .. figure:: ./_static/AxiCacheLruArray.png

    .. hwt-autodoc::
    """

    def _config(self):
        CacheAddrTypeConfig._config(self)
        self.INCR_PORT_CNT = Param(2)
        self.WAY_CNT = Param(4)

    def _compute_constants(self):
        """
        Check the WAY_CNT and resolve the constants derived from the configuration
        (the address field widths and the LRU_WIDTH of a single LRU record).
        """
        assert self.WAY_CNT >= 1, self.WAY_CNT
        self._compupte_tag_index_offset_widths()
        self.LRU_WIDTH = PseudoLru.lru_reg_width(self.WAY_CNT)

    def _declr(self):
        self._compute_constants()
        addClkRstn(self)
        # used to initialize the LRU data (in the case of cache reset)
        # while set port is active all other ports are blocked
        s = self.set = AddrDataHs()
        s.ADDR_WIDTH = self.INDEX_W
        s.DATA_WIDTH = self.LRU_WIDTH

        # used to increment the LRU data in the case of hit
        self.incr = HObjList(IndexWayHs() for _ in range(self.INCR_PORT_CNT))
        for incr_port in self.incr:
            incr_port.INDEX_WIDTH = self.INDEX_W
            incr_port.WAY_CNT = self.WAY_CNT

        # get a victim for a selected cacheline index
        # The cacheline returned as a victim is also marked as used just now
        vr = self.victim_req = AddrHs()
        vr.ADDR_WIDTH = self.INDEX_W
        vr.ID_WIDTH = 0
        vd = self.victim_data = Handshaked()._m()
        vd.DATA_WIDTH = log2ceil(self.WAY_CNT - 1)

        m = self.lru_mem = RamXorSingleClock()
        m.ADDR_WIDTH = self.INDEX_W
        m.DATA_WIDTH = self.LRU_WIDTH
        m.PORT_CNT = (
            # victim_req preload, victim_req write back or set,
            READ, WRITE,
            # incr preload, incr write back...
            *flatten((READ, WRITE) for _ in range(self.INCR_PORT_CNT))
        )

    def merge_successor_writes_into_incr_one_hot(self, succ_writes, incr_val_oh):
        """
        OR the one-hot masks of the colliding writes into the incr_val_oh.

        :param succ_writes: sequence of tuples (vld, one-hot mask), the mask
            is used only if its vld flag is set
        :param incr_val_oh: one-hot mask to merge the other masks into
        :return: the incr_val_oh with all valid masks from succ_writes merged in
        """
        if succ_writes:
            for succ_write_vld, succ_write_oh in succ_writes:
                # replicate the vld bit to the width of the mask so it can be used as a bit-wise enable
                width = succ_write_oh._dtype.bit_length()
                en_mask = Concat(*[succ_write_vld for _ in range(width)])
                incr_val_oh = incr_val_oh | (succ_write_oh & en_mask)
        return incr_val_oh

    def _impl(self):
        """
        Connect the victim_req/victim_data, incr and set ports to the ports of the lru_mem.

        Each access first reads the record and stores the request in a "tmp" register,
        then the updated record is written back. The writes which are in the "tmp"
        registers at the same time and have the same index are merged into
        a single write performed by the first of them (victim_req first, then incr
        ports in the order of their index).
        """
        m = self.lru_mem
        victim_req_r, victim_req_w = m.port[:2]

        # victim selection ports
        victim_req = self.victim_req
        victim_req_r.en(victim_req.vld)
        victim_req_r.addr(victim_req.addr)
        victim_req_tmp = self._reg(
            "victim_req_tmp",
            HStruct(
                (victim_req.addr._dtype, "index"),
                (BIT, "vld")
            ),
            def_val={"vld": 0}
        )
        set_ = self.set
        victim_data = self.victim_data
        victim_req.rd(~set_.vld & (~victim_req_tmp.vld | victim_data.rd))
        If((~victim_req_tmp.vld | victim_data.rd) & ~set_.vld,
            victim_req_tmp.index(victim_req.addr),
            victim_req_tmp.vld(victim_req.vld),
        )
        incr_rw = list(grouper(2, m.port[2:]))

        # in the first step we have to collect all pending addresses
        # because we need it in order to resolve any potential access merging
        incr_tmp_mask_oh = []
        for i, (incr_in, _) in enumerate(zip(self.incr, incr_rw)):
            incr_tmp = self._reg(
                f"incr_tmp{i:d}",
                HStruct(
                    (incr_in.index._dtype, "index"),
                    (incr_in.way._dtype, "way"),
                    (BIT, "vld")
                ),
                def_val={"vld": 0}
            )
            incr_tmp.index(incr_in.index)
            incr_tmp.way(incr_in.way)
            incr_tmp.vld(incr_in.vld & ~set_.vld)
            # NOTE(review): the mask is built from the live incr_in.way while incr_tmp.way
            # is stored and never read in this class - confirm that this is intended
            incr_val_oh = rename_signal(self, binToOneHot(incr_in.way), f"incr_val{i:d}_oh")
            incr_tmp_mask_oh.append((incr_tmp, incr_val_oh))

        lru = PseudoLru(victim_req_r.dout)
        victim = rename_signal(self, lru.get_lru(), "victim")
        victim_oh = rename_signal(self, binToOneHot(victim), "victim_oh")

        victim_data.data(victim)
        victim_data.vld(victim_req_tmp.vld & ~set_.vld)

        # the victim write back goes first, all pending incr writes to the same index are merged into it
        succ_writes = [
            (incr2_tmp.vld & incr2_tmp.index._eq(victim_req_tmp.index), incr2_val_oh)
            for incr2_tmp, incr2_val_oh in incr_tmp_mask_oh
        ]
        _victim_oh = self.merge_successor_writes_into_incr_one_hot(succ_writes, victim_oh)
        # the name must not depend on the index variable of the loop above,
        # the variable does not exist if there are no incr ports
        _victim_oh = rename_signal(self, _victim_oh, "victim_oh_final")

        set_.rd(1)
        If(set_.vld,
            # repurpose victim_req_w port for an initialization of LRU array using "set" port
            victim_req_w.en(1),
            victim_req_w.addr(set_.addr),
            victim_req_w.din(set_.data),
        ).Else(
            # use victim_req_w port for a victim req write back as usual
            victim_req_w.en(victim_req_tmp.vld),
            victim_req_w.addr(victim_req_tmp.index),
            victim_req_w.din(lru.mark_use_many(_victim_oh)),
        )

        for i, (incr_in, (incr_r, incr_w), (incr_tmp, incr_val_oh)) in enumerate(
                zip(self.incr, incr_rw, incr_tmp_mask_oh)):
            # drive incr_r port
            incr_r.addr(incr_in.index)
            incr_r.en(incr_in.vld)
            incr_in.rd(~set_.vld)

            # resolve the final mask of LRU incrementation for this port and drive incr_w
            prev_writes = [
                victim_req_tmp.vld & victim_req_tmp.index._eq(incr_tmp.index)
            ] + [
                incr2_tmp.vld & incr2_tmp.index._eq(incr_tmp.index)
                for incr2_tmp, _ in incr_tmp_mask_oh[:i]
            ]
            # if any of previous ports writes to the same index we need to omit this write
            # as it is written by that previous port
            incr_w.addr(incr_tmp.index)
            incr_w.en(incr_tmp.vld & ~Or(*prev_writes) & ~set_.vld)

            succ_writes = [
                (incr2_tmp.vld & incr2_tmp.index._eq(incr_tmp.index), incr2_val_oh)
                for incr2_tmp, incr2_val_oh in incr_tmp_mask_oh[i + 1:]
            ]
            # if collides with others merge the incr_val_oh
            incr_val_oh = self.merge_successor_writes_into_incr_one_hot(succ_writes, incr_val_oh)
            incr_val_oh = rename_signal(self, incr_val_oh, f"incr_val{i:d}_oh_final")
            incr_w.din(PseudoLru(incr_r.dout).mark_use_many(incr_val_oh))

        propagateClkRstn(self)
if __name__ == "__main__":
    # imported locally as it is required only if this file is executed as a script
    from hwt.synthesizer.utils import to_rtl_str

    # print the RTL generated for the component in its default configuration
    print(to_rtl_str(AxiCacheLruArray()))