Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 113 additions & 7 deletions protocols/src/interpreter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,12 @@ pub struct Evaluator<'a> {
args_mapping: HashMap<SymbolId, BitVecValue>,
input_mapping: HashMap<SymbolId, ExprRef>,
output_mapping: HashMap<SymbolId, Output>,
input_dependencies: HashMap<SymbolId, Vec<SymbolId>>,
output_dependencies: HashMap<SymbolId, Vec<SymbolId>>,

// tracks forbidden ports due to combinational dependencies
forbidden_inputs: Vec<SymbolId>,
forbidden_outputs: Vec<SymbolId>,

// tracks the input pins and their values
input_vals: HashMap<SymbolId, InputValue>,
Expand Down Expand Up @@ -112,8 +118,8 @@ impl<'a> Evaluator<'a> {
let dut = tr.type_param.unwrap();
let dut_symbols = &st.get_children(&dut);

let mut input_mapping = HashMap::new();
let mut output_mapping = HashMap::new();
let mut input_mapping: HashMap<SymbolId, ExprRef> = HashMap::new();
let mut output_mapping: HashMap<SymbolId, Output> = HashMap::new();

for input in &sys.inputs {
info!(
Expand Down Expand Up @@ -156,6 +162,61 @@ impl<'a> Evaluator<'a> {
}
}

// find the combinational cone of influence for each input
let mut output_dependencies: HashMap<SymbolId, Vec<SymbolId>> = HashMap::new();
let mut input_dependencies: HashMap<SymbolId, Vec<SymbolId>> = HashMap::new();

// initialize: keys are outputs -> output_dependencies, and inputs -> input_dependencies
for symbol_id in output_mapping.keys() {
output_dependencies.insert(*symbol_id, Vec::new());
}
for symbol_id in input_mapping.keys() {
input_dependencies.insert(*symbol_id, Vec::new());
}

for (out_sym, out) in output_mapping.clone() {
let input_exprs =
patronus::system::analysis::cone_of_influence_comb(ctx, sys, out.expr);
// println!("{:?} {:?}", st[out_sym].name(), input_exprs.len());
for input_expr in input_exprs {
if let Some(input_sym) = input_mapping
.iter()
.find_map(|(k, v)| if *v == input_expr { Some(*k) } else { None })
{
// println!("{:?}", input_sym.clone());
if let Some(vec) = output_dependencies.get_mut(&out_sym) {
vec.push(input_sym);
}
if let Some(vec) = input_dependencies.get_mut(&input_sym) {
vec.push(out_sym.clone());
}
}
}
}

// DEBUG
// for (out_sym, inputs) in &output_dependencies {
// let out_name = st[out_sym].name();
// let input_names: Vec<String> = inputs.iter().map(|s| st[s].name().to_string()).collect();
// println!(
// "Output '{}' ({:?}) depends on inputs: {:?}",
// out_name,
// out_sym,
// input_names
// );
// }

// for (in_sym, outputs) in &input_dependencies {
// let in_name = st[in_sym].name();
// let output_names: Vec<String> = outputs.iter().map(|s| st[s].name().to_string()).collect();
// println!(
// "Input '{}' ({:?}) affects outputs: {:?}",
// in_name,
// in_sym,
// output_names
// );
// }

// For simplicity, we initialize an RNG with the seed 0 when generating
// random values for `DontCare`s
let mut rng = StdRng::seed_from_u64(0);
Expand Down Expand Up @@ -191,6 +252,10 @@ impl<'a> Evaluator<'a> {
args_mapping,
input_mapping,
output_mapping,
input_dependencies,
output_dependencies,
forbidden_inputs: Vec::new(),
forbidden_outputs: Vec::new(),
input_vals,
assertions_enabled: false,
rng,
Expand Down Expand Up @@ -220,6 +285,10 @@ impl<'a> Evaluator<'a> {
self.st = todo.st;
self.args_mapping = Evaluator::generate_args_mapping(self.st, todo.args);
self.next_stmt_map = todo.next_stmt_map;

// during each context switch (in a non-fixed point approach), we're in a new cycle so we need to clear forbidden ports
self.forbidden_inputs = Vec::new();
self.forbidden_outputs = Vec::new();
}

pub fn input_vals(&self) -> HashMap<SymbolId, InputValue> {
Expand Down Expand Up @@ -263,12 +332,28 @@ impl<'a> Evaluator<'a> {
Expr::Sym(sym_id) => {
let name = self.st[sym_id].name();
if let Some(expr_ref) = self.input_mapping.get(sym_id) {
Ok(ExprValue::Concrete(
// FIXME: if we observe a dut input prot, nothing to do??
return Ok(ExprValue::Concrete(
self.sim.get(*expr_ref).try_into().unwrap(),
))
));
} else if let Some(output) = self.output_mapping.get(sym_id) {
// if observing this output port is forbidden, error out
// FIXME: make a new error type for this
if self.forbidden_outputs.contains(sym_id) {
return Err(ExecutionError::dont_care_operation(
String::from("OBSERVED FORBIDDEN PORT"),
String::from(""),
*expr_id,
));
}

// if we observe a dut output port, restrict assignments to dependent input ports
if let Some(deps) = self.input_dependencies.get(sym_id) {
self.forbidden_inputs.extend(deps.iter().copied());
}

Ok(ExprValue::Concrete(
self.sim.get((output).expr).try_into().unwrap(),
self.sim.get(output.expr).try_into().unwrap(),
))
} else if let Some(bvv) = self.args_mapping.get(sym_id) {
Ok(ExprValue::Concrete(bvv.clone()))
Expand Down Expand Up @@ -448,6 +533,15 @@ impl<'a> Evaluator<'a> {
// FIXME: This should return a DontCare or a NewValue
let expr_val = self.evaluate_expr(expr_id)?;

// if the symbol is a forbidden input, error out
if self.forbidden_inputs.contains(symbol_id) {
return Err(ExecutionError::dont_care_operation(
String::from("ASSIGNED FORBIDDEN PORT"),
String::from(""),
*expr_id,
));
}

// if the symbol is currently a DontCare or OldValue, turn it into a NewValue
// if the symbol is currently a NewValue, error out -- two threads are trying to assign to the same input
if let Some(value) = self.input_vals.get_mut(symbol_id) {
Expand All @@ -456,7 +550,11 @@ impl<'a> Evaluator<'a> {
match expr_val {
ExprValue::DontCare => {
// Do nothing for DontCare
// (we don't want to re-randomize within a cycle because that would prevent convergence)

// assigning don't care: proceed with the assignment and forbid all dependent outputs
if let Some(deps) = self.output_dependencies.get(symbol_id) {
self.forbidden_outputs.extend(deps.iter().copied());
}
}
ExprValue::Concrete(bvv) => {
*value = InputValue::NewValue(bvv);
Expand All @@ -469,6 +567,11 @@ impl<'a> Evaluator<'a> {
&mut self.rng,
old_val.width(),
));

// assigning don 't care: proceed with the assignment and forbid all dependent outputs
if let Some(deps) = self.output_dependencies.get(symbol_id) {
self.forbidden_outputs.extend(deps.iter().copied());
}
}
ExprValue::Concrete(bvv) => {
*value = InputValue::NewValue(bvv);
Expand All @@ -477,7 +580,10 @@ impl<'a> Evaluator<'a> {
InputValue::NewValue(current_val) => {
match expr_val {
ExprValue::DontCare => {
// do nothing
// assigning don't care: proceed with the assignment and forbid all dependent outputs
if let Some(deps) = self.output_dependencies.get(symbol_id) {
self.forbidden_outputs.extend(deps.iter().copied());
}
}
ExprValue::Concrete(new_val) => {
// no width check needed; guaranteed to be the same
Expand Down
2 changes: 1 addition & 1 deletion protocols/src/parser.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// Copyright 2025 Cornell University
// released under MIT License
// author: Nikil Shyamunder <[email protected]>
// author: Nikil Shyamsunder <[email protected]>
// author: Kevin Laeufer <[email protected]>
// author: Francis Pham <[email protected]>

Expand Down
102 changes: 51 additions & 51 deletions protocols/src/scheduler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ pub type TodoItem = (String, Vec<BitVecValue>);
type TransactionInfo<'a> = (&'a Transaction, &'a SymbolTable, NextStmtMap);

/// The maximum number of iterations to run for convergence before breaking with an ExecutionLimitExceeded error
const MAX_ITERS: usize = 10000;
const MAX_ITERS: usize = 0;

/// A `Todo` is a function call to be executed (i.e. a line in the `.tx` file)
#[derive(Debug, Clone)]
Expand Down Expand Up @@ -262,62 +262,62 @@ impl<'a> Scheduler<'a> {
// initially there are no previous values.
// we always need to cycle at least twice to check convergence,
// and the first time we will get a previous input val.
let mut previous_input_vals: Option<HashMap<SymbolId, InputValue>> = None;
let mut active_input_vals: HashMap<SymbolId, InputValue>;
// let mut previous_input_vals: Option<HashMap<SymbolId, InputValue>> = None;
// let mut active_input_vals: HashMap<SymbolId, InputValue>;

// fixed point iteration with assertions off
self.evaluator.disable_assertions();

let mut iters = 0;
loop {
// run every active thread up to the next step to synchronize on
self.run_all_active_until_next_step(iters == 0); // only enable forks on the first iteration

// if there are threads now in next_threads, we need to move them to active_threads
if !self.next_threads.is_empty() {
info!(
"Moving {} threads from next_threads to active_threads",
self.next_threads.len()
);
self.active_threads.append(&mut self.next_threads);
}

// update the active input vals to reflect the current state
// for each thread, get its current input_vals (read-only clone)
active_input_vals = self.evaluator.input_vals();

if let Some(prev_vals) = previous_input_vals {
if prev_vals == active_input_vals {
break;
}
}

// if we've exceeded the max number of iterations before convergence,
// return an ExecutionLimitExceeded error on every thread.
// we should be able to theoretically show convergence is always possible, however
if iters > MAX_ITERS {
for thread in &self.active_threads {
self.results[thread.todo_idx] =
Err(ExecutionError::execution_limit_exceeded(MAX_ITERS));
}
// Emit diagnostics for all errors before returning
self.emit_all_diagnostics();
return self.results.clone();
}

info!("Active Input Vals {:?}", active_input_vals);

// change the previous input vals to equal the active input vals
previous_input_vals = Some(active_input_vals);

iters += 1;
}
// self.evaluator.disable_assertions();

// let mut iters = 0;
// loop {
// // run every active thread up to the next step to synchronize on
// self.run_all_active_until_next_step(iters == 0); // only enable forks on the first iteration

// // if there are threads now in next_threads, we need to move them to active_threads
// if !self.next_threads.is_empty() {
// info!(
// "Moving {} threads from next_threads to active_threads",
// self.next_threads.len()
// );
// self.active_threads.append(&mut self.next_threads);
// }

// // update the active input vals to reflect the current state
// // for each thread, get its current input_vals (read-only clone)
// active_input_vals = self.evaluator.input_vals();

// if let Some(prev_vals) = previous_input_vals {
// if prev_vals == active_input_vals {
// break;
// }
// }

// // if we've exceeded the max number of iterations before convergence,
// // return an ExecutionLimitExceeded error on every thread.
// // we should be able to theoretically show convergence is always possible, however
// if iters > MAX_ITERS {
// for thread in &self.active_threads {
// self.results[thread.todo_idx] =
// Err(ExecutionError::execution_limit_exceeded(MAX_ITERS));
// }
// // Emit diagnostics for all errors before returning
// self.emit_all_diagnostics();
// return self.results.clone();
// }

// info!("Active Input Vals {:?}", active_input_vals);

// // change the previous input vals to equal the active input vals
// previous_input_vals = Some(active_input_vals);

// iters += 1;
// }

// achieved convergence, run one more time with assertions on
info!("Achieved Convergence. Running once more with assertions enabled...");
// info!("Achieved Convergence. Running once more with assertions enabled...");
self.evaluator.enable_assertions();
// Disable forks when we run all threads till the next
self.run_all_active_until_next_step(false);
self.run_all_active_until_next_step(true);

// Move each active thread into inactive or next
while let Some(mut active_thread) = self.active_threads.pop() {
Expand Down
32 changes: 32 additions & 0 deletions protocols/tests/identities/dual_identity_d0/dual_identity_d0.prot
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
struct DualIdentity {
in a: u1,
in b: u1,
out s1: u1,
out s2: u1,
}

fn identity<dut: DualIdentity>(in a: u1, in b: u1) {
dut.a := 1'b0;
dut.b := 1'b0;

step();

assert_eq(dut.s1, 1'b0);
assert_eq(dut.s2, 1'b0);

if (dut.s2 == 1'b0) {
// during fixed point, dut.a is set to a
// even though the else branch is taken at fixed point
dut.a := a;
dut.b := b;
} else {
dut.b := b;
}


assert_eq(dut.s2, b);

step();
fork();
step();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
// ARGS: --verilog identities/dual_identity_d0/dual_identity_d0.v --protocol identities/dual_identity_d0/dual_identity_d0.prot
// RETURN: 1
identity(1, 1);
10 changes: 10 additions & 0 deletions protocols/tests/identities/dual_identity_d0/dual_identity_d0.v
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
module identity_d0 (
input clk,
input a,
input b,
output s1,
output s2,
);
assign s1 = a;
assign s2 = b;
endmodule