"""Some useful hardware generators (e.g. muxes, signed multipliers, etc.)"""
from __future__ import annotations
import itertools
from pyrtl.conditional import otherwise
from pyrtl.core import Block, LogicNet, working_block
from pyrtl.pyrtlexceptions import PyrtlError, PyrtlInternalError
from pyrtl.rtllib import barrel, muxes
from pyrtl.wire import Const, WireVector, WireVectorLike, WrappedWireVector
[docs]
def mux(
    index: WireVectorLike, *mux_ins: WireVectorLike, default: WireVectorLike = None
) -> WireVector:
    """Return the ``mux_ins`` entry addressed by ``index``.

    An ``index`` of ``0`` picks the first ``mux_in`` argument, ``1`` the second, and
    so on.

    .. note::

        For a 1-bit ``index`` that acts as a predicate (``True`` or ``False`` rather
        than a number), :func:`select` reads better because its arguments follow the
        order of the C-style ternary operator (``condition ? truecase :
        falsecase``).

    .. doctest only::

        >>> import pyrtl
        >>> pyrtl.reset_working_block()

    Multiplexing between four values::

        >>> index = pyrtl.Input(name="index", bitwidth=2)
        >>> selected = pyrtl.WireVector(name="selected")
        >>> selected <<= pyrtl.mux(index, 4, 5, 6, 7)

        >>> sim = pyrtl.Simulation()
        >>> sim.step(provided_inputs={"index": 0})
        >>> sim.inspect("selected")
        4
        >>> sim.step(provided_inputs={"index": 3})
        >>> sim.inspect("selected")
        7

    .. doctest only::

        >>> pyrtl.reset_working_block()

    Filling the unlisted slots with ``default``::

        >>> index = pyrtl.Input(name="index", bitwidth=2)
        >>> selected = pyrtl.WireVector(name="selected")
        >>> selected <<= pyrtl.mux(index, 4, 5, default=9)

        >>> sim = pyrtl.Simulation()
        >>> sim.step(provided_inputs={"index": 1})
        >>> sim.inspect("selected")
        5
        >>> sim.step(provided_inputs={"index": 2})
        >>> sim.inspect("selected")
        9

    :param index: Selection input of the multiplexer. A :class:`WireVector`, or
        anything :func:`as_wires` can coerce to one.
    :param mux_ins: The candidates to choose from. Each is a :class:`WireVector`, or
        anything :func:`as_wires` can coerce to one.
    :param default: Value used for every ``index`` value that has no corresponding
        ``mux_in``. For instance, a 3-bit ``index`` with only 6 ``mux_ins`` leaves
        ``0b110`` and ``0b111`` uncovered, and ``default`` fills both.

    :raises PyrtlError: If the number of inputs (after padding with ``default``, when
        given) is not exactly ``2 ** index.bitwidth``.

    :return: :class:`WireVector` whose :attr:`~WireVector.bitwidth` matches the widest
        input (``index`` excluded).
    """
    index = as_wires(index)
    num_addressable = 2 ** len(index)
    inputs = list(mux_ins)

    # Pad the slots nobody listed with ``default``, when one was supplied.
    num_missing = num_addressable - len(inputs)
    if default is not None and num_missing > 0:
        inputs += [default] * num_missing

    if len(inputs) != num_addressable:
        msg = (
            f"Mux select line is {len(index)} bits, but selecting from {len(inputs)} "
            "inputs."
        )
        raise PyrtlError(msg)

    # Base case: a single index bit is an ordinary two-way select.
    if len(index) == 1:
        return select(index, truecase=inputs[1], falsecase=inputs[0])

    # Recursive case: the top index bit chooses between two half-sized muxes that
    # are both driven by the remaining lower index bits.
    top_bit = index[-1]
    midpoint = len(inputs) // 2
    lower_result = mux(index[0:-1], *inputs[:midpoint])
    upper_result = mux(index[0:-1], *inputs[midpoint:])
    return select(top_bit, truecase=upper_result, falsecase=lower_result)
[docs]
def select(
    sel: WireVectorLike, truecase: WireVectorLike, falsecase: WireVectorLike
) -> WireVector:
    """Two-way multiplexer: ``truecase`` when ``sel == 1``, ``falsecase`` otherwise.

    This does the same job as :func:`mux` with a 1-bit ``index``, but takes
    ``truecase`` before ``falsecase``. That matches the ternary operator of C-style
    languages and tends to read more naturally, so ``select`` is the better choice
    whenever there are exactly two options.

    .. doctest only::

        >>> import pyrtl
        >>> pyrtl.reset_working_block()

    Example that computes ``min(a, 5)``::

        >>> a = pyrtl.Input(name="a", bitwidth=4)
        >>> min = pyrtl.WireVector(name="min")
        >>> min <<= pyrtl.select(a < 5, 5, a)

        >>> sim = pyrtl.Simulation()
        >>> sim.step(provided_inputs={"a": 4})
        >>> sim.inspect("min")
        5
        >>> sim.step(provided_inputs={"a": 6})
        >>> sim.inspect("min")
        6

    :param sel: Selection input. A :class:`WireVector`, or anything
        :func:`as_wires` can coerce to one.
    :param truecase: Value chosen when ``sel == 1``. A :class:`WireVector`, or
        anything :func:`as_wires` can coerce to one.
    :param falsecase: Value chosen when ``sel == 0``. A :class:`WireVector`, or
        anything :func:`as_wires` can coerce to one.

    :return: :class:`WireVector` as wide as the wider of ``truecase`` and
        ``falsecase``.
    """
    # Coerce in the order the mux net lists its arguments: sel, falsecase, truecase.
    sel = as_wires(sel)
    when_false = as_wires(falsecase)
    when_true = as_wires(truecase)
    when_false, when_true = match_bitwidth(when_false, when_true)

    result = WireVector(bitwidth=len(when_false))
    mux_net = LogicNet(
        op="x", op_param=None, args=(sel, when_false, when_true), dests=(result,)
    )
    # Adding the net also runs the block's sanity check on the mux.
    working_block().add_net(mux_net)
    return result
[docs]
def concat(*args: WireVectorLike) -> WireVector:
    """Join several :class:`WireVectors<WireVector>` into one :class:`WireVector`.

    Any number of arguments may be given. The first argument supplies the most
    significant bits of the result and the last argument the least significant bits.

    .. note::

        With a :class:`list` of pieces, index 0 is usually meant to be the least
        significant part. Unpacking such a :class:`list` into ``concat``'s ``args``
        produces a backwards result; :func:`concat_list` exists for exactly that
        situation.

    .. note::

        Consider using :func:`wire_struct` or :func:`wire_matrix` instead, which helps
        with consistently disassembling, naming, and reassembling fields.

    .. doctest only::

        >>> import pyrtl
        >>> pyrtl.reset_working_block()

    Example that concatenates two bytes into a 16-bit ``output``::

        >>> msb = pyrtl.Input(name="msb", bitwidth=8)
        >>> lsb = pyrtl.Input(name="lsb", bitwidth=8)
        >>> output = pyrtl.Output(name="output")
        >>> output <<= pyrtl.concat(msb, lsb)
        >>> output.bitwidth
        16

        >>> sim = pyrtl.Simulation()
        >>> sim.step(provided_inputs={"msb": 0xab, "lsb": 0xcd})
        >>> hex(sim.inspect("output"))
        '0xabcd'

    :param args: Pieces to concatenate, most significant first. Each is a
        :class:`WireVector`, or anything :func:`as_wires` can coerce to one.

    :raises PyrtlError: If called with no arguments.

    :return: :class:`WireVector` whose :attr:`~WireVector.bitwidth` is the sum of the
        :attr:`bitwidths<~WireVector.bitwidth>` of all ``args``.
    """
    if not args:
        msg = "error, concat requires at least 1 argument"
        raise PyrtlError(msg)

    # A single piece needs no concat net; it is just coerced and returned.
    if len(args) == 1:
        return as_wires(args[0])

    pieces = tuple(map(as_wires, args))
    total_bits = sum(map(len, pieces))
    joined = WireVector(bitwidth=total_bits)
    concat_net = LogicNet(op="c", op_param=None, args=pieces, dests=(joined,))
    working_block().add_net(concat_net)
    return joined
