"""Helper functions that make constructing hardware easier."""
from __future__ import annotations
import collections
import numbers
import random
from typing import TYPE_CHECKING, NamedTuple
from pyrtl.core import Block, _get_debug_mode, _NameIndexer, working_block
from pyrtl.corecircuits import (
as_wires,
concat,
concat_list,
rtl_all,
rtl_any,
select,
shift_left_logical,
)
from pyrtl.pyrtlexceptions import PyrtlError, PyrtlInternalError
from pyrtl.wire import (
Const,
Input,
Output,
Register,
WireVector,
WireVectorLike,
WrappedWireVector,
)
if TYPE_CHECKING:
from pyrtl.simulation import Simulation
# -----------------------------------------------------------------
# ___ __ ___ __ __
# |__| |__ | |__) |__ |__) /__`
# | | |___ |___ | |___ | \ .__/
#
# Name generator used by probe() to build unique default probe names (prefix "Probe-").
probeIndexer = _NameIndexer("Probe-")
[docs]
def probe(w: WireVector, name: str | None = None) -> WireVector:
    """Print useful information about a :class:`WireVector` in debug mode.

    ``probe`` can be inserted into an existing design easily because it returns the
    original wire unmodified. For example::

        y <<= x[0:3] + 4

    could be rewritten as::

        y <<= probe(x)[0:3] + 4

    to give visibility into both the origin of ``x`` (including the line where the
    :class:`WireVector` was originally created) and the run-time values of ``x`` (which
    will be named and thus show up by default in a trace). Likewise::

        y <<= probe(x[0:3]) + 4
        y <<= probe(x[0:3] + 4)
        probe(y) <<= x[0:3] + 4

    are all valid uses of ``probe``.

    .. note::

        ``probe`` actually adds an :class:`Output` wire to the :ref:`working_block` of
        ``w``, which can confuse various post-processing transforms such as
        :func:`output_to_verilog`.

    :param w: :class:`WireVector` from which to get info
    :param name: optional name for probe (defaults to an autogenerated name)
    :return: original :class:`WireVector` ``w``
    """
    if not isinstance(w, WireVector):
        msg = "Only WireVectors can be probed"
        raise PyrtlError(msg)

    probe_name = name
    if probe_name is None:
        probe_name = f"({probeIndexer.make_valid_string()}: {w.name})"

    if _get_debug_mode():
        print("Probe: " + probe_name + " " + get_stack(w))

    # The Output takes its bitwidth from ``w`` when it is connected below.
    probe_output = Output(name=probe_name)
    probe_output <<= w
    return w
# Name generator used by rtl_assert() to name the assertion Output wires it creates
# (prefix "assertion").
assertIndexer = _NameIndexer("assertion")
[docs]
def rtl_assert(
    w: WireVector, exp: Exception, block: Block | None = None
) -> Output:
    """Add a hardware assertion to be checked during simulation.

    If at any time during execution the wire ``w`` is not ``True`` (i.e. when it is
    asserted low) then :class:`Simulation` will raise ``exp``.

    :param w: A 1-bit :class:`WireVector` to assert.
    :param exp: :class:`Exception` to throw when the assertion fails during
        :class:`Simulation`.
    :param block: :class:`Block` to which the assertion should be added (default to
        :ref:`working_block`).
    :raises exp: When ``w`` is not ``True``.
    :raises PyrtlError: If ``w`` is not a 1-bit :class:`WireVector` belonging to
        ``block``, or ``exp`` is not an :class:`Exception` instance (or is a
        :class:`KeyError`).
    :return: The :class:`Output` wire for the assertion, which can be ignored in most
        cases.
    """
    block = working_block(block)

    if not isinstance(w, WireVector):
        msg = "Only WireVectors can be asserted with rtl_assert"
        raise PyrtlError(msg)
    if len(w) != 1:
        msg = "rtl_assert checks only a WireVector of bitwidth 1"
        raise PyrtlError(msg)
    if not isinstance(exp, Exception):
        msg = "the second argument to rtl_assert must be an instance of Exception"
        raise PyrtlError(msg)
    # KeyError is rejected outright. NOTE(review): presumably because
    # check_rtl_assertions() also catches KeyError from Simulation.inspect(); keep
    # the two functions in sync.
    if isinstance(exp, KeyError):
        msg = "the second argument to rtl_assert cannot be a KeyError"
        raise PyrtlError(msg)
    if w not in block.wirevector_set:
        msg = "assertion wire not part of the block to which it is being added"
        raise PyrtlError(msg)
    if w in block.rtl_assert_dict:
        msg = "assertion conflicts with existing registered assertion"
        raise PyrtlInternalError(msg)

    # The assertion is tracked through a dedicated 1-bit Output that mirrors ``w``.
    assert_wire = Output(
        bitwidth=1, name=assertIndexer.make_valid_string(), block=block
    )
    assert_wire <<= w
    block.rtl_assert_dict[assert_wire] = exp
    return assert_wire
def check_rtl_assertions(sim: Simulation):
    """Checks for failing assertions in ``sim``.

    :class:`Simulation` calls this automatically, so users generally shouldn't need to
    call this. See :func:`rtl_assert`.

    Assertion wires that ``sim`` cannot inspect (``sim.inspect`` raises
    :class:`KeyError`) are skipped.

    :param sim: Simulation in which to check assertions.
    :raises Exception: The exception registered with :func:`rtl_assert` for the first
        inspected assertion wire whose current value is falsy.
    """
    for assert_wire, exp in sim.block.rtl_assert_dict.items():
        # Only the lookup may legitimately raise KeyError; keeping ``raise exp`` out of
        # the ``try`` ensures a registered exception is never swallowed here.
        try:
            value = sim.inspect(assert_wire)
        except KeyError:
            continue
        if not value:
            raise exp
[docs]
def log2(integer_val: int) -> int:
    """Return the base-2 logarithm of an integer.

    Useful when checking that powers of 2 are provided as function inputs.

    Examples::

        >>> log2(2)
        1
        >>> log2(256)
        8
        >>> log2(100)
        Traceback (most recent call last):
        ...
        pyrtl.pyrtlexceptions.PyrtlError: this function can only take even powers of 2

    :param integer_val: The integer to take the log base 2 of.
    :raise PyrtlError: If the input is not an integer, is less than 1, or is not an
        exact power of 2.
    :return: The log base 2 of ``integer_val``.
    """
    n = integer_val
    problem = None
    if not isinstance(n, int):
        problem = "this function can only take integers"
    elif n < 1:
        problem = "this function can only take positive numbers 1 or greater"
    elif n & (n - 1):
        # A power of two has exactly one set bit, so clearing the lowest set bit
        # must leave zero.
        problem = "this function can only take even powers of 2"
    if problem is not None:
        raise PyrtlError(problem)
    return n.bit_length() - 1
[docs]
def truncate(
    wirevector_or_integer: WireVector | int, bitwidth: int
) -> WireVector | int:
    """Returns a :class:`WireVector` or integer truncated to the specified ``bitwidth``.

    Truncation removes the most significant bits of ``wirevector_or_integer``, leaving a
    result that is only :attr:`~WireVector.bitwidth` bits wide. For :class:`ints<int>`
    this is performed with a simple bitmask of size :attr:`~WireVector.bitwidth`, and
    returns an :class:`int`. For :class:`WireVectors<WireVector>` the function calls
    :meth:`WireVector.truncate` and returns a :class:`WireVector` with the specified
    :attr:`~WireVector.bitwidth`.

    .. doctest only::

        >>> import pyrtl
        >>> pyrtl.reset_working_block()

    Examples::

        >>> truncate(0b101_001, bitwidth=3)
        1
        >>> bin(truncate(0b111_101, bitwidth=3))
        '0b101'
        >>> # -1 is 0b1111111... with the number of 1-bits equal to the bitwidth. Python
        >>> # ints are arbitrary-precision, so this can produce any number of 1-bits.
        >>> bin(truncate(-1, bitwidth=3))
        '0b111'
        >>> input = pyrtl.Input(name="input", bitwidth=8)
        >>> output = truncate(input, bitwidth=4)
        >>> output.name = "output"
        >>> output.bitwidth
        4
        >>> sim = pyrtl.Simulation()
        >>> sim.step(provided_inputs={"input": 0xab})
        >>> hex(sim.inspect("output"))
        '0xb'

    :param wirevector_or_integer: A :class:`WireVector` or :class:`int` to truncate.
    :param bitwidth: The length to which ``wirevector_or_integer`` should be truncated.
    :raises PyrtlError: If ``bitwidth`` is less than 1.
    :return: A truncated :class:`WireVector` or :class:`int`, depending on the input
        type.
    """
    if bitwidth < 1:
        msg = "bitwidth must be a positive integer"
        raise PyrtlError(msg)
    value = wirevector_or_integer
    try:
        # Anything with a ``truncate`` method (i.e. WireVectors) truncates itself.
        truncated = value.truncate(bitwidth)
    except AttributeError:
        # Plain integers: keep only the low ``bitwidth`` bits.
        truncated = value & ((1 << bitwidth) - 1)
    return truncated
[docs]
class MatchedFields(NamedTuple):
    """Result returned by :func:`match_bitpattern`.

    Besides being an ordinary two-element tuple, a ``MatchedFields`` works as a context
    manager. Entering it pushes :attr:`matched` as the current condition and yields
    :attr:`fields`; exiting pops that condition again.
    """

    matched: WireVector
    """1-bit :class:`WireVector` indicating if ``w`` matches ``bitpattern``."""
    fields: NamedTuple
    """:class:`NamedTuple` containing the matched fields, if any."""

    def __enter__(self):
        # Imported lazily (NOTE(review): presumably to avoid a circular import with
        # pyrtl.conditional).
        from pyrtl.conditional import _push_condition as push_condition

        push_condition(self.matched)
        return self.fields

    def __exit__(self, *exc_info):
        from pyrtl.conditional import _pop_condition as pop_condition

        pop_condition()
