#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from hwt.code import If
from hwt.hdl.constants import Time
from hwt.interfaces.std import VectSignal
from hwt.simulator.simTestCase import SimTestCase
from hwt.synthesizer.param import Param
from hwt.synthesizer.unit import Unit
from hwt.synthesizer.hObjList import HObjList
[docs]class BitonicSorter(Unit):
"""
Bitonic sorter of arbitrary data
.. hwt-autodoc::
"""
[docs] def __init__(self, cmpFn=lambda x, y: x < y):
"""
:param cmpFn: function (item0, item1) if returns true,
items are not swaped
"""
Unit.__init__(self)
self.cmpFn = cmpFn
def _config(self):
self.ITEMS = Param(2)
self.DATA_WIDTH = Param(64)
self.SIGNED = Param(False)
def _declr(self):
w = self.DATA_WIDTH
sig = bool(self.SIGNED)
self.inputs = HObjList(
VectSignal(w, sig) for _ in range(int(self.ITEMS))
)
self.outputs = HObjList(
VectSignal(w, sig)._m() for _ in range(int(self.ITEMS))
)
[docs] def bitonic_sort(self, cmpFn, x, layer=0, offset=0):
if len(x) <= 1:
return x
else:
_offset = len(x) // 2
first = self.bitonic_sort(cmpFn,
x[:_offset],
layer, offset)
second = self.bitonic_sort(lambda x, y: ~cmpFn(x, y),
x[_offset:],
layer, offset + _offset)
return self.bitonic_merge(cmpFn, first + second,
layer=layer + _offset, offset=offset)
[docs] def bitonic_merge(self, cmpFn, x, layer, offset):
# assume input x is bitonic, and sorted list is returned
if len(x) == 1:
return x
else:
x = self.bitonic_compare(cmpFn, x, layer, offset)
_offset = len(x) // 2
first = self.bitonic_merge(cmpFn, x[:_offset],
layer + 1, offset)
second = self.bitonic_merge(cmpFn, x[_offset:],
layer + 1, offset + _offset)
return first + second
[docs] def bitonic_compare(self, cmpFn, x, layer, offset):
dist = len(x) // 2
_x = [self._sig(f"sort_tmp_{layer:d}_{offset:d}", x[0]._dtype) for _ in x]
for i in range(dist):
If(cmpFn(x[i], x[i + dist]),
# keep
_x[i](x[i]),
_x[i + dist](x[i + dist])
).Else(
# swap
_x[i](x[i + dist]),
_x[i + dist](x[i]),
)
return _x
def _impl(self):
outs = self.bitonic_sort(self.cmpFn, self.inputs)
for o, otmp in zip(self.outputs, outs):
o(otmp)
[docs]class BitonicSorterTC(SimTestCase):
SIM_TIME = 40 * Time.ns
[docs] @classmethod
def setUpClass(cls):
cls.u = BitonicSorter()
cls.compileSim(cls.u)
[docs] def getOutputs(self):
return [outp._ag.data[-1] for outp in self.u.outputs]
[docs] def test_reversed(self):
u = self.u
ref = [i for i in range(int(u.ITEMS))]
self.setInputs(reversed(ref))
self.runSim(self.SIM_TIME)
self.assertValSequenceEqual(self.getOutputs(), ref)
[docs] def test_sorted(self):
u = self.u
ref = [i for i in range(int(u.ITEMS))]
self.setInputs(ref)
self.runSim(self.SIM_TIME)
self.assertValSequenceEqual(self.getOutputs(), ref)
if __name__ == "__main__":
import unittest
from hwt.synthesizer.utils import to_rtl_str
u = BitonicSorter()
print(to_rtl_str(u))
testLoader = unittest.TestLoader()
# suite = unittest.TestSuite([BitonicSorterTC("test_sorted")])
suite = testLoader.loadTestsFromTestCase(BitonicSorterTC)
runner = unittest.TextTestRunner(verbosity=3)
runner.run(suite)