[docs]
def concat_list(wire_list: list[WireVectorLike]) -> WireVector:
    """Join a list of :class:`WireVectors<WireVector>` into one :class:`WireVector`.

    The element at index 0 of ``wire_list`` becomes the least significant bits of the
    result, which is the reverse of :func:`concat`'s argument order. This is handy
    when the number of :class:`WireVectors<WireVector>` to join is not fixed;
    otherwise :func:`concat` is preferred.

    .. note::

        Consider using :func:`wire_struct` or :func:`wire_matrix` instead, which helps
        with consistently disassembling, naming, and reassembling fields.

    .. doctest only::

        >>> import pyrtl
        >>> pyrtl.reset_working_block()

    Example that concatenates two bytes into a 16-bit ``output``::

        >>> msb = pyrtl.Input(name="msb", bitwidth=8)
        >>> lsb = pyrtl.Input(name="lsb", bitwidth=8)
        >>> output = pyrtl.Output(name="output")
        >>> output <<= pyrtl.concat_list([lsb, msb])
        >>> output.bitwidth
        16

        >>> sim = pyrtl.Simulation()
        >>> sim.step(provided_inputs={"msb": 0xab, "lsb": 0xcd})
        >>> hex(sim.inspect("output"))
        '0xabcd'

    :param wire_list: Pieces to concatenate, least significant first. Each is a
        :class:`WireVector`, or anything :func:`as_wires` can coerce to one.

    :return: :class:`WireVector` whose :attr:`~WireVector.bitwidth` is the sum of the
        :attr:`bitwidths<~WireVector.bitwidth>` of everything in ``wire_list``.
    """
    # ``concat`` wants the most significant piece first, so flip the order.
    msb_first = reversed(wire_list)
    return concat(*msb_first)
def _signed_input_to_wirevector(x):
    """Coerce ``x`` to a wire, building Python ints as signed ``Const`` values.

    Anything that is not an ``int`` is handed to `as_wires` unchanged.
    """
    return Const(x, signed=True) if isinstance(x, int) else as_wires(x)
[docs]
def signed_add(a: WireVectorLike, b: WireVectorLike) -> WireVector:
    """Signed addition of ``a`` and ``b``.

    Both operands are :meth:`~WireVector.sign_extended` to the result's bitwidth
    before they are added.

    .. doctest only::

        >>> import pyrtl
        >>> pyrtl.reset_working_block()

    Example::

        >>> neg_three = pyrtl.Const(val=-3, signed=True, bitwidth=3)
        >>> neg_five = pyrtl.Const(val=-5, signed=True, bitwidth=5)
        >>> output = pyrtl.Output(name="output")
        >>> output <<= pyrtl.signed_add(neg_three, neg_five)
        >>> output.bitwidth
        6

        >>> sim = pyrtl.Simulation()
        >>> sim.step()
        >>> pyrtl.val_to_signed_integer(sim.inspect("output"), bitwidth=output.bitwidth)
        -8

    :param a: A :class:`WireVector`, or anything :func:`as_wires` can coerce to one.
        Python ints are built as signed constants.
    :param b: A :class:`WireVector`, or anything :func:`as_wires` can coerce to one.
        Python ints are built as signed constants.

    :return: A :class:`WireVector` holding ``a + b``, with
        :attr:`~WireVector.bitwidth` ``max(a.bitwidth, b.bitwidth) + 1``.
    """
    lhs = _signed_input_to_wirevector(a)
    rhs = _signed_input_to_wirevector(b)

    # One bit more than the wider operand is the width of the result.
    width = max(lhs.bitwidth, rhs.bitwidth) + 1
    total = lhs.sign_extended(width) + rhs.sign_extended(width)
    return total.truncate(width)
[docs]
def signed_sub(a: WireVectorLike, b: WireVectorLike) -> WireVector:
    """Signed subtraction of ``b`` from ``a``.

    Both operands are :meth:`~WireVector.sign_extended` to the result's bitwidth
    before they are subtracted.

    .. doctest only::

        >>> import pyrtl
        >>> pyrtl.reset_working_block()

    Example::

        >>> neg_three = pyrtl.Const(val=-3, signed=True, bitwidth=3)
        >>> neg_five = pyrtl.Const(val=-5, signed=True, bitwidth=5)
        >>> output = pyrtl.Output(name="output")
        >>> output <<= pyrtl.signed_sub(neg_three, neg_five)
        >>> output.bitwidth
        6

        >>> sim = pyrtl.Simulation()
        >>> sim.step()
        >>> pyrtl.val_to_signed_integer(sim.inspect("output"), bitwidth=output.bitwidth)
        2

    :param a: A :class:`WireVector`, or anything :func:`as_wires` can coerce to one.
        Python ints are built as signed constants.
    :param b: A :class:`WireVector`, or anything :func:`as_wires` can coerce to one.
        Python ints are built as signed constants.

    :return: A :class:`WireVector` holding ``a - b``, with
        :attr:`~WireVector.bitwidth` ``max(a.bitwidth, b.bitwidth) + 1``.
    """
    lhs = _signed_input_to_wirevector(a)
    rhs = _signed_input_to_wirevector(b)

    # One bit more than the wider operand is the width of the result.
    width = max(lhs.bitwidth, rhs.bitwidth) + 1
    difference = lhs.sign_extended(width) - rhs.sign_extended(width)
    return difference.truncate(width)
def mult_signed(a, b):
    """Deprecated name kept for older callers; use ``signed_mult`` instead."""
    product = signed_mult(a, b)
    return product
[docs]
def signed_mult(a: WireVectorLike, b: WireVectorLike) -> WireVector:
    """Signed multiplication of ``a`` and ``b``.

    Both operands are :meth:`~WireVector.sign_extended` to the result's bitwidth
    before they are multiplied.

    .. doctest only::

        >>> import pyrtl
        >>> pyrtl.reset_working_block()

    Example::

        >>> neg_three = pyrtl.Const(val=-3, signed=True, bitwidth=3)
        >>> neg_five = pyrtl.Const(val=-5, signed=True, bitwidth=5)
        >>> output = pyrtl.Output(name="output")
        >>> output <<= pyrtl.signed_mult(neg_three, neg_five)
        >>> output.bitwidth
        8

        >>> sim = pyrtl.Simulation()
        >>> sim.step()
        >>> pyrtl.val_to_signed_integer(sim.inspect("output"), bitwidth=output.bitwidth)
        15

    :param a: A :class:`WireVector`, or anything :func:`as_wires` can coerce to one.
        Python ints are built as signed constants.
    :param b: A :class:`WireVector`, or anything :func:`as_wires` can coerce to one.
        Python ints are built as signed constants.

    :return: A :class:`WireVector` holding ``a * b``, with
        :attr:`~WireVector.bitwidth` ``a.bitwidth + b.bitwidth``.
    """
    lhs = _signed_input_to_wirevector(a)
    rhs = _signed_input_to_wirevector(b)

    # The result is as wide as both operands put together.
    width = lhs.bitwidth + rhs.bitwidth
    product = lhs.sign_extended(width) * rhs.sign_extended(width)
    return product.truncate(width)
