# Source code for pyrtl.rtllib.adders

"""
Basic integer addition is defined in PyRTL's core library, see:

- :meth:`.WireVector.__add__` for unsigned integer addition.

- :func:`.signed_add` for signed integer addition.

The functions below provide more complex alternatives.
"""

import itertools
import math
from collections.abc import Callable

import pyrtl


def kogge_stone(
    a: pyrtl.WireVector, b: pyrtl.WireVector, cin: pyrtl.wire.WireVectorLike = 0
) -> pyrtl.WireVector:
    """Creates a Kogge-Stone adder given two inputs.

    The Kogge-Stone adder is a fast tree-based adder with `O(log(n))` propagation
    delay, useful for performance critical designs. However, it has `O(n log(n))`
    area usage, and large fan out.

    .. doctest only::

        >>> import pyrtl
        >>> pyrtl.reset_working_block()

    Example::

        >>> a = pyrtl.Input(name="a", bitwidth=4)
        >>> b = pyrtl.Input(name="b", bitwidth=4)
        >>> output = pyrtl.Output(name="output")
        >>> output <<= pyrtl.rtllib.adders.kogge_stone(a, b)

        >>> sim = pyrtl.Simulation()
        >>> sim.step(provided_inputs={"a": 2, "b": 3})
        >>> sim.inspect("output")
        5

    :param a: A :class:`.WireVector` to add up. Bitwidths don't need to match.
    :param b: A :class:`.WireVector` to add up. Bitwidths don't need to match.
    :param cin: An optional 1-bit carry-in :class:`.WireVector`. Can be any type
        that can be coerced to :class:`.WireVector` by :func:`.as_wires`. The
        carry-in feeds the whole carry chain, so the result is ``a + b + cin``.
    :return: A :class:`.WireVector` representing the output of the adder. It is
        one bit wider than the wider of ``a`` and ``b``.
    """
    a, b = pyrtl.match_bitwidth(a, b)
    cin = pyrtl.as_wires(cin)

    prop_orig = a ^ b
    prop_bits = list(prop_orig)
    gen_bits = list(a & b)

    # Treat the carry-in as "generated" by bit 0 whenever bit 0 propagates, so
    # it flows through the prefix tree into every higher carry. Without this the
    # carry-in only reaches sum bit 0 (e.g. 1 + 0 + cin=1 would produce 0).
    # A constant-zero carry-in contributes nothing, so skip the extra gates.
    if not (isinstance(cin, pyrtl.Const) and cin.val == 0):
        gen_bits[0] = gen_bits[0] | (prop_bits[0] & cin)

    prop_dist = 1
    # creation of the carry calculation: after the loop, gen_bits[i] is the
    # carry out of bit i
    while prop_dist < len(a):
        # Walk from high to low so gen_bits[i - prop_dist] and
        # prop_bits[i - prop_dist] still hold the previous level's values.
        for i in reversed(range(prop_dist, len(a))):
            prop_old = prop_bits[i]
            gen_bits[i] = gen_bits[i] | (prop_old & gen_bits[i - prop_dist])
            if i >= prop_dist * 2:  # to prevent creating unnecessary nets and wires
                prop_bits[i] = prop_old & prop_bits[i - prop_dist]
        prop_dist *= 2

    # assembling the result of the addition. The carry into bit 0 is cin and the
    # carry into bit i is gen_bits[i - 1]; inserting cin at the front shifts the
    # gen bits into place, and the last gen bit becomes the carry-out MSB (the
    # XOR zero-extends prop_orig to match).
    gen_bits.insert(0, cin)
    return pyrtl.concat_list(gen_bits) ^ prop_orig
def one_bit_add(a, b, cin=0):
    """Builds a single-bit full adder.

    :param a: A 1-bit :class:`.WireVector`.
    :param b: A 1-bit :class:`.WireVector`.
    :param cin: A 1-bit carry-in; anything :func:`.as_wires` accepts.
    :return: A 2-bit :class:`.WireVector`: carry-out in the MSB, sum in the LSB.
    """
    carry_out, sum_bit = _one_bit_add_no_concat(a, b, cin)
    return pyrtl.concat(carry_out, sum_bit)
def _one_bit_add_no_concat(a, b, cin=0):
    """Full-adder logic returned as separate wires.

    :param a: A 1-bit :class:`.WireVector`.
    :param b: A 1-bit :class:`.WireVector`.
    :param cin: A 1-bit carry-in; anything :func:`.as_wires` accepts.
    :return: Tuple ``(carry_out, sum)`` of 1-bit wires.
    """
    # Convert first so that an int carry-in has a length to check.
    carry_in = pyrtl.as_wires(cin)
    assert len(a) == len(b) == len(carry_in) == 1
    total = (a ^ b) ^ carry_in
    # Majority function: the carry is set when at least two inputs are set.
    carry_out = ((a & b) | (a & carry_in)) | (b & carry_in)
    return carry_out, total
