Source code for pyrtl.rtllib.matrix

from __future__ import annotations

import builtins
from functools import reduce

from pyrtl import (
    Const,
    PyrtlError,
    WireVector,
    as_wires,
    concat,
    formatted_str_to_val,
    select,
)
from pyrtl.rtllib.multipliers import fused_multiply_adder
from pyrtl.wire import WireVectorLike


class Matrix:
    """Class for making a Matrix using PyRTL.

    Provides the ability to perform different matrix operations.
    """

    # Internally, this class uses a Python list of lists of WireVectors.
    # So, a Matrix is represented as follows for a 2 x 2:
    # [[WireVector, WireVector],
    #  [WireVector, WireVector]]

    def __init__(
        self,
        rows: int,
        columns: int,
        bits: int,
        signed: bool = False,  # noqa: ARG002
        value: WireVector | list[list[WireVectorLike]] | None = None,
        max_bits: int | None = 64,
    ):
        """Constructs a Matrix object.

        :param rows: The number of rows in the matrix. Must be greater than 0.
        :param columns: The number of columns in the matrix. Must be greater than 0.
        :param bits: The number of bits per :class:`.WireVector` matrix element. Must
            be greater than 0.
        :param signed: Currently not supported (will be added in the future).
        :param value: The value you want to initialize the ``Matrix`` to. If a
            :class:`.WireVector`, must be of size ``rows * columns * bits``. If a
            :class:`list`, must have ``rows`` rows and ``columns`` columns, and every
            element must be representable with a :attr:`~.WireVector.bitwidth` of
            ``bits``. If ``None``, the matrix initializes to 0.
        :param max_bits: The maximum number of bits each :class:`.WireVector` element
            can grow to. Operations like multiplication and addition can produce
            matrices with more ``bits``, but results will be limited to
            ``max_bits``. ``None`` disables the limit.
        :raises PyrtlError: If ``rows``, ``columns`` or ``bits`` is not a positive
            :class:`int`, or if ``value`` has the wrong type or shape.
        """
        if not isinstance(rows, int):
            msg = (
                f'Rows must be of type int, instead "{rows}" was passed of type '
                f"{type(rows)}"
            )
            raise PyrtlError(msg)
        if rows <= 0:
            msg = (
                f"Rows cannot be less than or equal to zero. Rows value passed: {rows}"
            )
            raise PyrtlError(msg)
        if not isinstance(columns, int):
            msg = (
                f'Columns must be of type int, instead "{columns}" was passed of type '
                f"{type(columns)}"
            )
            raise PyrtlError(msg)
        if columns <= 0:
            msg = (
                "Columns cannot be less than or equal to zero. Columns value passed: "
                f"{columns}"
            )
            raise PyrtlError(msg)
        self._validate_bits(bits)
        if max_bits is not None and bits > max_bits:
            bits = max_bits

        self._matrix = [[0 for _ in range(columns)] for _ in range(rows)]

        if value is None:
            for i in range(rows):
                for j in range(columns):
                    self._matrix[i][j] = Const(0)
        elif isinstance(value, WireVector):
            if value.bitwidth != bits * rows * columns:
                msg = (
                    "Initialized bitwidth value does not match given value.bitwidth: "
                    f"{value.bitwidth}, expected: {bits * rows * columns}"
                )
                raise PyrtlError(msg)
            for i in range(rows):
                for j in range(columns):
                    # Slices are taken from the least significant end and stored
                    # from the last element backwards, so element (0, 0) holds
                    # the most significant bits (the layout to_wirevector makes).
                    start_index = (j * bits) + (i * columns * bits)
                    self._matrix[rows - i - 1][columns - j - 1] = as_wires(
                        value[start_index : start_index + bits], bitwidth=bits
                    )
        elif isinstance(value, list):
            if len(value) != rows or any(len(row) != columns for row in value):
                # Report the first ragged row if any; fall back to the first
                # row's length, or 0 for an empty list, so a bad ``value``
                # produces a PyrtlError instead of an IndexError.
                actual_columns = next(
                    (len(row) for row in value if len(row) != columns),
                    len(value[0]) if value else 0,
                )
                msg = (
                    "Rows and columns mismatch\n"
                    f"Rows: {len(value)}, expected: {rows}\n"
                    f"Columns: {actual_columns}, expected: {columns}"
                )
                raise PyrtlError(msg)
            for i in range(rows):
                for j in range(columns):
                    self._matrix[i][j] = as_wires(value[i][j], bitwidth=bits)
        else:
            msg = (
                "Initialized value must be of type WireVector or list. Instead was "
                f"passed value of type {type(value)}"
            )
            raise PyrtlError(msg)

        self.rows = rows
        self.columns = columns
        self._bits = bits
        self.bits = bits
        self.signed = False
        self.max_bits = max_bits

    @staticmethod
    def _validate_bits(bits: int) -> None:
        """Raise PyrtlError unless ``bits`` is a positive int."""
        if not isinstance(bits, int):
            msg = (
                f'Bits must be of type int, instead "{bits}" was passed of type '
                f"{type(bits)}"
            )
            raise PyrtlError(msg)
        if bits <= 0:
            msg = f'Bits cannot be negative or zero, instead "{bits}" was passed'
            raise PyrtlError(msg)

    @property
    def bits(self) -> int:
        """The number of bits for each matrix element.

        Reducing the number of ``bits`` will :meth:`~.WireVector.truncate` the most
        significant bits of each matrix element.
        """
        return self._bits

    @bits.setter
    def bits(self, bits: int):
        self._validate_bits(bits)
        self._bits = bits
        # Mutate the existing row lists in place rather than rebinding
        # ``_matrix``, as the original implementation did.
        for row in self._matrix:
            for j, element in enumerate(row):
                row[j] = element[:bits]

    def __len__(self) -> int:
        """Returns the total bitwidth for all elements in the ``Matrix``.

        :return: The ``Matrix``'s total :attr:`~.WireVector.bitwidth`:
            ``rows * columns * bits``.
        """
        return self.bits * self.rows * self.columns

    def to_wirevector(self) -> WireVector:
        """Returns all elements in the ``Matrix`` in one :class:`.WireVector`.

        This :func:`concatenates<.concat>` all the ``Matrix``'s elements together, in
        row-major order.

        For example, a 1 x 2 matrix ``[[wire_a, wire_b]]`` would become
        ``pyrtl.concat(wire_a, wire_b)``.

        :return: A concatenated :class:`.WireVector` containing all of the
            ``Matrix``'s elements.
        """
        elements = [
            as_wires(self[i, j], bitwidth=self.bits)
            for i in range(self.rows)
            for j in range(self.columns)
        ]
        return as_wires(concat(*elements), bitwidth=len(self))

    def transpose(self) -> Matrix:
        """
        :return: A ``Matrix`` representing the transpose of ``self``.
        """
        result = Matrix(self.columns, self.rows, self.bits, max_bits=self.max_bits)
        for i in range(result.rows):
            for j in range(result.columns):
                result[i, j] = self[j, i]
        return result

    def __reversed__(self) -> Matrix:
        """Invoked with the :func:`reversed` builtin.

        :return: A ``Matrix`` with all row and column indices reversed.
        """
        result = Matrix(self.rows, self.columns, self.bits, max_bits=self.max_bits)
        for i in range(self.rows):
            for j in range(self.columns):
                result[i, j] = self[self.rows - 1 - i, self.columns - 1 - j]
        return result

    def _normalize_index(self, index: int | slice, size: int, name: str) -> slice:
        """Convert an int or slice along one axis into a bounds-checked slice.

        Negative ints and negative slice bounds count back from ``size``; a
        ``None`` start or stop means the beginning or end of the axis. The
        returned slice has non-negative ``start``/``stop`` and step 1 (any step
        in ``index`` is ignored).

        :param index: The row or column index to normalize.
        :param size: Number of rows or columns along this axis.
        :param name: ``"rows"`` or ``"columns"``; used in error messages.
        :raises PyrtlError: If ``index`` has the wrong type or is out of bounds.
        """
        if not isinstance(index, slice):
            if not isinstance(index, int):
                msg = (
                    f'{name.capitalize()} must be of type int or slice, instead '
                    f'"{index}" was passed of type {type(index)}'
                )
                raise PyrtlError(msg)
            if index < 0:
                index = size - abs(index)
                if index < 0:
                    msg = (
                        f"Invalid bounds for {name}. Max {name}: {size}, got: "
                        f"{index}"
                    )
                    raise PyrtlError(msg)
            index = slice(index, index + 1, 1)

        start, stop = index.start, index.stop
        if start is None:
            start = 0
        elif start < 0:
            start = size - abs(start)
        if stop is None:
            stop = size
        elif stop < 0:
            stop = size - abs(stop)

        if start > size or stop > size or start < 0 or stop < 0:
            msg = f"Invalid bounds for {name}. Max {name}: {size}, got: {start}:{stop}"
            raise PyrtlError(msg)
        return slice(start, stop, 1)

    def __getitem__(
        self, key: int | slice | tuple[int | slice, int | slice]
    ) -> WireVector | Matrix:
        """Access elements in the ``Matrix``.

        Invoked with square brackets, like ``matrix[...]``.

        Examples::

            int_matrix = [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
            matrix = Matrix(rows=3, columns=3, bits=4, value=int_matrix)

            # Retrieve the second row.
            matrix[1] == [3, 4, 5]

            # Retrieve the last row.
            matrix[-1] == [6, 7, 8]

            # Retrieve the element in row 2, column 0.
            matrix[2, 0] == 6
            matrix[(2, 0)] == 6

            # Retrieve the first two rows.
            matrix[slice(0, 2), slice(0, 3)] == [[0, 1, 2], [3, 4, 5]]
            matrix[0:2, 0:3] == [[0, 1, 2], [3, 4, 5]]
            matrix[:2] == [[0, 1, 2], [3, 4, 5]]

            # Retrieve the last two rows.
            matrix[-2:] == [[3, 4, 5], [6, 7, 8]]

        :param key: The key value to get.
        :return: :class:`.WireVector` or ``Matrix`` containing the value of key.
        :raises PyrtlError: If ``key`` has the wrong type, is out of bounds, or
            selects no elements.
        """
        if isinstance(key, tuple):
            rows, columns = key
            rows = self._normalize_index(rows, self.rows, "rows")
            columns = self._normalize_index(columns, self.columns, "columns")
            num_rows = rows.stop - rows.start
            num_columns = columns.stop - columns.start

            if num_rows <= 0 or num_columns <= 0:
                msg = (
                    f"Selection is empty. Rows: {rows.start}:{rows.stop}, columns: "
                    f"{columns.start}:{columns.stop}"
                )
                raise PyrtlError(msg)

            # A single element is returned as a WireVector, not a 1 x 1 Matrix.
            if num_rows == 1 and num_columns == 1:
                return as_wires(
                    self._matrix[rows.start][columns.start], bitwidth=self.bits
                )

            result = [row[columns] for row in self._matrix[rows]]
            return Matrix(
                num_rows,
                num_columns,
                self._bits,
                signed=self.signed,
                value=result,
                max_bits=self.max_bits,
            )

        # Second case when we just want to get a full row.
        if isinstance(key, int):
            if key < 0 and self.rows - abs(key) < 0:
                msg = f"Index {key} is out of bounds for matrix with {self.rows} rows"
                raise PyrtlError(msg)
            return self[key, :]

        # Third case when we want multiple rows.
        if isinstance(key, slice):
            return self[key, :]

        # Otherwise an improper value was passed.
        msg = (
            f'Rows must be of type int or slice, instead "{key}" was passed of type '
            f"{type(key)}"
        )
        raise PyrtlError(msg)

    def __setitem__(
        self,
        key: int | slice | tuple[int | slice, int | slice],
        value: WireVectorLike | Matrix,
    ):
        """Mutate the ``Matrix``.

        Invoked with square brackets, like ``matrix[a, b] = value``. ``value`` will
        be truncated so it fits in :attr:`bits`. Negative indices count back from the
        end, exactly as in :meth:`__getitem__`.

        This modifies the ``Matrix``'s :class:`lists<list>` in its internal
        :class:`list` of :class:`list` of :class:`WireVectors<.WireVector>`, which
        makes the ``Matrix`` use a different set of :class:`WireVectors<.WireVector>`
        as its elements. It does not modify any :class:`WireVectors<.WireVector>`.

        :param key: The key value to set.
        :param value: The value in which to set the key.
        :raises PyrtlError: If ``key`` is invalid or out of bounds, or if ``value``
            does not match the shape selected by ``key``.
        """
        if isinstance(key, tuple):
            rows, columns = key
            rows = self._normalize_index(rows, self.rows, "rows")
            columns = self._normalize_index(columns, self.columns, "columns")
            num_rows = rows.stop - rows.start
            num_columns = columns.stop - columns.start

            # First case when setting value to Matrix.
            if isinstance(value, Matrix):
                if value.rows != num_rows:
                    msg = (
                        "Value rows mismatch. Expected Matrix of rows "
                        f'"{num_rows}", instead received Matrix of rows '
                        f'"{value.rows}"'
                    )
                    raise PyrtlError(msg)
                if value.columns != num_columns:
                    msg = (
                        "Value columns mismatch. Expected Matrix of columns "
                        f'"{num_columns}", instead received Matrix of '
                        f'columns "{value.columns}"'
                    )
                    raise PyrtlError(msg)
                for i in range(num_rows):
                    for j in range(num_columns):
                        self._matrix[rows.start + i][columns.start + j] = as_wires(
                            value[i, j], bitwidth=self.bits
                        )
            # Second case when setting value to a single wirevector.
            elif isinstance(value, (int, WireVector)):
                if num_rows != 1 or num_columns != 1:
                    msg = "Value mismatch: expected Matrix, instead received WireVector"
                    raise PyrtlError(msg)
                self._matrix[rows.start][columns.start] = as_wires(
                    value, bitwidth=self.bits
                )
            else:
                msg = f"Invalid value of type {type(value)}"
                raise PyrtlError(msg)
        # Second case if we just want to set a full row.
        elif isinstance(key, int):
            if key < 0 and self.rows - abs(key) < 0:
                msg = f"Index {key} is out of bounds for matrix with {self.rows} rows"
                raise PyrtlError(msg)
            self[key, :] = value
        # Third case if we want to set full rows.
        elif isinstance(key, slice):
            self[key, :] = value
        else:
            msg = (
                f'Rows must be of type int or slice, instead "{key}" was passed of '
                f"type {type(key)}"
            )
            raise PyrtlError(msg)

    def copy(self) -> Matrix:
        """Constructs a copy of the ``Matrix``.

        The returned copy will have new set of :class:`WireVectors<.WireVector>` for
        its elements, but each new :class:`WireVector` will be wired to the
        corresponding :class:`WireVector` in the original ``Matrix``.

        :return: A new instance of ``Matrix`` that indirectly refers to the same
            underlying :class:`WireVectors<.WireVector>` as ``self``.
        """
        return Matrix(
            self.rows,
            self.columns,
            self.bits,
            value=self.to_wirevector(),
            max_bits=self.max_bits,
        )

    def __iadd__(self, other: Matrix) -> Matrix:
        """Perform the in-place addition operation.

        Invoked with ``a += b``. Performs elementwise addition.

        :param other: The ``Matrix`` to add.
        :return: A copy of ``self`` after ``self`` takes on the elementwise sum.
        """
        new_value = self + other
        self._matrix = new_value._matrix
        # Elements of new_value already have new_value.bits bits, so set _bits
        # directly instead of re-truncating through the ``bits`` setter.
        self._bits = new_value._bits
        return self.copy()

    def __add__(self, other: Matrix) -> Matrix:
        """Perform the addition operation.

        Invoked with ``a + b``. Performs elementwise addition.

        :param other: The ``Matrix`` to add.
        :return: a Matrix object containing the elementwise sum.
        :raises PyrtlError: If ``other`` is not a ``Matrix`` of the same shape.
        """
        if not isinstance(other, Matrix):
            msg = f"error: expecting a Matrix, got {type(other)} instead"
            raise PyrtlError(msg)
        if self.columns != other.columns:
            msg = (
                f"error: columns mismatch. Matrix a: {self.columns} columns, Matrix b: "
                f"{other.columns} columns"
            )
            raise PyrtlError(msg)
        if self.rows != other.rows:
            msg = (
                f"error: row mismatch. Matrix a: {self.rows} rows, Matrix b: "
                f"{other.rows} rows"
            )
            raise PyrtlError(msg)

        # One extra bit for the carry out of the wider operand.
        new_bits = max(self.bits, other.bits)
        result = Matrix(self.rows, self.columns, new_bits + 1, max_bits=self.max_bits)
        for i in range(result.rows):
            for j in range(result.columns):
                result[i, j] = self[i, j] + other[i, j]
        return result

    def __isub__(self, other: Matrix) -> Matrix:
        """Perform the in-place subtraction operation.

        Invoked with ``a -= b``. Performs elementwise subtraction.

        :param other: The ``Matrix`` to subtract.
        :return: A ``Matrix`` object with the result of elementwise subtraction.
        """
        new_value = self - other
        self._matrix = new_value._matrix
        self._bits = new_value._bits
        return self.copy()

    def __sub__(self, other: Matrix) -> Matrix:
        """Perform the subtraction operation.

        Invoked with ``a - b``. Performs elementwise subtraction.

        .. note::
            If ``signed=False``, the result will be floored at 0.

        :param other: The ``Matrix`` to subtract.
        :return: a ``Matrix`` object with the result of elementwise subtraction.
        :raises PyrtlError: If ``other`` is not a ``Matrix`` of the same shape.
        """
        if not isinstance(other, Matrix):
            msg = f"error: expecting a Matrix, got {type(other)} instead"
            raise PyrtlError(msg)
        if self.columns != other.columns:
            msg = (
                f"error: columns mismatch. Matrix a: {self.columns} columns, "
                f"Matrix b: {other.columns} columns"
            )
            raise PyrtlError(msg)
        if self.rows != other.rows:
            msg = (
                f"error: row mismatch. Matrix a: {self.rows} rows, Matrix b: "
                f"{other.rows} rows"
            )
            raise PyrtlError(msg)

        new_bits = max(self.bits, other.bits)
        result = Matrix(self.rows, self.columns, new_bits, max_bits=self.max_bits)
        for i in range(result.rows):
            for j in range(result.columns):
                if self.signed:
                    # __init__ always sets signed to False, so this branch
                    # is currently unreachable.
                    result[i, j] = self[i, j] - other[i, j]
                else:
                    result[i, j] = select(
                        self[i, j] > other[i, j], self[i, j] - other[i, j], Const(0)
                    )
        return result

    def __imul__(self, other: Matrix | WireVector) -> Matrix:
        """Perform the in-place multiplication operation.

        Invoked with ``a *= b``. Performs elementwise or scalar multiplication.

        :param other: The ``Matrix`` or scalar to multiply.
        :return: A ``Matrix`` object with the product.
        """
        new_value = self * other
        self._matrix = new_value._matrix
        self._bits = new_value._bits
        return self.copy()

    def __mul__(self, other: Matrix | WireVector) -> Matrix:
        """Perform the elementwise or scalar multiplication operation.

        Invoked with ``a * b``.

        :param other: The ``Matrix`` or scalar to multiply.
        :return: A ``Matrix`` object with the product.
        :raises PyrtlError: If ``other`` is neither a same-shaped ``Matrix`` nor a
            :class:`.WireVector`.
        """
        if isinstance(other, Matrix):
            if self.columns != other.columns:
                msg = (
                    f"error: columns mismatch. Matrix a: {self.columns} columns, "
                    f"Matrix b: {other.columns} columns"
                )
                raise PyrtlError(msg)
            if self.rows != other.rows:
                msg = (
                    f"error, row mismatch Matrix a: {self.rows} rows, Matrix b: "
                    f"{other.rows} rows"
                )
                raise PyrtlError(msg)
            bits = self.bits + other.bits
        elif isinstance(other, WireVector):
            bits = self.bits + len(other)
        else:
            msg = f"Expecting a Matrix or WireVector got {type(other)} instead"
            raise PyrtlError(msg)

        result = Matrix(self.rows, self.columns, bits, max_bits=self.max_bits)
        for i in range(self.rows):
            for j in range(self.columns):
                if isinstance(other, Matrix):
                    result[i, j] = self[i, j] * other[i, j]
                else:
                    result[i, j] = self[i, j] * other
        return result

    def __imatmul__(self, other: Matrix) -> Matrix:
        """Performs the inplace matrix multiplication operation.

        Invoked with ``a @= b``.

        :param other: The second ``Matrix``.
        :return: A ``Matrix`` that contains the product.
        """
        new_value = self @ other
        # The product may have a different shape than self.
        self.columns = new_value.columns
        self.rows = new_value.rows
        self._matrix = new_value._matrix
        self._bits = new_value._bits
        return self.copy()

    def __matmul__(self, other: Matrix) -> Matrix:
        """Performs the matrix multiplication operation.

        Invoked with ``a @ b``.

        :param other: The second ``Matrix``.
        :return: A ``Matrix`` that contains the product.
        :raises PyrtlError: If ``other`` is not a ``Matrix`` or if ``self.columns``
            differs from ``other.rows``.
        """
        if not isinstance(other, Matrix):
            msg = f"error: expecting a Matrix, got {type(other)} instead"
            raise PyrtlError(msg)
        if self.columns != other.rows:
            msg = (
                f"error: rows and columns mismatch. Matrix a: {self.columns} columns, "
                f"Matrix b: {other.rows} rows"
            )
            raise PyrtlError(msg)

        # Generous width for the accumulated dot products; the constructor
        # clamps it to max_bits.
        result = Matrix(
            self.rows,
            other.columns,
            self.columns * other.rows * (self.bits + other.bits),
            max_bits=self.max_bits,
        )
        for i in range(self.rows):
            for j in range(other.columns):
                for k in range(self.columns):
                    # Accumulate self[i, k] * other[k, j] into result[i, j].
                    result[i, j] = fused_multiply_adder(
                        self[i, k], other[k, j], result[i, j], signed=self.signed
                    )
        return result

    def __ipow__(self, power: int) -> Matrix:
        """Performs the matrix power operation.

        Invoked with ``a **= b``. This performs a chain of matrix multiplications,
        where ``self`` is matrix multiplied by ``self``, ``power`` times.

        :param power: The power to raise the matrix to.
        :return: A ``Matrix`` containing the result.
        """
        new_value = self**power
        self._matrix = new_value._matrix
        self._bits = new_value._bits
        return self.copy()

    def __pow__(self, power: int) -> Matrix:
        """Performs the matrix power operation.

        Invoked with ``a ** b``. This performs a chain of matrix multiplications,
        where ``self`` is matrix multiplied by ``self``, ``power`` times.

        :param power: The power to raise the matrix to. ``0`` returns the identity
            matrix.
        :return: A ``Matrix`` containing the result.
        :raises PyrtlError: If ``power`` is not a non-negative :class:`int` or the
            ``Matrix`` is not square.
        """
        if not isinstance(power, int):
            msg = (
                "Unexpected power given. Type int expected, but received type "
                f"{type(power)}"
            )
            raise PyrtlError(msg)
        if self.rows != self.columns:
            msg = "Matrix must be square"
            raise PyrtlError(msg)
        if power < 0:
            msg = "Power must be greater than or equal to 0"
            raise PyrtlError(msg)

        # First case: return the identity matrix (new matrices start at zero).
        if power == 0:
            identity = Matrix(self.rows, self.columns, self.bits, max_bits=self.max_bits)
            for i in range(self.rows):
                identity[i, i] = Const(1)
            return identity

        # Second case: chain matrix multiplications. Start from a copy so the
        # result never aliases self (matters for power == 1).
        result = self.copy()
        return reduce(Matrix.__matmul__, [result] * power)

    def put(
        self,
        ind: int | list[int] | tuple[int],
        v: int | WireVector | list[int] | tuple[int] | Matrix,
        mode: str = "raise",
    ) -> None:
        """Replace specified elements of the ``Matrix`` with values ``v``.

        Note that the index ``ind`` is on the flattened matrix.

        :param ind: Target indices.
        :param v: Values to place in ``Matrix`` at ``ind``. If ``v`` is shorter than
            ``ind``, the last value of ``v`` is repeated for the remaining indices.
            A ``Matrix`` ``v`` must be a row vector.
        :param mode: How out-of-bounds indices behave. ``raise`` raises an error,
            ``wrap`` wraps around, and ``clip`` clips to the range.
        :raises PyrtlError: On bad index types, a non-row-vector ``Matrix`` ``v``,
            an unknown ``mode``, or an out-of-bounds index with ``mode="raise"``.
        """
        count = self.rows * self.columns
        if isinstance(ind, int):
            ind = (ind,)
        elif not isinstance(ind, (tuple, list)):
            msg = f"Expected int or list-like indices, got {type(ind)}"
            raise PyrtlError(msg)

        # A single scalar value is treated as a one-element sequence.
        if isinstance(v, (int, WireVector)):
            v = (v,)
        if isinstance(v, (tuple, list)) and len(v) == 0:
            return
        if isinstance(v, Matrix) and v.rows != 1:
            msg = f"Expected a row-vector matrix, instead got matrix with {v.rows} rows"
            raise PyrtlError(msg)
        if mode not in ["raise", "wrap", "clip"]:
            msg = (
                f"Unexpected mode {mode}; allowable modes are 'raise', 'wrap', and "
                "'clip'"
            )
            raise PyrtlError(msg)

        def get_ix(ix):
            # Map a (possibly negative or out-of-range) flat index to a valid one.
            if ix < 0:
                ix = count - abs(ix)
            if ix < 0 or ix >= count:
                if mode == "raise":
                    msg = f"index {ix} is out of bounds with size {count}"
                    raise PyrtlError(msg)
                if mode == "wrap":
                    ix = ix % count
                elif mode == "clip":
                    ix = 0 if ix < 0 else count - 1
            return ix

        def get_value(ix):
            # Position ix of v, repeating v's last value once ix runs past its end.
            if isinstance(v, (tuple, list)):
                return v[min(ix, len(v) - 1)]
            if isinstance(v, Matrix):
                return v[0, min(ix, v.columns - 1)]
            return None

        for v_ix, mat_ix in enumerate(ind):
            mat_ix = get_ix(mat_ix)
            row, col = divmod(mat_ix, self.columns)
            self[row, col] = get_value(v_ix)

    def reshape(self, *newshape: int | tuple, order: str = "C") -> Matrix:
        """Create a ``Matrix`` of the given shape from ``self``.

        One shape dimension in ``newshape`` can be ``-1``; in this case, the value
        for that dimension is inferred from the other given dimension (if any) and
        the number of elements in ``self``.

        Examples::

            int_matrix = [[0, 1, 2, 3], [4, 5, 6, 7]]
            matrix = Matrix(2, 4, 4, value=int_matrix)

            matrix.reshape(-1) == [[0, 1, 2, 3, 4, 5, 6, 7]]
            matrix.reshape(8) == [[0, 1, 2, 3, 4, 5, 6, 7]]
            matrix.reshape(1, 8) == [[0, 1, 2, 3, 4, 5, 6, 7]]
            matrix.reshape((1, 8)) == [[0, 1, 2, 3, 4, 5, 6, 7]]
            matrix.reshape((1, -1)) == [[0, 1, 2, 3, 4, 5, 6, 7]]

            matrix.reshape(4, 2) == [[0, 1], [2, 3], [4, 5], [6, 7]]
            matrix.reshape(-1, 2) == [[0, 1], [2, 3], [4, 5], [6, 7]]
            matrix.reshape(4, -1) == [[0, 1], [2, 3], [4, 5], [6, 7]]

        :param newshape: Shape of the matrix to return. If ``newshape`` is a single
            :class:`int`, the new shape will be a 1-D row-vector of that length. If
            ``newshape`` is a :class:`tuple`, the :class:`tuple` specifies the new
            number of rows and columns. ``newshape`` can also be varargs.
        :param order: ``C`` means to read from self using row-major order (C-style),
            and ``F`` means to read from self using column-major order
            (Fortran-style).
        :return: A copy of the matrix with same data, with a new number of rows and
            columns.
        :raises PyrtlError: If the shape is malformed or incompatible with the number
            of elements, or ``order`` is not ``"C"`` or ``"F"``.
        """
        count = self.rows * self.columns

        shape = newshape
        # reshape((r, c)) passes the whole shape as one tuple argument.
        if shape and isinstance(shape[0], tuple):
            shape = shape[0]
        if len(shape) == 0:
            msg = "newshape must specify at least one dimension"
            raise PyrtlError(msg)
        if len(shape) == 1:
            # A single dimension means a 1-D row vector.
            shape = (1, shape[0])
        if len(shape) > 2:
            msg = "length of newshape tuple must be <= 2"
            raise PyrtlError(msg)

        rows, cols = shape
        if not isinstance(rows, int) or not isinstance(cols, int):
            msg = (
                "newshape dimensions must be integers, instead got "
                f"{type(rows)} and {type(cols)}"
            )
            raise PyrtlError(msg)
        if rows == -1 and cols == -1:
            msg = "Both dimensions in newshape cannot be -1"
            raise PyrtlError(msg)
        if any(dim != -1 and dim <= 0 for dim in (rows, cols)):
            msg = f"newshape dimensions must be positive or -1, instead got {shape}"
            raise PyrtlError(msg)
        if rows == -1:
            rows = count // cols
        elif cols == -1:
            cols = count // rows

        if rows * cols != count:
            msg = f"Cannot reshape matrix of size {count} into shape {(rows, cols)}"
            raise PyrtlError(msg)
        if order not in ("C", "F"):
            msg = (
                f"Invalid order {order}. Acceptable orders are 'C' (for row-major "
                "C-style order) and 'F' (for column-major Fortran-style order)."
            )
            raise PyrtlError(msg)

        value = [[0] * cols for _ in range(rows)]
        ix = 0
        if order == "C":
            # Read and write in row-wise order.
            for newr in range(rows):
                for newc in range(cols):
                    r, c = divmod(ix, self.columns)
                    value[newr][newc] = self[r, c]
                    ix += 1
        else:
            # Read and write in column-wise order.
            for newc in range(cols):
                for newr in range(rows):
                    c, r = divmod(ix, self.rows)
                    value[newr][newc] = self[r, c]
                    ix += 1

        return Matrix(
            rows,
            cols,
            self.bits,
            signed=self.signed,
            value=value,
            max_bits=self.max_bits,
        )

    def flatten(self, order: str = "C") -> Matrix:
        """Flatten the ``Matrix`` into a single row.

        :param order: ``C`` means row-major order (C-style), and ``F`` means
            column-major order (Fortran-style)
        :return: A copy of the ``Matrix`` flattened into a row vector.
        """
        return self.reshape(self.rows * self.columns, order=order)