[docs]
def signed_lt(a: WireVectorLike, b: WireVectorLike) -> WireVector:
    """Return a 1-bit :class:`WireVector` for the result of a signed ``<`` comparison.

    The inputs are :meth:`~WireVector.sign_extended` to matching bitwidths before
    comparing. Python ``int`` inputs are built as signed constants, the same way
    :func:`signed_add`, :func:`signed_sub` and :func:`signed_mult` treat them.

    .. doctest only::

        >>> import pyrtl
        >>> pyrtl.reset_working_block()

    Example::

        >>> neg_three = pyrtl.Const(val=-3, signed=True, bitwidth=3)
        >>> neg_five = pyrtl.Const(val=-5, signed=True, bitwidth=5)
        >>> output = pyrtl.Output(name="output")
        >>> output <<= pyrtl.signed_lt(neg_three, neg_five)
        >>> output.bitwidth
        1

        >>> sim = pyrtl.Simulation()
        >>> sim.step()
        >>> sim.inspect("output")
        0

    :param a: A :class:`WireVector`, or any type that can be coerced to
        :class:`WireVector` by :func:`as_wires`. Ints are treated as signed values.
    :param b: A :class:`WireVector`, or any type that can be coerced to
        :class:`WireVector` by :func:`as_wires`. Ints are treated as signed values.

    :return: A 1-bit :class:`WireVector` indicating if ``a`` is less than ``b``.
    """
    # Plain ``as_wires`` builds ints as unsigned constants; sign-extending such a
    # constant would reinterpret its top bit as a sign bit. Converting ints to
    # signed constants first keeps their numeric value.
    a, b = match_bitwidth(
        _signed_input_to_wirevector(a), _signed_input_to_wirevector(b), signed=True
    )
    r = a - b
    # Top bit of the difference XORed with both sign bits; the two inversions
    # cancel each other under XOR.
    return r[-1] ^ (~a[-1]) ^ (~b[-1])
[docs]
def signed_le(a: WireVectorLike, b: WireVectorLike) -> WireVector:
    """Return a 1-bit :class:`WireVector` for the result of a signed ``<=`` comparison.

    The inputs are :meth:`~WireVector.sign_extended` to matching bitwidths before
    comparing. Python ``int`` inputs are built as signed constants, the same way
    :func:`signed_add`, :func:`signed_sub` and :func:`signed_mult` treat them.

    .. doctest only::

        >>> import pyrtl
        >>> pyrtl.reset_working_block()

    Example::

        >>> neg_three = pyrtl.Const(val=-3, signed=True, bitwidth=3)
        >>> neg_five = pyrtl.Const(val=-5, signed=True, bitwidth=5)
        >>> output = pyrtl.Output(name="output")
        >>> output <<= pyrtl.signed_le(neg_three, neg_five)
        >>> output.bitwidth
        1

        >>> sim = pyrtl.Simulation()
        >>> sim.step()
        >>> sim.inspect("output")
        0

    :param a: A :class:`WireVector`, or any type that can be coerced to
        :class:`WireVector` by :func:`as_wires`. Ints are treated as signed values.
    :param b: A :class:`WireVector`, or any type that can be coerced to
        :class:`WireVector` by :func:`as_wires`. Ints are treated as signed values.

    :return: A 1-bit :class:`WireVector` indicating if ``a`` is less than or equal to
        ``b``.
    """
    # Plain ``as_wires`` builds ints as unsigned constants; sign-extending such a
    # constant would reinterpret its top bit as a sign bit. Converting ints to
    # signed constants first keeps their numeric value.
    a, b = match_bitwidth(
        _signed_input_to_wirevector(a), _signed_input_to_wirevector(b), signed=True
    )
    r = a - b
    # Signed "less than" (top bit of the difference XORed with both sign bits),
    # ORed with equality of the width-matched operands.
    return (r[-1] ^ (~a[-1]) ^ (~b[-1])) | (a == b)
[docs]
def signed_gt(a: WireVectorLike, b: WireVectorLike) -> WireVector:
    """Return a 1-bit :class:`WireVector` for the result of a signed ``>`` comparison.

    The inputs are :meth:`~WireVector.sign_extended` to matching bitwidths before
    comparing. Python ``int`` inputs are built as signed constants, the same way
    :func:`signed_add`, :func:`signed_sub` and :func:`signed_mult` treat them.

    .. doctest only::

        >>> import pyrtl
        >>> pyrtl.reset_working_block()

    Example::

        >>> neg_three = pyrtl.Const(val=-3, signed=True, bitwidth=3)
        >>> neg_five = pyrtl.Const(val=-5, signed=True, bitwidth=5)
        >>> output = pyrtl.Output(name="output")
        >>> output <<= pyrtl.signed_gt(neg_three, neg_five)
        >>> output.bitwidth
        1

        >>> sim = pyrtl.Simulation()
        >>> sim.step()
        >>> sim.inspect("output")
        1

    :param a: A :class:`WireVector`, or any type that can be coerced to
        :class:`WireVector` by :func:`as_wires`. Ints are treated as signed values.
    :param b: A :class:`WireVector`, or any type that can be coerced to
        :class:`WireVector` by :func:`as_wires`. Ints are treated as signed values.

    :return: A 1-bit :class:`WireVector` indicating if ``a`` is greater than ``b``.
    """
    # Plain ``as_wires`` builds ints as unsigned constants; sign-extending such a
    # constant would reinterpret its top bit as a sign bit. Converting ints to
    # signed constants first keeps their numeric value.
    a, b = match_bitwidth(
        _signed_input_to_wirevector(a), _signed_input_to_wirevector(b), signed=True
    )
    # ``a > b`` is evaluated as ``b < a``, hence the swapped subtraction.
    r = b - a
    return r[-1] ^ (~a[-1]) ^ (~b[-1])
