#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from hwt.code import If
from hwt.constants import Time
from hwt.hwIOs.hwIOArray import HwIOArray
from hwt.hwIOs.std import HwIOVectSignal
from hwt.hwModule import HwModule
from hwt.hwParam import HwParam
from hwt.pyUtils.typingFuture import override
from hwt.simulator.simTestCase import SimTestCase
class BitonicSorter(HwModule):
    """
    Bitonic sorter of arbitrary data.

    The sorting network is built recursively from compare-and-swap stages
    (``If``/``Else`` muxes over ``_sig`` temporaries); no clock or register
    is used in this module.

    Parameters (see :meth:`hwConfig`):

    * ``ITEMS`` - number of items to sort, must be a power of two (>= 1)
    * ``DATA_WIDTH`` - bit width of every item
    * ``SIGNED`` - signedness of the item vectors

    .. hwt-autodoc::
    """

    def __init__(self, cmpFn=lambda x, y: x < y):
        """
        :param cmpFn: function ``(item0, item1) -> condition``; if the
            condition is true the items are kept in place, otherwise they are
            swapped. The default (``x < y``) places smaller items at lower
            output indexes. The returned condition must support ``~``,
            because the descending half of the network uses
            ``~cmpFn(x, y)``.
        """
        HwModule.__init__(self)
        self.cmpFn = cmpFn

    @override
    def hwConfig(self):
        self.ITEMS = HwParam(2)
        self.DATA_WIDTH = HwParam(64)
        self.SIGNED = HwParam(False)

    @override
    def hwDeclr(self):
        items = int(self.ITEMS)
        # The recursion halves the item list at every level and bitonic_compare
        # only drives pairs (i, i + len // 2). For an odd length >= 3 the last
        # temporary would stay undriven, and every non-power-of-two ITEMS
        # reaches such a length. Reject those configurations here instead of
        # generating broken hardware.
        if items < 1 or items & (items - 1):
            raise ValueError(
                f"ITEMS must be a power of two >= 1, got {items:d}")
        w = self.DATA_WIDTH
        sig = bool(self.SIGNED)
        self.inputs = HwIOArray(
            HwIOVectSignal(w, sig) for _ in range(items)
        )
        self.outputs = HwIOArray(
            HwIOVectSignal(w, sig) for _ in range(items)
        )._m()

    def bitonic_sort(self, cmpFn, x, layer=0, offset=0):
        """
        Recursively build a sorting network for ``x``.

        The first half is sorted by ``cmpFn`` and the second half by its
        negation (``~cmpFn``), so their concatenation is bitonic. That
        sequence is then sorted by ``cmpFn`` using :meth:`bitonic_merge`.

        :param cmpFn: compare function, see :meth:`__init__`
        :param x: sequence of signals to sort (length must be a power of two)
        :param layer: number used only to name the generated temporaries
        :param offset: index of ``x[0]`` within the whole input array, used
            only to name the generated temporaries
        :return: list of signals holding the items of ``x`` ordered by
            ``cmpFn`` (``x`` itself when it has at most one item)
        """
        if len(x) <= 1:
            return x
        else:
            _offset = len(x) // 2
            first = self.bitonic_sort(cmpFn,
                                      x[:_offset],
                                      layer, offset)
            # ~ is the hwt logical negation of the 1-bit compare result
            second = self.bitonic_sort(lambda x, y: ~cmpFn(x, y),
                                       x[_offset:],
                                       layer, offset + _offset)
            return self.bitonic_merge(cmpFn, first + second,
                                      layer=layer + _offset, offset=offset)

    def bitonic_merge(self, cmpFn, x, layer, offset):
        """
        Sort a bitonic sequence of signals by ``cmpFn``.

        One compare stage is applied over the whole sequence, and then each
        half is merged recursively.

        :param cmpFn: compare function, see :meth:`__init__`
        :param x: list of signals forming a bitonic sequence (length must be
            a power of two)
        :param layer: number used only to name the generated temporaries
        :param offset: index of ``x[0]`` within the whole input array, used
            only to name the generated temporaries
        :return: list of signals ordered by ``cmpFn``
        """
        # assume input x is bitonic, and sorted list is returned
        if len(x) == 1:
            return x
        else:
            x = self.bitonic_compare(cmpFn, x, layer, offset)
            _offset = len(x) // 2
            first = self.bitonic_merge(cmpFn, x[:_offset],
                                       layer + 1, offset)
            second = self.bitonic_merge(cmpFn, x[_offset:],
                                        layer + 1, offset + _offset)
            return first + second

    def bitonic_compare(self, cmpFn, x, layer, offset):
        """
        Build one compare-and-swap stage.

        For every ``i < len(x) // 2``, ``x[i]`` is compared with
        ``x[i + len(x) // 2]``. If ``cmpFn`` is true the pair keeps its
        order, otherwise the pair is swapped.

        :param cmpFn: compare function, see :meth:`__init__`
        :param x: list of signals (length must be even)
        :param layer: number used only in the temporary signal names
        :param offset: number used only in the temporary signal names
        :return: list of new temporary signals holding the stage output
        """
        dist = len(x) // 2
        _x = [self._sig(f"sort_tmp_{layer:d}_{offset:d}", x[0]._dtype) for _ in x]
        for i in range(dist):
            # both branches assign both temporaries of the pair
            If(cmpFn(x[i], x[i + dist]),
               # keep
               _x[i](x[i]),
               _x[i + dist](x[i + dist])
            ).Else(
               # swap
               _x[i](x[i + dist]),
               _x[i + dist](x[i]),
            )
        return _x

    @override
    def hwImpl(self):
        outs = self.bitonic_sort(self.cmpFn, self.inputs)
        for o, otmp in zip(self.outputs, outs):
            o(otmp)
class BitonicSorterTC(SimTestCase):
    """
    Simulation tests of :class:`BitonicSorter` in its default configuration.
    """
    SIM_TIME = 40 * Time.ns

    @classmethod
    @override
    def setUpClass(cls):
        cls.dut = BitonicSorter()
        cls.compileSim(cls.dut)

    def setInputs(self, values):
        """
        Queue one value on each input agent (``values[i]`` goes to
        ``dut.inputs[i]``).

        :param values: iterable with exactly one value per sorter input
        """
        values = list(values)
        inputs = self.dut.inputs
        # zip() would silently drop extra or missing values
        self.assertEqual(len(values), len(inputs),
                         "one value per sorter input is required")
        for inp, v in zip(inputs, values):
            inp._ag.data.append(v)

    def getOutputs(self):
        """
        :return: list with the last value recorded by each output agent
        """
        return [outp._ag.data[-1] for outp in self.dut.outputs]

    def test_reversed(self):
        ref = list(range(int(self.dut.ITEMS)))
        self.setInputs(reversed(ref))
        self.runSim(self.SIM_TIME)
        self.assertValSequenceEqual(self.getOutputs(), ref)

    def test_sorted(self):
        ref = list(range(int(self.dut.ITEMS)))
        self.setInputs(ref)
        self.runSim(self.SIM_TIME)
        self.assertValSequenceEqual(self.getOutputs(), ref)
if __name__ == "__main__":
    import unittest
    from hwt.synth import to_rtl_str

    # Show the RTL generated for the default configuration first.
    print(to_rtl_str(BitonicSorter()))

    # Then run every test of BitonicSorterTC with verbose output.
    loader = unittest.TestLoader()
    tests = loader.loadTestsFromTestCase(BitonicSorterTC)
    unittest.TextTestRunner(verbosity=3).run(tests)