[docs]
def match_bitpattern(
    w: WireVector, bitpattern: str, field_map: dict[str, str] | None = None
) -> MatchedFields:
    """Returns a single-bit :class:`WireVector` that is ``1`` if and only if ``w``
    matches the ``bitpattern``, and a tuple containing the matched fields, if any.
    Compatible with the ``with`` statement.

    This function will compare a multi-bit :class:`WireVector` to a specified pattern of
    bits, where some of the pattern can be "wildcard" bits. If any of the ``1`` or ``0``
    values specified in the bitpattern fail to match the :class:`WireVector` during
    execution, a ``0`` will be produced, otherwise the value carried on the wire will be
    ``1``. The wildcard characters can be any other alphanumeric character, with
    characters other than ``?`` having special functionality (see below). The string
    must have length equal to the :class:`WireVector` specified, although whitespace and
    underscore characters will be ignored and can be used for pattern readability. A
    pattern made up entirely of wildcards always matches.

    For all other characters besides ``1``, ``0``, or ``?``, a tuple of
    :class:`WireVectors<WireVector>` will be returned as the second return value. Each
    character will be treated as the name of a field, and non-consecutive fields with
    the same name will be concatenated together, left-to-right, into a single field in
    the resultant tuple. For example, ``01aa1?bbb11a`` will match a string such as
    ``010010100111``, and the resultant matched fields are::

        (a, b) = (0b001, 0b100)

    where the ``a`` field is the concatenation of bits 9, 8, and 0, and the ``b`` field
    is the concatenation of bits 5, 4, and 3. Thus, arbitrary characters beside ``?``
    act as wildcard characters for the purposes of matching, with the additional benefit
    of returning the :class:`WireVectors<WireVector>` corresponding to those fields.
    Field names (after applying ``field_map``) become :class:`NamedTuple` field names,
    so they must be valid Python identifiers.

    A prime example of this is for decoding instructions. Here we decode some RISC-V::

        with pyrtl.conditional_assignment:
            with match_bitpattern(
                inst, "iiiiiiiiiiiirrrrr010ddddd0000011"
            ) as (imm, rs1, rd):
                regfile[rd] |= mem[(regfile[rs1] + imm.sign_extended(32)).truncate(32)]
                pc.next |= pc + 1
            with match_bitpattern(
                inst, "iiiiiiirrrrrsssss010iiiii0100011"
            ) as (imm, rs2, rs1):
                mem[(regfile[rs1] + imm.sign_extended(32)).truncate(32)] |= regfile[rs2]
                pc.next |= pc + 1
            with match_bitpattern(
                inst, "0000000rrrrrsssss111ddddd0110011"
            ) as (rs2, rs1, rd):
                regfile[rd] |= regfile[rs1] & regfile[rs2]
                pc.next |= pc + 1
            with match_bitpattern(
                inst, "0000000rrrrrsssss000ddddd0110011"
            ) as (rs2, rs1, rd):
                regfile[rd] |= (regfile[rs1] + regfile[rs2]).truncate(32)
                pc.next |= pc + 1
            # ...etc...

    Some smaller examples::

        # Basically the same as w == '0b0101'.
        m, _ = match_bitpattern(w, '0101')

        # m will be true when w is '0101' or '0111'.
        m, _ = match_bitpattern(w, '01?1')

        # m will be true when last two bits of w are '01'.
        m, _ = match_bitpattern(w, '??01')

        # spaces/underscores are ignored, same as line above.
        m, _ = match_bitpattern(w, '??_0 1')

        # All bits with the same letter make up same field.
        m, (a, b) = match_bitpattern(w, '01aa1?bbb11a')

        # Fields will be named `fs.foo` and `fs.bar`.
        m, fs = match_bitpattern(w, '01aa1?bbb11a', {'a': 'foo', 'b': 'bar'})

    :param w: The :class:`WireVector` to be compared to the ``bitpattern``
    :param bitpattern: A string holding the pattern (of bits and wildcards) to match
    :param field_map: (optional) A map from single-character field name in the
        bitpattern to the desired name of field in the returned :class:`NamedTuple`. If
        given, all non-``1``/``0``/``?`` characters in the ``bitpattern`` must be
        present in the map.
    :raises PyrtlError: If ``bitpattern`` is not a string, its length (ignoring
        whitespace and underscores) differs from ``len(w)``, or ``field_map`` is
        missing a field.
    :return: A :class:`NamedTuple` consisting of a 1-bit :class:`WireVector` carrying
        the result of the comparison, and a :class:`NamedTuple` containing the
        matched fields, if any.
    """
    w = as_wires(w)
    if not isinstance(bitpattern, str):
        msg = "bitpattern must be a string"
        raise PyrtlError(msg)
    # Drop underscores and *all* whitespace (spaces, tabs, newlines), as documented.
    bitpattern = "".join(bitpattern.replace("_", "").split())
    if len(w) != len(bitpattern):
        msg = "bitpattern string different length than wirevector provided"
        raise PyrtlError(msg)

    # Reverse ``bitpattern`` so index 0 is the least significant bit. This makes
    # ``w[i]`` and ``reversed_bitpattern[i]`` refer to the same bit ``i``.
    reversed_bitpattern = bitpattern[::-1]
    zero_bits = [w[index] for index, x in enumerate(reversed_bitpattern) if x == "0"]
    one_bits = [w[index] for index, x in enumerate(reversed_bitpattern) if x == "1"]

    # Only build the reductions that have inputs, so rtl_all/rtl_any are never called
    # with an empty argument list (patterns with no 1s, no 0s, or only wildcards).
    if one_bits and zero_bits:
        match = rtl_all(*one_bits) & ~rtl_any(*zero_bits)
    elif one_bits:
        match = rtl_all(*one_bits)
    elif zero_bits:
        match = ~rtl_any(*zero_bits)
    else:
        match = Const(1, bitwidth=1)

    def field_name(name: str) -> str:
        """Retrieve a field's name from ``field_map``."""
        if field_map is not None:
            if name not in field_map:
                msg = (
                    f"field_map argument has been given, but {name} field is not "
                    "present"
                )
                raise PyrtlError(msg)
            return field_map[name]
        return name

    # Map each field character to its bits, least significant bit first (the order
    # concat_list expects).
    field_bits = collections.defaultdict(list)
    for i, c in enumerate(reversed_bitpattern):
        if c not in "01?":
            field_bits[c].append(w[i])

    # Order fields by their first (leftmost) appearance in ``bitpattern``.
    ordered_fields = sorted(field_bits.items(), key=lambda m: bitpattern.index(m[0]))
    Fields = collections.namedtuple(
        "Fields", [field_name(name) for name, _ in ordered_fields]
    )
    matched_fields = Fields(
        **{
            field_name(name): concat_list(wirevector_list)
            for name, wirevector_list in ordered_fields
        }
    )
    return MatchedFields(match, matched_fields)