def multiply(first: Matrix, second: Matrix | WireVector) -> Matrix:
    """Perform the elementwise or scalar multiplication operation.

    .. WARNING::
        Use :meth:`Matrix.__mul__` instead.

    :param first: The ``Matrix`` on the left-hand side.
    :param second: A ``Matrix`` (elementwise multiplication) or a
        :class:`.WireVector` (scalar multiplication).
    :return: ``first * second``, i.e. a Matrix object with the elementwise or
        scalar multiplication performed.
    :raises PyrtlError: if ``first`` is not a ``Matrix``.
    """
    if not isinstance(first, Matrix):
        # Report the type of the argument that actually failed the check
        # (previously this mistakenly reported ``type(second)``).
        msg = f"error: expecting a Matrix, got {type(first)} instead"
        raise PyrtlError(msg)
    return first * second
def sum(
    matrix: Matrix | WireVector, axis: int | None = None, bits: int | None = None
) -> Matrix | WireVector:
    """Reduce a ``Matrix`` by addition, either entirely or along one ``axis``.

    Elements are combined pairwise with ``+`` via :func:`functools.reduce`.

    :param matrix: The ``Matrix`` whose elements are added together. A
        :class:`.WireVector` is returned unchanged.
    :param axis: ``None`` adds every element into one value, ``0`` adds down each
        column, and ``1`` adds across each row. Defaults to ``None``.
    :param bits: Bitwidth of each element of the ``Matrix`` returned for axis
        ``0`` or ``1``. Defaults to ``matrix.bits``. It is validated but not used
        when ``axis`` is ``None``.
    :return: A :class:`.WireVector` when ``axis`` is ``None`` (or when ``matrix``
        is a :class:`.WireVector`); otherwise a ``Matrix`` with one row and one
        element per column (``axis=0``) or per row (``axis=1``).
    :raises PyrtlError: on a wrong argument type, non-positive ``bits``, or an
        ``axis`` other than ``None``, ``0``, or ``1``.
    """

    def add_pair(lhs, rhs):
        return lhs + rhs

    if isinstance(matrix, WireVector):
        return matrix
    if not isinstance(matrix, Matrix):
        msg = (
            f"error: expecting a Matrix or WireVector for matrix, got {type(matrix)} "
            "instead"
        )
        raise PyrtlError(msg)
    if bits is not None and not isinstance(bits, int):
        msg = f"error: expecting an int/None for bits, got {type(bits)} instead"
        raise PyrtlError(msg)
    if axis is not None and not isinstance(axis, int):
        msg = f"error: expecting an int or None for axis, got {type(axis)} instead"
        raise PyrtlError(msg)

    bits = matrix.bits if bits is None else bits
    if bits <= 0:
        msg = f"error: bits cannot be negative or zero, got {bits} instead"
        raise PyrtlError(msg)

    n_rows, n_cols = matrix.rows, matrix.columns

    if axis is None:
        # Row-major walk over every element, folded into a single sum.
        everything = [matrix[r, c] for r in range(n_rows) for c in range(n_cols)]
        return reduce(add_pair, everything)

    if axis == 0:
        # One total per column, accumulated top to bottom.
        totals = Matrix(1, n_cols, signed=matrix.signed, bits=bits)
        for c in range(n_cols):
            totals[0, c] = reduce(add_pair, [matrix[r, c] for r in range(n_rows)])
        return totals

    if axis == 1:
        # One total per row, accumulated left to right.
        totals = Matrix(1, n_rows, signed=matrix.signed, bits=bits)
        for r in range(n_rows):
            totals[0, r] = reduce(add_pair, [matrix[r, c] for c in range(n_cols)])
        return totals

    msg = f"Axis invalid: expected (None, 0, or 1), got {axis}"
    raise PyrtlError(msg)