[docs]
def signed_ge(a: WireVectorLike, b: WireVectorLike) -> WireVector:
    """Return a 1-bit :class:`WireVector` for the result of a signed ``>=`` comparison.

    The inputs are :meth:`~WireVector.sign_extended` to matching bitwidths before
    comparing. Python ``int`` inputs are built as signed constants, the same way
    :func:`signed_add`, :func:`signed_sub` and :func:`signed_mult` treat them.

    .. doctest only::

        >>> import pyrtl
        >>> pyrtl.reset_working_block()

    Example::

        >>> neg_three = pyrtl.Const(val=-3, signed=True, bitwidth=3)
        >>> neg_five = pyrtl.Const(val=-5, signed=True, bitwidth=5)
        >>> output = pyrtl.Output(name="output")
        >>> output <<= pyrtl.signed_ge(neg_three, neg_five)
        >>> output.bitwidth
        1

        >>> sim = pyrtl.Simulation()
        >>> sim.step()
        >>> sim.inspect("output")
        1

    :param a: A :class:`WireVector`, or any type that can be coerced to
        :class:`WireVector` by :func:`as_wires`. Ints are treated as signed values.
    :param b: A :class:`WireVector`, or any type that can be coerced to
        :class:`WireVector` by :func:`as_wires`. Ints are treated as signed values.

    :return: A 1-bit :class:`WireVector` indicating if ``a`` is greater than or equal to
        ``b``.
    """
    # Plain ``as_wires`` builds ints as unsigned constants; sign-extending such a
    # constant would reinterpret its top bit as a sign bit. Converting ints to
    # signed constants first keeps their numeric value.
    a, b = match_bitwidth(
        _signed_input_to_wirevector(a), _signed_input_to_wirevector(b), signed=True
    )
    # ``a > b`` is evaluated as ``b < a`` (swapped subtraction), then ORed with
    # equality of the width-matched operands.
    r = b - a
    return (r[-1] ^ (~a[-1]) ^ (~b[-1])) | (a == b)
[docs]
def shift_right_arithmetic(
    bits_to_shift: WireVector, shift_amount: WireVector | int
) -> WireVector:
    """Shift ``bits_to_shift`` right, filling with copies of its sign bit.

    An arithmetic shift treats ``bits_to_shift`` as a signed number, which is why the
    vacated positions on the left receive the sign bit rather than zeroes.

    .. doctest only::

        >>> import pyrtl
        >>> pyrtl.reset_working_block()

    Example::

        >>> neg_forty = pyrtl.Const(val=-40, signed=True, bitwidth=7)
        >>> shift_amount = pyrtl.Input(name="shift_amount", bitwidth=3)
        >>> output = pyrtl.Output(name="output")
        >>> output <<= pyrtl.shift_right_arithmetic(neg_forty, shift_amount)

        >>> sim = pyrtl.Simulation()
        >>> sim.step(provided_inputs={"shift_amount": 3})
        >>> pyrtl.val_to_signed_integer(sim.inspect("output"), bitwidth=output.bitwidth)
        -5
        >>> int(-40 / 2 ** 3)
        -5

    Right shifting by ``N`` bits is equivalent to dividing by ``2^N``.

    :param bits_to_shift: Value to shift right arithmetically.
    :param shift_amount: Number of bit positions to shift, as an unsigned integer.

    :return: A new :class:`WireVector` with the same bitwidth as ``bits_to_shift``.
    """
    if not isinstance(shift_amount, int):
        # Shift amount is a wire: use a barrel shifter that shifts in the sign bit.
        sign_bit = bits_to_shift[-1]
        direction = barrel.Direction.RIGHT
        return barrel.barrel_shifter(bits_to_shift, sign_bit, direction, shift_amount)

    # Shift amount is a Python int, so the shift reduces to slicing.
    width = len(bits_to_shift)
    if shift_amount >= bits_to_shift.bitwidth:
        # Everything is shifted out; only copies of the sign bit remain.
        surviving_bits = bits_to_shift[-1]
    else:
        surviving_bits = bits_to_shift[shift_amount:]
    return surviving_bits.sign_extended(width)
[docs]
def shift_left_logical(
    bits_to_shift: WireVector, shift_amount: WireVector | int
) -> WireVector:
    """Logical left shift operation.

    Logical shifting treats the ``bits_to_shift`` as an unsigned number. Zeroes will be
    added on the right and the result will be truncated to ``bits_to_shift.bitwidth``.

    .. doctest only::

        >>> import pyrtl
        >>> pyrtl.reset_working_block()

    Example::

        >>> three = pyrtl.Const(val=3, bitwidth=6)
        >>> shift_amount = pyrtl.Input(name="shift_amount", bitwidth=3)
        >>> output = pyrtl.Output(name="output")
        >>> output <<= pyrtl.shift_left_logical(three, shift_amount)

        >>> sim = pyrtl.Simulation()
        >>> sim.step(provided_inputs={"shift_amount": 3})
        >>> sim.inspect("output")
        24
        >>> 3 * 2 ** 3
        24

    Left shifting by ``N`` bits is equivalent to multiplying by ``2^N``.

    :param bits_to_shift: Value to shift left logically.
    :param shift_amount: Number of bit positions to shift, as an unsigned integer. An
        integer ``0`` returns all the bits of ``bits_to_shift`` unchanged.

    :return: A new :class:`WireVector` with the same bitwidth as ``bits_to_shift``.
    """
    if isinstance(shift_amount, int):
        if shift_amount >= bits_to_shift.bitwidth:
            # Every original bit is shifted out.
            return Const(val=0, bitwidth=bits_to_shift.bitwidth)
        if shift_amount == 0:
            # The general expression below cannot express a shift by zero:
            # ``bits_to_shift[:-0]`` is the empty slice ``[:0]`` and ``Const(0, 0)``
            # would be a zero-width constant. Select every bit instead, mirroring
            # what the right shifts do for a zero shift amount.
            return bits_to_shift[0:]
        # Drop the top ``shift_amount`` bits and append that many zeroes below.
        return concat(bits_to_shift[:-shift_amount], Const(0, shift_amount))

    bit_in = 0  # shift in a 0
    direction = barrel.Direction.LEFT
    return barrel.barrel_shifter(bits_to_shift, bit_in, direction, shift_amount)
# ``shift_left_arithmetic`` is bound to the very same function object as
# ``shift_left_logical``; no separate implementation exists.
shift_left_arithmetic = shift_left_logical
"""Alias for :func:`shift_left_logical`"""
[docs]
def shift_right_logical(
    bits_to_shift: WireVector, shift_amount: WireVector | int
) -> WireVector:
    """Shift ``bits_to_shift`` right, filling with zeroes.

    A logical shift treats ``bits_to_shift`` as an unsigned number, which is why the
    vacated positions on the left receive zeroes.

    .. doctest only::

        >>> import pyrtl
        >>> pyrtl.reset_working_block()

    Example::

        >>> forty = pyrtl.Const(val=40, bitwidth=6)
        >>> shift_amount = pyrtl.Input(name="shift_amount", bitwidth=3)
        >>> output = pyrtl.Output(name="output")
        >>> output <<= pyrtl.shift_right_logical(forty, shift_amount)

        >>> sim = pyrtl.Simulation()
        >>> sim.step(provided_inputs={"shift_amount": 3})
        >>> sim.inspect("output")
        5
        >>> int(40 / 2 ** 3)
        5

    Right shifting by ``N`` bits is equivalent to dividing by ``2^N``.

    :param bits_to_shift: Value to shift right logically.
    :param shift_amount: Number of bit positions to shift, as an unsigned integer.

    :return: A new :class:`WireVector` with the same bitwidth as ``bits_to_shift``.
    """
    if not isinstance(shift_amount, int):
        # Shift amount is a wire: use a barrel shifter that shifts in zeroes.
        zero_fill = 0
        direction = barrel.Direction.RIGHT
        return barrel.barrel_shifter(bits_to_shift, zero_fill, direction, shift_amount)

    # Shift amount is a Python int, so the shift reduces to slicing.
    if shift_amount >= bits_to_shift.bitwidth:
        # Every original bit is shifted out.
        return Const(val=0, bitwidth=bits_to_shift.bitwidth)
    surviving_bits = bits_to_shift[shift_amount:]
    return surviving_bits.zero_extended(len(bits_to_shift))
