# Source code for pyrtl.rtllib.muxes

"""
Basic multiplexers are defined in PyRTL's core library, see:

- :func:`.select` for a multiplexer that selects between two options.

- :func:`.mux` for a multiplexer that selects between an arbitrary number of options.

- :ref:`conditional_assignment` for a more readable alternative to nested
  :func:`selects<.select>` and :func:`muxes<.mux>`.

The functions below provide more complex alternatives.
"""

import numbers

import pyrtl
from pyrtl import WireVector


def prioritized_mux(selects: list[WireVector], vals: list[WireVector]) -> WireVector:
    """Return the value whose ``select`` bit is the first one set to ``1``.

    Priority follows list order: ``vals[0]`` wins whenever ``selects[0]`` is ``1``,
    ``vals[1]`` wins when ``selects[1]`` is the first ``1``, and so on. If none of
    the ``selects`` are ``1``, the last ``val`` is returned (the last ``select`` is
    therefore never consulted).

    .. doctest only::

        >>> import pyrtl
        >>> pyrtl.reset_working_block()

    Example::

        >>> selects = [pyrtl.Input(name=f"select{i}", bitwidth=1)
        ...            for i in range(3)]
        >>> vals = [pyrtl.Const(n) for n in range(2, 5)]
        >>> output = pyrtl.Output(name="output")

        >>> output <<= pyrtl.rtllib.muxes.prioritized_mux(selects, vals)

        >>> sim = pyrtl.Simulation()
        >>> sim.step(provided_inputs={"select0": 1, "select1": 0, "select2": 0})
        >>> sim.inspect("output")
        2
        >>> sim.step(provided_inputs={"select0": 0, "select1": 1, "select2": 0})
        >>> sim.inspect("output")
        3
        >>> sim.step(provided_inputs={"select0": 0, "select1": 0, "select2": 0})
        >>> sim.inspect("output")
        4

    :param selects: A list of :class:`WireVectors<.WireVector>` signaling whether a
        wire should be chosen.
    :param vals: Values to return when the corresponding ``select`` value is ``1``.
    :return: The selected value.
    :raises PyrtlError: if ``selects`` and ``vals`` differ in length, or are empty.
    """
    count = len(vals)
    if len(selects) != count:
        msg = "Number of select and val signals must match"
        raise pyrtl.PyrtlError(msg)
    if not count:
        msg = "Must have a signal to mux"
        raise pyrtl.PyrtlError(msg)
    if count == 1:
        # Nothing left to arbitrate: this is the fallback value.
        return vals[0]

    # Divide and conquer: if any select in the front half is set, the winner lives
    # in the front half; otherwise it lives in the back half.
    pivot = count // 2
    front_hit = pyrtl.rtl_any(*selects[:pivot])
    front_winner = prioritized_mux(selects[:pivot], vals[:pivot])
    back_winner = prioritized_mux(selects[pivot:], vals[pivot:])
    return pyrtl.select(front_hit, truecase=front_winner, falsecase=back_winner)
def _is_equivalent(w1, w2):
    """Return whether two wires are known to always carry the same value.

    Two :class:`pyrtl.Const` wires are equivalent when both their value and their
    bitwidth match. Any other pair is considered equivalent only when it is the very
    same wire object, so functionally-equal but distinct wires compare as ``False``.
    """
    # Logical `and` (not bitwise `&`) so the checks short-circuit.
    if isinstance(w1, pyrtl.Const) and isinstance(w2, pyrtl.Const):
        return w1.val == w2.val and w1.bitwidth == w2.bitwidth
    return w1 is w2


SparseDefault = "default"
"""
A special key for :func:`sparse_mux`'s ``vals`` :class:`dict` that specifies the mux's
default value.
"""
def sparse_mux(sel: WireVector, vals: dict[int, WireVector]) -> WireVector:
    """Mux that avoids instantiating unnecessary ``mux_2s`` when possible.

    ``sparse_mux`` supports not having a full specification. Indices that are not
    specified are treated as don't-cares.

    .. doctest only::

        >>> import pyrtl
        >>> pyrtl.reset_working_block()

    Example::

        >>> select = pyrtl.Input(name="select", bitwidth=3)
        >>> vals = {2: pyrtl.Const(3),
        ...         4: pyrtl.Const(5),
        ...         pyrtl.rtllib.muxes.SparseDefault: pyrtl.Const(7)}
        >>> output = pyrtl.Output(name="output")
        >>> output <<= pyrtl.rtllib.muxes.sparse_mux(select, vals)

        >>> sim = pyrtl.Simulation()
        >>> sim.step(provided_inputs={"select": 2})
        >>> sim.inspect("output")
        3
        >>> sim.step(provided_inputs={"select": 4})
        >>> sim.inspect("output")
        5
        >>> sim.step(provided_inputs={"select": 3})
        >>> sim.inspect("output")
        7

    The caller's ``vals`` dict is never modified. When :data:`SparseDefault` is
    present, the default is expanded into a private copy that covers every index of
    ``sel``, i.e. ``2 ** len(sel)`` entries, so use it with care on wide selects.

    :param sel: Select wire, which chooses one of the mux input ``vals`` to output.
    :param vals: :class:`dict` of mux input values. If the special key
        :data:`SparseDefault` exists, it specifies the ``sparse_mux``'s default
        value.
    :return: The :class:`.WireVector` selected from ``vals`` by ``sel``.
    :raises PyrtlError: if a key is neither an integer nor :data:`SparseDefault`,
        if an integer key is outside ``range(2 ** len(sel))``, or if ``vals`` is
        empty.
    """
    max_val = 2 ** len(sel) - 1
    # Work on a private copy so the caller's dict is never mutated, not even when
    # validation below fails part-way through.
    vals = dict(vals)
    if SparseDefault in vals:
        default_val = vals.pop(SparseDefault)
        # _sparse_mux treats missing indices as don't-cares, so the default only
        # takes effect if every unspecified index is filled in explicitly.
        for i in range(max_val + 1):
            vals.setdefault(i, default_val)
    for key in vals:
        if not isinstance(key, numbers.Integral):
            msg = f"value {key} must be either an integer or 'default'"
            raise pyrtl.PyrtlError(msg)
        if key < 0 or key > max_val:
            msg = f"value {key} is out of range of the sel wire"
            raise pyrtl.PyrtlError(msg)
    return _sparse_mux(sel, vals)
