from __future__ import annotations
import _ctypes
import ctypes
import platform
import shutil
import subprocess
import sys
import tempfile
import warnings
from collections.abc import Mapping
from os import path
from pyrtl.core import Block, working_block
from pyrtl.helperfuncs import infer_val_and_bitwidth
from pyrtl.memory import MemBlock, RomBlock
from pyrtl.pyrtlexceptions import PyrtlError, PyrtlInternalError
from pyrtl.simulation import SimulationTrace, _trace_sort_key
from pyrtl.wire import Const, Input, Output, Register, WireVector
# Public API of this module: only the simulator class is exported by ``import *``.
__all__ = ["CompiledSimulation"]
class DllMemInspector(Mapping):
    """Dictionary-like, read-only view of a memory inside a ``CompiledSimulation``.

    Keys are the addresses ``0 .. 2**addrwidth - 1`` of the memory. Values are read
    from the C hashmap that backs the memory, through the simulation's ``lookup``
    function. Multi-limb entries are reassembled into a single Python int.
    """

    def __init__(self, sim, mem):
        self._aw = mem.addrwidth
        # Number of 64-bit words ("limbs") used to store one memory entry.
        self._limbs = sim._limbs(mem)
        # Name of the exported C variable holding this memory's hashmap pointer.
        self._vn = vn = sim.varname[mem]
        self._mem = ctypes.c_void_p.in_dll(sim._dll, vn)
        self._sim = sim  # keep reference to avoid freeing dll

    def __getitem__(self, ind):
        """Return the value stored at address ``ind``."""
        arr = self._sim._mem_lookup(self._mem, ind)
        val = 0
        # Limb 0 is least significant, so accumulate from the top limb down.
        for n in reversed(range(self._limbs)):
            val <<= 64
            val |= arr[n]
        return val

    def __iter__(self):
        return iter(range(len(self)))

    def __len__(self):
        return 1 << self._aw

    def __eq__(self, other):
        """Compare contents address by address.

        Addresses missing from ``other`` are treated as holding 0. Two inspectors of
        the same memory in the same simulation are equal without scanning.
        Non-mapping operands return ``NotImplemented``, so ``inspector == None``
        evaluates to ``False`` instead of raising ``AttributeError``.
        """
        if not isinstance(other, Mapping):
            return NotImplemented
        if (
            isinstance(other, DllMemInspector)
            and self._sim is other._sim
            and self._vn == other._vn
        ):
            return True
        return all(self[x] == other.get(x, 0) for x in self)

    def __hash__(self):
        return hash(self._sim) ^ hash(self._vn)
