Source code for hwtLib.amba.axi_comp.cache.pseudo_lru

from operator import ne
from typing import List, Dict

from hwt.code import Concat, And, Or
from hwt.code_utils import _mkOp
from hwt.math import isPow2, log2ceil
from hwt.synthesizer.rtlLevel.rtlSignal import RtlSignal


def parity(bit_vector):
    """
    Combine all items of ``bit_vector`` with the ``!=`` operator.

    For single-bit operands ``!=`` behaves as XOR, so the result is the
    parity of the bits (assuming ``_mkOp`` chains the binary operator over
    all operands - confirm in ``hwt.code_utils``).

    :param bit_vector: iterable of bits, it is unpacked into positional
        operands
    """
    reduce_with_ne = _mkOp(ne)
    return reduce_with_ne(*bit_vector)
# https://chipress.co/2019/07/09/how-to-implement-pseudo-lru/
class PseudoLru():
    r"""
    Tree-PLRU, Pseudo Least Recently Used (LRU) algorithm

    * Often used to select least used value in caches etc.

    Example for four-way set associative cache (three bits)

    each bit represents one branch point in a binary decision tree; let 1
    represent that the left side has been referenced more recently than the
    right side, and 0 vice-versa

    .. code-block::

                  are all 4 lines valid?
                       /       \
                     yes        no, use an invalid line
                      |
                      |
                      |
                 bit_0 == 0?            state | replace      ref to | next state
                  /       \             ------+--------      -------+-----------
                 y         n             00x  |  line_0      line_0 |    11_
                /           \            01x  |  line_1      line_1 |    10_
         bit_1 == 0?    bit_2 == 0?      1x0  |  line_2      line_2 |    0_1
           /    \          /    \        1x1  |  line_3      line_3 |    0_0
          y      n        y      n
         /        \      /        \        ('x' means       ('_' means unchanged)
       line_0  line_1  line_2  line_3      don't care)

    :note: that there is a 6-bit encoding for true LRU for four-way set associative

    bit 0: bank[1] more recently used than bank[0]
    bit 1: bank[2] more recently used than bank[0]
    bit 2: bank[2] more recently used than bank[1]
    bit 3: bank[3] more recently used than bank[0]
    bit 4: bank[3] more recently used than bank[1]
    bit 5: bank[3] more recently used than bank[2]

    :note: this is not a component in order to make this alg independent on lru reg storage type

    :ivar lru_regs: register with bits which represents binary tree used in pseudo LRU.
        It uses a common binary tree in array node representation (0-based)
        index of left is 2x parent index + 1; index of right is 2x parent index + 2
    """

    @staticmethod
    def lru_reg_width(items):
        """
        :param items: number of items (ways) tracked by the LRU tree
        :return: ``2 ** log2ceil(items) - 1``, the number of bits of the
            LRU tree register
        """
        return 2 ** log2ceil(items) - 1

    @staticmethod
    def lru_reg_items(width):
        """
        :param width: number of bits of the LRU tree register
        :return: ``2 ** log2ceil(width + 1)``, the number of items tracked
            by a register of this width
        """
        return 2 ** log2ceil(width + 1)

    def __init__(self, lru_reg: RtlSignal):
        """
        :param lru_reg: signal holding the bits of the LRU tree; its width
            has to be ``2**n - 1`` (1, 3, 7, ...), which is the shape
            produced by :meth:`lru_reg_width`
        """
        width = lru_reg._dtype.bit_length()
        # a complete binary tree has 2**n - 1 nodes, that means width + 1
        # has to be a power of two (checking width - 1 would reject 7, 15, ...
        # and accept 2, 5, ... which then index out of the register)
        assert width > 0 and isPow2(width + 1), width
        self.lru_regs = lru_reg

    def node_selected_mask(self, lru_tree, node_i):
        """
        Yield the items of ``lru_tree`` found at the leaf positions of the
        subtree rooted in ``node_i``.

        :param lru_tree: array with lru binary tree, nodes are lru registers,
            leafs are select flags
        :param node_i: index of node which we are checking
        """
        # NOTE(review): the leaf test uses the width of the register, so the
        # registers on the last level of the tree are treated as leaves and
        # yield their own item of lru_tree instead of the select flags stored
        # after the registers - confirm that this is intended
        is_leaf = 2 * node_i + 1 >= self.lru_regs._dtype.bit_length()
        if is_leaf:
            yield lru_tree[node_i]
        else:
            # left child
            yield from self.node_selected_mask(lru_tree, 2 * node_i + 1)
            # right child
            yield from self.node_selected_mask(lru_tree, 2 * node_i + 2)

    def mark_use_many(self, used_item_mask):
        """
        Mark values as used just now

        :param used_item_mask: iterable of flags, the flags are placed after
            the lru register bits to form the array passed to
            :meth:`node_selected_mask`
        :return: ``self.lru_regs`` xor-ed with the mask of the bits to flip
        """
        lru_tree = [*self.lru_regs, *used_item_mask]
        # flip lru node if it is accessed odd-number times
        invert_mask = [
            parity(self.node_selected_mask(lru_tree, node_i))
            for node_i in range(self.lru_regs._dtype.bit_length())
        ]
        # the mask was collected from node 0 up, reversed so the flag of
        # node 0 is the last operand of Concat
        return self.lru_regs ^ Concat(*reversed(invert_mask))

    def _build_node_paths(self,
                          node_paths: Dict[int, List[RtlSignal]],
                          i: int,
                          prefix: List[RtlSignal]):
        """
        Collect in tree paths for items.

        :param node_paths: output dictionary, for every visited node index
            it receives the list of conditions of the path from the root
            followed by the bit of the node itself (the bit is inverted
            for the nodes on the last level)
        :param i: index of the node in ``self.lru_regs`` to process
        :param prefix: conditions collected on the path from the root
            to the node ``i``
        """
        is_last_level = 2 * i + 1 >= self.lru_regs._dtype.bit_length()
        this_node_bit = self.lru_regs[i]
        if is_last_level:
            this_node_bit = ~this_node_bit
            node_paths[i] = [*prefix, this_node_bit]
        else:
            node_paths[i] = [*prefix, this_node_bit]
            # left, reached if the bit of this node is 0
            self._build_node_paths(node_paths, 2 * i + 1, [*prefix, ~this_node_bit])
            # right, reached if the bit of this node is 1
            self._build_node_paths(node_paths, 2 * i + 2, [*prefix, this_node_bit])

    def get_lru(self):
        """
        To find LRU, we can perform a depth-first-search starting from root,
        and traverse nodes in lower levels. If the node is 0, then we traverse
        the left sub-tree; otherwise, we traverse the right sub-tree.

        Every level of the tree produces one bit of the result, the bit is
        the OR of the path expressions of all nodes on that level
        (see :meth:`_build_node_paths`).

        :return: concatenation of the per-level bits, the root level first
        """
        # node_index: bits of lru register on the path to the node
        node_paths = {}
        self._build_node_paths(node_paths, 0, [])
        # also number of levels of lru tree
        bin_index_w = log2ceil(self.lru_reg_items(self.lru_regs._dtype.bit_length()))

        lru_index_bin = []
        # msb first in lru binary index
        for output_bit_i in range(bin_index_w):
            items_on_current_level = 2 ** output_bit_i
            # nodes of a level are stored next to each other, right after
            # all nodes of the previous levels
            current_level_offset = items_on_current_level - 1
            possible_paths = [
                And(*node_paths[node_i])
                for node_i in range(
                    current_level_offset,
                    current_level_offset + items_on_current_level)
            ]
            lru_index_bin.append(Or(*possible_paths))

        # bits were collected MSB first and are passed to Concat in this order (MSB..LSB)
        return Concat(*lru_index_bin)