def bitpattern_to_val(bitpattern: str, *ordered_fields, **named_fields) -> int:
"""Return an unsigned integer representation of field format filled with the
provided values.
This function will compare a specified pattern of bits, where some of the pattern
can be "wildcard" bits. The wildcard bits must all be named with a single letter
and, unlike the related function :func:`match_bitpattern`, no "?" can be used. The
function will take the provided ``bitpattern`` and create an integer that
substitutes the provided fields in for the given wildcards at the bit level. This
sort of bit substitution is useful when creating values for testing when the
resulting values will be "chopped" up by the hardware later (e.g. instruction decode
or other bitfield heavy functions).
If a special keyword argument, ``field_map``, is provided, then the named fields
provided can be longer, human-readable field names, which will correspond to the
field in the bitpattern according to the ``field_map``. See the third example below.
Examples::
>>> # RISC-V ADD instruction.
>>> bin(bitpattern_to_val('0000000sssssrrrrr000ddddd0110011', s=2, r=1, d=3))
'0b1000001000000110110011'
>>> # RISC-V SW instruction.
>>> bin(bitpattern_to_val('iiiiiiisssssrrrrr010iiiii0100011', i=1, s=4, r=3))
'0b10000011010000010100011'
>>> # RISC-V SW instruction.
>>> bin(bitpattern_to_val(
... 'iiiiiiisssssrrrrr010iiiii0100011',
... imm=1, rs2=4, rs1=3,
... field_map={'i': 'imm', 's': 'rs2', 'r': 'rs1'}))
'0b10000011010000010100011'
:param bitpattern: A string holding the pattern (of bits and wildcards) to match
:param ordered_fields: A list of parameters to be matched to the provided bit
pattern in the order provided. If ``ordered_fields`` are provided then no
``named_fields`` can be used.
:param named_fields: A list of parameters to be matched to the provided bit pattern
by the names provided. If ``named_fields`` are provided then no
``ordered_fields`` can be used. A special keyword argument, ``field_map``, can
be provided, which will allow you to specify a correspondence between the
1-letter field names in the bitpattern string and longer, human readable field
names. See the example above.
:return: An unsigned integer carrying the result of the field substitution.
"""
if not bitpattern:
msg = "bitpattern must be nonempty"
raise PyrtlError(msg)
if len(ordered_fields) > 0 and len(named_fields) > 0:
msg = "named and ordered fields cannot be mixed"
raise PyrtlError(msg)
def letters_in_field_order():
seen = []
for c in bitpattern:
if c != "0" and c != "1" and c not in seen:
seen.append(c)
return seen
field_map = None
if "field_map" in named_fields:
field_map = named_fields["field_map"]
named_fields.pop("field_map")
bitlist = []
lifo = letters_in_field_order()
if ordered_fields:
if len(lifo) != len(ordered_fields):
msg = "number of fields and number of unique patterns do not match"
raise PyrtlError(msg)
intfields = [int(f) for f in ordered_fields]
else:
if len(lifo) != len(named_fields):
msg = "number of fields and number of unique patterns do not match"
raise PyrtlError(msg)
try:
def fn(n):
return field_map[n] if field_map else n
intfields = [int(named_fields[fn(n)]) for n in lifo]
except KeyError as exc:
msg = f"bitpattern field {exc.args[0]} was not provided in named_field list"
raise PyrtlError(msg) from exc
fmap = dict(zip(lifo, intfields, strict=True))
for c in bitpattern[::-1]:
if c == "0" or c == "1":
bitlist.append(c)
elif c == "?":
msg = "all fields in the bitpattern must have names"
raise PyrtlError(msg)
else:
bitlist.append(str(fmap[c] & 0x1)) # append lsb of the field
fmap[c] = fmap[c] >> 1 # and bit shift by one position
for f, intfield in fmap.items():
if intfield not in [0, -1]:
msg = f"too many bits given to value to fit in field {f}"
raise PyrtlError(msg)
if len(bitpattern) != len(bitlist):
msg = "resulting values have different bitwidths"
raise PyrtlInternalError(msg)
final_str = "".join(bitlist[::-1])
return int(final_str, 2)
[docs]
def chop(w: WireVector, *segment_widths: int) -> list[WireVector]:
    """Returns a list of :class:`WireVectors<WireVector>`, each a slice of the original
    ``w``.

    This function chops a :class:`WireVector` into a set of smaller
    :class:`WireVectors<WireVector>` of different lengths. It is most useful when
    multiple "fields" are contained within a single :class:`WireVector`, for example
    when breaking apart an instruction.

    For example, if you wish to break apart a 32-bit MIPS I-type (Immediate) instruction
    you know it has a 6-bit opcode, 2 5-bit operands, and 16-bit offset. You could take
    each of those slices in absolute terms: ``offset=instr[0:16]``, ``rt=instr[16:21]``
    and so on, but then you have to do the arithmetic yourself. With this function you
    can do all the fields at once which can be seen in the examples below.

    As a check, ``chop`` will throw an error if the sum of the lengths of the fields
    given is not the same as the length of the :class:`WireVector` to ``chop``. Note
    also that ``chop`` assumes that the "rightmost" arguments are the least significant
    bits (just like :func:`concat`) which is normal for hardware functions but makes the
    list order a little counter intuitive.

    .. note::

        Consider using :func:`wire_struct` or :func:`wire_matrix` instead, which helps
        with consistently disassembling, naming, and reassembling fields.

    .. doctest only::

        >>> import pyrtl
        >>> pyrtl.reset_working_block()

    Example::

        >>> input = pyrtl.Input(name="input", bitwidth=12)
        >>> high, middle, low = pyrtl.chop(input, 4, 4, 4)
        >>> high.name = "high"
        >>> middle.name = "middle"
        >>> low.name = "low"
        >>> sim = pyrtl.Simulation()
        >>> sim.step(provided_inputs={"input": 0xabc})
        >>> hex(sim.inspect("high"))
        '0xa'
        >>> hex(sim.inspect("middle"))
        '0xb'
        >>> hex(sim.inspect("low"))
        '0xc'

    :param w: The :class:`WireVector` to be chopped up into segments
    :param segment_widths: Additional arguments are positive integers which are
        bitwidths, listed from most to least significant segment
    :raises PyrtlError: If a width is not a positive integer, or the widths do not sum
        to ``len(w)``.
    :return: A list of :class:`WireVectors<WireVector>` each with a proper segment
        width.
    """
    w = as_wires(w)
    for seg in segment_widths:
        if not isinstance(seg, int):
            msg = "segment widths must be integers"
            raise PyrtlError(msg)
        # Without this, negative widths can still sum to len(w) and yield bogus slices.
        if seg < 1:
            msg = "segment widths must be positive integers"
            raise PyrtlError(msg)
    if sum(segment_widths) != len(w):
        msg = "sum of segment widths must equal length of wirevetor"
        raise PyrtlError(msg)

    # Segments are listed most-significant first, so carve them off from the top down.
    segments = []
    upper = len(w)
    for width in segment_widths:
        lower = upper - width
        segments.append(w[lower:upper])
        upper = lower
    return segments
def input_list(
    names: str | list[str], bitwidth: int | list[int] | None = None
) -> list[Input]:
    """Allocate and return a list of :class:`Inputs<Input>`.

    See :func:`wirevector_list`. This is shorthand for::

        wirevector_list(names, bitwidth, wvtype=pyrtl.Input)

    .. WARNING::
        Avoid using this function. Lists of ``Inputs`` can be created with list
        comprehensions, which are easier to understand because they compose familiar
        concepts, rather than introducing a new concept.

    :param names: Names for the ``Inputs``. Can be a list or single
        comma/space-separated string
    :param bitwidth: The desired bitwidth for the resulting ``Inputs``.
    :return: The newly created ``Inputs``, one per name.
    """
    return wirevector_list(names=names, bitwidth=bitwidth, wvtype=Input)
def output_list(
    names: str | list[str], bitwidth: int | list[int] | None = None
) -> list[Output]:
    """Allocate and return a list of :class:`Outputs<Output>`.

    See :func:`wirevector_list`. This is shorthand for::

        wirevector_list(names, bitwidth, wvtype=pyrtl.Output)

    .. WARNING::
        Avoid using this function. Lists of ``Outputs`` can be created with list
        comprehensions, which are easier to understand because they compose familiar
        concepts, rather than introducing a new concept.

    :param names: Names for the ``Outputs``. Can be a list or single
        comma/space-separated string
    :param bitwidth: The desired bitwidth for the resulting ``Outputs``.
    :return: The newly created ``Outputs``, one per name.
    """
    return wirevector_list(names=names, bitwidth=bitwidth, wvtype=Output)
def register_list(
    names: str | list[str], bitwidth: int | list[int] | None = None
) -> list[Register]:
    """Allocate and return a list of :class:`Registers<Register>`.

    See :func:`wirevector_list`. This is shorthand for::

        wirevector_list(names, bitwidth, wvtype=pyrtl.Register)

    .. WARNING::
        Avoid using this function. Lists of ``Registers`` can be created with list
        comprehensions, which are easier to understand because they compose familiar
        concepts, rather than introducing a new concept.

    :param names: Names for the ``Registers``. Can be a list or single
        comma/space-separated string
    :param bitwidth: The desired bitwidth for the resulting ``Registers``.
    :return: The newly created ``Registers``, one per name.
    """
    return wirevector_list(names=names, bitwidth=bitwidth, wvtype=Register)
def wirevector_list(
    names: str | list[str],
    bitwidth: int | list[int] | None = None,
    wvtype: type[WireVector] = WireVector,
) -> list[WireVector]:
    """Allocate and return a list of :class:`WireVectors<WireVector>`.

    The strings in ``names`` can also contain an additional bitwidth specification,
    separated by a ``/``. This cannot be combined with an explicit ``bitwidth``
    argument (any ``bitwidth`` other than ``None`` is rejected); names without a ``/``
    then default to a bitwidth of ``1``.

    Examples::

        wirevector_list(['name1', 'name2', 'name3'])
        wirevector_list('name1, name2, name3')
        wirevector_list('input1 input2 input3', bitwidth=8, wvtype=pyrtl.Input)
        wirevector_list('output1, output2 output3', bitwidth=3, wvtype=pyrtl.Output)
        wirevector_list('two_bits/2, four_bits/4, eight_bits/8')
        wirevector_list(['name1', 'name2', 'name3'], bitwidth=[2, 4, 8])

    .. WARNING::
        Avoid using this function. Create lists of :class:`WireVectors<WireVector>` with
        list comprehensions, which are easier to understand because they compose
        familiar concepts, rather than introducing a new concept::

            [WireVector(name) for name in ['name1', 'name2', 'name3']]
            [Input(name) for name in 'input1 input2 input3'.split(' ')]

    :param names: Names for the :class:`WireVectors<WireVector>`. Can be a list or a
        single comma/space-separated string
    :param bitwidth: The desired bitwidth for the resulting
        :class:`WireVectors<WireVector>`: a single integer for all of them, or a list
        with one bitwidth per name. Defaults to ``1``.
    :param wvtype: The :class:`WireVector` type to create.
    :raises PyrtlError: If a ``/`` bitwidth is used together with ``bitwidth``.
    :raises ValueError: If a list of bitwidths does not match the number of names.
    :return: The newly created wires, in the same order as ``names``.
    """
    if isinstance(names, str):
        names = names.replace(",", " ").split()
    if any("/" in name for name in names) and bitwidth is not None:
        msg = 'only one of optional "/" or bitwidth parameter allowed'
        raise PyrtlError(msg)
    if bitwidth is None:
        bitwidth = 1
    if isinstance(bitwidth, numbers.Integral):
        bitwidth = [bitwidth] * len(names)
    if len(bitwidth) != len(names):
        msg = (
            f"number of names {len(names)} should match number of bitwidths "
            f"{len(bitwidth)}"
        )
        raise ValueError(msg)

    wirelist = []
    for fullname, bw in zip(names, bitwidth, strict=True):
        # "name/width" overrides the width; anything that doesn't split into exactly
        # two parts is used verbatim as the name with the default width.
        try:
            name, bw = fullname.split("/")
        except ValueError:
            name = fullname
        wirelist.append(wvtype(bitwidth=int(bw), name=name))
    return wirelist