def _sparse_mux(sel, vals):
    """Recursive worker for :func:`sparse_mux`.

    Builds a tree of :func:`pyrtl.select` calls by splitting ``vals`` on the most
    significant bit of ``sel`` at each level. Indices absent from ``vals`` are
    treated as don't-cares: a level whose lower or upper half is entirely absent
    skips the select for that bit, and a select whose two inputs are equivalent is
    elided.

    :param WireVector sel: Select wire; its most significant bit chooses between the
        lower and upper halves of the index space at this level.
    :param {int: WireVector} vals: Map from index to the wire to output when ``sel``
        equals that index. Keys are expected to have been validated by
        :func:`sparse_mux` (integers in ``range(2 ** len(sel))``).
    :return: WireVector carrying the selected value.
    :raises PyrtlError: if ``vals`` is empty, or if ``sel`` is 1 bit wide and
        ``vals`` has two or more entries but lacks key ``0`` or ``1``.
    """
    if not vals:
        msg = "Needs at least one parameter for val"
        raise pyrtl.PyrtlError(msg)
    if len(vals) == 1:
        # Only one candidate remains; every other index is a don't-care.
        return next(iter(vals.values()))

    if len(sel) == 1:
        try:
            false_result = vals[0]
            true_result = vals[1]
        except KeyError as exc:
            msg = (
                "Failed to retrieve values for smartmux. The length of sel might be "
                "wrong"
            )
            raise pyrtl.PyrtlError(msg) from exc
    else:
        # Indices below `half` have the MSB of sel cleared; the rest have it set.
        half = 2 ** (len(sel) - 1)
        lower_sel = sel[:-1]
        first_dict = {indx: wire for indx, wire in vals.items() if indx < half}
        second_dict = {
            indx - half: wire for indx, wire in vals.items() if indx >= half
        }
        # If one half is entirely don't-care, the MSB does not matter.
        if not first_dict:
            return sparse_mux(lower_sel, second_dict)
        if not second_dict:
            return sparse_mux(lower_sel, first_dict)
        false_result = sparse_mux(lower_sel, first_dict)
        true_result = sparse_mux(lower_sel, second_dict)

    # Identical branches need no select at all.
    if _is_equivalent(false_result, true_result):
        return true_result
    return pyrtl.select(sel[-1], falsecase=false_result, truecase=true_result)