def min(
    matrix: Matrix | WireVector, axis: int | None = None, bits: int | None = None
) -> Matrix | WireVector:
    """Reduce a ``Matrix`` to its minimum, either overall or along one ``axis``.

    Elements are compared pairwise with ``<`` and the smaller one is kept via
    :func:`.select`.

    :param matrix: The ``Matrix`` to take the minimum of. A :class:`.WireVector`
        is returned unchanged.
    :param axis: ``None`` takes the minimum of all elements, ``0`` the minimum of
        each column, and ``1`` the minimum of each row. Defaults to ``None``.
    :param bits: Bitwidth of each element of the ``Matrix`` returned for axis
        ``0`` or ``1``. Defaults to ``matrix.bits``. It is validated but not used
        when ``axis`` is ``None``.
    :return: A :class:`.WireVector` when ``axis`` is ``None`` (or when ``matrix``
        is a :class:`.WireVector`); otherwise a ``Matrix`` with one row and one
        element per column (``axis=0``) or per row (``axis=1``).
    :raises PyrtlError: on a wrong argument type, non-positive ``bits``, or an
        ``axis`` other than ``None``, ``0``, or ``1``.
    """
    if isinstance(matrix, WireVector):
        return matrix
    if not isinstance(matrix, Matrix):
        msg = (
            f"error: expecting a Matrix or WireVector for matrix, got {type(matrix)} "
            "instead"
        )
        raise PyrtlError(msg)
    if bits is not None and not isinstance(bits, int):
        msg = f"error: expecting an int/None for bits, got {type(bits)} instead"
        raise PyrtlError(msg)
    if axis is not None and not isinstance(axis, int):
        msg = f"error: expecting an int or None for axis, got {type(axis)} instead"
        raise PyrtlError(msg)

    bits = matrix.bits if bits is None else bits
    if bits <= 0:
        msg = f"error: bits cannot be negative or zero, got {bits} instead"
        raise PyrtlError(msg)

    def smaller(lhs, rhs):
        # Keep lhs only when it is strictly smaller; ties resolve to rhs.
        return select(lhs < rhs, lhs, rhs)

    def column(c):
        return [matrix[r, c] for r in range(matrix.rows)]

    def row(r):
        return [matrix[r, c] for c in range(matrix.columns)]

    if axis is None:
        cells = [cell for r in range(matrix.rows) for cell in row(r)]
        return reduce(smaller, cells)

    if axis == 0:
        mins = Matrix(1, matrix.columns, signed=matrix.signed, bits=bits)
        for c in range(matrix.columns):
            mins[0, c] = reduce(smaller, column(c))
        return mins

    if axis == 1:
        mins = Matrix(1, matrix.rows, signed=matrix.signed, bits=bits)
        for r in range(matrix.rows):
            mins[0, r] = reduce(smaller, row(r))
        return mins

    msg = f"Axis invalid: expected (None, 0, or 1), got {axis}"
    raise PyrtlError(msg)
