# Most users should only interact with the objects "conditional_assignment" and
# "otherwise". The classes defined below are internal implementation details.
from pyrtl.pyrtlexceptions import PyrtlError, PyrtlInternalError
from pyrtl.wire import Const, Register, WireVector
# -----------------------------------------------------------------------
# __ __ __ ___ __ __
# / ` / \ |\ | | \ | | | / \ |\ | /\ | /__`
# \__, \__/ | \| |__/ | | | \__/ | \| /~~\ |___ .__/
#
def currently_under_condition() -> bool:
    """
    Report whether a :data:`conditional_assignment` block is currently active.

    :return: ``True`` iff execution is currently in the context of a
        :data:`conditional_assignment` (i.e. the module's nesting depth is above 0).
    """
    return _depth > 0
# `conditional_assignment` and `otherwise`, both visible in the pyrtl module, are
# defined as instances (hopefully the only and unchanging instances) of the following
# two types.
class _ConditionalAssignment:
    """Context providing functionality of "conditional_assignment".

    ``conditional_assignment`` is a single shared instance of this class. Calling it
    (``conditional_assignment(defaults)``) records the default values for the next
    block and returns the same instance, so it can be used directly in a ``with``
    statement.
    """

    def __init__(self):
        # Mapping of wirevector/register -> default value, handed to _finalize on exit.
        self.defaults = {}

    def __call__(self, defaults):
        """Record ``defaults`` for the upcoming ``with`` block and return ``self``."""
        self.defaults = defaults
        return self

    def __enter__(self):
        """Open a conditional context; raises PyrtlError if one is already open."""
        global _depth
        _check_no_nesting()
        _depth = 1

    def __exit__(self, *exc_info):
        """Build all queued conditional assignments, then clear the module state."""
        try:
            _finalize(self.defaults)
        finally:
            # even if the above finalization throws an error we need to reset the state
            # to prevent errors from bleeding over
            _reset_conditional_state()  # sets _depth back to 0
            # This object is a shared singleton: defaults given to one block must not
            # silently apply to a later plain ``with conditional_assignment:`` block.
            self.defaults = {}
class _Otherwise:
    """Context manager behind ``otherwise``: the catch-all branch of a condition chain."""

    def __enter__(self):
        # The ``otherwise`` singleton is pushed like an ordinary predicate;
        # _current_select recognizes it by identity and treats it specially.
        _push_condition(otherwise)

    def __exit__(self, *_exc):
        # Leave the level opened in __enter__; exceptions are not suppressed.
        _pop_condition()
def _reset_conditional_state():
    """Initialise (or re-initialise) every piece of module-level conditional state."""
    global _depth, _conditions_list_stack, _predicate_map, _conflicts_map

    # Nesting depth: 0 outside any conditional_assignment, 1 directly inside one, and
    # one more for every condition block entered via _push_condition.
    _depth = 0

    # Stack of lists of current conditions; the root level starts with none.
    _conditions_list_stack = [[]]

    # _predicate_map: wirevector or mem -> [(final_pred, rhs), ...]
    _predicate_map = {}

    # _conflicts_map: wirevector or mem -> [set([(pred, bool), ...]), ...]
    #   * every write to a target appends that write's predicate set (a set of
    #     (predicate, bool) tuples) to the target's list
    #   * a new write is only legal if, against every earlier set, it shares at least
    #     one predicate used with the opposite polarity; otherwise it is an error
    _conflicts_map = {}
# The two user-facing singletons; all other names in this module are internal.
conditional_assignment = _ConditionalAssignment()
otherwise = _Otherwise()
# Establish the initial (empty) conditional state at import time.
_reset_conditional_state()
def _push_condition(predicate):
    """
    Enter a new condition level as we enter a new condition block.

    The predicate is appended to the condition list of the current level. A fresh,
    empty list is then pushed to collect the conditions of any blocks nested inside
    it. ``_Otherwise.__enter__`` calls this with the ``otherwise`` singleton.

    :param predicate: a wirevector of length 1, or ``otherwise``.
    :raises PyrtlError: if not inside a ``conditional_assignment``, or if the
        predicate is longer than one bit.
    """
    global _depth
    _check_under_condition()
    # Validate before touching any state, so a rejected predicate cannot leave _depth
    # incremented without a matching level on _conditions_list_stack.
    if predicate is not otherwise and len(predicate) > 1:
        msg = "all predicates for conditional assignments must be wirevectors of len 1"
        raise PyrtlError(msg)
    _depth += 1
    _conditions_list_stack[-1].append(predicate)
    _conditions_list_stack.append([])