class CompiledSimulation:
    """Simulate a block by generating, compiling, and running C code.

    ``CompiledSimulation`` provides significant execution speed improvements over
    :class:`FastSimulation`, at the cost of even longer start-up time. Generally this
    will do better than :class:`FastSimulation` for simulations requiring over 1000
    steps.

    ``CompiledSimulation`` is not built to be a debugging tool, though it may help with
    debugging. Note that only :class:`Input` and :class:`Output` wires can be traced
    with ``CompiledSimulation``.

    .. note::

        For very large circuits, :class:`FastSimulation` can sometimes be a better
        choice than ``CompiledSimulation`` because ``CompiledSimulation`` will generate
        an extremely large ``.c`` file, which can take prohibitively long to compile and
        optimize. :class:`FastSimulation` will generate an extremely large ``.py``
        file, but Python will interpret that generated code as needed, instead of trying
        to process all the generated code at once.

    .. WARNING::

        This code is still experimental, but has been used on designs of significant
        scale to good effect.

    To use ``CompiledSimulation``, you'll need:

    - A 64-bit processor
    - GCC (tested on version 4.8.4)
    - A 64-bit build of Python

    If using the multiplication operator, only some architectures are supported:

    - ``x86-64`` / ``amd64``
    - ``arm64`` / ``aarch64``
    - ``mips64`` (untested)

    ``default_value`` is currently only implemented for :class:`Registers<Register>`,
    not :class:`MemBlocks<MemBlock>`.

    ``CompiledSimulation`` is a drop-in replacement for :class:`Simulation`, so the two
    classes share the same interface. See :class:`Simulation` for interface
    documentation, and more details about PyRTL simulations.
    """

    def __init__(
        self,
        tracer: SimulationTrace = True,
        register_value_map: dict[Register, int] | None = None,
        memory_value_map: dict[MemBlock, dict[int, int]] | None = None,
        default_value: int = 0,
        block: Block | None = None,
    ):
        """Generate, compile, and load the C simulation of ``block``.

        :param tracer: ``SimulationTrace`` to record into, or ``True`` to create a new
            one. Wires this simulator cannot trace are removed from it.
        :param register_value_map: Initial register values. Registers not listed use
            their ``reset_value``, or ``default_value`` when that is ``None``.
        :param memory_value_map: Initial contents of (non-ROM) memories, as
            ``{MemBlock: {address: value}}``.
        :param default_value: Fallback initial value for registers.
        :param block: Block to simulate; defaults to the working block.
        :raises PyrtlError: if ``memory_value_map`` contains a memory that is not used
            by the block, or a ``RomBlock``.
        :raises subprocess.CalledProcessError: if the C compiler fails.
        """
        if memory_value_map is None:
            memory_value_map = {}
        if register_value_map is None:
            register_value_map = {}
        # Set first so that __del__ is safe even if construction fails below.
        self._dll = self._dir = None
        self.block = working_block(block)
        self.block.sanity_check()
        if tracer is True:
            tracer = SimulationTrace()
        self.tracer = tracer
        self._remove_untraceable()
        self.default_value = default_value
        self._regmap = {}  # Updated below
        self._memmap = memory_value_map
        self._uid_counter = 0
        self.varname = {}  # mapping from wires and memories to C variables
        for r in self.block.wirevector_subset(Register):
            rval = register_value_map.get(r, r.reset_value)
            if rval is None:
                rval = self.default_value
            self._regmap[r] = rval
        self.tracer._set_initial_values(
            default_value, register_value_map, memory_value_map
        )
        self._create_dll()
        self._initialize_mems()

    def inspect_mem(self, mem: MemBlock) -> Mapping[int, int]:
        """Return a read-only mapping view of ``mem``'s current contents."""
        return DllMemInspector(self, mem)

    def inspect(self, w: str) -> int:
        """Return the latest traced value of wire ``w`` (a name or ``WireVector``).

        :raises PyrtlError: if no step has been run yet, or if ``w`` is not traced
            (internal wires cannot be inspected).
        """
        if isinstance(w, WireVector):
            w = w.name
        try:
            vals = self.tracer.trace[w]
        except KeyError:
            pass
        else:
            if not vals:
                msg = "No context available. Please run a simulation step"
                raise PyrtlError(msg)
            return vals[-1]
        msg = "CompiledSimulation does not support inspecting internal WireVectors"
        raise PyrtlError(msg)

    def step(self, provided_inputs: dict[str, int] | None = None, inputs=None):
        """Run a single simulation step with the given input values.

        ``inputs`` is a deprecated alias of ``provided_inputs``.
        """
        if provided_inputs is None:
            provided_inputs = {}
        if inputs is not None:
            warnings.warn(
                "CompiledSimulation.step: `inputs` was renamed to `provided_inputs`",
                DeprecationWarning,
                stacklevel=2,
            )
            provided_inputs = inputs
        self.run([provided_inputs])

    def step_multiple(
        self,
        provided_inputs: dict[str, list[int]] | None = None,
        expected_outputs: dict[str, list[int]] | None = None,
        nsteps: int | None = None,
        file=sys.stdout,
        stop_after_first_error: bool = False,
    ):
        """Run several steps, optionally checking outputs against expected values.

        :param provided_inputs: Per-input lists of values, one per step.
        :param expected_outputs: Per-output lists of expected values, one per step;
            an entry of ``"?"`` skips the check for that step.
        :param nsteps: Number of steps; defaults to the longest input list.
        :param file: Where mismatches are reported.
        :param stop_after_first_error: Stop simulating after the first step with a
            mismatch.
        :raises PyrtlError: on inconsistent step counts or missing values.
        """
        if expected_outputs is None:
            expected_outputs = {}
        if provided_inputs is None:
            provided_inputs = {}

        if not nsteps and not provided_inputs:
            msg = "need to supply either input values or a number of steps to simulate"
            raise PyrtlError(msg)

        if provided_inputs:
            msteps = max(len(values) for values in provided_inputs.values())
            if nsteps:
                if nsteps > msteps:
                    msg = (
                        "nsteps is specified but is greater than the number of values "
                        "supplied for each input"
                    )
                    raise PyrtlError(msg)
            else:
                nsteps = msteps

        if nsteps < 1:
            msg = "must simulate at least one step"
            raise PyrtlError(msg)

        if any(len(values) < nsteps for values in provided_inputs.values()):
            msg = (
                "must supply a value for each provided wire for each step of simulation"
            )
            raise PyrtlError(msg)

        if any(len(values) < nsteps for values in expected_outputs.values()):
            msg = (
                "any expected outputs must have a supplied value each step of "
                "simulation"
            )
            raise PyrtlError(msg)

        failed = []
        for i in range(nsteps):
            self.step({w: int(v[i]) for w, v in provided_inputs.items()})
            for expvar in expected_outputs:
                expected = expected_outputs[expvar][i]
                if expected == "?":
                    continue
                expected = int(expected)
                actual = self.inspect(expvar)
                if expected != actual:
                    failed.append((i, expvar, expected, actual))
            if failed and stop_after_first_error:
                break

        if failed:
            if stop_after_first_error:
                s = "(stopped after step with first error):"
            else:
                s = "on one or more steps:"
            file.write("Unexpected output " + s + "\n")
            file.write(
                "{:>5} {:>10} {:>8} {:>8}\n".format(
                    "step", "name", "expected", "actual"
                )
            )

            def _sort_tuple(t):
                # Sort by step and then wire name
                return (t[0], _trace_sort_key(t[1]))

            failed_sorted = sorted(failed, key=_sort_tuple)
            for step, name, expected, actual in failed_sorted:
                file.write(f"{step:>5} {name:>10} {expected:>8} {actual:>8}\n")
            file.flush()

    def run(self, inputs: list[dict[str, int]]):
        """Run many steps of the ``CompiledSimulation``.

        :meth:`CompiledSimulation.step` and :meth:`CompiledSimulation.step_multiple` are
        wrappers around this lower-level method.

        :param inputs: A list of input mappings for each step; its length is the number
            of steps to be executed. Inputs not mentioned in a step's mapping are 0.
        """
        steps = len(inputs)
        # create i/o arrays of the appropriate length
        ibuf_type = ctypes.c_uint64 * (steps * self._ibufsz)
        obuf_type = ctypes.c_uint64 * (steps * self._obufsz)
        ibuf = ibuf_type()
        obuf = obuf_type()
        # these array will be passed to _crun
        self._crun.argtypes = [ctypes.c_uint64, ibuf_type, obuf_type]

        # build the input array
        for n, inmap in enumerate(inputs):
            for w in inmap:
                if isinstance(w, WireVector):
                    name = w.name
                else:
                    name = w
                start, count = self._inputpos[name]
                start += n * self._ibufsz
                val = inmap[w]
                val = infer_val_and_bitwidth(val, bitwidth=self._inputbw[name]).value
                # pack input, least-significant 64-bit limb first
                for pos in range(start, start + count):
                    ibuf[pos] = val & ((1 << 64) - 1)
                    val >>= 64

        # run the simulation
        self._crun(steps, ibuf, obuf)

        # save traced wires
        for name in self.tracer.trace:
            rname = self._probe_mapping.get(name, name)
            if rname in self._outputpos:
                start, count = self._outputpos[rname]
                buf, sz = obuf, self._obufsz
            elif rname in self._inputpos:
                start, count = self._inputpos[rname]
                buf, sz = ibuf, self._ibufsz
            else:
                msg = "Untraceable wire in tracer"
                raise PyrtlInternalError(msg)
            res = []
            for _step in range(steps):
                val = 0
                # unpack output, most-significant limb first
                for pos in reversed(range(start, start + count)):
                    val <<= 64
                    val |= buf[pos]
                res.append(val)
                start += sz
            self.tracer.trace[name].extend(res)

    def _traceable(self, wv):
        """Check if wv is able to be traced.

        If it is traceable due to a probe, record that probe in _probe_mapping.
        """
        if isinstance(wv, (Input, Output)):
            return True
        for net in self.block.logic:
            if (
                net.op == "w"
                and net.args[0].name == wv.name
                and isinstance(net.dests[0], Output)
            ):
                self._probe_mapping[wv.name] = net.dests[0].name
                return True
        return False

    def _remove_untraceable(self):
        """Remove from the tracer those wires that CompiledSimulation cannot track.

        Create _probe_mapping for wires only traceable via probes.
        """
        self._probe_mapping = {}
        wvs = {wv for wv in self.tracer.wires_to_track if self._traceable(wv)}
        self.tracer.wires_to_track = wvs
        self.tracer._wires = {wv.name: wv for wv in wvs}
        self.tracer.trace.__init__(wvs)

    def _compile_command(self, src, out):
        """Return the gcc argument list that compiles ``src`` into library ``out``.

        Platform-specific flags that do not apply are omitted entirely. An empty-string
        argument would be treated by the compiler as an input file name.
        """
        if platform.system() == "Darwin":
            shared = "-dynamiclib"
            march = []
        else:
            shared = "-shared"
            march = ["-march=native"]
        return [
            "gcc",
            "-O0",
            *march,
            "-std=c99",
            "-m64",
            shared,
            "-fPIC",
            src,
            "-o",
            out,
        ]

    def _create_dll(self):
        """Create a dynamically-linked library implementing the simulation logic."""
        self._dir = tempfile.mkdtemp()
        src = path.join(self._dir, "pyrtlsim.c")
        lib = path.join(self._dir, "pyrtlsim.so")
        with open(src, "w") as f:
            self._create_code(lambda s: f.write(s + "\n"))
        subprocess.check_call(
            self._compile_command(src, lib),
            shell=(platform.system() == "Windows"),
        )
        self._dll = ctypes.CDLL(lib)
        self._crun = self._dll.sim_run_all
        self._crun.restype = None  # argtypes set on use
        self._initialize_mems = self._dll.initialize_mems
        self._initialize_mems.argtypes = []
        self._initialize_mems.restype = None
        self._mem_lookup = self._dll.lookup
        # The C signature is lookup(hashmap_t *, uint64_t). Without argtypes ctypes
        # would pass addresses as a C int, rejecting addresses >= 2**31.
        self._mem_lookup.argtypes = [ctypes.c_void_p, ctypes.c_uint64]
        self._mem_lookup.restype = ctypes.POINTER(ctypes.c_uint64)

    def _limbs(self, w):
        """Number of 64-bit words needed to store value of wire."""
        return (w.bitwidth + 63) // 64

    def _makeini(self, w, v):
        """C initializer string for a wire with a given value.

        Limbs are listed least-significant first, e.g. ``{0x5,0x1}``.
        """
        pieces = []
        for _ in range(self._limbs(w)):
            pieces.append(hex(v & ((1 << 64) - 1)))
            v >>= 64
        return ",".join(pieces).join("{}")

    def _romwidth(self, m):
        """Bitwidth of integer type sufficient to hold rom entry.

        On large memories, returns 64; an array will be needed.
        """
        if m.bitwidth <= 8:
            return 8
        if m.bitwidth <= 16:
            return 16
        if m.bitwidth <= 32:
            return 32
        return 64

    def _makemask(self, dest, res, pos):
        """Create a bitmask suffix such as ``&0x3F``, or ``""`` when none is needed.

        The value being masked is of width `res` (``None`` meaning unknown). Limb
        number `pos` of `dest` is being assigned to. A mask is needed only on
        `dest`'s partially-used top limb, and only when the value may be wider than
        `dest`.
        """
        if (res is None or dest.bitwidth < res) and 0 < (dest.bitwidth - 64 * pos) < 64:
            return f"&0x{(1 << (dest.bitwidth % 64)) - 1:X}"
        return ""

    def _getarglimb(self, arg, n):
        """Get the nth limb of the given wire.

        Returns '0' when the wire does not have sufficient limbs.
        """
        return f"{self.varname[arg]}[{n}]" if arg.bitwidth > 64 * n else "0"

    def _clean_name(self, prefix, obj):
        """Create a C variable name with the given prefix based on the name of obj.

        Only ASCII letters and digits of the name are kept. Non-ASCII characters
        would yield identifiers that older C compilers reject. The unique id keeps
        names distinct even after filtering.
        """
        safe = "".join(c for c in obj.name if c.isascii() and c.isalnum())
        return f"{prefix}{self._uid()}_{safe}"

    def _uid(self):
        """Get an auto-incrementing number suitable for use as a unique identifier."""
        x = self._uid_counter
        self._uid_counter += 1
        return x

    def _declare_roms(self, write, roms):
        """Emit a constant C lookup table holding every entry of each ROM."""
        for mem in roms:
            self.varname[mem] = vn = self._clean_name("m", mem)
            # extract data from mem
            romval = [mem._get_read_data(n) for n in range(1 << mem.addrwidth)]
            write(
                f"static const uint{self._romwidth(mem)}_t {vn}[]"
                f"[{self._limbs(mem)}] = {{"
            )
            for rv in romval:
                write(self._makeini(mem, rv) + ",")
            write("};")

    def _declare_mems(self, write, mems):
        """Emit exported hashmap pointers for memories, plus ``initialize_mems()``.

        ``initialize_mems`` allocates each hashmap and inserts the initial contents
        given in ``memory_value_map``.
        """
        for mem in mems:
            self.varname[mem] = vn = self._clean_name("m", mem)
            write("EXPORT")
            write(f"hashmap_t *{vn};")

        next_tmp = 0
        write("EXPORT")
        write("void initialize_mems() {")
        for mem in mems:
            # Create hashmap
            write(f"{self.varname[mem]} = create_hash_map(256, {self._limbs(mem)});")
            if mem in self._memmap:
                # Insert default values
                for k, v in self._memmap[mem].items():
                    write(f"val_t t{next_tmp}[] = {self._makeini(mem, v)};")
                    write(f"insert({self.varname[mem]}, {k}, t{next_tmp});")
                    next_tmp += 1
        write("}")

    def _declare_wv(self, write, w):
        """Emit the C array declaration for wire ``w``.

        Consts get their value; registers are ``static`` so they keep their value
        across steps and start at their initial value; other wires are left
        uninitialized.
        """
        self.varname[w] = vn = self._clean_name("w", w)
        if isinstance(w, Const):
            write(f"const uint64_t {vn}[{self._limbs(w)}] = {self._makeini(w, w.val)};")
        elif isinstance(w, Register):
            rval = self._regmap.get(w, w.reset_value)
            if rval is None:
                rval = self.default_value
            write(f"static uint64_t {vn}[{self._limbs(w)}] = {self._makeini(w, rval)};")
        else:
            write(f"uint64_t {vn}[{self._limbs(w)}];")

    def _build_memread(self, write, _op, param, args, dest):
        """Emit a memory read: ROM table index, or hashmap ``lookup`` for memories."""
        mem = param[1]
        for n in range(self._limbs(dest)):
            if isinstance(mem, RomBlock):
                write(
                    f"{self.varname[dest]}[{n}] = {self.varname[mem]}["
                    f"{self.varname[args[0]]}[0]][{n}]"
                    f"{self._makemask(dest, mem.bitwidth, n)};"
                )
            else:
                write(
                    f"{self.varname[dest]}[{n}] = lookup({self.varname[mem]}, "
                    f"{self.varname[args[0]]}[0])[{n}]"
                    f"{self._makemask(dest, mem.bitwidth, n)};"
                )

    def _build_wire(self, write, _op, _param, args, dest):
        """Emit a limb-by-limb copy of ``args[0]`` into ``dest``."""
        for n in range(self._limbs(dest)):
            write(
                f"{self.varname[dest]}[{n}] = "
                f"{self.varname[args[0]]}[{n}]"
                f"{self._makemask(dest, args[0].bitwidth, n)};"
            )

    def _build_not(self, write, _op, _param, args, dest):
        """Emit bitwise NOT; the top limb is masked since ``~`` sets unused bits."""
        for n in range(self._limbs(dest)):
            write(
                f"{self.varname[dest]}[{n}] = "
                f"(~{self.varname[args[0]]}[{n}]){self._makemask(dest, None, n)};"
            )

    def _build_bitwise(self, write, op, _param, args, dest):  # &, |, ^ only
        """Emit a limb-wise ``&``, ``|`` or ``^`` of the two arguments."""
        for n in range(self._limbs(dest)):
            arg0 = self._getarglimb(args[0], n)
            arg1 = self._getarglimb(args[1], n)
            write(
                "{dest}[{n}] = ({arg0}{op}{arg1}){mask};".format(
                    dest=self.varname[dest],
                    n=n,
                    arg0=arg0,
                    arg1=arg1,
                    op=op,
                    mask=self._makemask(
                        dest, max(args[0].bitwidth, args[1].bitwidth), n
                    ),
                )
            )

    def _build_nand(self, write, _op, _param, args, dest):
        """Emit a limb-wise NAND; the top limb is masked."""
        for n in range(self._limbs(dest)):
            arg0 = self._getarglimb(args[0], n)
            arg1 = self._getarglimb(args[1], n)
            write(
                f"{self.varname[dest]}[{n}] = "
                f"(~({arg0}&{arg1})){self._makemask(dest, None, n)};"
            )

    def _build_eq(self, write, _op, _param, args, dest):
        """Emit equality: all limbs pairwise equal (missing limbs compare as 0)."""
        cond = []
        for n in range(max(self._limbs(args[0]), self._limbs(args[1]))):
            arg0 = self._getarglimb(args[0], n)
            arg1 = self._getarglimb(args[1], n)
            cond.append(f"({arg0}=={arg1})")
        write(
            "{dest}[0] = {cond};".format(dest=self.varname[dest], cond="&&".join(cond))
        )

    def _build_cmp(self, write, op, _param, args, dest):  # <, > only
        """Emit an unsigned multi-limb ``<`` or ``>`` comparison.

        Each higher limb wraps the condition built so far: it decides the result
        unless its limbs are equal, in which case the lower limbs decide.
        """
        cond = None
        for n in range(max(self._limbs(args[0]), self._limbs(args[1]))):
            arg0 = self._getarglimb(args[0], n)
            arg1 = self._getarglimb(args[1], n)
            c = f"({arg0}{op}{arg1})"
            if cond is None:
                cond = c
            else:
                cond = f"({c}||(({arg0}=={arg1})&&{cond}))"
        write(f"{self.varname[dest]}[0] = {cond};")

    def _build_mux(self, write, _op, _param, args, dest):
        """Emit a 2:1 mux: ``args[2]`` if the select ``args[0]`` is set, else ``args[1]``."""
        write(f"if ({self.varname[args[0]]}[0]) {{")
        for n in range(self._limbs(dest)):
            write(
                f"{self.varname[dest]}[{n}] = "
                f"{self.varname[args[2]]}[{n}]"
                f"{self._makemask(dest, args[2].bitwidth, n)};"
            )
        write("} else {")
        for n in range(self._limbs(dest)):
            write(
                f"{self.varname[dest]}[{n}] = "
                f"{self.varname[args[1]]}[{n}]"
                f"{self._makemask(dest, args[1].bitwidth, n)};"
            )
        write("}")

    def _build_add(self, write, _op, _param, args, dest):
        """Emit multi-limb addition, propagating carry via unsigned wraparound tests."""
        write("carry = 0;")
        for n in range(self._limbs(dest)):
            arg0 = self._getarglimb(args[0], n)
            arg1 = self._getarglimb(args[1], n)
            write(f"tmp = {arg0}+{arg1};")
            write(
                "{dest}[{n}] = (tmp + carry){mask};".format(
                    dest=self.varname[dest],
                    n=n,
                    mask=self._makemask(
                        dest, max(args[0].bitwidth, args[1].bitwidth) + 1, n
                    ),
                )
            )
            write(f"carry = (tmp < {arg0})|({self.varname[dest]}[{n}] < tmp);")

    def _build_sub(self, write, _op, _param, args, dest):
        """Emit multi-limb subtraction, propagating borrow via wraparound tests."""
        write("carry = 0;")
        for n in range(self._limbs(dest)):
            arg0 = self._getarglimb(args[0], n)
            arg1 = self._getarglimb(args[1], n)
            write(f"tmp = {arg0}-{arg1};")
            write(
                f"{self.varname[dest]}[{n}] = (tmp - carry)"
                f"{self._makemask(dest, None, n)};"
            )
            write(f"carry = (tmp > {arg0})|({self.varname[dest]}[{n}] > tmp);")

    def _build_mul(self, write, _op, _param, args, dest):
        """Emit schoolbook multi-limb multiplication using the ``mul128`` macro.

        Partial products landing beyond ``dest``'s limbs are skipped.
        """
        for n in range(self._limbs(dest)):
            write(f"{self.varname[dest]}[{n}] = 0;")
        for p0 in range(self._limbs(args[0])):
            write("carry = 0;")
            arg0 = self._getarglimb(args[0], p0)
            for p1 in range(self._limbs(args[1])):
                if self._limbs(dest) <= p0 + p1:
                    break
                arg1 = self._getarglimb(args[1], p1)
                write(f"mul128({arg0}, {arg1}, tmplo, tmphi);")
                write(f"tmp = {self.varname[dest]}[{p0 + p1}];")
                write("tmplo += carry; carry = tmplo < carry; tmplo += tmp;")
                write("tmphi += carry + (tmplo < tmp); carry = tmphi;")
                write(
                    "{dest}[{p}] = tmplo{mask};".format(
                        dest=self.varname[dest],
                        p=p0 + p1,
                        mask=self._makemask(
                            dest, args[0].bitwidth + args[1].bitwidth, p0 + p1
                        ),
                    )
                )
            if self._limbs(dest) > p0 + self._limbs(args[1]):
                write(
                    "{dest}[{p}] = carry{mask};".format(
                        dest=self.varname[dest],
                        p=p0 + self._limbs(args[1]),
                        mask=self._makemask(
                            dest,
                            args[0].bitwidth + args[1].bitwidth,
                            p0 + self._limbs(args[1]),
                        ),
                    )
                )

    def _build_concat(self, write, _op, _param, args, dest):
        """Emit concatenation by shifting argument limbs into ``dest``'s limbs.

        ``args[-1]`` supplies the least-significant bits. A piece that straddles a
        64-bit boundary is split; its remainder starts the next destination limb.
        """
        cattotal = sum(x.bitwidth for x in args)
        # (varname, limb index, start bit within limb, number of bits)
        pieces = (
            (self.varname[a], lx, 0, min(64, a.bitwidth - 64 * lx))
            for a in reversed(args)
            for lx in range(self._limbs(a))
        )
        curr = next(pieces)
        for n in range(self._limbs(dest)):
            res = []
            dpos = 0
            while True:
                arg, alimb, astart, asize = curr
                res.append(f"(({arg}[{alimb}]>>{astart})<<{dpos})")
                dpos += asize
                if dpos >= dest.bitwidth - 64 * n:
                    break
                if dpos > 64:
                    curr = (arg, alimb, 64 - (dpos - asize), dpos - 64)
                    break
                curr = next(pieces)
                if dpos == 64:
                    break
            write(
                "{dest}[{n}] = ({res}){mask};".format(
                    dest=self.varname[dest],
                    n=n,
                    res="|".join(res),
                    mask=self._makemask(dest, cattotal, n),
                )
            )

    def _build_select(self, write, _op, param, args, dest):
        """Emit bit selection: bit ``param[i]`` of ``args[0]`` becomes bit ``i``."""
        for n in range(self._limbs(dest)):
            bits = [
                f"((1&({self.varname[args[0]]}[{b // 64}]>>{b % 64}))<<{en})"
                for en, b in enumerate(param[64 * n : min(dest.bitwidth, 64 * (n + 1))])
            ]
            write(
                "{dest}[{n}] = {bits};".format(
                    dest=self.varname[dest], n=n, bits="|".join(bits)
                )
            )

    def _declare_mem_helpers(self, write):
        """Emit the C chained-hashmap used to store (non-ROM) memory contents.

        ``lookup`` is exported and returns the shared zero-filled ``default_value``
        for keys never inserted. ``insert`` overwrites an existing key in place and
        only allocates a node for a new key.
        """
        helpers = """
        typedef uint64_t val_t;

        typedef struct node
        {
            uint64_t key;
            val_t *val;
            struct node *next;
        } node_t;

        typedef struct hashmap
        {
            int size;
            int val_limbs;
            val_t *default_value;
            node_t **list;
        } hashmap_t;

        hashmap_t *create_hash_map(int size, int val_limbs)
        {
            int i;
            hashmap_t *h = (hashmap_t *) malloc(sizeof(hashmap_t));
            h->size = size;
            h->val_limbs = val_limbs;
            h->list = (node_t **) malloc(sizeof(node_t *) * size);
            h->default_value = (val_t *) malloc(sizeof(val_t) * val_limbs);
            for (i = 0; i < val_limbs; i++)
                h->default_value[i] = 0;
            for (i = 0; i < size; i++)
                h->list[i] = NULL;
            return h;
        }

        int hash_code(hashmap_t *h, uint64_t key)
        {
            return key % h->size;
        }

        void insert(hashmap_t *h, uint64_t key, val_t val[])
        {
            int pos = hash_code(h, key);
            node_t *list = h->list[pos];
            node_t *temp = list;
            node_t *new_node;
            while (temp)
            {
                if (temp->key == key)
                {
                    memcpy(temp->val, val, sizeof(val_t) * h->val_limbs);
                    return;
                }
                temp = temp->next;
            }
            new_node = (node_t *) malloc(sizeof(node_t));
            new_node->key = key;
            new_node->val = (val_t *) malloc(sizeof(val_t) * h->val_limbs);
            memcpy(new_node->val, val, sizeof(val_t) * h->val_limbs);
            new_node->next = list;
            h->list[pos] = new_node;
        }

        EXPORT
        val_t* lookup(hashmap_t *h, uint64_t key)
        {
            int pos = hash_code(h, key);
            node_t *list = h->list[pos];
            node_t *temp = list;
            while (temp)
            {
                if (temp->key == key)
                {
                    return temp->val;
                }
                temp = temp->next;
            }
            return h->default_value;
        }
        """
        write(helpers)

    def _mul128_macro(self, machine):
        """Return the C ``#define`` of ``mul128`` for ``machine``, or ``None``.

        ``mul128(t0, t1, pl, ph)`` stores the low/high 64 bits of the 128-bit product
        ``t0 * t1`` into ``pl``/``ph`` using inline assembly. This avoids
        ``uint128_t``, which -O0 does not handle well. ``machine`` is a
        ``platform.machine()`` string; case and common aliases are normalized.
        """
        machine_alias = {"amd64": "x86_64", "aarch64": "arm64", "aarch64_be": "arm64"}
        machine = machine.lower()
        machine = machine_alias.get(machine, machine)
        mulinstr = {
            "x86_64": '"mulq %q3":"=a"(pl),"=d"(ph):"%0"(t0),"r"(t1):"cc"',
            "arm64": '"mul %0, %2, %3\\n\\t" \\\n'
            '"umulh %1, %2, %3":"=&r"(pl),"=r"(ph):"r"(t0),"r"(t1):"cc"',
            "mips64": '"dmultu %2, %3\\n\\t" \\\n'
            '"mflo %0\\n\\t" \\\n'
            '"mfhi %1":"=r"(pl),"=r"(ph):"r"(t0),"r"(t1)',
        }
        asm = mulinstr.get(machine)
        if asm is None:
            return None
        return f"#define mul128(t0, t1, pl, ph) __asm__({asm})"

    def _create_code(self, write):
        """Emit the complete C source of the simulation through ``write``.

        Also records the buffer layout used by :meth:`run`: ``_inputpos``,
        ``_inputbw`` and ``_ibufsz`` for inputs, and ``_outputpos`` and ``_obufsz``
        for outputs.

        :raises PyrtlError: if ``memory_value_map`` has a memory not read or written
            by the block, or a ``RomBlock``.
        """
        write("#include <stdint.h>")
        write("#include <stdlib.h>")
        write("#include <string.h>")

        # windows dllexport needed to make symbols visible
        if platform.system() == "Windows":
            write("#define EXPORT __declspec(dllexport)")
        else:
            write("#define EXPORT")

        # multiplication macro; on unsupported machines nothing is emitted, so a
        # design that multiplies will fail to compile there.
        mul_macro = self._mul128_macro(platform.machine())
        if mul_macro is not None:
            write(mul_macro)

        # declare memories
        mems = {net.op_param[1] for net in self.block.logic_subset("m@")}
        for key in self._memmap:
            if key not in mems:
                msg = "unrecognized MemBlock in memory_value_map"
                raise PyrtlError(msg)
            if isinstance(key, RomBlock):
                msg = "RomBlock in memory_value_map"
                raise PyrtlError(msg)
        self._declare_mem_helpers(write)
        roms = {mem for mem in mems if isinstance(mem, RomBlock)}
        self._declare_roms(write, roms)
        mems = {
            mem
            for mem in mems
            if isinstance(mem, MemBlock) and not isinstance(mem, RomBlock)
        }
        self._declare_mems(write, mems)

        # single step function
        write("static void sim_run_step(uint64_t inputs[], uint64_t outputs[]) {")
        write("uint64_t tmp, carry, tmphi, tmplo;")  # temporary variables

        # declare wire vectors
        for w in self.block.wirevector_set:
            self._declare_wv(write, w)

        # inputs copied in
        inputs = list(self.block.wirevector_subset(Input))
        # for each input wire, start and number of elements in input array
        self._inputpos = {}
        self._inputbw = {}  # bitwidth of each input wire
        ipos = 0
        for w in inputs:
            self._inputpos[w.name] = ipos, self._limbs(w)
            self._inputbw[w.name] = w.bitwidth
            for n in range(self._limbs(w)):
                write(f"{self.varname[w]}[{n}] = inputs[{ipos}];")
                ipos += 1
        self._ibufsz = ipos  # total length of input array

        # combinational logic
        op_builders = {
            "m": self._build_memread,
            "w": self._build_wire,
            "~": self._build_not,
            "&": self._build_bitwise,
            "|": self._build_bitwise,
            "^": self._build_bitwise,
            "n": self._build_nand,
            "=": self._build_eq,
            "<": self._build_cmp,
            ">": self._build_cmp,
            "x": self._build_mux,
            "+": self._build_add,
            "-": self._build_sub,
            "*": self._build_mul,
            "c": self._build_concat,
            "s": self._build_select,
        }
        for net in self.block:  # topological order
            if net.op in "r@":
                continue  # skip synchronized nets
            op, param, args, dest = net.op, net.op_param, net.args, net.dests[0]
            write(
                "// net {op} : {args} -> {dest}".format(
                    op=op,
                    args=", ".join(self.varname[x] for x in args),
                    dest=self.varname[dest],
                )
            )
            op_builders[op](write, op, param, args, dest)

        # memory writes (after all combinational reads of this step)
        for net in self.block.logic_subset("@"):
            mem = net.op_param[1]
            write(f"if ({self.varname[net.args[2]]}[0]) {{")
            write(
                f"insert({self.varname[mem]}, {self.varname[net.args[0]]}[0], "
                f"{self.varname[net.args[1]]});"
            )
            write("}")

        # register updates
        regnets = list(self.block.logic_subset("r"))
        for x, net in enumerate(regnets):
            rin = net.args[0]
            write(f"uint64_t regtmp{x}[{self._limbs(rin)}];")
            for n in range(self._limbs(rin)):
                write(f"regtmp{x}[{n}] = {self.varname[rin]}[{n}];")
        # double loop to ensure register-to-register chains update correctly
        for x, net in enumerate(regnets):
            rout = net.dests[0]
            for n in range(self._limbs(rout)):
                write(f"{self.varname[rout]}[{n}] = regtmp{x}[{n}];")

        # output copied out
        outputs = list(self.block.wirevector_subset(Output))
        # for each output wire, start and number of elements in output array
        self._outputpos = {}
        opos = 0
        for w in outputs:
            self._outputpos[w.name] = opos, self._limbs(w)
            for n in range(self._limbs(w)):
                write(f"outputs[{opos}] = {self.varname[w]}[{n}];")
                opos += 1
        self._obufsz = opos  # total length of output array

        write("}")

        # entry point: run `stepcount` steps over consecutive buffer slices
        write("EXPORT")
        write(
            "void sim_run_all("
            "uint64_t stepcount, uint64_t inputs[], uint64_t outputs[]) {"
        )
        write("uint64_t input_pos = 0, output_pos = 0;")
        write("for (uint64_t stepnum = 0; stepnum < stepcount; stepnum++) {")
        write("sim_run_step(inputs+input_pos, outputs+output_pos);")
        write(f"input_pos += {self._ibufsz};")
        write(f"output_pos += {self._obufsz};")
        write("}}")

    def __del__(self):
        """Handle removal of the DLL when the simulator is deleted."""
        if self._dll is not None:
            handle = self._dll._handle
            if platform.system() == "Windows":
                _ctypes.FreeLibrary(handle)
            else:
                _ctypes.dlclose(handle)
            self._dll = None
        if self._dir is not None:
            shutil.rmtree(self._dir)
            self._dir = None