def max(
    matrix: Matrix | WireVector, axis: int | None = None, bits: int | None = None
) -> Matrix | WireVector:
    """Reduce a ``Matrix`` to its maximum, either overall or along one ``axis``.

    Elements are compared pairwise with ``>`` and the larger one is kept via
    :func:`.select`.

    :param matrix: The ``Matrix`` to take the maximum of. A :class:`.WireVector`
        is returned unchanged.
    :param axis: ``None`` takes the maximum of all elements, ``0`` the maximum of
        each column, and ``1`` the maximum of each row. Defaults to ``None``.
    :param bits: Bitwidth of each element of the ``Matrix`` returned for axis
        ``0`` or ``1``. Defaults to ``matrix.bits``. It is validated but not used
        when ``axis`` is ``None``.
    :return: A :class:`.WireVector` when ``axis`` is ``None`` (or when ``matrix``
        is a :class:`.WireVector`); otherwise a ``Matrix`` with one row and one
        element per column (``axis=0``) or per row (``axis=1``).
    :raises PyrtlError: on a wrong argument type, non-positive ``bits``, or an
        ``axis`` other than ``None``, ``0``, or ``1``.
    """
    if isinstance(matrix, WireVector):
        return matrix
    if not isinstance(matrix, Matrix):
        msg = (
            f"error: expecting a Matrix or WireVector for matrix, got {type(matrix)} "
            "instead"
        )
        raise PyrtlError(msg)
    if bits is not None and not isinstance(bits, int):
        msg = f"error: expecting an int/None for bits, got {type(bits)} instead"
        raise PyrtlError(msg)
    if axis is not None and not isinstance(axis, int):
        msg = f"error: expecting an int or None for axis, got {type(axis)} instead"
        raise PyrtlError(msg)

    bits = matrix.bits if bits is None else bits
    if bits <= 0:
        msg = f"error: bits cannot be negative or zero, got {bits} instead"
        raise PyrtlError(msg)

    def larger(lhs, rhs):
        # Keep lhs only when it is strictly larger; ties resolve to rhs.
        return select(lhs > rhs, lhs, rhs)

    def column(c):
        return [matrix[r, c] for r in range(matrix.rows)]

    def row(r):
        return [matrix[r, c] for c in range(matrix.columns)]

    if axis is None:
        cells = [cell for r in range(matrix.rows) for cell in row(r)]
        return reduce(larger, cells)

    if axis == 0:
        maxes = Matrix(1, matrix.columns, signed=matrix.signed, bits=bits)
        for c in range(matrix.columns):
            maxes[0, c] = reduce(larger, column(c))
        return maxes

    if axis == 1:
        maxes = Matrix(1, matrix.rows, signed=matrix.signed, bits=bits)
        for r in range(matrix.rows):
            maxes[0, r] = reduce(larger, row(r))
        return maxes

    msg = f"Axis invalid: expected (None, 0, or 1), got {axis}"
    raise PyrtlError(msg)