def _pop_condition():
    """Leave the innermost condition level, undoing the matching _push_condition."""
    global _depth
    _check_under_condition()
    # Discard the child-condition list belonging to the block being exited.
    _conditions_list_stack.pop()
    _depth = _depth - 1
def _build(lhs, rhs):
    """
    Queue the assignment ``lhs <- rhs`` under the currently active conditions.

    No hardware is built here. The pair is stored in ``_predicate_map`` and turned
    into muxes by ``_finalize`` when the enclosing ``conditional_assignment`` exits.

    :raises PyrtlError: if not under a condition, if no predicate is active, or if
        this write conflicts with an earlier write to ``lhs``.
    """
    _check_under_condition()
    select_wire, active_preds = _current_select()
    _check_and_add_pred_set(lhs, active_preds)
    pending = _predicate_map.setdefault(lhs, [])
    pending.append((select_wire, rhs))
def _build_read_port(mem, addr):
    """Create a read port on ``mem`` for ``addr`` by delegating to the memory object."""
    # TODO: reduce number of ports through collapsing reads
    make_port = mem._build_read_port
    return make_port(addr)
def _check_no_nesting():
    """Raise PyrtlError if a conditional_assignment block is already active."""
    if _depth == 0:
        return
    raise PyrtlError("no nesting of conditional assignments allowed")
def _check_under_condition():
    """Raise PyrtlError unless called inside a conditional_assignment block."""
    if currently_under_condition():
        return
    raise PyrtlError('conditional assignment "|=" only valid under a condition')
def _check_and_add_pred_set(lhs, pred_set):
    """
    Verify a new write to ``lhs`` against every earlier one, then record it.

    :raises PyrtlError: if ``pred_set`` is not provably exclusive with the predicate
        set of some earlier write to ``lhs`` (see ``_pred_sets_are_in_conflict``).
    """
    earlier_sets = _conflicts_map.setdefault(lhs, [])
    if any(_pred_sets_are_in_conflict(pred_set, seen) for seen in earlier_sets):
        raise PyrtlError(f"conflicting conditions for {lhs}")
    earlier_sets.append(pred_set)
def _pred_sets_are_in_conflict(pred_set_a, pred_set_b):
    """
    Report whether two predicate sets are NOT provably mutually exclusive.

    Each set holds ``(predicate, bool)`` tuples as built by ``_current_select``, where
    the bool marks a negated predicate. Two sets are exclusive when some predicate
    object appears in both with opposite polarity (negated in one, non-negated in the
    other). The writes they guard can then never fire together.

    :return: ``False`` when such a shared, opposite-polarity predicate exists (no
        conflict), ``True`` otherwise (conflict).
    """
    # Predicates are matched by identity: it must be the very same wire object.
    return not any(
        pred_a is pred_b and negated_a != negated_b
        for pred_a, negated_a in pred_set_a
        for pred_b, negated_b in pred_set_b
    )
def _finalize(defaults):
    """
    Turn every queued conditional assignment into muxes and build the target.

    For a MemBlock, all queued ``(addr, data, enable)`` writes are folded into a single
    write port. For a WireVector or Register, the queued ``(predicate, rhs)`` pairs
    are chained as selects on top of a default value.

    :param defaults: mapping of wirevector/register -> value used when no condition
        fires. Targets absent from it default to themselves (registers) or 0 (wires).
    :raises PyrtlInternalError: if a queued target is neither a MemBlock nor a
        WireVector.
    """
    from pyrtl.corecircuits import select
    from pyrtl.memory import MemBlock

    for target, writes in _predicate_map.items():
        if isinstance(target, MemBlock):
            # Memory write port: enable falls back to 0 if no predicate fires; addr and
            # data fall back to the first write's values.
            first_pred, (addr, data, enable) = writes[0]
            enable_sig = select(first_pred, truecase=enable, falsecase=Const(0))
            addr_sig, data_sig = addr, data
            for pred, (addr, data, enable) in writes[1:]:
                enable_sig = select(pred, truecase=enable, falsecase=enable_sig)
                addr_sig = select(pred, truecase=addr, falsecase=addr_sig)
                data_sig = select(pred, truecase=data, falsecase=data_sig)
            target._build(addr_sig, data_sig, enable_sig)
            continue

        # WireVector / Register: choose the value used when no predicate fires.
        # Register must be tested first since it is also a WireVector.
        if isinstance(target, Register):
            fallback = target  # default for registers is "self"
        elif isinstance(target, WireVector):
            fallback = 0  # default for wire is "0"
        else:
            raise PyrtlInternalError("unknown assignment in finalize")
        value = defaults[target] if target in defaults else fallback

        # One mux per queued write, in the order the writes were recorded.
        for pred, rhs in writes:
            value = select(pred, truecase=rhs, falsecase=value)
        target._build(value)