[docs]
def match_bitwidth(
    *args: WireVector, signed: bool = False
) -> tuple[WireVector, ...]:
    """Matches multiple :class:`WireVector` :attr:`bitwidths<~WireVector.bitwidth>` via
    zero- or sign-extension.

    :class:`WireVectors<WireVector>` with shorter
    :attr:`bitwidths<~WireVector.bitwidth>` will be extended to match the longest
    :attr:`~WireVector.bitwidth` in ``args``. :class:`WireVectors<WireVector>` will be
    :meth:`~WireVector.sign_extended` or :meth:`~WireVector.zero_extended`, depending on
    ``signed``.

    .. doctest only::

        >>> import pyrtl
        >>> pyrtl.reset_working_block()

    Example with sign-extension::

        >>> a = pyrtl.Const(-1, name="a_short", signed=True, bitwidth=2)
        >>> b = pyrtl.Const(-3, name="b", signed=True, bitwidth=4)
        >>> a, b = match_bitwidth(a, b, signed=True)
        >>> a.name = "a_long"
        >>> a.bitwidth, b.bitwidth
        (4, 4)

        >>> sim = pyrtl.Simulation()
        >>> sim.step()
        >>> bin(sim.inspect("b"))
        '0b1101'
        >>> bin(sim.inspect("a_short"))
        '0b11'
        >>> bin(sim.inspect("a_long"))
        '0b1111'

    :param args: :class:`WireVectors<WireVector>` of which to match
        :attr:`~WireVector.bitwidth`
    :param signed: If ``True``, extend shorter :class:`WireVectors<WireVector>` with
        :meth:`~WireVector.sign_extended`. Otherwise, extend with
        :meth:`~WireVector.zero_extended`.

    :return: :class:`tuple` of :class:`WireVectors<WireVector>`, in the same order they
        appeared in ``args``, all with :attr:`~WireVector.bitwidth` equal to the
        longest :attr:`~WireVector.bitwidth` in ``args``.
    """
    max_len = max(len(wv) for wv in args)
    # Build a real tuple, as documented: a bare generator expression could only be
    # consumed once and supports neither ``len()`` nor indexing.
    if signed:
        return tuple(wv.sign_extended(max_len) for wv in args)
    return tuple(wv.zero_extended(max_len) for wv in args)
[docs]
def as_wires(
    val: WireVectorLike,
    bitwidth: int | None = None,
    truncating: bool = True,
    block: Block = None,
) -> WireVector:
    """Convert ``val`` to a :class:`WireVector`.

    ``val`` may be a :class:`WireVector`, :class:`int` (including
    :class:`~enum.IntEnum`), :class:`str`, or :class:`bool`.

    .. doctest only::

        >>> import pyrtl
        >>> pyrtl.reset_working_block()

    ``as_wires`` is mainly used to coerce values into :class:`WireVectors<WireVector>`
    (for example, operations such as ``x + 1`` where ``1`` needs to be converted to a
    :class:`Const` :class:`WireVector`). See :ref:`wirevector_coercion`. An example::

        >>> def make_my_hardware(a, b):
        ...     a = as_wires(a)
        ...     b = as_wires(b)
        ...     return (a + b) & 0xf

        >>> input = pyrtl.Input(name="input", bitwidth=8)
        >>> output = pyrtl.Output(name="output", bitwidth=4)
        >>> output <<= make_my_hardware(input, 7)

        >>> sim = pyrtl.Simulation()
        >>> sim.step(provided_inputs={"input": 20})
        >>> sim.inspect("output")
        11
        >>> (20 + 7) % 16
        11

    In the example above, ``as_wires`` will convert the ``7`` to ``Const(7)`` but keep
    ``input`` unchanged.

    :param val: A :class:`WireVector`, or a constant value that can be converted into a
        :class:`Const`.
    :param bitwidth: The :attr:`~WireVector.bitwidth` of the resulting
        :class:`WireVector`. A plain :class:`WireVector` narrower than ``bitwidth`` is
        zero-extended.
    :param truncating: Determines whether bits will be dropped to achieve the desired
        :attr:`~WireVector.bitwidth` if ``val`` is too long (if ``True``, the
        most-significant bits will be dropped).
    :param block: ``Block`` to use for the returned :class:`WireVector`. Defaults to the
        :ref:`working_block`.

    :raises PyrtlError: If ``val`` is not of a supported type, if ``bitwidth`` is less
        than 1 when resizing a :class:`WireVector`, or if ``val`` is a
        :class:`WireVector` with no defined bitwidth.

    :return: ``val`` as a :class:`WireVector`.
    """
    # Function-scope import; presumably avoids a circular import with pyrtl.memory.
    from pyrtl.memory import _MemIndexed

    block = working_block(block)

    if isinstance(val, (int, str)):
        # note that this case captures bool as well (as bools are instances of ints)
        return Const(val, bitwidth=bitwidth, block=block)
    if isinstance(val, _MemIndexed):
        # convert to a memory read when the value is actually used; the resulting
        # wire is cached on ``val`` so later calls reuse it
        if val.wire is None:
            val.wire = as_wires(
                val.mem._readaccess(val.index), bitwidth, truncating, block
            )
        return val.wire
    if isinstance(val, WrappedWireVector):
        # The wrapped wire is returned as-is; ``bitwidth`` is not applied here.
        return val.wire
    if not isinstance(val, WireVector):
        msg = (
            "error, expecting a wirevector, int, or Verilog-style const string got "
            f"{val} instead"
        )
        raise PyrtlError(msg)
    # ``bitwidth`` is an int, so it must be compared numerically; comparing it
    # with the string "0" could never be true and left this check dead.
    if bitwidth is not None and bitwidth < 1:
        msg = "error, bitwidth must be >= 1"
        raise PyrtlError(msg)
    if val.bitwidth is None:
        msg = "error, attempting to use wirevector with no defined bitwidth"
        raise PyrtlError(msg)
    if bitwidth is not None and bitwidth > val.bitwidth:
        return val.zero_extended(bitwidth)
    if bitwidth is not None and truncating and bitwidth < val.bitwidth:
        return val[:bitwidth]  # truncate the upper bits
    return val