[docs]
def val_to_signed_integer(value: int, bitwidth: int) -> int:
    """Return ``value`` interpreted as a two's complement signed integer.

    Reinterpret an unsigned integer (not a :class:`WireVector`!) as a signed integer.
    This is useful for printing and interpreting two's complement values::

        >>> val_to_signed_integer(0xff, bitwidth=8)
        -1

    ``val_to_signed_integer`` can also be used as an ``repr_func`` for
    :meth:`SimulationTrace.render_trace`, to display signed integers in traces::

        bitwidth = 3
        counter = Register(name='counter', bitwidth=bitwidth)
        counter.next <<= counter + 1
        sim = Simulation()
        sim.step_multiple(nsteps=2 ** bitwidth)

        # Generates a trace like:
        #          │0 │1 │2 │3 │4 │5 │6 │7
        #
        # counter ──┤1 │2 │3 │-4│-3│-2│-1
        sim.tracer.render_trace(repr_func=val_to_signed_integer)

    :func:`infer_val_and_bitwidth` performs the opposite conversion::

        >>> integer = val_to_signed_integer(0xff, bitwidth=8)
        >>> hex(infer_val_and_bitwidth(integer, bitwidth=8).value)
        '0xff'

    Bits of ``value`` above ``bitwidth`` are ignored.

    :param value: A Python integer holding the value to convert.
    :param bitwidth: The length of the integer in bits to assume for conversion.
    :raises PyrtlError: If either argument is a :class:`WireVector`, or ``bitwidth``
        is less than 1.
    :return: ``value`` as a signed integer
    """
    if any(isinstance(arg, WireVector) for arg in (value, bitwidth)):
        msg = "inputs must not be wirevectors"
        raise PyrtlError(msg)
    if bitwidth < 1:
        msg = "bitwidth must be a positive integer"
        raise PyrtlError(msg)
    # In two's complement the top bit carries weight -2**(bitwidth - 1).
    sign_bit = 1 << (bitwidth - 1)
    magnitude = value & (sign_bit - 1)
    return magnitude - (value & sign_bit)
[docs]
class ValueBitwidthTuple(NamedTuple):
    """Return type for :func:`infer_val_and_bitwidth`.

    A two-field named tuple, so it can be unpacked as ``value, bitwidth = ...`` or read
    by attribute.
    """

    value: int
    """Inferred value, as an unsigned integer (negative inputs are stored in their
    two's complement form)."""
    bitwidth: int
    """Inferred bitwidth."""
[docs]
def infer_val_and_bitwidth(
    rawinput: int | bool | str, bitwidth: int | None = None, signed: bool = False
) -> ValueBitwidthTuple:
    """Return a ``(value, bitwidth)`` :class:`tuple` inferred from the specified input.

    Given a boolean, integer, or Verilog-style string constant, this function returns a
    :class:`ValueBitwidthTuple` ``(value, bitwidth)`` which are inferred from the
    specified ``rawinput``. If ``signed`` is ``True``, bits will be included to ensure a
    proper two's complement representation is possible, otherwise it assumes a standard
    unsigned representation. Error checks are performed that determine if the bitwidths
    specified are sufficient and appropriate for the values specified. Examples::

        >>> infer_val_and_bitwidth(2, bitwidth=5)
        ValueBitwidthTuple(value=2, bitwidth=5)

        >>> # Infer bitwidth from value.
        >>> infer_val_and_bitwidth(3)
        ValueBitwidthTuple(value=3, bitwidth=2)
        >>> infer_val_and_bitwidth(3).bitwidth
        2

        >>> # Signed values need an additional sign bit.
        >>> infer_val_and_bitwidth(3, signed=True)
        ValueBitwidthTuple(value=3, bitwidth=3)

        >>> val, bitwidth = infer_val_and_bitwidth(-1, bitwidth=3)
        >>> (bin(val), bitwidth)
        ('0b111', 3)

        >>> infer_val_and_bitwidth("5'd12")
        ValueBitwidthTuple(value=12, bitwidth=5)

    :func:`val_to_signed_integer` performs the opposite conversion::

        >>> val, bitwidth = infer_val_and_bitwidth(-1, bitwidth=3)
        >>> val_to_signed_integer(val, bitwidth)
        -1

    :param rawinput: a bool, int, or Verilog-style string constant
    :param bitwidth: an integer bitwidth or (by default) ``None``
    :param signed: a bool (by default ``False``) to include bits for proper two's
        complement
    :raises PyrtlError: If ``rawinput`` has an unsupported type or does not fit.
    :return: tuple of integers ``(value, bitwidth)``
    """
    # Order matters: bool is a subclass of int (and so an Integral), so it must be
    # checked first.
    converters = (
        (bool, _convert_bool),
        (numbers.Integral, _convert_int),
        (str, _convert_verilog_str),
    )
    for accepted_type, convert in converters:
        if isinstance(rawinput, accepted_type):
            return convert(rawinput, bitwidth, signed)

    msg = (
        f'error, the value provided is of an improper type, "{type(rawinput)}" proper '
        "types are bool, int, and string"
    )
    raise PyrtlError(msg)
def _convert_bool(
    bool_val: bool, bitwidth: int | None = None, signed: bool = False
) -> ValueBitwidthTuple:
    """Convert a boolean into a 1-bit ``(value, bitwidth)`` pair.

    :param bool_val: the boolean to convert.
    :param bitwidth: optional bitwidth; if given it must be ``1``.
    :param signed: must be ``False``; booleans cannot be signed.
    :raises PyrtlError: If ``signed`` is set, or ``bitwidth`` is given and is not 1.
    :return: ``(0 or 1, 1)``.
    """
    if signed:
        msg = "error, booleans cannot be signed (convert to int first)"
        raise PyrtlError(msg)
    value = int(bool_val)
    width = 1 if bitwidth is None else bitwidth
    if width != 1:
        msg = "error, boolean has bitwidth not equal to 1"
        raise PyrtlError(msg)
    return ValueBitwidthTuple(value, width)
def _convert_int(
    val: numbers.Integral, bitwidth: int | None = None, signed: bool = False
) -> ValueBitwidthTuple:
    """Infer a ``(value, bitwidth)`` pair from an integer.

    Non-negative values are returned unchanged. Negative values are returned as their
    ``bitwidth``-bit two's complement encoding.

    :param val: integer to convert (any ``numbers.Integral``, e.g. ``numpy.int32``).
    :param bitwidth: explicit bitwidth, or ``None`` to use the minimum that fits.
    :param signed: if ``True``, reserve room for a sign bit.
    :raises PyrtlError: If ``bitwidth`` is too small for ``val`` (including
        ``bitwidth < 1``), or ``val`` is negative with neither ``signed`` nor
        ``bitwidth`` given.
    :return: ``(value, bitwidth)`` with ``value`` non-negative.
    """
    # Convert val from numbers.Integral to int. This avoids issues with
    # limited-precision types like numpy.int32.
    val = int(val)

    if val >= 0:
        # Minimum unsigned width; zero still needs one bit.
        min_bitwidth = max(val.bit_length(), 1)
        if signed and val != 0:
            min_bitwidth += 1  # extra bit for the (zero) sign bit
        if bitwidth is None:
            bitwidth = min_bitwidth
        elif bitwidth < min_bitwidth:
            msg = (
                f"bitwidth specified ({bitwidth}) is insufficient to represent "
                f"constant {val}"
            )
            raise PyrtlError(msg)
        return ValueBitwidthTuple(val, bitwidth)

    # val is negative
    if not signed and bitwidth is None:
        msg = (
            f"negative constant {val} requires either signed=True or specified "
            "bitwidth"
        )
        raise PyrtlError(msg)
    if bitwidth is None:
        # Smallest two's complement width: the magnitude bits of ~val plus a sign bit
        # (-1 needs 1 bit, -2 needs 2, -3 and -4 need 3, ...).
        bitwidth = (~val).bit_length() + 1
    # It fits iff everything from the sign bit upward is ones. The ``bitwidth < 1``
    # guard also avoids a ValueError from a negative shift count.
    if bitwidth < 1 or (val >> (bitwidth - 1)) != -1:
        msg = f"insufficient bits ({bitwidth}) for negative number {val}"
        raise PyrtlError(msg)
    return ValueBitwidthTuple(val & ((1 << bitwidth) - 1), bitwidth)
def _convert_verilog_str(
    val: str, bitwidth: int | None = None, signed: bool = False
) -> ValueBitwidthTuple:
    """Parse a Verilog-style constant such as ``"8'hff"`` or ``"-4'd3"``.

    The text before ``'`` is the bitwidth. After it comes an optional base letter
    (``b``, ``o``, ``d``, ``h`` or ``x``; decimal if omitted) and then the digits,
    which may contain underscores. A leading ``-`` stores the value in two's complement
    form.

    :param val: the Verilog-style constant string.
    :param bitwidth: optional bitwidth, which must be at least the Verilog bitwidth. If
        ``None`` (or ``0``), the Verilog bitwidth is used.
    :param signed: must be ``False``; signed constants are not supported.
    :raises PyrtlError: If the string is malformed or the value does not fit.
    :return: ``(value, bitwidth)`` with ``value`` as an unsigned integer.
    """
    if signed:
        msg = 'error, "signed" option with Verilog-style string constants not supported'
        raise PyrtlError(msg)

    bases = {"b": 2, "o": 8, "d": 10, "h": 16, "x": 16}
    neg = False
    if val.startswith("-"):
        neg = True
        val = val[1:]
    split_string = val.lower().split("'")
    if len(split_string) != 2:
        msg = "error, string not in Verilog style format"
        raise PyrtlError(msg)

    try:
        verilog_bitwidth = int(split_string[0])
        if verilog_bitwidth < 1:
            # A non-positive size is not a valid Verilog constant, and would otherwise
            # lead to negative shift counts below.
            msg = "error, string not in Verilog style format"
            raise PyrtlError(msg)
        bitwidth = bitwidth or verilog_bitwidth  # if bitwidth is None, use Verilog's
        if verilog_bitwidth > bitwidth:
            msg = (
                f"bitwidth parameter passed ({bitwidth}) cannot fit Verilog-style "
                f"constant with bitwidth {verilog_bitwidth} (if bitwidth=None is used, "
                "PyRTL will determine the bitwidth from the Verilog-style constant "
                "specification)"
            )
            raise PyrtlError(msg)
        sval = split_string[1]
        if sval[0] == "s":
            msg = "error, signed integers are not supported in Verilog-style constants"
            raise PyrtlError(msg)
        base = 10
        if sval[0] in bases:
            base = bases[sval[0]]
            sval = sval[1:]
        sval = sval.replace("_", "")
        num = int(sval, base)
    except (IndexError, ValueError) as exc:
        msg = "error, string not in Verilog style format"
        raise PyrtlError(msg) from exc

    if neg and num:
        # A bitwidth-bit two's complement number reaches down to -2**(bitwidth - 1),
        # e.g. "-4'd8" is 4'b1000; this matches _convert_int(-8, bitwidth=4).
        if num > 1 << (bitwidth - 1):
            msg = "error, insufficient bits for negative number"
            raise PyrtlError(msg)
        num = (1 << bitwidth) - num
    if num >> bitwidth != 0:
        msg = (
            f"specified bitwidth {bitwidth} for Verilog constant insufficient to store "
            f"value {num}"
        )
        raise PyrtlError(msg)
    return ValueBitwidthTuple(num, bitwidth)