def half_adder(a, b):
    """Builds a half adder from two 1-bit wires.

    :param a: A 1-bit :class:`.WireVector`.
    :param b: A 1-bit :class:`.WireVector`.
    :return: Tuple ``(carry_out, sum)`` of 1-bit wires.
    """
    assert len(a) == len(b) == 1
    sum_bit = a ^ b
    return a & b, sum_bit
def ripple_add(a, b, cin=0):
    """Creates a ripple-carry adder.

    Bit ``i`` of the result comes from a full adder fed with ``a[i]``, ``b[i]``
    and the carry out of bit ``i - 1``. Once the narrower input runs out, half
    adders propagate the carry through the remaining bits of the wider one.

    The chain is built with a loop rather than one recursive call per bit. Very
    wide adders therefore do not hit Python's recursion limit, and the inputs
    are not re-sliced at every bit position.

    :param a: A :class:`.WireVector` to add up. Bitwidths don't need to match.
    :param b: A :class:`.WireVector` to add up. Bitwidths don't need to match.
    :param cin: A 1-bit carry-in; anything :func:`.as_wires` accepts.
    :return: A :class:`.WireVector` one bit wider than the wider of ``a`` and
        ``b``.
    """
    if len(a) < len(b):  # make sure that b is the shorter wire
        a, b = b, a
    carry = pyrtl.as_wires(cin)
    sum_bits = []
    for i in range(len(a)):
        if i < len(b):
            carry, sum_bit = _one_bit_add_no_concat(a[i], b[i], carry)
        else:
            carry, sum_bit = half_adder(a[i], carry)
        sum_bits.append(sum_bit)
    # concat_list is LSB-first, so the final carry-out becomes the MSB.
    return pyrtl.concat_list([*sum_bits, carry])
def ripple_half_add(a, cin=0):
    """Adds a 1-bit carry-in to ``a`` with a chain of half adders.

    The chain is built with a loop, so wide inputs neither recurse once per bit
    nor get re-sliced at every position.

    :param a: The :class:`.WireVector` to increment.
    :param cin: A 1-bit carry-in; anything :func:`.as_wires` accepts.
    :return: A :class:`.WireVector` one bit wider than ``a``.
    """
    carry = pyrtl.as_wires(cin)
    sum_bits = []
    for i in range(len(a)):
        carry, sum_bit = half_adder(a[i], carry)
        sum_bits.append(sum_bit)
    # concat_list is LSB-first, so the final carry-out becomes the MSB.
    return pyrtl.concat_list([*sum_bits, carry])
def carrysave_adder(
    a: pyrtl.WireVector,
    b: pyrtl.WireVector,
    c: pyrtl.WireVector,
    final_adder: Callable = ripple_add,
) -> pyrtl.WireVector:
    """Adds three :class:`WireVectors<.WireVector>` up in an efficient manner.

    A single layer of full adders first compresses the three inputs into a
    partial sum and a carry vector. ``final_adder`` then adds those two values.

    .. doctest only::

        >>> import pyrtl
        >>> pyrtl.reset_working_block()

    Example::

        >>> a = pyrtl.Input(name="a", bitwidth=4)
        >>> b = pyrtl.Input(name="b", bitwidth=4)
        >>> c = pyrtl.Input(name="c", bitwidth=4)
        >>> output = pyrtl.Output(name="output")
        >>> output <<= pyrtl.rtllib.adders.carrysave_adder(a, b, c)

        >>> sim = pyrtl.Simulation()
        >>> sim.step(provided_inputs={"a": 2, "b": 3, "c": 4})
        >>> sim.inspect("output")
        9

    :param a: A :class:`.WireVector` to add up. Bitwidths don't need to match.
    :param b: A :class:`.WireVector` to add up. Bitwidths don't need to match.
    :param c: A :class:`.WireVector` to add up. Bitwidths don't need to match.
    :param final_adder: The adder to use for the final addition.
    :return: A :class:`.WireVector` with bitwidth equal to the longest input, plus 2.
    """
    a, b, c = pyrtl.match_bitwidth(a, b, c)
    partial_sum = a ^ b ^ c
    # Majority of the three inputs: the carry of each bit position.
    shift_carry = (a | b) & (a | c) & (b | c)
    # The carries have weight 2, so they line up with partial_sum[1:]. For 1-bit
    # inputs that slice would select no bits at all, which PyRTL does not allow,
    # so a 1-bit zero stands in for it.
    upper_sum = partial_sum[1:] if len(partial_sum) > 1 else pyrtl.Const(0)
    return pyrtl.concat(final_adder(upper_sum, shift_carry), partial_sum[0])