[docs]
def bitfield_update(
    w: WireVectorLike,
    range_start: int,
    range_end: int,
    newvalue: int,
    truncating: bool = False,
) -> WireVector:
    """Return a copy of ``w`` in which one range of bits is replaced by ``newvalue``.

    The result equals ``w`` everywhere except in the half-open bit range
    ``[range_start, range_end)``, where ``newvalue`` is placed instead. For example::

        bitfield_update(w, range_start=20, range_end=23, newvalue=0b111)

    gives a :class:`WireVector` as long as ``w`` and with the same values, except that
    bits 20, 21, and 22 are all ``1``.

    ``range_start`` and ``range_end`` are used as the bounds of a Python slice, so the
    usual slicing rules hold (negative values count from the end, ``None`` means "to
    the edge")::

        # Sets bits 20, 21, 22 to 1.
        w = bitfield_update(w, 20, 23, 0b111)
        # Sets bit 20 to 0, bits 21 and 22 to 1.
        w = bitfield_update(w, 20, 23, 0b110)
        # Assuming w is 32 bits, sets bits 31..20 = 0x7.
        w = bitfield_update(w, 20, None, 0x7)
        # Set the MSB (bit) to 1.
        w = bitfield_update(w, -1, None, 0x1)
        # Set the bits before the MSB (bit) to 9.
        w = bitfield_update(w, None, -1, 0x9)
        # Set the LSB (bit) to 1.
        w = bitfield_update(w, None, 1, 0x1)
        # Set the bits after the LSB (bit) to 9.
        w = bitfield_update(w, 1, None, 0x9)

    .. note::

        Consider using :func:`wire_struct` or :func:`wire_matrix` instead, which helps
        with consistently disassembling, naming, and reassembling fields.

    :param w: Starting point for the update. A :class:`WireVector`, or anything
        :func:`as_wires` can coerce to one.
    :param range_start: First bit of the range to replace.
    :param range_end: End (exclusive) of the range to replace.
    :param newvalue: Value written into the ``range_start:range_end`` range.
    :param truncating: If ``True``, clip ``newvalue`` to the width of the range when
        it is too large.

    :raise PyrtlError: If the range selects no bits, or if ``newvalue`` does not fit
        in the selected range and ``truncating`` is ``False``.

    :return: ``w`` with the selected bits overwritten by ``newvalue``.
    """
    w = as_wires(w)
    width = len(w)

    # Slicing a range of bit positions lets Python resolve negative and ``None``
    # bounds for us.
    selected = range(width)[range_start:range_end]
    if len(selected) == 0:
        msg = "Cannot update bitfield of size 0 (i.e. there are no bits to update)"
        raise PyrtlError(msg)
    first_bit, last_bit = selected[0], selected[-1]

    newvalue = as_wires(newvalue, bitwidth=len(selected), truncating=truncating)
    if len(newvalue) != len(selected):
        msg = (
            f"Cannot update bitfield of length {len(selected)} with value of length "
            f"{len(newvalue)} unless truncating=True is specified"
        )
        raise PyrtlError(msg)

    # Collect the pieces least-significant first: untouched low bits, the new
    # field, then the untouched high bits.
    pieces = []
    if first_bit > 0:
        pieces.append(w[0:first_bit])
    pieces.append(newvalue)
    if last_bit + 1 < width:
        pieces.append(w[last_bit + 1 : width])
    result = concat_list(pieces)

    # Internal consistency check: the rebuilt wire must be as long as the input.
    if len(result) != width:
        msg = f"len(result)={len(result)}, len(original)={len(w)}"
        raise PyrtlInternalError(msg)
    return result
[docs]
def bitfield_update_set(
    w: WireVectorLike,
    update_set: dict[tuple[int, int], WireVectorLike],
    truncating: bool = False,
) -> WireVector:
    """Return a copy of ``w`` with every bit range named in ``update_set`` replaced.

    The result equals ``w`` except in the bit ranges listed in ``update_set``. When
    several non-overlapping fields change in the same cycle, this states the intent
    more clearly than a chain of :func:`bitfield_update` calls::

        w = bitfield_update_set(w, update_set={
            (20, 23): 0x6,     # sets bit 20 to 0, bits 21 and 22 to 1
            (26, None): 0x7,   # assuming w is 32 bits, sets bits 31..26 to 0x7
            (None, 1): 0x0,    # set the LSB (bit) to 0
        })

    .. note::

        Consider using :func:`wire_struct` or :func:`wire_matrix` instead, which helps
        with consistently disassembling, naming, and reassembling fields.

    :param w: Starting point for the update. A :class:`WireVector`, or anything
        :func:`as_wires` can coerce to one.
    :param update_set: Maps ``(range_start, range_end)`` tuples to the new value for
        that range of bits.
    :param truncating: If ``True``, clip a new value to the width of its range when it
        is too large.

    :raise PyrtlError: If ``update_set`` contains overlapping fields.

    :return: ``w`` with the listed bit ranges updated.
    """
    w = as_wires(w)

    # One flag per bit of ``w``; a flag that is already set means a previous field
    # claimed that bit.
    already_updated = [False] * len(w)

    for (range_start, range_end), new_value in update_set.items():
        field = slice(range_start, range_end)
        touched = already_updated[field]
        if True in touched:
            msg = "Bitfields for update are overlapping"
            raise PyrtlError(msg)
        already_updated[field] = [True] * len(touched)

        # Apply this field on top of the updates made so far.
        w = bitfield_update(w, range_start, range_end, new_value, truncating)
    return w
[docs]
def enum_mux(
    cntrl: WireVector,
    table: dict[int, WireVector],
    default: WireVector = None,
    strict: bool = True,
) -> WireVector:
    """Build a mux whose cases are the members of an :class:`enum.IntEnum`.

    .. note::

        Consider using :ref:`conditional_assignment` instead.

    .. doctest only::

        >>> import pyrtl
        >>> pyrtl.reset_working_block()

    Example::

        >>> from enum import IntEnum
        >>> class Command(IntEnum):
        ...     ADD = 0
        ...     SUB = 1
        >>> command = pyrtl.Input(name="command", bitwidth=1)
        >>> a = pyrtl.Input(name="a", bitwidth=4)
        >>> b = pyrtl.Input(name="b", bitwidth=4)
        >>> output = pyrtl.Output(name="output")

        >>> output <<= pyrtl.enum_mux(cntrl=command, table={
        ...     Command.ADD: a + b,
        ...     Command.SUB: a - b,
        ... })

        >>> sim = pyrtl.Simulation()
        >>> sim.step(provided_inputs={"command": Command.ADD, "a": 1, "b": 2})
        >>> sim.inspect("output")
        3
        >>> sim.step(provided_inputs={"command": Command.SUB, "a": 5, "b": 3})
        >>> sim.inspect("output")
        2

    :param cntrl: Control input of the mux.
    :param table: Maps :class:`enum.IntEnum` members to :class:`WireVector`.
    :param default: :class:`WireVector` used when a key is absent from ``table``. The
        key :data:`otherwise` in ``table`` serves the same purpose; supplying both is
        an error.
    :param strict: When ``True``, require ``table`` to have an entry for every member
        of the :class:`enum.IntEnum`. The check is skipped when a ``default`` is
        available, since it covers any member that is not listed.

    :raises PyrtlError: If the keys of ``table`` are not all of one type, if that type
        is not an enum, if both :data:`otherwise` and ``default`` are given, or if
        ``strict`` is set and ``table`` is incomplete with no default.

    :return: Result of the mux.
    """
    # All keys except ``otherwise`` must share one type.
    key_types = {type(key) for key in table if key is not otherwise}
    if len(key_types) != 1:
        msg = f"table mixes multiple types {key_types} as keys"
        raise PyrtlError(msg)
    (key_type,) = key_types

    # That type must expose its members the way an Enum does.
    try:
        members = list(key_type.__members__.values())
    except AttributeError as exc:
        msg = f"type {key_type} not an Enum and does not support the same interface"
        raise PyrtlError(msg) from exc
    absent = [member for member in members if member not in table]

    # An ``otherwise`` entry acts as the default, but only one default is allowed.
    if otherwise in table:
        if default is not None:
            msg = 'both "otherwise" and default provided to enum_mux'
            raise PyrtlError(msg)
        default = table[otherwise]

    if strict and default is None and absent:
        msg = f"table provided is incomplete, missing: {absent}"
        raise PyrtlError(msg)

    # Build the sparse mux table, keyed by each member's integer value.
    cases = {}
    for key, wire in table.items():
        if key is not otherwise:
            cases[key.value] = wire
    if default is not None:
        cases["default"] = default
    return muxes.sparse_mux(cntrl, cases)
