# Source code for pyrtl.rtllib.barrel
"""
Basic shifting is defined in PyRTL's core library, see:

- :func:`.shift_left_logical`
- :func:`.shift_right_logical`
- :func:`.shift_right_arithmetic`

:func:`barrel_shifter` should only be used when more complex shifting behavior is
required.
"""
from enum import IntEnum
from pyrtl.wire import WireVector, WireVectorLike
# [docs]
class Direction(IntEnum):
    """Readable names for the two values accepted as a shift ``direction``.

    Members are integers, so ``Direction.RIGHT`` and ``Direction.LEFT`` can be
    used anywhere the literal values ``0`` and ``1`` would be.
    """

    RIGHT = 0  # shift right
    LEFT = 1  # shift left
# [docs]
def barrel_shifter(
    bits_to_shift: WireVector,
    bit_in: WireVectorLike,
    direction: WireVectorLike,
    shift_dist: WireVector,
    wrap_around=0,
) -> WireVector:
    """Build a barrel shifter out of one mux stage per bit of ``shift_dist``.

    Stage ``i`` shifts the running value by ``2 ** i`` positions when bit ``i`` of
    ``shift_dist`` is set, and passes it through unchanged otherwise. Vacated
    positions are filled with copies of ``bit_in``. Once a stage's shift amount
    reaches the width of ``bits_to_shift``, selecting that stage replaces the whole
    value with the fill pattern.

    .. doctest only::

        >>> import pyrtl
        >>> pyrtl.reset_working_block()

    Example::

        >>> bits_to_shift = pyrtl.Input(name="input", bitwidth=8)
        >>> shift_dist = pyrtl.Input(name="shift_dist", bitwidth=3)
        >>> output = pyrtl.Output(name="output")
        >>> output <<= pyrtl.rtllib.barrel.barrel_shifter(
        ...     bits_to_shift,
        ...     bit_in=1,
        ...     direction=pyrtl.rtllib.barrel.Direction.RIGHT,
        ...     shift_dist=shift_dist)

        >>> sim = pyrtl.Simulation()
        >>> sim.step(provided_inputs={"input": 0x55, "shift_dist": 4})
        >>> hex(sim.inspect("output"))
        '0xf5'

    :param bits_to_shift: :class:`.WireVector` holding the value to be shifted.
    :param bit_in: A 1-bit :class:`.WireVector` giving the value shifted into the
        vacated positions.
    :param direction: A 1-bit :class:`.WireVector` choosing the shift direction
        (``0`` = shift right, ``1`` = shift left). When the direction is constant,
        prefer :class:`Direction` for readability (``direction=Direction.RIGHT``
        rather than ``direction=0``).
    :param shift_dist: :class:`.WireVector` holding the number of positions to
        shift by.
    :param wrap_around: Not implemented; must be left at ``0``.

    :raises NotImplementedError: If ``wrap_around`` is not ``0``.

    :return: The shifted :class:`.WireVector`.
    """
    from pyrtl import concat, select  # just for readability

    if wrap_around != 0:
        raise NotImplementedError

    width = len(bits_to_shift)
    result = bits_to_shift
    fill = bit_in

    # One mux stage per bit of the shift distance; stage amounts are 1, 2, 4, 8, ...
    for stage in range(len(shift_dist)):
        step = 1 << stage

        if step >= width:
            # A shift this large pushes out every original bit, so the stage
            # chooses between the fill pattern and the value so far.
            result = select(shift_dist[stage], truecase=fill, falsecase=result)
            continue

        shifted_left = concat(result[:-step], fill)
        shifted_right = concat(fill, result[step:])
        shifted = select(direction, truecase=shifted_left, falsecase=shifted_right)

        # Apply this stage's shift only when its bit of shift_dist is set.
        result = select(shift_dist[stage], truecase=shifted, falsecase=result)

        # The next stage shifts twice as far, so double the fill pattern,
        # capped at the full width.
        fill = concat(fill, fill)[:width]

    return result