def argmax(
    matrix: Matrix | WireVector, axis: int | None = None, bits: int | None = None
) -> Matrix | WireVector:
    """Returns the index of the max value of the ``Matrix``.

    .. note::
        If there are two indices with the same max value, this function picks the
        first instance.

    :param matrix: The ``Matrix`` to perform argmax operation on. If it is a
        :class:`.WireVector`, it is treated as a single element and ``Const(0)``
        is returned.
    :param axis: ``None`` returns the flat, row-major index of the maximum over
        all elements. ``0`` returns, for each column, the row index of that
        column's maximum. ``1`` returns, for each row, the column index of that
        row's maximum. Defaults to ``None``.
    :param bits: The bits per element of the ``Matrix`` returned for axis ``0``
        or ``1``; also forwarded to :func:`max`. Defaults to ``matrix.bits``.
    :return: A :class:`.WireVector` (``axis=None``) or a one-row ``Matrix``
        (``axis=0`` or ``1``) holding the argmax index/indices.
    :raises PyrtlError: on a wrong argument type, non-positive ``bits``, or an
        ``axis`` other than ``None``, ``0``, or ``1``.
    """
    if isinstance(matrix, WireVector):
        return Const(0)
    if not isinstance(matrix, Matrix):
        msg = (
            f"error: expecting a Matrix or WireVector for matrix, got {type(matrix)} "
            "instead"
        )
        raise PyrtlError(msg)
    if not isinstance(bits, int) and bits is not None:
        msg = f"error: expecting an int/None for bits, got {type(bits)} instead"
        raise PyrtlError(msg)
    if not isinstance(axis, int) and axis is not None:
        msg = f"error: expecting an int or None for axis, got {type(axis)} instead"
        raise PyrtlError(msg)
    if bits is None:
        bits = matrix.bits
    if bits <= 0:
        msg = f"error: bits cannot be negative or zero, got {bits} instead"
        raise PyrtlError(msg)
    # Reject a bad axis explicitly instead of relying on max() to raise and
    # leaving an unreachable ``return None`` at the end of this function.
    if axis not in (None, 0, 1):
        msg = f"Axis invalid: expected (None, 0, or 1), got {axis}"
        raise PyrtlError(msg)

    def first_index_of(target, candidates):
        # Build a select chain from the last candidate back to the first, so an
        # earlier match overrides any later one (ties pick the first index).
        # Const(0) is the fallback when nothing matches.
        index = Const(0)
        for position in reversed(range(len(candidates))):
            index = select(target == candidates[position], Const(position), index)
        return index

    # ``max`` is this module's Matrix reduction (shadowing builtins.max).
    max_number = max(matrix, axis=axis, bits=bits)
    n_rows, n_cols = matrix.rows, matrix.columns

    if axis is None:
        flat = [matrix[r, c] for r in range(n_rows) for c in range(n_cols)]
        return first_index_of(max_number, flat)

    # NOTE(review): indices range up to n_rows - 1 (axis 0) or n_cols - 1
    # (axis 1) and must fit in ``bits``; confirm how Matrix.__setitem__ handles
    # wider values.
    if axis == 0:
        result = Matrix(1, n_cols, signed=matrix.signed, bits=bits)
        for c in range(n_cols):
            column = [matrix[r, c] for r in range(n_rows)]
            result[0, c] = first_index_of(max_number[0, c], column)
        return result

    result = Matrix(1, n_rows, signed=matrix.signed, bits=bits)
    for r in range(n_rows):
        row = [matrix[r, c] for c in range(n_cols)]
        result[0, r] = first_index_of(max_number[0, r], row)
    return result