[docs]
def and_all_bits(vector: WireVector) -> WireVector:
    """Returns the result of bitwise ANDing all the bits in ``vector``.

    The individual bits are combined with ``&`` by ``tree_reduce``.

    .. doctest only::

        >>> import pyrtl
        >>> pyrtl.reset_working_block()

    Example::

        >>> input = pyrtl.Input(name="input", bitwidth=4)
        >>> output = pyrtl.Output(name="output")
        >>> output <<= pyrtl.and_all_bits(input)

        >>> sim = pyrtl.Simulation()
        >>> sim.step(provided_inputs={"input": 0b0101})
        >>> sim.inspect("output")
        0
        >>> sim.step(provided_inputs={"input": 0b1111})
        >>> sim.inspect("output")
        1

    :param vector: Takes a single arbitrary length :class:`WireVector`.
    :return: Returns a 1-bit result, the bitwise ``&`` of all of the bits in ``vector``.
    """

    def _and(lhs, rhs):
        # Pairwise combiner handed to the reduction.
        return lhs & rhs

    return tree_reduce(_and, vector)
[docs]
def or_all_bits(vector: WireVector) -> WireVector:
    """Returns the result of bitwise ORing all the bits in ``vector``.

    The individual bits are combined with ``|`` by ``tree_reduce``.

    .. doctest only::

        >>> import pyrtl
        >>> pyrtl.reset_working_block()

    Example::

        >>> input = pyrtl.Input(name="input", bitwidth=4)
        >>> output = pyrtl.Output(name="output")
        >>> output <<= pyrtl.or_all_bits(input)

        >>> sim = pyrtl.Simulation()
        >>> sim.step(provided_inputs={"input": 0b0000})
        >>> sim.inspect("output")
        0
        >>> sim.step(provided_inputs={"input": 0b0100})
        >>> sim.inspect("output")
        1

    :param vector: Takes a single arbitrary length :class:`WireVector`.
    :return: Returns a 1-bit result, the bitwise ``|`` of all of the bits in ``vector``.
    """

    def _or(lhs, rhs):
        # Pairwise combiner handed to the reduction.
        return lhs | rhs

    return tree_reduce(_or, vector)
[docs]
def xor_all_bits(vector: WireVector) -> WireVector:
    """Returns the result of bitwise XORing all the bits in ``vector``.

    The individual bits are combined with ``^`` by ``tree_reduce``.

    .. doctest only::

        >>> import pyrtl
        >>> pyrtl.reset_working_block()

    Example::

        >>> input = pyrtl.Input(name="input", bitwidth=4)
        >>> output = pyrtl.Output(name="output")
        >>> output <<= pyrtl.xor_all_bits(input)

        >>> sim = pyrtl.Simulation()
        >>> sim.step(provided_inputs={"input": 0b0100})
        >>> sim.inspect("output")
        1
        >>> sim.step(provided_inputs={"input": 0b0101})
        >>> sim.inspect("output")
        0

    :param vector: Takes a single arbitrary length :class:`WireVector`.
    :return: Returns a 1-bit result, the bitwise ``^`` of all of the bits in ``vector``.
    """

    def _xor(lhs, rhs):
        # Pairwise combiner handed to the reduction.
        return lhs ^ rhs

    return tree_reduce(_xor, vector)
parity = xor_all_bits  # alias: a second name bound to the same function object
"""Alias for :func:`xor_all_bits`."""
def tree_reduce(op, vector: WireVector) -> WireVector:
    """Combine every element of ``vector`` with ``op`` using a balanced split.

    ``vector`` is halved recursively; each half is reduced on its own and the two
    partial results are merged as ``op(left, right)``.

    :param op: Two-argument callable used to merge two partial results.
    :param vector: Sequence to reduce. Only ``len()``, indexing and slicing are
        used on it.
    :raise PyrtlError: If ``vector`` has no elements.
    :return: ``vector[0]`` for a single-element input, otherwise the merged result.
    """
    size = len(vector)
    # Emptiness is tested through len() rather than truthiness on purpose.
    if size == 0:
        msg = "Cannot reduce empty vectors"
        raise PyrtlError(msg)
    if size == 1:
        return vector[0]
    midpoint = size // 2
    lower_half = tree_reduce(op, vector[:midpoint])
    upper_half = tree_reduce(op, vector[midpoint:])
    return op(lower_half, upper_half)
def _apply_op_over_all_bits(op, vector):
if len(vector) < 1:
msg = "Cannot reduce empty vectors"
raise PyrtlError(msg)
if len(vector) == 1:
return vector[0]
rest = _apply_op_over_all_bits(op, vector[1:])
return op(vector[0], rest)
[docs]
def rtl_any(*vectorlist: WireVectorLike) -> WireVector:
    """Hardware equivalent of Python's :func:`any`.

    Given any number of :class:`WireVectors<WireVector>`, return a 1-bit
    :class:`WireVector` which will hold a ``1`` if any of the inputs are ``1``. In other
    words, this generates a large OR gate. If no inputs are provided, it will return a
    :class:`Const` ``0`` (since there are no ``1s`` present) similar to Python's
    :func:`any` called with an empty list.

    .. note::

        ``rtl_any`` is most useful when working with a variable number of
        :class:`WireVectors<WireVector>`. For a fixed number of
        :class:`WireVectors<WireVector>`, it is clearer to use ``|``::

            any_ones = a | b | c

    .. doctest only::

        >>> import pyrtl
        >>> pyrtl.reset_working_block()

    Example::

        >>> inputs = [pyrtl.Input(name="input0", bitwidth=1),
        ...           pyrtl.Input(name="input1", bitwidth=1)]
        >>> output = pyrtl.Output(name="output")
        >>> output <<= pyrtl.rtl_any(*inputs)

        >>> sim = pyrtl.Simulation()
        >>> sim.step(provided_inputs={"input0": 0, "input1": 0})
        >>> sim.inspect("output")
        0
        >>> sim.step(provided_inputs={"input0": 0, "input1": 1})
        >>> sim.inspect("output")
        1

    :param vectorlist: All arguments are length 1 :class:`WireVector`, or any type that
        can be coerced to :class:`WireVector` by :func:`as_wires`, with length 1.

    :raise PyrtlError: If any argument's :attr:`~WireVector.bitwidth` is not 1.

    :return: Length 1 :class:`WireVector` indicating if any bits in ``vectorlist`` are
        ``1``.
    """
    # ``vectorlist`` is the *args tuple, so plain truthiness detects "no inputs".
    if not vectorlist:
        return as_wires(False)

    wires = [as_wires(candidate) for candidate in vectorlist]
    if not all(len(wire) == 1 for wire in wires):
        msg = "only length 1 WireVectors can be inputs to rtl_any"
        raise PyrtlError(msg)

    # Pack the 1-bit inputs into one vector and OR its bits together.
    packed = concat_list(wires)
    return or_all_bits(packed)