def get_stacks(*wires):
    """Return the creation call stacks of ``wires`` as one printable string.

    Only the first wire is checked for a recorded ``init_call_stack``. If it has none,
    a hint suggesting ``set_debug_mode()`` is returned instead of the stacks.

    :param wires: :class:`WireVectors<WireVector>` to describe.
    :return: one entry per wire (the wire, a colon, then its stack from
        :func:`get_stack`), joined by newlines; ``""`` when no wires are given.
    """
    if not wires:
        return ""
    if not getattr(wires[0], "init_call_stack", None):
        return (
            " No call info found for wires: use set_debug_mode() to provide more "
            "information\n"
        )
    return "\n".join(str(wire) + ":\n" + get_stack(wire) for wire in wires)
def get_stack(wire):
    """Return a printable traceback of where ``wire`` was created.

    :param wire: the :class:`WireVector` to describe.
    :raises PyrtlError: If ``wire`` is not a :class:`WireVector`.
    :return: the recorded ``init_call_stack`` frames (all but the last), or a hint to
        use ``set_debug_mode()`` when no call stack was recorded.
    """
    if not isinstance(wire, WireVector):
        msg = "Only WireVectors can be traced"
        raise PyrtlError(msg)

    call_stack = getattr(wire, "init_call_stack", None)
    if not call_stack:
        return (
            " No call info found for wire: use set_debug_mode() to provide more "
            "information"
        )

    # The last recorded frame is dropped. NOTE(review): presumably it is the
    # WireVector constructor itself; confirm.
    frames = " ".join(call_stack[:-1])
    return "Wire Traceback, most recent call last \n" + frames + "\n"
def _check_for_loop(block=None):
    """Prune ``block`` down to the logic that could be part of a combinational loop.

    Repeatedly discards every net none of whose arguments is a still-unresolved
    internal wire, marking that net's destinations as resolved. Inputs, Consts, Outputs
    and Registers never count as unresolved. Stops when a pass discards nothing.

    :param block: block to check (defaults to the working block).
    :return: ``None`` if every net was discarded, otherwise a tuple
        ``(wires_left, logic_left)`` of the unresolved wires and the remaining nets.
    """
    block = working_block(block)
    pending_logic = block.logic.copy()
    unresolved = block.wirevector_subset(exclude=(Input, Const, Output, Register))

    while True:
        # Collect into a separate set: ``pending_logic`` can't be mutated while it is
        # being iterated.
        resolvable = set()
        for net in pending_logic:
            if all(arg not in unresolved for arg in net.args):
                resolvable.add(net)
                unresolved.difference_update(net.dests)
        if not resolvable:
            break
        pending_logic -= resolvable

    if not pending_logic:
        return None
    return unresolved, pending_logic
def find_loop(block=None):
    """Find one combinational loop in ``block``, if there is one.

    The block is sanity-checked and pruned with ``_check_for_loop``. A depth-first walk
    then starts from a randomly chosen remaining wire and follows each driving net's
    arguments backwards until it reaches a wire already on the current path.

    :param block: block to search (defaults to the working block).
    :return: ``None`` if pruning leaves no logic; otherwise a list of
        ``_FilteringState`` objects (with ``dst_w`` and ``net`` attributes) for the
        wires forming the loop. The list starts at the most recently visited wire and
        ends at the wire that closes the loop.
    :raises PyrtlError: If the walk ends without locating the loop.
    """
    block = working_block(block)
    block.sanity_check()  # make sure that the block is sane first
    result = _check_for_loop(block)
    if not result:
        return None
    wires_left, logic_left = result

    class _FilteringState:
        # One frame of the explicit DFS stack: the wire being explored and the index of
        # the driving net's argument currently being followed (-1 = not yet visited).
        def __init__(self, dst_w):
            self.dst_w = dst_w
            self.arg_num = -1

    def dead_end():
        # clean up after a wire is found to not be part of the loop
        # (reads ``cur_item`` from the enclosing while-loop through the closure)
        wires_left.discard(cur_item.dst_w)
        current_wires.discard(cur_item.dst_w)
        del checking_stack[-1]

    # now making a map to quickly look up nets
    dest_nets = {dest_w: net_ for net_ in logic_left for dest_w in net_.dests}
    initial_w = random.sample(list(wires_left), 1)[0]
    # wires currently on the DFS path
    current_wires = set()
    checking_stack = [_FilteringState(initial_w)]
    # we don't use a recursive method as Python has a limited stack (default recursion
    # limit: 1000 frames)
    while checking_stack:
        cur_item = checking_stack[-1]
        if cur_item.arg_num == -1:
            # first time testing this item
            if cur_item.dst_w not in wires_left:
                dead_end()
                continue
            current_wires.add(cur_item.dst_w)
            cur_item.net = dest_nets[cur_item.dst_w]
            # a net with op "r" (presumably a register) ends a combinational path
            if cur_item.net.op == "r":
                dead_end()
                continue
        cur_item.arg_num += 1  # go to the next item
        if cur_item.arg_num == len(cur_item.net.args):
            # every argument explored without closing a loop
            dead_end()
            continue
        next_wire = cur_item.net.args[cur_item.arg_num]
        if next_wire not in current_wires:
            current_wires.add(next_wire)
            checking_stack.append(_FilteringState(next_wire))
        else:  # We have found the loop!!!!!
            # Walk back down the stack until reaching the wire that closes the cycle.
            loop_info = []
            for f_state in reversed(checking_stack):
                loop_info.append(f_state)
                if f_state.dst_w is next_wire:
                    break
            else:
                msg = "Shouldn't get here! Couldn't figure out the loop"
                raise PyrtlError(msg)
            return loop_info
    msg = "Error in detecting loop"
    raise PyrtlError(msg)
def find_and_print_loop(block=None):
    """Search ``block`` for a loop, print the result, and return it.

    Convenience wrapper that calls :func:`find_loop` and passes its result to
    :func:`print_loop`.

    :param block: block to search; defaults to the working block.
    :return: whatever :func:`find_loop` returned (``None`` when no loop was found).
    """
    found = find_loop(block)
    print_loop(found)
    return found
def print_loop(loop_data):
    """Print the nets of a loop returned by :func:`find_loop`.

    :param loop_data: a falsy value (e.g. ``None``) when no loop exists; otherwise
        an iterable of objects with a ``net`` attribute, printed one per line,
        followed by a blank line.
    """
    if not loop_data:
        print("No Loop Found")
        return
    print("Loop found:")
    net_lines = [f"{state.net}" for state in loop_data]
    print("\n".join(net_lines))
    print()
def _currently_in_jupyter_notebook():
    """Return True when running inside a Jupyter notebook, otherwise False.

    Checking for ``__IPYTHON__`` alone is not enough, because it is defined both in
    Jupyter notebooks and in IPython terminals. The class name of the active
    IPython shell is used to tell them apart.
    """
    try:
        # get_ipython() is in the global namespace when IPython is started; under a
        # plain Python interpreter the name lookup raises NameError.
        shell_name = get_ipython().__class__.__name__
    except NameError:
        return False  # Probably a standard Python interpreter.
    # ZMQInteractiveShell backs Jupyter notebooks and qtconsole. Any other shell,
    # such as the IPython terminal's TerminalInteractiveShell, is not a notebook.
    return shell_name == "ZMQInteractiveShell"
def _print_netlist_latex(netlist):
    """Display each net in ``netlist`` as a row of a LaTeX array.

    Requires IPython, which is imported lazily so the dependency is only needed
    when this function is actually called.

    :param netlist: iterable of nets; each is rendered with ``str()``.
    """
    from IPython.display import Latex, display

    rows = "\\hline\n".join(str(net) for net in netlist)
    latex_source = (
        "\n\\begin{array}{ \\| c \\| c \\| l \\| }\n"
        + "\n\\hline\n"
        + rows
        + "\\hline\n\\end{array}\n"
    )
    display(Latex(latex_source))
class _NetCount:
    """Track a block's net count to decide when an iterative process should stop.

    Mainly useful for optimization loops that repeat while the number of nets keeps
    shrinking.
    """

    def __init__(self, block=None):
        self.block = working_block(block)
        # Start far above the real count so that, with the default thresholds, the
        # first shrank() call on a non-empty block reports a shrink.
        self.prev_nets = 1000 * len(self.block.logic)

    def shrank(self, block=None, percent_diff=0, abs_diff=1):
        """Return whether the block now has sufficiently fewer nets than last time.

        The current count must be at most
        ``previous * (1 - percent_diff) - abs_diff``. The current count is then
        remembered as the new "previous" count, whatever the outcome.

        :param Block block: block to measure; defaults to the block given at
            construction.
        :param Number percent_diff: required relative reduction, as a fraction of the
            previous count (e.g. ``0.1`` for 10%).
        :param int abs_diff: required absolute reduction in nets.
        :return: boolean
        """
        measured = self.block if block is None else block
        current_count = len(measured.logic)
        threshold = self.prev_nets * (1 - percent_diff) - abs_diff
        self.prev_nets = current_count
        return current_count <= threshold

    shrinking = shrank