def cla_adder(
    a: pyrtl.WireVector,
    b: pyrtl.WireVector,
    cin: pyrtl.wire.WireVectorLike = 0,
    la_unit_len: int = 4,
) -> pyrtl.WireVector:
    """Carry Look-Ahead Adder.

    A Carry Look-Ahead Adder is an adder that is faster than :func:`ripple_add`,
    as it calculates the carry bits faster. It is not as fast as
    :func:`kogge_stone`, but uses less area.

    .. doctest only::

        >>> import pyrtl
        >>> pyrtl.reset_working_block()

    Example::

        >>> a = pyrtl.Input(name="a", bitwidth=4)
        >>> b = pyrtl.Input(name="b", bitwidth=4)
        >>> output = pyrtl.Output(name="output")
        >>> output <<= pyrtl.rtllib.adders.cla_adder(a, b)

        >>> sim = pyrtl.Simulation()
        >>> sim.step(provided_inputs={"a": 2, "b": 3})
        >>> sim.inspect("output")
        5

    :param a: A :class:`.WireVector` to add up. Bitwidths don't need to match.
    :param b: A :class:`.WireVector` to add up. Bitwidths don't need to match.
    :param cin: A 1-bit carry-in. Can be any type that can be coerced to
        :class:`.WireVector` by :func:`.as_wires`.
    :param la_unit_len: The length of input that every unit processes. Must be
        at least 1.
    :return: A :class:`.WireVector` representing the output of the adder, one
        bit wider than the wider of ``a`` and ``b``.
    :raises PyrtlError: if ``la_unit_len`` is less than 1.
    """
    if la_unit_len < 1:
        msg = f"la_unit_len must be at least 1, got {la_unit_len}"
        raise pyrtl.PyrtlError(msg)
    a, b = pyrtl.match_bitwidth(a, b)

    # Chain look-ahead units from the LSB upwards; each unit's carry-out is the
    # next unit's carry-in.
    carry = cin
    unit_sums = []
    for lsb in range(0, len(a), la_unit_len):
        msb = min(lsb + la_unit_len, len(a))
        unit_sum, carry = _cla_adder_unit(a[lsb:msb], b[lsb:msb], carry)
        unit_sums.append(unit_sum)
    # pyrtl.concat is MSB-first: final carry-out, then unit sums high to low.
    return pyrtl.concat(carry, *reversed(unit_sums))
def _cla_adder_unit(a, b, cin):
    """One carry look-ahead unit.

    Carry generation and propagation signals are calculated only from the
    inputs; their values don't rely on the sum. The sum bits use the carries
    rippled inside the unit. The unit's carry-out comes from the group
    generate/propagate signals, so it does not wait for that ripple. That
    carry-out is used as the carry-in of the next unit.

    :param a: Unit slice of the first operand.
    :param b: Unit slice of the second operand, same width as ``a``.
    :param cin: 1-bit carry-in (a wire or anything PyRTL operators accept).
    :return: Tuple ``(sum, cout)``: a wire as wide as ``a``, and the 1-bit
        carry-out.
    """
    gen = a & b
    prop = a ^ b
    width = len(prop)

    sum_bits = []
    carry = cin
    for i in range(width):
        sum_bits.append(prop[i] ^ carry)
        # The ripple carry out of the top bit is never needed (cout below is the
        # look-ahead version), so don't build that dead logic.
        if i < width - 1:
            carry = gen[i] | (prop[i] & carry)

    # Group generate / propagate over the whole unit.
    group_gen = gen[0]
    group_prop = prop[0]
    for i in range(1, width):
        group_gen = gen[i] | (prop[i] & group_gen)
        group_prop = group_prop & prop[i]
    cout = group_gen | (group_prop & cin)
    return pyrtl.concat_list(sum_bits), cout