def dot(
    first: Matrix | WireVector, second: Matrix | WireVector
) -> Matrix | WireVector:
    """Performs the dot product on two matrices.

    Specifically, the dot product on two matrices is:

    1. If either ``first`` or ``second`` are :class:`WireVectors<.WireVector>`,
       or have both rows and columns equal to 1, ``dot`` is equivalent to
       :meth:`Matrix.__mul__`
    2. If ``first`` and ``second`` are both arrays (have rows or columns equal to
       1), ``dot`` is the inner product of the vectors.
    3. Otherwise ``dot`` is :meth:`Matrix.__matmul__` between ``first`` and
       ``second``.

    .. note::
        Row vectors and column vectors are both treated as arrays.

    :param first: The first matrix (or a :class:`.WireVector`).
    :param second: The second matrix (or a :class:`.WireVector`).
    :return: The dot product of ``first`` and ``second``: a :class:`.WireVector`
        for the inner-product case (or when both inputs are WireVectors),
        otherwise a ``Matrix``.
    :raises PyrtlError: if either argument is neither a ``Matrix`` nor a
        :class:`.WireVector`.
    """
    # Both arguments accept the same types, so both messages say so.
    if not isinstance(first, (WireVector, Matrix)):
        msg = f"error: expecting a Matrix/WireVector, got {type(first)} instead"
        raise PyrtlError(msg)
    if not isinstance(second, (WireVector, Matrix)):
        msg = f"error: expecting a Matrix/WireVector, got {type(second)} instead"
        raise PyrtlError(msg)

    # Case 1: plain (scalar-style) multiplication.
    if isinstance(first, WireVector):
        if isinstance(second, WireVector):
            return first * second
        return second[:, :] * first
    if isinstance(second, WireVector):
        return first[:, :] * second
    if (first.rows == 1 and first.columns == 1) or (
        second.rows == 1 and second.columns == 1
    ):
        return first[:, :] * second[:, :]

    # Case 2: inner product of two vectors. ``sum`` is this module's Matrix
    # reduction (shadowing builtins.sum); with axis=None it yields one value.
    if first.rows == 1:
        if second.rows == 1:
            return sum(first * second)
        if second.columns == 1:
            return sum(first * second.transpose())
    elif first.columns == 1:
        if second.rows == 1:
            return sum(first * second.transpose())
        if second.columns == 1:
            return sum(first * second)

    # Case 3: general matrix multiplication.
    return first @ second