def _current_select():
    """Compute the select predicate for an assignment made in the current context.

    The select is the AND, over every enclosing condition level, of:

    * the negation of each predicate that precedes the current one at that level,
      counting only those after the most recent ``otherwise`` at that level, and
    * the current predicate of that level itself, unless it is ``otherwise``.

    :return: a tuple ``(select, pred_set)``. ``select`` is the conjunction described
        above (checked to be 1 bit wide). ``pred_set`` is a ``set`` of
        ``(predicate, bool)`` tuples, one per predicate used, where ``True`` marks a
        predicate that was negated and ``False`` one used as-is. This is the format
        described in ``_reset_conditional_state`` and compared by
        ``_pred_sets_are_in_conflict``.
    :raises PyrtlError: if no predicate applies at all, e.g. when called directly
        under ``conditional_assignment`` without any enclosing condition block.
    :raises PyrtlInternalError: if the resulting select is not exactly 1 bit wide.
    """

    # helper to create the conjunction of predicates; None means "no term yet"
    def and_with_possible_none(a, b):
        assert a is not None or b is not None
        if a is None:
            return b
        if b is None:
            return a
        return a & b

    # Return the entries of predlist strictly after its last ``otherwise`` (if any)
    # and strictly before its final (current) entry.
    def between_otherwise_and_current(predlist):
        lastother = None
        for i, p in enumerate(predlist[:-1]):
            if p is otherwise:
                lastother = i
        if lastother is None:
            return predlist[:-1]
        return predlist[lastother + 1 : -1]

    select = None
    pred_set = set()

    # for all conditions except the current children (the last list on the stack
    # collects blocks nested inside the innermost one, which should be [])
    for predlist in _conditions_list_stack[:-1]:
        # negate all of the predicates between "otherwise" and the current one
        for predicate in between_otherwise_and_current(predlist):
            select = and_with_possible_none(select, ~predicate)
            pred_set.add((predicate, True))
        # include the predicate for the current one (not negated)
        if predlist[-1] is not otherwise:
            predicate = predlist[-1]
            select = and_with_possible_none(select, predicate)
            pred_set.add((predicate, False))

    # No term was contributed by any level.
    if select is None:
        msg = "problem with conditional assignment"
        raise PyrtlError(msg)
    if len(select) != 1:
        msg = "conditional predicate with length greater than 1"
        raise PyrtlInternalError(msg)

    return select, pred_set
# Some examples that were helpful in the design and testing of conditional assignment.
# The comment after each "with" gives the select predicate active inside that block.
# 1 with a: # a
# 2 with b: # not(a) and b
# 3 with x: # not(a) and b and x
# 4 with otherwise: # not(a) and b and not(x)
# 5 with y: # not(a) and b and y; check(3,4)
# 6 with i: # not(a) and b and y and i; check(3,4)
# 7 with j: # not(a) and b and y and not(i) and j; check(3,4)
# 8 with otherwise: # not(a) and b and y and not(i) and not(j): check(3,4)
# 9 with k: # not(a) and b and y and k; check(3,4,6,7,8)
# 10 with m: # not(a) and b and y and not(k) and m; check(3,4,6,7,8)
# 11 with otherwise: #not(a) and not(b)
# 12 with c: #c; check(1,2,3,4,5,6,7,8,9,10,11)
# 0 with a: # a
# 1 with otherwise: # not(a)
# 2 with b: # not(a) and b; check(0,1)
# 3 with x: # not(a) and b and x; check(0,1)
# 4 with otherwise: # not(a) and b and not(x); check(0,1)
# 5 with y: # not(a) and b and y; check(0,1,3,4)
# 6 with i: # not(a) and b and y and i; check(0,1,3,4)
# 7 with j: # not(a) and b and y and not(i) and j; check(0,1,3,4)
# 8 with otherwise: # not(a) and b and y and not(i) and not(j): check(0,1,3,4)
# 9 with k: # not(a) and b and y and k; check(0,1,3,4,6,7,8)
# 10 with m: # not(a) and b and y and not(k) and m; check(0,1,3,4,6,7,8)
# with z: check(0,1,3,4)
# with otherwise: check(0,1,3,4)
# with g: check(0,1,3,4,5,6,7,8,9,10)
# 11 with otherwise: #not(a) and not(b); check(0,1)
# 12 with c: #c; check(0,1,2,3,4,5,6,7,8,9,10,11)