def wallace_reducer(
    wire_array_2: list[list[pyrtl.WireVector]],
    result_bitwidth: int,
    final_adder: Callable = kogge_stone,
) -> pyrtl.WireVector:
    """The reduction and final adding part of a Wallace tree.

    Useful for adding many numbers together with :func:`fast_group_adder`. The
    use of single bitwidth wires allows for additional flexibility.

    Each pass compresses every column with as many full adders as fit, plus a
    half adder for a leftover pair. Passes repeat until no column holds more
    than two wires; the two remaining rows are then summed by ``final_adder``.
    The caller's lists are not modified.

    :param wire_array_2: An array of arrays of single bitwidth
        :class:`WireVectors<.WireVector>`. The outer index is the bit weight.
    :param result_bitwidth: Bitwidth of the resulting wire. Used to eliminate
        unnecessary wires.
    :param final_adder: The adder used for the final addition.
    :return: :class:`.WireVector` with :attr:`~.WireVector.bitwidth`
        ``result_bitwidth``.
    :raises PyrtlError: if an element is not a 1-bit :class:`.WireVector`.
    """
    # verification that the wires are actually wirevectors of length 1
    for wire_set in wire_array_2:
        for a_wire in wire_set:
            if not isinstance(a_wire, pyrtl.WireVector) or len(a_wire) != 1:
                msg = (
                    f"The item {a_wire} is not a valid element for the wire_array_2. "
                    "It must be a WireVector of bitwidth 1"
                )
                raise pyrtl.PyrtlError(msg)

    # Work on copies so the caller's column lists are left intact.
    columns = [list(column) for column in wire_array_2]
    while any(len(column) > 2 for column in columns):
        # Room for every current column plus a carry out of the top one; columns
        # at or above result_bitwidth are dropped below.
        deferred = [[] for _ in range(max(result_bitwidth, len(columns)) + 1)]
        # Start with low weights and start reducing
        for weight, column in enumerate(columns):
            n_full = len(column) // 3
            for j in range(n_full):
                cout, sum_bit = _one_bit_add_no_concat(*column[3 * j : 3 * j + 3])
                deferred[weight].append(sum_bit)
                deferred[weight + 1].append(cout)
            leftover = column[3 * n_full :]
            if len(leftover) == 2:
                cout, sum_bit = half_adder(*leftover)
                deferred[weight].append(sum_bit)
                deferred[weight + 1].append(cout)
            else:
                deferred[weight].extend(leftover)
        columns = deferred[:result_bitwidth]

    # At this stage in the multiplication we have only 2 wire vectors left.
    # now we need to add them up
    result = _sparse_adder(columns, final_adder)
    if len(result) > result_bitwidth:
        return result[:result_bitwidth]
    return result
def dada_reducer(
    wire_array_2: list[list[pyrtl.WireVector]],
    result_bitwidth: int,
    final_adder: Callable = kogge_stone,
) -> pyrtl.WireVector:
    """The reduction and final adding part of a Dadda tree.

    Useful for adding many numbers together with :func:`fast_group_adder`. The
    use of single bitwidth wires allows for additional flexibility.

    Columns are reduced stage by stage to the heights of the sequence
    2, 3, 4, 6, 9, 13, ... (each term is ``int(previous * 3 / 2)``). Each stage
    uses only as many adders as that stage needs. The final two rows are summed
    by ``final_adder``. The caller's lists are not modified.

    :param wire_array_2: An array of arrays of single bitwidth
        :class:`WireVectors<.WireVector>`. The outer index is the bit weight.
    :param result_bitwidth: Bitwidth of the resulting wire. Used to eliminate
        unnecessary wires.
    :param final_adder: The adder used for the final addition.
    :return: :class:`.WireVector` with :attr:`~.WireVector.bitwidth`
        ``result_bitwidth``.
    :raises PyrtlError: if an element is not a 1-bit :class:`.WireVector`, or if
        a stage fails to reach its target height.
    """
    # verification that the wires are actually wirevectors of length 1
    for wire_set in wire_array_2:
        for a_wire in wire_set:
            if not isinstance(a_wire, pyrtl.WireVector) or len(a_wire) != 1:
                msg = (
                    f"The item {a_wire} is not a valid element for the wire_array_2. "
                    "It must be a WireVector of bitwidth 1"
                )
                raise pyrtl.PyrtlError(msg)

    # Work on copies so popping wires does not empty the caller's lists.
    columns = [list(column) for column in wire_array_2]

    max_width = max(len(column) for column in columns)
    reduction_schedule = [2]
    while reduction_schedule[-1] <= max_width:
        reduction_schedule.append(int(reduction_schedule[-1] * 3 / 2))

    # The last schedule entry exceeds max_width, so skip it and reduce downwards.
    for reduction_target in reversed(reduction_schedule[:-1]):
        # Room for every current column plus a carry out of the top one; columns
        # at or above result_bitwidth are dropped below.
        deferred = [[] for _ in range(max(result_bitwidth, len(columns)) + 1)]
        # Start with low weights and start reducing
        for i, w_array in enumerate(columns):
            # deferred[i] already holds carries from column i - 1, so they count
            # toward this column's height.
            while len(w_array) + len(deferred[i]) > reduction_target:
                if len(w_array) + len(deferred[i]) - reduction_target >= 2:
                    cout, sum_bit = _one_bit_add_no_concat(
                        *(w_array.pop(0) for _ in range(3))
                    )
                else:
                    cout, sum_bit = half_adder(*(w_array.pop(0) for _ in range(2)))
                deferred[i].append(sum_bit)
                deferred[i + 1].append(cout)
            deferred[i].extend(w_array)
            if len(deferred[i]) > reduction_target:
                msg = "Expected that the code would be able to reduce more wires"
                raise pyrtl.PyrtlError(msg)
        columns = deferred[:result_bitwidth]

    # At this stage in the multiplication we have only 2 wire vectors left.
    # now we need to add them up
    result = _sparse_adder(columns, final_adder)
    if len(result) > result_bitwidth:
        return result[:result_bitwidth]
    return result