# Metadata describing one component of a wire_struct or wire_matrix: its ``name``
# (a str for wire_struct, an int index for wire_matrix), its ``bitwidth``, and its
# ``type``. A ``type`` of None means the default component_type should be used.
_ComponentMeta = collections.namedtuple("_ComponentMeta", "name bitwidth type")
def _make_component(
    component_meta: _ComponentMeta,
    block: Block,
    name: str,
    component_type,
    component_value,
):
    """Instantiate one component wire and drive it with ``component_value``.

    :param component_meta: the component's name, bitwidth, and optional type.
    :param block: block that will contain the new component.
    :param name: name of the enclosing struct/matrix; an empty string leaves the
        component's name empty as well.
    :param component_type: default component type. For a nested wire_struct or
        wire_matrix component, it is passed on as that component's
        ``concatenated_type`` instead.
    :param component_value: driver for the component, or ``None`` for no driver.
    :return: the new component (a wire_struct, wire_matrix, :class:`Const`, or
        ``component_type`` instance).
    """
    # A nested wire_struct/wire_matrix records its own class in the metadata, and
    # that class becomes the primary type. Plain components (type None) use the
    # default component_type.
    kind = component_type if component_meta.type is None else component_meta.type

    full_name = ""
    if len(name) > 0:
        if isinstance(component_meta.name, str):
            # wire_struct components get dotted names, like `struct.component`.
            full_name = f"{name}.{component_meta.name}"
        else:
            # wire_matrix components get bracketed indices, like `matrix[0]`.
            full_name = f"{name}[{component_meta.name}]"

    # A fresh struct/matrix/WireVector is always created, even when component_value
    # already has a suitable type and name; reusing it would be an optimization.
    # Nested components always receive a single concatenated value, so nested
    # wire_structs and wire_matrices always slice it.
    if hasattr(kind, "_is_wire_struct"):
        # Nested wire_struct: its concatenated value (possibly None) is passed
        # under the keyword named after the struct's class.
        return kind(
            name=full_name,
            block=block,
            concatenated_type=component_type,
            **{kind._class_name: component_value},
        )
    if hasattr(kind, "_is_wire_matrix"):
        # Nested wire_matrix: one value (possibly None) to slice.
        return kind(
            name=full_name,
            block=block,
            concatenated_type=component_type,
            values=[component_value],
        )
    if isinstance(component_value, int) and kind is WireVector:
        # An integer driving a plain WireVector becomes a Const directly.
        return Const(
            bitwidth=component_meta.bitwidth,
            name=full_name,
            block=block,
            val=component_value,
        )
    wire = kind(bitwidth=component_meta.bitwidth, name=full_name, block=block)
    if component_value is not None:
        wire <<= component_value
    return wire
def _bitslice(value: int, start: int, end: int) -> int:
    """Return bits ``start`` (inclusive) through ``end`` (exclusive) of ``value``.

    Bit 0 is the least significant bit. The result is shifted down so that bit
    ``start`` of ``value`` becomes bit 0 of the result.
    """
    width = end - start
    low_bits_mask = ~(-1 << width)
    return low_bits_mask & (value >> start)
def _slice(
    block: Block,
    schema: list[_ComponentMeta],
    bitwidth: int,
    component_type: type,
    name: str,
    concatenated,
    components,
    concatenated_value,
):
    """Slice ``concatenated`` into components.

    ``concatenated_value`` is the driver for ``concatenated``. Some optimizations are
    possible by inspecting ``concatenated_value``. For example, integer values are
    sliced immediately instead of generating slicing logic.

    :param schema: components in order, most significant first.
    :param bitwidth: total bitwidth of ``concatenated``.
    :param components: mapping (or list) filled in place, keyed by each component's
        ``name``.
    """
    # A Const concatenated wire already carries its value, so it is not driven again.
    if concatenated_value is not None and not isinstance(concatenated, Const):
        concatenated <<= concatenated_value

    value_is_int = isinstance(concatenated_value, int)
    high = bitwidth
    for meta in schema:
        low = high - meta.bitwidth
        if value_is_int:
            # Special case: slice the integer now rather than in hardware.
            piece = _bitslice(concatenated_value, low, high)
        else:
            piece = concatenated[low:high]
        components[meta.name] = _make_component(
            component_meta=meta,
            block=block,
            name=name,
            component_type=component_type,
            component_value=piece,
        )
        high = low
def _concatenate(
    block: Block,
    schema: list[_ComponentMeta],
    component_type: type,
    name: str,
    concatenated,
    components,
    component_map,
):
    """Build components from ``component_map`` and concatenate them into
    ``concatenated``.

    :param schema: components in order; the first component ends up in the most
        significant position of the :func:`concat`.
    :param components: mapping (or list) filled in place, keyed by each component's
        ``name``.
    :param component_map: driver for each component, indexed by component ``name``.
    """
    parts = []
    for meta in schema:
        part = _make_component(
            component_meta=meta,
            block=block,
            name=name,
            component_type=component_type,
            component_value=component_map[meta.name],
        )
        components[meta.name] = part
        parts.append(part)
    concatenated <<= concat(*parts)
