From ff9ed80c3f57ac53cb0269a55f4a9e9efca782b9 Mon Sep 17 00:00:00 2001 From: Sebastian Holzapfel Date: Mon, 21 Apr 2025 12:55:42 +0200 Subject: [PATCH] lib.fixed: add initial implementation and test cases --- amaranth/lib/fixed.py | 328 ++++++++++++++++++++++++++ tests/test_lib_fixed.py | 506 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 834 insertions(+) create mode 100644 amaranth/lib/fixed.py create mode 100644 tests/test_lib_fixed.py diff --git a/amaranth/lib/fixed.py b/amaranth/lib/fixed.py new file mode 100644 index 000000000..0102f34f8 --- /dev/null +++ b/amaranth/lib/fixed.py @@ -0,0 +1,328 @@ +# Based on latest iteration of fixed point types RFC, which +# is an effort undertaken by the Amaranth community, as well +# as an early (incomplete) RFC implementation by zyp@ +# +# RFC (community): https://github.com/amaranth-lang/rfcs/pull/41 +# Early implementation (zyp@): https://github.com/amaranth-lang/amaranth/pull/1005 +# +# SPDX-License-Identifier: BSD-3-Clause + +from .. import hdl, Mux +from ..utils import bits_for + +__all__ = ["Shape", "SQ", "UQ", "Value", "Const"] + +class Shape(hdl.ShapeCastable): + + def __init__(self, shape, f_bits=0): + self._storage_shape = shape + self.i_bits, self.f_bits = shape.width-f_bits, f_bits + if self.i_bits < 0 or self.f_bits < 0: + raise TypeError(f"fixed.Shape may not be created with negative bit widths (i_bits={self.i_bits}, f_bits={self.f_bits})") + if shape.signed and self.i_bits == 0: + raise TypeError(f"A signed fixed.Shape cannot be created with i_bits=0") + if self.i_bits + self.f_bits == 0: + raise TypeError(f"fixed.Shape may not be created with zero width") + + @property + def signed(self): + return self._storage_shape.signed + + @staticmethod + def cast(shape, f_bits=0): + if not isinstance(shape, hdl.Shape): + raise TypeError(f"Object {shape!r} cannot be converted to a fixed.Shape") + return Shape(shape, f_bits) + + def const(self, value): + if value is None: + value = 0 + return Const(value, self)._target + + def as_shape(self): + return self._storage_shape + + def __call__(self, target): + return Value(self, target) + + def min(self): + c = Const(0, self) + c._value = c._min_value() + return c + + def max(self): + c = Const(0, self) + c._value = c._max_value() + return c + + def from_bits(self, raw): + c = Const(0, self) + c._value = raw + if self.signed and raw > c._max_value(): + # 2s complement signed value, but `raw` was unsigned. + c._value = c._min_value() + c._value - c._max_value() - 1 + if c._value < c._min_value() or c._value > c._max_value(): + raise ValueError( + f"{raw} outside expected range {c._min_value()}, {c._max_value()}") + return c + + def __repr__(self): + return f"fixed.Shape({self._storage_shape}, f_bits={self.f_bits})" + + +class SQ(Shape): + def __init__(self, i_bits, f_bits): + super().__init__(hdl.Shape(i_bits + f_bits, signed=True), f_bits) + + +class UQ(Shape): + def __init__(self, i_bits, f_bits): + super().__init__(hdl.Shape(i_bits + f_bits, signed=False), f_bits) + + +class Value(hdl.ValueCastable): + def __init__(self, shape, target): + self._shape = shape + if self.signed and not target.shape().signed: + # When methods bit-pick or concatenate to + # the _target of a Value, and then use this + # to reconstruct a Value, we may lose the + # signedness of its underlying _target. + self._target = target.as_signed() + else: + self._target = target + + @property + def signed(self): + return self._shape.signed + + @staticmethod + def cast(value, f_bits=0): + return Shape.cast(value.shape(), f_bits)(value) + + @property + def i_bits(self): + return self._shape.i_bits + + @property + def f_bits(self): + return self._shape.f_bits + + def shape(self): + return self._shape + + def as_value(self): + return self._target + + def eq(self, other): + if isinstance(other, hdl.Value): + return self.as_value().eq(other) + elif isinstance(other, int) or isinstance(other, float): + other = Const(other, self.shape()) + elif not isinstance(other, Value): + raise TypeError(f"Object {other!r} cannot be converted to a fixed.Value") + other = other.reshape(self.f_bits) + return self.as_value().eq(other.as_value()) + + def reshape(self, f_bits): + # If we're increasing precision, extend with more fractional bits. If we're + # reducing precision, truncate bits. + shape = hdl.Shape(self.i_bits + f_bits, signed=self.signed) + if f_bits > self.f_bits: + result = Shape(shape, f_bits)(hdl.Cat(hdl.Const(0, f_bits - self.f_bits), self.as_value())) + else: + result = Shape(shape, f_bits)(self.as_value()[self.f_bits - f_bits:]) + return result + + def truncate(self, f_bits=0): + if f_bits > self.f_bits: + raise ValueError( + f"`.truncate(f_bits={f_bits}) exceeds the underlying type's f_bits={self.f_bits}. " + "Use `.reshape()` to instead extend `f_bits`." + ) + return self.reshape(f_bits) + + def clamp(self, lo, hi): + if not isinstance(lo, Value) or not isinstance(hi, Value): + raise TypeError(f"Cannot `clamp` as lo, hi are not fixed.Value") + lo = lo.reshape(self.f_bits) + hi = hi.reshape(self.f_bits) + return Value(self.shape(), Mux( + self > hi, hi, + Mux(self < lo, lo, self) + )) + + def saturate(self, shape): + if not isinstance(shape, Shape): + raise TypeError(f"Cannot `saturate` to bounds of {shape!r} as it is not a fixed.Shape") + if not shape.i_bits <= self.i_bits: + raise ValueError(f"Cannot `saturate`: shape.i_bits={shape.i_bits} > self.i_bits={self.i_bits} would have no effect.") + clamped = self.reshape(shape.f_bits).clamp(shape.min(), shape.max()) + return Value(shape, clamped.as_value()) + + def _binary_op(self, rhs, operator, callable_f_bits = lambda a, b: max(a, b), pre_reshape=True, post_cast=True): + if isinstance(rhs, hdl.Value): + rhs = Value.cast(rhs) + elif isinstance(rhs, int): + rhs = Const(rhs) + elif not isinstance(rhs, Value): + raise TypeError(f"Object {rhs!r} cannot be converted to a fixed.Value") + f_bits = callable_f_bits(self.f_bits, rhs.f_bits) + if pre_reshape: + lhs = self.reshape(f_bits) + rhs = rhs.reshape(f_bits) + else: + lhs = self + value = getattr(lhs.as_value(), operator)(rhs.as_value()) + return Value.cast(value, f_bits) if post_cast else value + + def __mul__(self, other): + return self._binary_op(other, '__mul__', lambda a, b: a + b, pre_reshape=False) + + __rmul__ = __mul__ + + def __add__(self, other): + return self._binary_op(other, '__add__') + + __radd__ = __add__ + + def __sub__(self, other): + return self._binary_op(other, '__sub__') + + def __rsub__(self, other): + return -self.__sub__(other) + + def __pos__(self): + return self + + def __neg__(self): + return Value.cast(-self.as_value(), self.f_bits) + + def __abs__(self): + return Value.cast(abs(self.as_value()), self.f_bits) + + def __lshift__(self, other): + if isinstance(other, int): + if other < 0: + raise ValueError("Shift amount cannot be negative") + + if other > self.f_bits: + value = hdl.Cat(hdl.Const(0, other - self.f_bits), self.as_value()) + return Value.cast(value.as_signed() if self.signed else value) + else: + return Value.cast(self.as_value(), self.f_bits - other) + elif not isinstance(other, hdl.Value): + raise TypeError("Shift amount must be an integer value") + if other.signed: + raise TypeError("Shift amount must be unsigned") + return Value.cast(self.as_value() << other, self.f_bits) + + def __rshift__(self, other): + if isinstance(other, int): + if other < 0: + raise ValueError("Shift amount cannot be negative") + # Extend f_bits by fixed shift amount. + i_bits = self.i_bits - other + f_bits = self.f_bits + other + numerator = self.as_value() + elif isinstance(other, hdl.Value): + if other.shape().signed: + raise TypeError("Shift amount must be unsigned") + # Extend by maximum possible shift represented by hdl.Value. + f_bits = self.f_bits + 2**other.shape().width - 1 + i_bits = self.i_bits - (f_bits - self.f_bits) + numerator = self.reshape(f_bits).as_value() >> other + else: + raise TypeError("Shift amount must be an integer value") + # Always keep at least 1 sign bit and prohibit negative i_bits. + # TODO: should we concat to _target for sign extension? (likely unnecessary) + if self.signed: + return SQ(max(1, i_bits), f_bits)(numerator) + else: + return UQ(max(0, i_bits), f_bits)(numerator) + + def _binary_compare(self, other, operator): + return self._binary_op(other, operator, post_cast=False) + + def __lt__(self, other): + return self._binary_compare(other, '__lt__') + + def __ge__(self, other): + return self._binary_compare(other, '__ge__') + + def __gt__(self, other): + return self._binary_compare(other, '__gt__') + + def __le__(self, other): + return self._binary_compare(other, '__le__') + + def __eq__(self, other): + return self._binary_compare(other, '__eq__') + + def __repr__(self): + return f"fixed.{'SQ' if self.signed else 'UQ'}({self.i_bits}, {self.f_bits}) {self._target!r}" + + +class Const(Value): + def __init__(self, value, shape=None, clamp=False): + + if isinstance(value, float) or isinstance(value, int): + num, den = value.as_integer_ratio() + elif isinstance(value, Const): + # FIXME: Memory inits seem to construct a fixed.Const with fixed.Const + self._shape = value._shape + self._value = value._value + return + else: + raise TypeError(f"Object {value!r} cannot be converted to a fixed.Const") + + # Determine smallest possible shape if not already selected. + if shape is None: + signed = num < 0 + f_bits = bits_for(den) - 1 + i_bits = max(0, bits_for(abs(num)) - f_bits) + shape = SQ(i_bits+1, f_bits) if signed else UQ(i_bits, f_bits) + + # Scale value to given precision. + if 2**shape.f_bits > den: + num *= 2**shape.f_bits // den + elif 2**shape.f_bits < den: + num = round(num / (den // 2**shape.f_bits)) + value = num + + self._shape = shape + + if value > self._max_value(): + if clamp: + value = self._max_value() + else: + raise ValueError(f"Constant {value!r} does not fit in {shape!r}.") + + if value < self._min_value(): + if clamp: + value = self._min_value() + else: + raise ValueError(f"Constant {value!r} does not fit in {shape!r}. ") + + self._value = value + + def _max_value(self): + return 2**(self._shape.i_bits + + self._shape.f_bits - (1 if self.signed else 0)) - 1 + + def _min_value(self): + if self._shape.signed: + return -1 * 2**(self._shape.i_bits + + self._shape.f_bits - 1) + else: + return 0 + + @property + def _target(self): + return hdl.Const(self._value, self._shape.as_shape()) + + def as_integer_ratio(self): + return self._value, 2**self.f_bits + + def as_float(self): + return self._value / 2**self.f_bits diff --git a/tests/test_lib_fixed.py b/tests/test_lib_fixed.py new file mode 100644 index 000000000..885c86233 --- /dev/null +++ b/tests/test_lib_fixed.py @@ -0,0 +1,506 @@ +from .utils import * + +from amaranth.hdl import * +from amaranth.sim import Simulator +from amaranth.lib import fixed + +class TestFixedShape(FHDLTestCase): + + def test_shape_uq_init(self): + + s = fixed.UQ(6, 5) + self.assertEqual(s.i_bits, 6) + self.assertEqual(s.f_bits, 5) + self.assertFalse(s.signed) + + s = fixed.UQ(0, 1) + self.assertEqual(s.i_bits, 0) + self.assertEqual(s.f_bits, 1) + self.assertFalse(s.signed) + + s = fixed.UQ(1, 0) + self.assertEqual(s.i_bits, 1) + self.assertEqual(s.f_bits, 0) + self.assertFalse(s.signed) + + with self.assertRaises(TypeError): + fixed.UQ(-1, 0) + + with self.assertRaises(TypeError): + fixed.UQ(1, -1) + + def test_shape_sq_init(self): + + s = fixed.SQ(6, 5) + self.assertEqual(s.i_bits, 6) + self.assertEqual(s.f_bits, 5) + self.assertTrue(s.signed) + + s = fixed.SQ(1, 0) + self.assertEqual(s.i_bits, 1) + self.assertEqual(s.f_bits, 0) + self.assertTrue(s.signed) + + with self.assertRaises(TypeError): + fixed.SQ(0, 1) + + with self.assertRaises(TypeError): + fixed.SQ(-1, 0) + + with self.assertRaises(TypeError): + fixed.SQ(1, -1) + + def test_cast_from_shape(self): + + s = fixed.Shape.cast(signed(12), f_bits=4) + self.assertEqual(s.i_bits, 8) + self.assertEqual(s.f_bits, 4) + self.assertTrue(s.signed) + + with self.assertRaises(TypeError): + fixed.Shape.cast("not a shape") + + def test_cast_to_shape(self): + + fixed_shape = fixed.Shape(unsigned(11), f_bits=5) + hdl_shape = fixed_shape.as_shape() + self.assertEqual(hdl_shape.width, 11) + self.assertFalse(hdl_shape.signed) + + def test_min_max(self): + + self.assertEqual(fixed.UQ(2, 4).max().as_value().__repr__(), "(const 6'd63)") + self.assertEqual(fixed.UQ(2, 4).min().as_value().__repr__(), "(const 6'd0)") + self.assertEqual(fixed.UQ(2, 4).max().as_float(), 3.9375) + self.assertEqual(fixed.UQ(2, 4).min().as_float(), 0) + + self.assertEqual(fixed.UQ(0, 2).max().as_value().__repr__(), "(const 2'd3)") + self.assertEqual(fixed.UQ(0, 2).min().as_value().__repr__(), "(const 2'd0)") + self.assertEqual(fixed.UQ(0, 2).max().as_float(), 0.75) + self.assertEqual(fixed.UQ(0, 2).min().as_float(), 0) + + self.assertEqual(fixed.SQ(2, 4).max().as_value().__repr__(), "(const 6'sd31)") + self.assertEqual(fixed.SQ(2, 4).min().as_value().__repr__(), "(const 6'sd-32)") + self.assertEqual(fixed.SQ(2, 4).max().as_float(), 1.9375) + self.assertEqual(fixed.SQ(2, 4).min().as_float(), -2) + + self.assertEqual(fixed.SQ(1, 0).max().as_value().__repr__(), "(const 1'sd0)") + self.assertEqual(fixed.SQ(1, 0).min().as_value().__repr__(), "(const 1'sd-1)") + self.assertEqual(fixed.SQ(1, 0).max().as_float(), 0) + self.assertEqual(fixed.SQ(1, 0).min().as_float(), -1) + + def test_from_bits(self): + + self.assertEqual(fixed.UQ(2, 4).from_bits(0b100000).as_float(), 2.0) + self.assertEqual(fixed.UQ(2, 4).from_bits(0b010000).as_float(), 1.0) + self.assertEqual(fixed.UQ(2, 4).from_bits(0b001000).as_float(), 0.5) + self.assertEqual(fixed.UQ(2, 4).from_bits(0b000100).as_float(), 0.25) + self.assertEqual(fixed.UQ(2, 4).from_bits(0b000000).as_float(), 0) + + self.assertEqual(fixed.SQ(2, 4).from_bits(0b000000).as_float(), 0) + self.assertEqual(fixed.SQ(2, 4).from_bits(0b000001).as_float(), 0.0625) + self.assertEqual(fixed.SQ(2, 4).from_bits(0b111111).as_float(), -0.0625) + self.assertEqual(fixed.SQ(2, 4).from_bits(0b010000).as_float(), 1) + self.assertEqual(fixed.SQ(2, 4).from_bits(0b100000).as_float(), -2) + +class TestFixedValue(FHDLTestCase): + + def assertFixedEqual(self, expression, expected, force_expected_shape=False): + + m = Module() + output = Signal.like(expected if force_expected_shape else expression) + m.d.comb += output.eq(expression) + + async def testbench(ctx): + out = ctx.get(output) + self.assertEqual(out.i_bits, expected.i_bits) + self.assertEqual(out.f_bits, expected.f_bits) + self.assertEqual(out.as_float(), expected.as_float()) + self.assertEqual(out.as_value().value, expected.as_value().value) + self.assertEqual(out.signed, expected.signed) + + sim = Simulator(m) + sim.add_testbench(testbench) + sim.run() + + def assertFixedBool(self, expression, expected): + + m = Module() + output = Signal.like(expression) + m.d.comb += output.eq(expression) + + async def testbench(ctx): + self.assertEqual(ctx.get(output), 1 if expected else 0) + + sim = Simulator(m) + sim.add_testbench(testbench) + sim.run() + + def test_mul(self): + + self.assertFixedEqual( + fixed.Const(1.5, fixed.UQ(3, 2)) * fixed.Const(0.25, fixed.SQ(1, 2)), + fixed.Const(0.375, fixed.SQ(4, 4)) + ) + + self.assertFixedEqual( + fixed.Const(1.5, fixed.UQ(3, 2)) * fixed.Const(-0.25, fixed.SQ(1, 2)), + fixed.Const(-0.375, fixed.SQ(4, 4)) + ) + + self.assertFixedEqual( + fixed.Const(1.5, fixed.UQ(3, 2)) * 3, + fixed.Const(4.5, fixed.UQ(5, 2)) + ) + + self.assertFixedEqual( + fixed.Const(1.5, fixed.UQ(3, 2)) * -3, + fixed.Const(-4.5, fixed.SQ(6, 2)) + ) + + with self.assertRaises(TypeError): + + self.assertFixedEqual( + fixed.Const(1.5, fixed.UQ(3, 2)) * 3.5, + fixed.Const(4.5, fixed.UQ(5, 2)) + ) + + + def test_add(self): + + self.assertFixedEqual( + fixed.Const(1.5, fixed.UQ(3, 3)) + fixed.Const(0.25, fixed.SQ(1, 2)), + fixed.Const(1.75, fixed.SQ(5, 3)), + ) + + self.assertFixedEqual( + fixed.Const(0.5, fixed.UQ(3, 3)) + fixed.Const(-0.75, fixed.SQ(1, 2)), + fixed.Const(-0.25, fixed.SQ(5, 3)) + ) + + self.assertFixedEqual( + fixed.Const(1.5, fixed.UQ(3, 3)) + fixed.Const(0.25, fixed.UQ(1, 2)), + fixed.Const(1.75, fixed.UQ(4, 3)), + ) + + def test_sub(self): + + self.assertFixedEqual( + fixed.Const(1.5, fixed.SQ(3, 3)) - fixed.Const(1.75, fixed.SQ(2, 2)), + fixed.Const(-0.25, fixed.SQ(4, 3)), + ) + + self.assertFixedEqual( + fixed.Const(1.5, fixed.UQ(3, 3)) - fixed.Const(2, fixed.UQ(2, 2)), + fixed.Const(-0.5, fixed.SQ(4, 3)), + ) + + self.assertFixedEqual( + fixed.Const(1.5, fixed.UQ(3, 3)) - 3, + fixed.Const(-1.5, fixed.SQ(4, 3)), + ) + + self.assertFixedEqual( + 3 - fixed.Const(1.5, fixed.UQ(3, 3)), + fixed.Const(1.5, fixed.SQ(5, 3)), + ) + + def test_shift(self): + + self.assertFixedEqual( + fixed.Const(1.5, fixed.UQ(3, 3)) << 1, + fixed.Const(3.0, fixed.UQ(4, 2)), + ) + + self.assertFixedEqual( + fixed.Const(1.5, fixed.UQ(3, 3)) >> 1, + fixed.Const(0.75, fixed.UQ(2, 4)), + ) + + self.assertFixedEqual( + fixed.Const(1.5, fixed.SQ(3, 3)) >> 3, + fixed.Const(0.1875, fixed.SQ(1, 6)), + ) + + self.assertFixedEqual( + fixed.Const(1.5, fixed.SQ(3, 3)) >> Const(3, unsigned(2)), + fixed.Const(0.1875, fixed.SQ(1, 6)), + ) + + self.assertFixedEqual( + fixed.Const(1.5, fixed.UQ(3, 3)) >> Const(3, unsigned(2)), + fixed.Const(0.1875, fixed.UQ(0, 6)), + ) + + self.assertFixedEqual( + fixed.Const(1.5, fixed.UQ(3, 3)) >> 3, + fixed.Const(0.1875, fixed.UQ(0, 6)), + ) + + self.assertFixedEqual( + fixed.Const(-1.5, fixed.SQ(3, 3)) << 4, + fixed.Const(-24.0, fixed.SQ(7, 0)), + ) + + with self.assertRaises(ValueError): + fixed.Const(1.5, fixed.UQ(3, 3)) << -1 + + with self.assertRaises(ValueError): + fixed.Const(1.5, fixed.UQ(3, 3)) >> -1 + + with self.assertRaises(TypeError): + fixed.Const(1.5, fixed.UQ(3, 3)) >> Const(-1, signed(2)) + + def test_abs(self): + + # fixed.SQ -> fixed.UQ + + self.assertFixedEqual( + abs(fixed.Const(-1.5, fixed.SQ(3, 3))), + fixed.Const(1.5, fixed.UQ(3, 3)) + ) + + self.assertFixedEqual( + abs(fixed.Const(-1, fixed.SQ(1, 2))), + fixed.Const(1, fixed.UQ(1, 2)) + ) + + self.assertFixedEqual( + abs(fixed.Const(-4, fixed.SQ(3, 3))), + fixed.Const(4, fixed.UQ(3, 3)) + ) + + # fixed.UQ -> fixed.UQ + + self.assertFixedEqual( + abs(fixed.Const(7, fixed.UQ(3, 3))), + fixed.Const(7, fixed.UQ(3, 3)) + ) + + def test_neg(self): + + # fixed.SQ -> fixed.SQ + + self.assertFixedEqual( + -fixed.Const(-1.5, fixed.SQ(3, 3)), + fixed.Const(1.5, fixed.SQ(4, 3)) + ) + + self.assertFixedEqual( + -fixed.Const(-1, fixed.SQ(1, 2)), + fixed.Const(1, fixed.SQ(2, 2)) + ) + + self.assertFixedEqual( + -fixed.Const(1.5, fixed.SQ(2, 2)), + fixed.Const(-1.5, fixed.SQ(3, 2)) + ) + + # fixed.UQ -> fixed.SQ + + self.assertFixedEqual( + -fixed.Const(1.5, fixed.UQ(2, 2)), + fixed.Const(-1.5, fixed.SQ(3, 2)) + ) + + def test_clamp(self): + + self.assertFixedEqual( + fixed.Const(3, fixed.SQ(3, 3)).clamp( + fixed.Const(-1), + fixed.Const(1)), + fixed.Const(1, fixed.SQ(3, 3)) + ) + + self.assertFixedEqual( + fixed.Const(3, fixed.SQ(3, 3)).clamp( + fixed.Const(-3), + fixed.Const(-2)), + fixed.Const(-2, fixed.SQ(3, 3)) + ) + + self.assertFixedEqual( + fixed.Const(3, fixed.SQ(3, 3)).clamp( + fixed.Const(-0.5), + fixed.Const(0.5)), + fixed.Const(0.5, fixed.SQ(3, 3)) + ) + + def test_saturate(self): + + # fixed.SQ -> fixed.SQ + + self.assertFixedEqual( + fixed.Const(-2, fixed.SQ(3, 3)).saturate(fixed.SQ(1, 1)), + fixed.Const(-1, fixed.SQ(1, 1)) + ) + + self.assertFixedEqual( + fixed.Const(-10.25, fixed.SQ(5, 3)).saturate(fixed.SQ(3, 1)), + fixed.Const(-4, fixed.SQ(3, 1)) + ) + + self.assertFixedEqual( + fixed.Const(14.25, fixed.SQ(8, 3)).saturate(fixed.SQ(4, 2)), + fixed.Const(7.75, fixed.SQ(4, 2)) + ) + + self.assertFixedEqual( + fixed.Const(0.995, fixed.SQ(1, 8)).saturate(fixed.SQ(1, 4)), + fixed.Const(0.9375, fixed.SQ(1, 4)) + ) + + with self.assertRaises(ValueError): + fixed.Const(0, fixed.SQ(8, 0)).saturate(fixed.SQ(9, 0)), + + # XXX: this 'odd' behaviour is an artifact of truncation rounding, + # and should be revisited when we have more rounding strategies. + + self.assertFixedEqual( + fixed.Const(-0.995, fixed.SQ(2, 8)).saturate(fixed.SQ(2, 4)), + fixed.Const(-1, fixed.SQ(2, 4)) + ) + + # fixed.UQ -> fixed.UQ + + self.assertFixedEqual( + fixed.Const(15, fixed.UQ(5, 2)).saturate(fixed.UQ(3, 1)), + fixed.Const(7.5, fixed.UQ(3, 1)) + ) + + # fixed.SQ -> fixed.UQ + + self.assertFixedEqual( + fixed.Const(14.25, fixed.SQ(8, 3)).saturate(fixed.UQ(2, 2)), + fixed.Const(3.75, fixed.UQ(2, 2)) + ) + + self.assertFixedEqual( + fixed.Const(-14.25, fixed.SQ(8, 3)).saturate(fixed.UQ(2, 2)), + fixed.Const(0, fixed.UQ(2, 2)) + ) + + # fixed.UQ -> fixed.SQ + + self.assertFixedEqual( + fixed.Const(255, fixed.UQ(8, 2)).saturate(fixed.SQ(8, 2)), + fixed.Const(127.75, fixed.SQ(8, 2)) + ) + + def test_lt(self): + + self.assertFixedBool( + fixed.Const(0.75, fixed.SQ(1, 2)) < fixed.Const(0.5, fixed.SQ(1, 2)), False) + self.assertFixedBool( + fixed.Const(0.5, fixed.SQ(1, 2)) < fixed.Const(0.75, fixed.SQ(1, 2)), True) + self.assertFixedBool( + fixed.Const(0.75, fixed.SQ(1, 2)) < fixed.Const(-0.5, fixed.SQ(1, 2)), False) + self.assertFixedBool( + fixed.Const(-0.5, fixed.SQ(1, 2)) < fixed.Const(0.75, fixed.SQ(1, 2)), True) + self.assertFixedBool( + fixed.Const(-0.25, fixed.SQ(1, 2)) < fixed.Const(0, fixed.SQ(1, 2)), True) + self.assertFixedBool( + fixed.Const(0.25, fixed.SQ(1, 2)) < fixed.Const(0, fixed.SQ(1, 2)), False) + self.assertFixedBool( + fixed.Const(-0.25, fixed.SQ(1, 2)) < fixed.Const(0), True) + self.assertFixedBool( + fixed.Const(0.25, fixed.SQ(1, 2)) < fixed.Const(0), False) + self.assertFixedBool( + fixed.Const(0, fixed.SQ(1, 2)) < fixed.Const(0), False) + self.assertFixedBool( + fixed.Const(0) < fixed.Const(0), False) + self.assertFixedBool( + fixed.Const(0) < 1, True) + self.assertFixedBool( + fixed.Const(0) < -1, False) + + def test_equality(self): + + self.assertFixedBool(fixed.Const(0) == 0, True) + self.assertFixedBool(fixed.Const(0) == fixed.Const(0), True) + self.assertFixedBool(fixed.Const(0.5) == fixed.Const(0.5), True) + self.assertFixedBool(fixed.Const(0.5) == fixed.Const(0.75), False) + self.assertFixedBool(fixed.Const(0.501) == fixed.Const(0.5), False) + + with self.assertRaises(TypeError): + self.assertFixedBool(0.5 == fixed.Const(0.5), False) + + def test_eq(self): + + self.assertFixedEqual( + fixed.Const(-1, fixed.SQ(2, 1)), + fixed.Const(-1, fixed.SQ(5, 1)), + force_expected_shape=True + ) + + self.assertFixedEqual( + fixed.SQ(1, 1).max(), + fixed.Const(0.5, fixed.SQ(5, 1)), + force_expected_shape=True + ) + + self.assertFixedEqual( + fixed.SQ(1, 1).max(), + fixed.Const(0.5, fixed.SQ(5, 1)), + force_expected_shape=True + ) + + self.assertFixedEqual( + fixed.Const(0.25, fixed.SQ(5, 5)), + fixed.Const(0.0, fixed.SQ(5, 1)), + force_expected_shape=True + ) + + # XXX: truncation rounding again + + self.assertFixedEqual( + fixed.Const(-0.25, fixed.SQ(5, 5)), + fixed.Const(-0.5, fixed.SQ(5, 1)), + force_expected_shape=True + ) + + # XXX: .eq() from fixed.SQ <-> fixed.UQ may over/underflow. + # fixed.SQ -> fixed.UQ: may overflow if fixed.SQ is negative + # fixed.UQ -> fixed.SQ: may overflow if i_bits (fixed.UQ) >= i_bits (fixed.SQ) + # same signedness: may overflow if i_bits > i_bits + # Should these really be prohibited completely? + + self.assertFixedEqual( + fixed.Const(-10, fixed.SQ(5, 2)), + fixed.Const(22, fixed.UQ(5, 2)), + force_expected_shape=True + ) + + self.assertFixedEqual( + fixed.Const(15, fixed.UQ(4, 2)), + fixed.Const(-1, fixed.SQ(4, 2)), + force_expected_shape=True + ) + + + def test_float_size_determination(self): + + self.assertFixedEqual( + fixed.Const(0.03125), + fixed.Const(0.03125, fixed.UQ(0, 5)) + ) + + self.assertFixedEqual( + fixed.Const(-0.03125), + fixed.Const(-0.03125, fixed.SQ(1, 5)) + ) + + self.assertFixedEqual( + fixed.Const(-0.5), + fixed.Const(-0.5, fixed.SQ(1, 1)) + ) + + self.assertFixedEqual( + fixed.Const(10), + fixed.Const(10, fixed.UQ(4, 0)) + ) + + self.assertFixedEqual( + fixed.Const(-10), + fixed.Const(-10, fixed.SQ(5, 0)) + )