def hstack(*matrices: Matrix) -> Matrix:
    """Stack ``matrices`` in sequence horizontally (column-wise).

    All the ``matrices`` must have the same number of rows and the same ``signed``
    value.

    For example::

        m1 = Matrix(rows=2, columns=3, bits=5, value=[[1, 2, 3], [4, 5, 6]])
        m2 = Matrix(rows=2, columns=1, bits=10, value=[[17], [23]])
        m3 = hstack(m1, m2)

    ``m3`` will look like::

        [[1, 2, 3, 17],
         [4, 5, 6, 23]]

    And ``m3.bits`` will be ``10``.

    :param matrices: Matrices to concatenate together horizontally.
    :return: A new ``Matrix``, with the same number of rows as the original, and
        columns equal to the sum of the columns of ``matrices``. The new
        ``Matrix``'s ``bits`` and ``max_bits`` are the max of those of all
        ``matrices``, and the shared ``signed`` value is passed on to it. A single
        argument is returned as a copy.
    :raises PyrtlError: if no matrices are given, if any argument is not a
        ``Matrix``, or if the row counts or ``signed`` values differ.
    """
    if len(matrices) == 0:
        msg = "Must supply at least one matrix to hstack()"
        raise PyrtlError(msg)

    if any(not isinstance(matrix, Matrix) for matrix in matrices):
        msg = "All arguments to hstack must be matrices."
        raise PyrtlError(msg)

    if len(matrices) == 1:
        return matrices[0].copy()

    new_rows = matrices[0].rows
    if any(m.rows != new_rows for m in matrices):
        msg = "All matrices being hstacked together must have the same number of rows"
        raise PyrtlError(msg)

    new_signed = matrices[0].signed
    if any(m.signed != new_signed for m in matrices):
        msg = "All matrices being hstacked together must have the same signedness"
        raise PyrtlError(msg)

    # builtins.* because this module defines its own Matrix-aware sum/max.
    new_cols = builtins.sum(m.columns for m in matrices)
    new_bits = builtins.max(m.bits for m in matrices)
    new_max_bits = builtins.max(m.max_bits for m in matrices)

    # Propagate the signedness validated above instead of silently dropping it.
    new = Matrix(
        new_rows, new_cols, new_bits, signed=new_signed, max_bits=new_max_bits
    )

    new_c = 0
    for matrix in matrices:
        for c in range(matrix.columns):
            for r in range(matrix.rows):
                new[r, new_c] = matrix[r, c]
            new_c += 1

    return new
def vstack(*matrices: Matrix) -> Matrix:
    """Stack matrices in sequence vertically (row-wise).

    All the ``matrices`` must have the same number of columns and the same
    ``signed`` value.

    For example::

        m1 = Matrix(rows=2, columns=3, bits=5, value=[[1, 2, 3], [4, 5, 6]])
        m2 = Matrix(rows=1, columns=3, bits=10, value=[[7, 8, 9]])
        m3 = vstack(m1, m2)

    ``m3`` will look like::

        [[1, 2, 3],
         [4, 5, 6],
         [7, 8, 9]]

    And ``m3.bits`` will be ``10``.

    :param matrices: Matrices to concatenate together vertically.
    :return: A new ``Matrix``, with the same number of columns as the original, and
        rows equal to the sum of the rows of ``matrices``. The new ``Matrix``'s
        ``bits`` and ``max_bits`` are the max of those of all ``matrices``, and
        the shared ``signed`` value is passed on to it. A single argument is
        returned as a copy.
    :raises PyrtlError: if no matrices are given, if any argument is not a
        ``Matrix``, or if the column counts or ``signed`` values differ.
    """
    if len(matrices) == 0:
        msg = "Must supply at least one matrix to vstack()"
        raise PyrtlError(msg)

    if any(not isinstance(matrix, Matrix) for matrix in matrices):
        msg = "All arguments to vstack must be matrices."
        raise PyrtlError(msg)

    if len(matrices) == 1:
        return matrices[0].copy()

    new_cols = matrices[0].columns
    if any(m.columns != new_cols for m in matrices):
        msg = (
            "All matrices being vstacked together must have the same number of columns"
        )
        raise PyrtlError(msg)

    new_signed = matrices[0].signed
    if any(m.signed != new_signed for m in matrices):
        msg = "All matrices being vstacked together must have the same signedness"
        raise PyrtlError(msg)

    # builtins.* because this module defines its own Matrix-aware sum/max.
    new_rows = builtins.sum(m.rows for m in matrices)
    new_bits = builtins.max(m.bits for m in matrices)
    new_max_bits = builtins.max(m.max_bits for m in matrices)

    # Propagate the signedness validated above instead of silently dropping it.
    new = Matrix(
        new_rows, new_cols, new_bits, signed=new_signed, max_bits=new_max_bits
    )

    new_r = 0
    for matrix in matrices:
        for r in range(matrix.rows):
            for c in range(matrix.columns):
                new[new_r, c] = matrix[r, c]
            new_r += 1

    return new