[docs]
def wire_struct(wire_struct_spec):
    """Decorator that assigns names to :class:`WireVector` slices.

    ``@wire_struct`` assigns names to *non-overlapping* :class:`WireVector` slices.
    Suppose we have an 8-bit wide :class:`WireVector` called ``byte``. We can refer to
    all 8 bits with the name ``byte``, but ``@wire_struct`` lets us refer to slices by
    name, for example we could name the high 4 bits ``byte.high`` and the low 4 bits
    ``byte.low``. Without ``@wire_struct``, we would refer to these slices as
    ``byte[4:8]`` and ``byte[0:4]``, which are prone to off-by-one errors and harder to
    read.

    .. note::

        See `example1.2-wire-struct.py
        <https://github.com/UCSBarchlab/PyRTL/blob/development/examples/example1.2-wire-struct.py>`_
        for additional examples.

    .. doctest only::

        >>> import pyrtl
        >>> pyrtl.reset_working_block()

    The example ``Byte`` ``@wire_struct`` can be defined as::

        >>> @wire_struct
        ... class Byte:
        ...     high: 4
        ...     low: 4

    Construction
    ------------

    Once a ``@wire_struct`` class is defined, it can be instantiated by providing
    drivers for all of its wires. This can be done in two ways:

    1. Provide a driver for *each* component wire, for example::

        >>> byte = Byte(high=0xA, low=0xB)

       Note how the component names (``high``, ``low``) are used as keyword args for the
       constructor. Drivers must be provided for *all* components.

    2. Provide a driver for the entire ``@wire_struct``, for example::

        >>> byte = Byte(Byte=0xAB)

       Note how the class name (``Byte``) is used as a keyword arg for the constructor.

    Accessing Slices
    ----------------

    After instantiating a ``@wire_struct``, the instance functions as a
    :class:`WireVector` containing all the wires. For example, ``byte`` functions as a
    :class:`WireVector` with bitwidth 8::

        >>> byte = Byte(Byte=0xAB)
        >>> byte.bitwidth
        8

    The named slice can be accessed through the ``.`` operator (``__getattr__``), for
    example ``byte.high`` and ``byte.low``, which both function as :class:`WireVector`
    with bitwidth 4::

        >>> byte = Byte(Byte=0xAB)
        >>> byte.high.bitwidth
        4
        >>> byte.low.bitwidth
        4

    Both the instance and the slices are first-class :class:`WireVector`, so they can be
    manipulated with all the usual PyRTL operators.

    ``len()`` returns the number of components in the ``@wire_struct``::

        >>> byte = Byte(Byte=0xAB)
        >>> len(byte)
        2

    Naming
    ------

    A ``@wire_struct`` can be assigned a name in the usual way::

        >>> byte = Byte(name="my_byte", high=0xA, low=0xB)

    When a ``@wire_struct`` is assigned a name (``my_byte``), its components will be
    assigned dotted names (``my_byte.high``, ``my_byte.low``)::

        >>> byte.high.name
        'my_byte.high'
        >>> byte.low.name
        'my_byte.low'

    .. WARNING::

        All ``@wire_struct`` names are only set during construction. You can later
        rename a ``@wire_struct`` or its components, but those changes are local, and
        will not propagate to other ``@wire_struct`` components.

    Composition
    -----------

    .. doctest comment::

        These `Composition` examples are currently not ``doctested`` because the
        ``doctest`` environment converts all class attributes to strings, due to its use
        of ``exec``. In the example below, the attribute will be "Byte" (``str``),
        instead of ``Byte`` (``class``), which breaks ``wire_struct``.

    ``@wire_struct`` can be composed with itself, and with :func:`wire_matrix`. For
    example, we can define a ``Pixel`` that contains three ``Bytes``::

        @pyrtl.wire_struct
        class Pixel:
            red: Byte
            green: Byte
            blue: Byte

    Drivers must be specified for all components, but they can be specified at any
    level. All these examples construct an equivalent ``@wire_struct``::

        pixel = Pixel(Pixel=0xABCDEF)
        pixel = Pixel(red=0xAB, green=0xCD, blue=0xEF)
        pixel = Pixel(red=Byte(high=0xA, low=0xB), green=0xCD, blue=0xEF)
        pixel = Pixel(red=Byte(high=0xA, low=0xB),
                      green=Byte(high=0xC, low=0xD),
                      blue=0xEF)

    Hierarchical ``@wire_struct`` components are accessed by composing ``.`` operators::

        pixel
        pixel.red
        pixel.red.high
        pixel.red.low
        pixel.green
        pixel.green.high
        pixel.green.low
        pixel.blue
        pixel.blue.high
        pixel.blue.low

    ``@wire_struct`` can be composed with :func:`wire_matrix`::

        Word = pyrtl.wire_matrix(component_schema=8, size=4)

        @pyrtl.wire_struct
        class CacheLine:
            address: Word
            data: Word
            valid: 1

        cache_line = CacheLine(address=0x01234567, data=0x89ABCDEF, valid=1)

    Leaf-level components can be accessed by combining the ``.`` and ``[]`` operators,
    for example ``cache_line.address[3]``.

    Types
    -----

    You can change the type of a ``@wire_struct``'s components to a :class:`WireVector`
    subclass like :class:`Input` or :class:`Output` with the ``component_type``
    constructor argument::

        # Generates Outputs named ``output_byte.low`` and ``output_byte.high``.
        >>> byte = Byte(name="output_byte", component_type=pyrtl.Output,
        ...             Byte=0xCD)
        >>> isinstance(byte.high, pyrtl.Output)
        True
        >>> byte.high.name
        'output_byte.high'

    You can also change the type of the ``@wire_struct`` itself with the
    ``concatenated_type`` constructor argument::

        # Generates an Input named ``input_byte``.
        >>> input_byte = Byte(name="input_byte", concatenated_type=pyrtl.Input)

    .. NOTE::

        No values are specified for ``input_byte`` because its value is not known until
        simulation time.
    """
    # Convert the decorated class' annotations (dict of attr_name: attr_value)
    # to a list of _ComponentMetas.
    #
    # dict iteration order is guaranteed to be insertion order in Python 3.7+.
    schema = []
    for component_name, component_bitwidth in wire_struct_spec.__annotations__.items():
        # This is a hack for doctests, which convert all __annotations__ to str due to
        # its use of `exec`.
        if isinstance(component_bitwidth, str):
            component_bitwidth = int(component_bitwidth)
        if isinstance(component_bitwidth, int):
            # An ordinary component ("foo: 4") that should use the default
            # component_type.
            schema.append(
                _ComponentMeta(
                    name=component_name, bitwidth=component_bitwidth, type=None
                )
            )
        else:
            # A nested component ("bar: Byte") that must use the nested
            # component's type.
            schema.append(
                _ComponentMeta(
                    name=component_name,
                    bitwidth=component_bitwidth._bitwidth,
                    type=component_bitwidth,
                )
            )
    total_bitwidth = sum(component.bitwidth for component in schema)

    # Name of the decorated class.
    class_name = wire_struct_spec.__name__

    class _WireStruct(WrappedWireVector):
        """``wire_struct`` implementation: Concatenate or slice :class:`WireVector`.

        ``wire_struct`` works by either concatenating component :class:`WireVector`
        to create the ``wire_struct``'s full value, *or* slicing a
        ``wire_struct``'s value to create component :class:`WireVectors<WireVector>`.
        A ``wire_struct`` can only concatenate or slice, not both. The decision
        to concatenate or slice is made in __init__.
        """

        _bitwidth = total_bitwidth
        _class_name = class_name
        _is_wire_struct = True

        def __init__(
            self,
            name="",
            block=None,
            concatenated_type=WireVector,
            component_type=WireVector,
            **kwargs,
        ):
            """Concatenate or slice :class:`WireVector` components.

            The remaining keyword args specify values for all wires. If the concatenated
            value is provided, its value must be provided with the keyword arg matching
            the decorated class name. For example, if the decorated class is::

                @wire_struct
                class Byte:
                    high: 4  # high is the 4 most significant bits.
                    low: 4  # low is the 4 least significant bits.

            then the concatenated value must be provided like this::

                byte = Byte(Byte=0xAB)

            And if the component values are provided instead, their values are set by
            keyword args matching the component names::

                byte = Byte(low=0xA, high=0xB)

            :param str name: The name of the concatenated wire. Must be unique. If none
                is provided, one will be autogenerated. If a name is provided,
                components will be assigned names of the form "{name}.{component_name}".
            :param Block block: The block containing the concatenated and component
                wires. Defaults to the :ref:`working_block`.
            :param type concatenated_type: Type for the concatenated
                :class:`WireVector`.
            :param type component_type: Type for each component.
            :raises PyrtlError: if kwargs mix the concatenated value with component
                values, omit a component, or name an unknown component.
            """
            # The concatenated WireVector contains all the _WireStruct's wires.
            # WrappedWireVector (base class) will forward all attribute and method
            # accesses on this _WireStruct to the concatenated WireVector.
            if (
                class_name in kwargs
                and isinstance(kwargs[class_name], int)
                and concatenated_type is WireVector
            ):
                # Special case: simplify the concatenated type to Const.
                concatenated = Const(
                    bitwidth=self._bitwidth,
                    name=name,
                    block=block,
                    val=kwargs[class_name],
                )
            else:
                concatenated = concatenated_type(
                    bitwidth=self._bitwidth, name=name, block=block
                )
            super().__init__(wire=concatenated)

            # self._components maps from component name to each component's WireVector.
            # Stored through __dict__ so the assignment is not forwarded to the
            # wrapped wire.
            components = {}
            self.__dict__["_components"] = components

            # Handle Input and Register special cases.
            if concatenated_type is Input or concatenated_type is Register:
                kwargs = {class_name: None}
            elif component_type is Input or component_type is Register:
                kwargs = {component_meta.name: None for component_meta in schema}

            if class_name in kwargs:
                # Check for unused kwargs.
                for component_name in kwargs:
                    if component_name != class_name:
                        msg = (
                            "Do not pass additional kwargs to @wire_struct when "
                            f'slicing. ("{class_name}" was passed so don\'t pass '
                            f'"{component_name}")'
                        )
                        raise PyrtlError(msg)
                # Concatenated value was provided. Slice it into components.
                _slice(
                    block=block,
                    schema=schema,
                    bitwidth=self._bitwidth,
                    component_type=component_type,
                    name=name,
                    concatenated=concatenated,
                    components=components,
                    concatenated_value=kwargs[class_name],
                )
            else:
                # Component values were provided; concatenate them.
                # Check that values were provided for all components.
                expected_component_names = [
                    component_meta.name for component_meta in schema
                ]
                for expected_component_name in expected_component_names:
                    if expected_component_name not in kwargs:
                        msg = (
                            "You must provide kwargs for all @wire_struct components "
                            "when concatenating (missing kwarg "
                            f'"{expected_component_name}")'
                        )
                        raise PyrtlError(msg)
                # Check for unused kwargs.
                for component_name in kwargs:
                    if component_name not in expected_component_names:
                        msg = (
                            "Do not pass additional kwargs to @wire_struct when "
                            f'concatenating (don\'t pass "{component_name}")'
                        )
                        raise PyrtlError(msg)
                _concatenate(
                    block=block,
                    schema=schema,
                    component_type=component_type,
                    name=name,
                    concatenated=concatenated,
                    components=components,
                    component_map=kwargs,
                )

        def __getattr__(self, component_name: str):
            """Retrieve a component by name.

            Components are concatenated to form the concatenated :class:`WireVector`, or
            sliced from the concatenated :class:`WireVector`. Names that are not
            components are delegated to the base class.

            :param component_name: The name of the component wire.
            """
            # Read _components via __dict__ to avoid recursing into __getattr__. Use
            # .get() so that lookups made before __init__ has stored _components
            # (e.g. during construction, copying, or hasattr() probes) fall through
            # to the base class instead of raising KeyError.
            components = self.__dict__.get("_components", {})
            if component_name in components:
                return components[component_name]
            return super().__getattr__(component_name)

        def __len__(self):
            """Return the number of components in this ``wire_struct``."""
            components = self.__dict__["_components"]
            return len(components)

    return _WireStruct