class MultiSelector:
    """The MultiSelector allows you to specify multiple wire value results for a single
    select wire.

    Useful for processors, finite state machines and other places where the result of
    many wire values are determined by a common wire signal (such as a 'state' wire).

    Example::

        with muxes.MultiSelector(select, res0, res1, res2, ...) as ms:
            ms.option(val1, data0, data1, data2, ...)
            ms.option(val2, data0_2, data1_2, data2_2, ...)
            ms.default(data0_d, data1_d, data2_d, ...)

    This means that when the ``select`` wire equals ``val1``, the results will have
    the values in ``data0, data1, data2, ...`` (all ints are converted to wires).

    .. WARNING::

        Use :ref:`conditional_assignment` instead.
    """

    def __init__(self, signal_wire, *dest_wires):
        self._final = False
        self.dest_wires = dest_wires
        self.signal_wire = signal_wire
        # Select values in the order they were added; index-aligned with every
        # list in dest_instrs_info.
        self.instructions = []
        self.dest_instrs_info = {dest_w: [] for dest_w in dest_wires}

    def __enter__(self):
        """For compatibility with `with` statements, which is the recommended method of
        using a MultiSelector.
        """
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Finalize on a clean exit; on an exception, report it and let it propagate."""
        if exc_type is None:
            self.finalize()
        else:
            print("The MultiSelector was not finalized due to uncaught exception")

    def _check_finalized(self):
        """Raise PyrtlError if :meth:`finalize` has already run."""
        if self._final:
            msg = "Cannot change InstrConnector, already finalized"
            raise pyrtl.PyrtlError(msg)

    def option(self, select_val, *data_signals):
        """Drive ``data_signals`` onto the destination wires when the select wire
        equals ``select_val``.

        :param select_val: Select value for this option (converted with
            :func:`pyrtl.infer_val_and_bitwidth` at the select wire's bitwidth).
        :param data_signals: One value per destination wire, in constructor order.
        :raises PyrtlError: if already finalized, if ``select_val`` was already
            added, or if the number of ``data_signals`` is wrong.
        """
        self._check_finalized()
        instr = pyrtl.infer_val_and_bitwidth(
            select_val, self.signal_wire.bitwidth
        ).value
        if instr in self.instructions:
            msg = f"instruction {select_val} already exists"
            raise pyrtl.PyrtlError(msg)
        self._add_case(instr, data_signals)

    def default(self, *data_signals):
        """Drive ``data_signals`` for every select value not given an :meth:`option`.

        :raises PyrtlError: if already finalized, if a default was already set, or
            if the number of ``data_signals`` is wrong.
        """
        self._check_finalized()
        if SparseDefault in self.instructions:
            msg = "default already exists"
            raise pyrtl.PyrtlError(msg)
        self._add_case(SparseDefault, data_signals)

    def _add_case(self, instr, data_signals):
        """Validate and convert ``data_signals``, then record them under ``instr``.

        Every check and conversion happens before any state is mutated, so a failure
        leaves ``instructions`` and ``dest_instrs_info`` aligned.
        """
        self._check_finalized()
        if len(data_signals) != len(self.dest_wires):
            msg = (
                "Incorrect number of data_signals for instruction received "
                f"{len(data_signals)}, expected {len(self.dest_wires)}"
            )
            raise pyrtl.PyrtlError(msg)
        converted = [
            pyrtl.as_wires(sig, dw.bitwidth)
            for dw, sig in zip(self.dest_wires, data_signals, strict=True)
        ]
        self.instructions.append(instr)
        for dw, data_signal in zip(self.dest_wires, converted, strict=True):
            self.dest_instrs_info[dw].append(data_signal)

    def finalize(self):
        """Connects the wires: each destination gets a :func:`sparse_mux` of its
        recorded values, keyed by select value.

        :raises PyrtlError: if already finalized.
        """
        self._check_finalized()
        self._final = True
        for dest_w, values in self.dest_instrs_info.items():
            mux_vals = dict(zip(self.instructions, values, strict=False))
            dest_w <<= sparse_mux(self.signal_wire, mux_vals)
def demux(select: WireVector) -> tuple[WireVector, ...]:
    """Demultiplexes a wire of arbitrary bitwidth.

    This effectively converts an unsigned binary value into a one-hot encoded value,
    returning each bit of the one-hot encoded value as a separate
    :class:`.WireVector`.

    .. doctest only::

        >>> import pyrtl
        >>> pyrtl.reset_working_block()

    Example::

        >>> data = pyrtl.Input(bitwidth=3)
        >>> outputs = pyrtl.rtllib.muxes.demux(data)
        >>> len(outputs)
        8
        >>> len(outputs[0])
        1
        >>> for i, wire in enumerate(outputs):
        ...     wire.name = f"outputs[{i}]"

        >>> sim = pyrtl.Simulation()
        >>> sim.step(provided_inputs={data.name: 5})
        >>> sim.inspect("outputs[4]")
        0
        >>> sim.inspect("outputs[5]")
        1
        >>> sim.inspect("outputs[6]")
        0

    In the example above, ``len(outputs)`` is ``8`` because ``2 ** 3 == 8``, and
    ``outputs[5]`` is ``1`` because the output index ``5`` matches the input value.

    See :func:`.binary_to_one_hot`, which performs a similar operation.

    .. WARNING::

        ``demux`` can create a very large number of ``WireVectors``. Use with
        caution.

    :param select: The value to demultiplex.
    :return: A tuple of 1-bit wires, where each wire indicates if the value of
        ``select`` equals the wire's index in the tuple. The tuple has length
        ``2 ** select.bitwidth``.
    """
    if len(select) == 1:
        return _demux_2(select)

    # Decode the low bits recursively, then split every decoded line on the most
    # significant bit: lines gated by ~msb come first (msb == 0) and lines gated by
    # msb second (msb == 1), keeping the tuple ordered by value.
    low_lines = demux(select[:-1])
    msb = select[-1]
    msb_clear = ~msb
    return tuple(msb_clear & line for line in low_lines) + tuple(
        msb & line for line in low_lines
    )
def _demux_2(select):
    """Decode a 1-bit ``select`` into the pair ``(~select, select)``.

    Element 0 is high when ``select`` is 0; element 1 is high when it is 1.
    """
    assert len(select) == 1
    inverted = ~select
    return (inverted, select)