def concatenate(
    matrices: list[Matrix] | tuple[Matrix, ...], axis: int = 0
) -> Matrix:
    """Join a sequence of ``matrices`` along an existing ``axis``.

    This function is just a wrapper around :func:`hstack` and :func:`vstack`.

    .. note::
        ``axis=0`` stacks horizontally and ``axis=1`` stacks vertically, which is
        the reverse of NumPy's ``concatenate`` convention.

    :param matrices: A sequence of matrices to concatenate together; it is
        unpacked into :func:`hstack` or :func:`vstack`.
    :param axis: Axis along which to concatenate. ``0`` is horizontally, ``1`` is
        vertically. Defaults to ``0``.
    :return: A new ``Matrix`` composed of the given matrices concatenated
        together.
    :raises PyrtlError: if ``axis`` is not ``0`` or ``1``, or if
        :func:`hstack` / :func:`vstack` rejects ``matrices``.
    """
    if axis == 0:
        return hstack(*matrices)
    if axis == 1:
        return vstack(*matrices)
    msg = "Only allowable axes are 0 or 1"
    raise PyrtlError(msg)
def matrix_wv_to_list(
    matrix_wv: int, rows: int, columns: int, bits: int
) -> list[list[int]]:
    """Convert the simulated value of a wire representing a :class:`Matrix` into a
    Python list of lists.

    During :class:`.Simulation`, this is useful when printing the value of an
    :meth:`inspected<.Simulation.inspect>` wire that represents a :class:`Matrix`.

    Example::

        m = Matrix.Matrix(rows=2, columns=3, bits=4, value=[[1, 2, 3], [4, 5, 6]])

        output = Output(name="output")
        output <<= m.to_wirevector()

        sim = Simulation()
        sim.step()

        raw_matrix = Matrix.matrix_wv_to_list(
            sim.inspect("output"), m.rows, m.columns, m.bits)
        print(raw_matrix)

        # Produces:
        # [[1, 2, 3], [4, 5, 6]]

    Elements are unpacked in row-major order. The element at row 0, column 0 is
    taken from the most significant ``bits`` bits of the lowest
    ``rows * columns * bits`` bits of ``matrix_wv``, and the last element from the
    least significant ``bits`` bits. Any bits above that span are ignored.

    :param matrix_wv: Integer value of the wire produced by
        :meth:`Matrix.to_wirevector`, e.g. as returned by
        :meth:`.Simulation.inspect`.
    :param rows: Number of rows in the matrix.
    :param columns: Number of columns in the matrix.
    :param bits: Number of bits for each element in the matrix. Must be positive.
    :return: A Python list of lists of unsigned ints.
    :raises PyrtlError: if ``bits`` is not positive.
    """
    if bits <= 0:
        msg = f"Number of bits per element must be positive, instead got {bits}"
        raise PyrtlError(msg)

    mask = (1 << bits) - 1
    # Row-major index of the element stored in the least significant bits.
    last = rows * columns - 1
    # Shift/mask on the integer rather than slicing a binary string, so a value
    # wider than rows * columns * bits cannot shift every element out of place.
    return [
        [
            (matrix_wv >> ((last - (r * columns + c)) * bits)) & mask
            for c in range(columns)
        ]
        for r in range(rows)
    ]
def list_to_int(matrix: list[list[int]], n_bits: int) -> int:
    """Convert a Python matrix (a :class:`list` of :class:`lists<list>`) into an
    :class:`int`.

    Integers that are signed will automatically be converted to their two's
    complement form.

    This function is helpful for turning a pure Python list of lists into a very
    large integer suitable for creating a :class:`.Const` that can be used as
    :meth:`Matrix.__init__`'s ``value`` argument, or for passing into a
    :meth:`.Simulation.step`'s ``provided_inputs`` for an :class:`.Input` wire.

    For example, calling ``list_to_int([[3, 5], [7, 9]], n_bits=4)`` produces
    ``13689``, which in binary looks like::

        0011 0101 0111 1001

    Note how the elements of the list of lists were added, 4 bits at a time, in
    row order, such that the element at row 0, column 0 is in the most
    significant 4 bits, and the element at row 1, column 1 is in the least
    significant 4 bits.

    Here's an example of using it in :class:`.Simulation`::

        a_vals = [[0, 1], [2, 3]]
        b_vals = [[2, 4, 6], [8, 10, 12]]

        a_in = pyrtl.Input(name="a_in", bitwidth=2 * 2 * 4)
        b_in = pyrtl.Input(name="b_in", bitwidth=2 * 3 * 4)
        a = Matrix.Matrix(rows=2, columns=2, bits=4, value=a_in)
        b = Matrix.Matrix(rows=2, columns=3, bits=4, value=b_in)
        ...

        sim = pyrtl.Simulation()
        sim.step({
            'a_in': Matrix.list_to_int(a_vals, n_bits=a.bits),
            'b_in': Matrix.list_to_int(b_vals, n_bits=b.bits),
        })

    :param matrix: A :class:`list` of :class:`lists<list>` of :class:`ints<int>`
        representing the data in a :class:`Matrix`. Every row must have the same
        length. An empty matrix (or one whose rows are all empty) produces ``0``.
    :param n_bits: The number of bits used to represent each element. If an
        element doesn't fit in ``n_bits``, its most significant bits will be
        truncated.
    :return: An :class:`int` with bitwidth ``N * n_bits``, containing the elements
        of ``matrix``, where ``N`` is the number of elements in ``matrix``.
    :raises PyrtlError: if ``n_bits`` is not positive, or if the rows of
        ``matrix`` have different lengths.
    """
    if n_bits <= 0:
        msg = f"Number of bits per element must be positive, instead got {n_bits}"
        raise PyrtlError(msg)

    # Width of the first row; every row must match it. Use len() rather than
    # truthiness so array-like inputs keep working. ``[]`` has no rows at all.
    n_columns = len(matrix[0]) if len(matrix) > 0 else 0
    element_format = "s" + str(n_bits)

    result = 0
    for row in matrix:
        if len(row) != n_columns:
            msg = (
                f"All rows of matrix must have the same length, expected {n_columns} "
                f"but got a row of length {len(row)}"
            )
            raise PyrtlError(msg)
        for element in row:
            # The signed "s<n_bits>" format handles negative elements (see the
            # two's complement note above).
            val = formatted_str_to_val(str(element), element_format)
            result = (result << n_bits) | val
    return result