[docs]
def wire_matrix(component_schema, size: int):
    """Returns a class that assigns numbered indices to :class:`WireVector` slices.

    ``wire_matrix`` assigns numbered indices to *non-overlapping* :class:`WireVector`
    slices. ``wire_matrix`` is very similar to :func:`wire_struct`, so read
    :func:`wire_struct`'s documentation first.

    .. note::

        See `example1.2-wire-struct.py
        <https://github.com/UCSBarchlab/PyRTL/blob/development/examples/example1.2-wire-struct.py>`_
        for additional examples.

    .. doctest only::

        >>> import pyrtl
        >>> pyrtl.reset_working_block()

    An example 32-bit ``Word`` ``wire_matrix``, which represents a group of four bytes,
    can be defined as::

        >>> Word = wire_matrix(component_schema=8, size=4)

    .. NOTE::

        ``wire_matrix`` returns a class, like :func:`~collections.namedtuple`.

    Construction
    ------------

    Once a ``wire_matrix`` class is defined, it can be instantiated by providing drivers
    for all of its wires. This can be done in two ways::

        # Provide a driver for each component, most significant bits first.
        >>> word = Word(values=[0x89, 0xAB, 0xCD, 0xEF])

        # Provide a single driver for the entire wire_matrix.
        >>> word = Word(values=[0x89ABCDEF])

    .. NOTE::

        When specifying drivers for each component, the most significant bits are
        specified first.

    After instantiating a ``wire_matrix``, regardless of how it was constructed, the
    instance functions as a :class:`WireVector` containing all the wires, so ``word``
    functions as a :class:`WireVector` with bitwidth 32. Each indexed slice can be
    accessed with square brackets (``__getitem__``), for example ``word[0]`` and
    ``word[3]``, which both function as :class:`WireVector` with bitwidth 8. ``word[0]``
    refers to the most significant byte, and ``word[3]`` refers to the least significant
    byte.

    Both the instance and the slices are first-class :class:`WireVector`, so they can be
    manipulated with all the usual PyRTL operators.

    Naming
    ------

    A ``wire_matrix`` can be assigned a name in the usual way::

        # The whole Word is named 'w', so the components will have names
        # w[0], w[1], ...
        >>> word = Word(name="w", values=[0x89, 0xAB, 0xCD, 0xEF])
        >>> word[0].name
        'w[0]'

    Composition
    -----------

    .. doctest comment::

        These `Composition` examples are currently not ``doctested`` because the
        ``doctest`` environment converts all class attributes to strings, due to its use
        of ``exec``. In the example below, the attribute will be "Byte" (``str``),
        instead of ``Byte`` (``class``), which breaks ``wire_struct``.

    ``wire_matrix`` can be composed with itself and :func:`wire_struct`. For example, we
    can define some multi-dimensional byte arrays::

        Array1D = wire_matrix(component_schema=8, size=2)
        Array2D = wire_matrix(component_schema=Array1D, size=2)

    Drivers must be specified for all components, but they can be specified at any
    level. All these examples construct an equivalent ``wire_matrix``::

        array_2d = Array2D(values=[0x89AB, 0xCDEF])
        array_2d = Array2D(values=[Array1D(values=[0x89, 0xAB]),
                                   0xCDEF])
        array_2d = Array2D(values=[Array1D(values=[0x89, 0xAB]),
                                   Array1D(values=[0xCD, 0xEF])])

    Accessing Slices
    ----------------

    Hierarchical components are accessed by composing ``[]`` operators, for example::

        print(array_2d[0][0].bitwidth)  # Prints 8.
        print(array_2d[0][1].bitwidth)  # Prints 8.

    When ``wire_matrix`` is composed with :func:`wire_struct`, components can be
    accessed by combining the ``[]`` and ``.`` operators::

        @wire_struct
        class Byte:
            high: 4
            low: 4

        Array1D = wire_matrix(component_schema=Byte, size=2)
        array_1d = Array1D(values=[0xAB, 0xCD])

        print(array_1d[0].high.bitwidth)  # Prints 4.

    ``len()`` returns the number of components in the ``wire_matrix``::

        print(len(array_1d))  # Prints '2'.

    Types
    -----

    You can change the type of a ``wire_matrix``'s components with the
    ``component_type`` constructor argument::

        # Generates Outputs named ``output_word[0]``, ``output_word[1]``, ...
        >>> word = Word(name="output_word",
        ...             component_type=pyrtl.Output,
        ...             values=[0x89ABCDEF])
        >>> isinstance(word[1], pyrtl.Output)
        True
        >>> word[1].name
        'output_word[1]'

    You can change the type of the ``wire_matrix`` itself with the ``concatenated_type``
    constructor argument::

        # Generates an Input named ``input_word``.
        >>> word = Word(name="input_word", concatenated_type=pyrtl.Input)

    .. NOTE::

        No values are specified for ``input_word`` because its value is not known until
        simulation time.
    """
    # Determine each component's bitwidth. A nested wire_struct/wire_matrix class
    # carries its own bitwidth; otherwise component_schema is the bitwidth itself
    # and there is no nested component type.
    if hasattr(component_schema, "_is_wire_struct") or hasattr(
        component_schema, "_is_wire_matrix"
    ):
        component_bitwidth = component_schema._bitwidth
    else:
        component_bitwidth = component_schema
        component_schema = None

    class _WireMatrix(WrappedWireVector):
        """``wire_matrix`` implementation: concatenate or slice indexed components."""

        _component_bitwidth = component_bitwidth
        _component_schema = component_schema
        _size = size
        _bitwidth = component_bitwidth * size
        _is_wire_matrix = True

        def __init__(
            self,
            name: str = "",
            block: Block | None = None,
            concatenated_type=WireVector,
            component_type=WireVector,
            values: list | None = None,
        ):
            """Concatenate or slice :class:`WireVector` components.

            :param name: Name of the concatenated wire. If a name is provided,
                components are named ``"{name}[{index}]"``.
            :param block: Block containing the concatenated and component wires.
                Defaults to the :ref:`working_block`.
            :param concatenated_type: Type for the concatenated :class:`WireVector`.
            :param component_type: Type for each component.
            :param values: Either one driver for the whole matrix, which is sliced
                into components, or one driver per component, most significant
                component first, which are concatenated.
            :raises PyrtlError: if ``values`` must be concatenated but does not
                contain exactly one driver per component.
            """
            # The concatenated WireVector contains all the _WireMatrix's wires.
            # WrappedWireVector (base class) will forward all attribute and method
            # accesses on this _WireMatrix to the concatenated WireVector.
            if values is None:
                values = []
            if (
                len(values) == 1
                and isinstance(values[0], int)
                and concatenated_type is WireVector
            ):
                # Special case: simplify the concatenated type to Const.
                concatenated = Const(
                    bitwidth=self._bitwidth, name=name, block=block, val=values[0]
                )
            else:
                concatenated = concatenated_type(
                    bitwidth=self._bitwidth, name=name, block=block
                )
            super().__init__(wire=concatenated)

            schema = [
                _ComponentMeta(
                    name=component_name,
                    bitwidth=self._component_bitwidth,
                    type=component_schema,
                )
                for component_name in range(self._size)
            ]

            # By default, slice the concatenated value into components iff exactly one
            # value is provided.
            slicing = len(values) == 1

            # Handle Input and Register special cases.
            if concatenated_type is Input or concatenated_type is Register:
                # Slice the concatenated value. Override the default 'slicing' because
                # 'values' is empty when slicing a concatenated Input or Register.
                #
                # Note that we can't just check len(values) == 1 after we set values to
                # [None] because that doesn't work when there is only one element in the
                # wire_matrix. We must distinguish between:
                #
                # 1. Slicing values to produce values[0] (this case).
                # 2. Concatenating values[0] to produce values (next case).
                #
                # But len(values) == 1 in both cases. The slice in (1) and concatenate
                # in (2) are both no-ops, but we have to get the direction right. In the
                # first case, values[0] is driven by values, and in the second case,
                # values is driven by values[0].
                slicing = True
                values = [None]
            elif component_type is Input or component_type is Register:
                values = [None] * self._size

            self._components = [None] * len(schema)
            if slicing:
                # Concatenated value was provided. Slice it into components.
                _slice(
                    block=block,
                    schema=schema,
                    bitwidth=self._bitwidth,
                    component_type=component_type,
                    name=name,
                    concatenated=concatenated,
                    components=self._components,
                    concatenated_value=values[0],
                )
            else:
                if len(values) != len(schema):
                    msg = (
                        "wire_matrix constructor expects 1 value to slice, or "
                        f"{len(schema)} values to concatenate (received {len(values)} "
                        "values)"
                    )
                    raise PyrtlError(msg)
                # Component values were provided; concatenate them.
                _concatenate(
                    block=block,
                    schema=schema,
                    component_type=component_type,
                    name=name,
                    concatenated=concatenated,
                    components=self._components,
                    component_map=values,
                )

        def __getitem__(self, key):
            """Return the component at index ``key`` (index 0 is most significant)."""
            return self._components[key]

        def __len__(self):
            """Return the number of components in this ``wire_matrix``."""
            return len(self._components)

    return _WireMatrix
[docs]
def one_hot_to_binary(w: WireVectorLike) -> WireVector:
    """Takes a one-hot input and returns the bit position of the high bit in binary.

    If the input contains multiple ``1``'s, the lowest bit position holding a ``1`` is
    returned. If the input contains no ``1``'s, ``0`` is returned.

    .. doctest only::

        >>> import pyrtl
        >>> pyrtl.reset_working_block()

    Examples::

        >>> input = pyrtl.Input(name="input", bitwidth=8)
        >>> output = pyrtl.Output(name="output")
        >>> output <<= pyrtl.one_hot_to_binary(input)
        >>> output.bitwidth
        3
        >>> 2 ** 3 == input.bitwidth
        True
        >>> sim = pyrtl.Simulation()
        >>> sim.step(provided_inputs={"input": 0b0010})
        >>> sim.inspect("output")
        1
        >>> sim.step(provided_inputs={"input": 64})
        >>> sim.inspect("output")
        6
        >>> 2 ** 6
        64
        >>> # Bit positions 2 and 3 contain 1's, but 2 is smaller.
        >>> sim.step(provided_inputs={"input": 0b1100})
        >>> sim.inspect("output")
        2
        >>> sim.step(provided_inputs={"input": 0})
        >>> sim.inspect("output")
        0

    :param w: A :class:`WireVector`, or any type that can be coerced to
        :class:`WireVector` by :func:`as_wires`.
    :return: a :class:`WireVector` containing the smallest bit position of the high bit,
        in binary.
    """
    wire = as_wires(w)
    # Scan from bit 0 upward. ``seen_one`` is true once any lower bit was 1, so only
    # the first (lowest) 1 can update ``position``.
    position = 0
    seen_one = as_wires(False)
    for bit_index in range(len(wire)):
        is_first_one = wire[bit_index] & ~seen_one
        position = select(is_first_one, bit_index, position)
        seen_one = seen_one | wire[bit_index]
    return position
[docs]
def binary_to_one_hot(
    bit_position: WireVectorLike, max_bitwidth: int | None = None
) -> WireVector:
    """Given a ``bit_position``, return a :class:`WireVector` with only that bit set to
    ``1``.

    If ``max_bitwidth`` is too small to hold the given ``bit_position``, a ``0``-valued
    :class:`WireVector` of size ``max_bitwidth`` is returned.

    .. doctest only::

        >>> import pyrtl
        >>> pyrtl.reset_working_block()

    Examples::

        >>> input = pyrtl.Input(name="input", bitwidth=4)
        >>> output = pyrtl.Output(name="output")
        >>> output <<= pyrtl.binary_to_one_hot(input)
        >>> output.bitwidth
        16
        >>> 16 == 2 ** input.bitwidth
        True
        >>> sim = pyrtl.Simulation()
        >>> sim.step(provided_inputs={"input": 0})
        >>> bin(sim.inspect("output"))
        '0b1'
        >>> sim.step(provided_inputs={"input": 3})
        >>> bin(sim.inspect("output"))
        '0b1000'

    :param bit_position: A :class:`WireVector`, or any type that can be coerced to
        :class:`WireVector` by :func:`as_wires`.
    :param max_bitwidth: Optional integer maximum bitwidth for the resulting one-hot
        :class:`WireVector`. Defaults to ``2 ** len(bit_position)``, one output bit
        per value ``bit_position`` can represent.
    :return: A :class:`WireVector` with the bit at ``bit_position`` set to ``1`` and all
        other bits set to ``0``. Bit position 0 is the least significant bit.
    """
    position_wire = as_wires(bit_position)
    width = 2 ** len(position_wire) if max_bitwidth is None else max_bitwidth
    # position_wire may not be a Const, so the single 1 is moved into place with a
    # shifter instead of being computed at construction time.
    return shift_left_logical(Const(1, bitwidth=width), position_wire)