[docs]
def rtl_all(*vectorlist: WireVectorLike) -> WireVector:
    """Hardware equivalent of Python's :func:`all`.

    Given any number of :class:`WireVectors<WireVector>`, return a 1-bit
    :class:`WireVector` which will hold a ``1`` only if all of the inputs are ``1``. In
    other words, this generates a large AND gate. If no inputs are provided, it will
    return a :class:`Const` ``1`` (since there are no ``0s`` present) similar to
    Python's :func:`all` called with an empty list.

    .. note::

        ``rtl_all`` is most useful when working with a variable number of
        :class:`WireVectors<WireVector>`. For a fixed number of
        :class:`WireVectors<WireVector>`, it is clearer to use ``&``::

            all_ones = a & b & c

    .. doctest only::

        >>> import pyrtl
        >>> pyrtl.reset_working_block()

    Example::

        >>> inputs = [pyrtl.Input(name="input0", bitwidth=1),
        ...           pyrtl.Input(name="input1", bitwidth=1)]
        >>> output = pyrtl.Output(name="output")
        >>> output <<= pyrtl.rtl_all(*inputs)

        >>> sim = pyrtl.Simulation()
        >>> sim.step(provided_inputs={"input0": 0, "input1": 1})
        >>> sim.inspect("output")
        0
        >>> sim.step(provided_inputs={"input0": 1, "input1": 1})
        >>> sim.inspect("output")
        1

    :param vectorlist: All arguments are length 1 :class:`WireVector`, or any type that
        can be coerced to :class:`WireVector` by :func:`as_wires`, with length 1.

    :raise PyrtlError: If any argument's :attr:`~WireVector.bitwidth` is not 1.

    :return: Length 1 :class:`WireVector` indicating if all bits in ``vectorlist`` are
        ``1``.
    """
    # ``vectorlist`` is the *args tuple, so plain truthiness detects "no inputs".
    if not vectorlist:
        return as_wires(True)

    wires = [as_wires(candidate) for candidate in vectorlist]
    if not all(len(wire) == 1 for wire in wires):
        msg = "only length 1 WireVectors can be inputs to rtl_all"
        raise PyrtlError(msg)

    # Pack the 1-bit inputs into one vector and AND its bits together.
    packed = concat_list(wires)
    return and_all_bits(packed)
def _basic_mult(A, B):
    """A stripped-down copy of the Wallace multiplier in rtllib.

    Every ``a & b`` partial product is filed into a column indexed by the sum of
    the two bit positions. Columns are then compressed pass by pass: groups of
    three entries go through a full adder and a remaining pair goes through a half
    adder, with carries moving to the next column. Once no column holds more than
    two entries, the two remaining rows are added and the result is cut to
    ``len(A) + len(B)`` bits.

    :param A: First operand.
    :param B: Second operand.
    :return: Product, ``len(A) + len(B)`` bits wide.
    """
    if len(B) == 1:
        A, B = B, A  # so that we can reuse the code below :)
    if len(A) == 1:
        # One-bit operand: mask each bit of the other operand with it.
        masked = [A & bit for bit in B]
        masked.append(Const(0))  # keep WireVector len consistent
        return concat_list(masked)

    product_width = len(A) + len(B)
    columns = [[] for _ in range(product_width)]
    for a_pos, a_bit in enumerate(A):
        for b_pos, b_bit in enumerate(B):
            columns[a_pos + b_pos].append(a_bit & b_bit)

    while any(len(column) > 2 for column in columns):
        # One extra slot catches the carry out of the top column; it is dropped
        # again when the list is trimmed below.
        reduced = [[] for _ in range(product_width + 1)]
        # Start with low weights and start reducing
        for weight, column in enumerate(columns):
            while len(column) >= 3:  # build a new full adder
                x = column.pop(0)
                y = column.pop(0)
                carry = column.pop(0)
                reduced[weight].append(x ^ y ^ carry)
                reduced[weight + 1].append(x & y | x & carry | y & carry)
            if len(column) == 2:
                # Half adder for the remaining pair.
                x, y = column
                reduced[weight].append(x ^ y)
                reduced[weight + 1].append(x & y)
            else:
                # Zero or one entry left: carry it over unchanged.
                reduced[weight].extend(column)
        columns = reduced[:product_width]

    # Transpose the columns into rows, padding short columns with zero.
    rows = tuple(itertools.zip_longest(*columns, fillvalue=Const(0)))
    total = concat_list(rows[0]) + concat_list(rows[1])
    return total[:product_width]
def _one_bit_add(a, b, carry_in):
assert len(a) == len(b) == 1
sumbit = a ^ b ^ carry_in
carry_out = a & b | a & carry_in | b & carry_in
return sumbit, carry_out
def _add_helper(a, b, carry_in):
    """Ripple-carry add ``a`` and ``b`` with an incoming carry.

    The operands are first passed through ``match_bitwidth``. Bit 0 is added with
    ``_one_bit_add`` and its carry feeds the recursive addition of the remaining
    bits.

    :param a: First operand.
    :param b: Second operand.
    :param carry_in: Carry into bit 0.
    :return: Tuple ``(sumbits, carry_out)``.
    """
    a, b = match_bitwidth(a, b)
    if len(a) == 1:
        # Base case: the single-bit adder already returns (sum, carry_out).
        return _one_bit_add(a, b, carry_in)

    low_bit, ripple = _one_bit_add(a[0], b[0], carry_in)
    high_bits, carry_out = _add_helper(a[1:], b[1:], ripple)
    return concat(high_bits, low_bit), carry_out
def _basic_add(a, b):
    """Add ``a`` and ``b`` with a carry-in of ``0``.

    :return: ``concat(carry_out, sumbits)`` from ``_add_helper``.
    """
    total, overflow = _add_helper(a, b, 0)
    return concat(overflow, total)
def _basic_sub(a, b):
    """Compute ``a + ~b + 1`` through ``_add_helper``.

    :return: ``concat(carry_out, sumbits)`` from ``_add_helper``.
    """
    inverted_b = ~b
    difference, carry = _add_helper(a, inverted_b, 1)
    return concat(carry, difference)
def _basic_eq(a, b):
    """Return the inverse of "some bit of ``a ^ b`` is set"."""
    differing_bits = a ^ b
    any_difference = or_all_bits(differing_bits)
    return ~any_difference
def _basic_lt(a, b):
assert len(a) == len(b)
a_msb = a[-1]
b_msb = b[-1]
if len(a) == 1:
return b_msb & ~a_msb
small = _basic_lt(a[:-1], b[:-1])
return (b_msb & ~a_msb) | (small & ~(a_msb ^ b_msb))
def _basic_gt(a, b):
    """Compare ``a > b`` by evaluating ``b < a`` with ``_basic_lt``."""
    b_is_less = _basic_lt(b, a)
    return b_is_less
def _basic_select(s, a, b):
    """Choose between ``a`` and ``b`` with the 1-bit select ``s`` using gates only.

    ``a`` is masked with copies of ``~s`` and ``b`` with copies of ``s``; the two
    masked values are ORed, so ``a`` passes through when ``s`` is ``0`` and ``b``
    when ``s`` is ``1``.

    :param s: Select signal; must have length 1.
    :param a: Value passed through when ``s`` is ``0``.
    :param b: Value passed through when ``s`` is ``1``; same length as ``a``.
    :return: ``(a & mask_a) | (b & mask_b)``.
    """
    assert len(a) == len(b)
    assert len(s) == 1
    # The inverted select is built once and repeated for every bit of ``a``.
    not_s = ~s
    mask_a = concat(*([not_s] * len(a)))
    mask_b = concat(*([s] * len(b)))
    return (a & mask_a) | (b & mask_b)