def _sparse_adder(wire_array_2, adder):
    """Sums a column array in which every column holds at most two 1-bit wires.

    Column ``i`` carries weight ``2**i``. Columns below the first two-wire
    column need no addition, so their bit (or zero, if empty) passes straight
    through. From the first two-wire column upwards, the columns are split into
    two rows (missing entries padded with zero) and summed with ``adder``.

    :param wire_array_2: List of columns (lists of 1-bit wires), LSB column
        first. NOTE(review): columns with more than two wires are not
        supported; the reducers guarantee at most two.
    :param adder: Two-operand adder, called as ``adder(x, y)``.
    :return: The summed :class:`.WireVector`.
    """
    zero = pyrtl.Const(0)
    first_pair = next(
        (weight for weight, column in enumerate(wire_array_2) if len(column) == 2),
        None,
    )
    if first_pair is None:
        # Nothing overlaps: each column is a single bit (or an empty column,
        # which contributes zero), so the columns are the result.
        return pyrtl.concat_list(
            [column[0] if column else zero for column in wire_array_2]
        )

    # Low columns before the first overlap pass straight through.
    low_bits = [column[0] if column else zero for column in wire_array_2[:first_pair]]
    # Transpose the remaining columns into (at most) two rows.
    add_wires = tuple(
        itertools.zip_longest(*wire_array_2[first_pair:], fillvalue=zero)
    )
    adder_result = adder(
        pyrtl.concat_list(add_wires[0]), pyrtl.concat_list(add_wires[1])
    )
    return pyrtl.concat(adder_result, *reversed(low_bits))


# Some adders that utilize these tree reducers
def fast_group_adder(
    wires_to_add: list[pyrtl.WireVector],
    reducer: Callable = wallace_reducer,
    final_adder: Callable = kogge_stone,
) -> pyrtl.WireVector:
    """A generalization of :func:`carrysave_adder`, ``fast_group_adder`` is
    designed to add many numbers together in a both area and time efficient
    manner. Uses a tree reducer to achieve this performance.

    The length of the result is::

        max(len(w) for w in wires_to_add) + ceil(log2(len(wires_to_add)))

    .. doctest only::

        >>> import pyrtl
        >>> pyrtl.reset_working_block()

    Example::

        >>> wires_to_add = [pyrtl.Const(n) for n in range(10)]
        >>> output = pyrtl.Output(name="output")
        >>> output <<= pyrtl.rtllib.adders.fast_group_adder(wires_to_add)

        >>> sim = pyrtl.Simulation()
        >>> sim.step()
        >>> sim.inspect("output")
        45
        >>> sum(range(10))
        45

    :param wires_to_add: A :class:`list` of :class:`WireVectors<.WireVector>` to
        add.
    :param reducer: The tree reducer to use. See :func:`wallace_reducer` and
        :func:`dada_reducer`.
    :param final_adder: The two value adder to use at the end.
    :return: A :class:`.WireVector` with the result of the addition.
    """
    if len(wires_to_add) == 1:
        # A single operand is already its own sum; there is nothing to reduce.
        return wires_to_add[0]

    longest_wire_len = max(len(w) for w in wires_to_add)
    # The sum of k values below 2**L is below 2**(L + ceil(log2(k))).
    result_bitwidth = longest_wire_len + math.ceil(math.log2(len(wires_to_add)))

    # Column i collects bit i of every operand (weight 2**i).
    bits = [[] for _ in range(longest_wire_len)]
    for wire in wires_to_add:
        for bit_loc, bit in enumerate(wire):
            bits[bit_loc].append(bit)

    return reducer(bits, result_bitwidth, final_adder)