Skip to content

Commit d6e693a

Browse files
committed
Merge BlockstreamResearch#212: Typed value
e5afbb0 Bit Machine: Handle wrong input during execution (Christian Lewe) f562d73 Bit Machine: Handle empty input (Christian Lewe) 93a83b1 Bit Machine: Check type of input (Christian Lewe) df31202 Value: Check for emptiness (Christian Lewe) e99cf03 Value: Test type checker (Christian Lewe) d27b143 Value: Add type checker (Christian Lewe) 5e7616a Value: Add conversion methods (Christian Lewe) fac715d Final: split_* -> as_* (Christian Lewe) Pull request description: Add type checker to values and use it in the Bit Machine. This is particularly useful for running expressions with non-unit input types. ACKs for top commit: apoelstra: ACK e5afbb0 Tree-SHA512: 4cea989243d25f2161941b97a98b39888e7ccb8821b7448e7f424321c78888b20932f1ef22738c79ff389b51183c791b876337d66d88becfdfda62e826e044d7
2 parents 7b5b738 + e5afbb0 commit d6e693a

File tree

4 files changed

+135
-31
lines changed

4 files changed

+135
-31
lines changed

src/bit_machine/mod.rs

+29-16
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ use crate::analysis;
1616
use crate::dag::{DagLike, NoSharing};
1717
use crate::jet::{Jet, JetFailed};
1818
use crate::node::{self, RedeemNode};
19+
use crate::types::Final;
1920
use crate::{Cmr, FailEntropy, Value};
2021
use frame::Frame;
2122

@@ -30,6 +31,8 @@ pub struct BitMachine {
3031
read: Vec<Frame>,
3132
/// Write frame stack
3233
write: Vec<Frame>,
34+
/// Acceptable source type
35+
source_ty: Arc<Final>,
3336
}
3437

3538
impl BitMachine {
@@ -42,6 +45,7 @@ impl BitMachine {
4245
next_frame_start: 0,
4346
read: Vec::with_capacity(program.bounds().extra_frames + analysis::IO_EXTRA_FRAMES),
4447
write: Vec::with_capacity(program.bounds().extra_frames + analysis::IO_EXTRA_FRAMES),
48+
source_ty: program.arrow().source.clone(),
4549
}
4650
}
4751

@@ -193,11 +197,17 @@ impl BitMachine {
193197

194198
/// Add a read frame with some given value in it, as input to the
195199
/// program
196-
pub fn input(&mut self, input: &Value) {
197-
// FIXME typecheck this
198-
self.new_frame(input.len());
199-
self.write_value(input);
200-
self.move_frame();
200+
pub fn input(&mut self, input: &Value) -> Result<(), ExecutionError> {
201+
if !input.is_of_type(&self.source_ty) {
202+
return Err(ExecutionError::InputWrongType(self.source_ty.clone()));
203+
}
204+
// Unit value doesn't need extra frame
205+
if !input.is_empty() {
206+
self.new_frame(input.len());
207+
self.write_value(input);
208+
self.move_frame();
209+
}
210+
Ok(())
201211
}
202212

203213
/// Execute the given program on the Bit Machine, using the given environment.
@@ -229,16 +239,14 @@ impl BitMachine {
229239
}
230240
}
231241

242+
if self.read.is_empty() != self.source_ty.is_empty() {
243+
return Err(ExecutionError::InputWrongType(self.source_ty.clone()));
244+
}
245+
232246
let mut ip = program;
233247
let mut call_stack = vec![];
234248
let mut iterations = 0u64;
235249

236-
let input_width = ip.arrow().source.bit_width();
237-
// TODO: convert into crate::Error
238-
assert!(
239-
self.read.is_empty() || input_width > 0,
240-
"Program requires a non-empty input to execute",
241-
);
242250
let output_width = ip.arrow().target.bit_width();
243251
if output_width > 0 {
244252
self.new_frame(output_width);
@@ -257,14 +265,14 @@ impl BitMachine {
257265
self.copy(size_a);
258266
}
259267
node::Inner::InjL(left) => {
260-
let (b, _c) = ip.arrow().target.split_sum().unwrap();
268+
let (b, _c) = ip.arrow().target.as_sum().unwrap();
261269
let padl_b_c = ip.arrow().target.bit_width() - b.bit_width() - 1;
262270
self.write_bit(false);
263271
self.skip(padl_b_c);
264272
call_stack.push(CallStack::Goto(left));
265273
}
266274
node::Inner::InjR(left) => {
267-
let (_b, c) = ip.arrow().target.split_sum().unwrap();
275+
let (_b, c) = ip.arrow().target.as_sum().unwrap();
268276
let padr_b_c = ip.arrow().target.bit_width() - c.bit_width() - 1;
269277
self.write_bit(true);
270278
self.skip(padr_b_c);
@@ -305,16 +313,16 @@ impl BitMachine {
305313
}
306314
node::Inner::Take(left) => call_stack.push(CallStack::Goto(left)),
307315
node::Inner::Drop(left) => {
308-
let size_a = ip.arrow().source.split_product().unwrap().0.bit_width();
316+
let size_a = ip.arrow().source.as_product().unwrap().0.bit_width();
309317
self.fwd(size_a);
310318
call_stack.push(CallStack::Back(size_a));
311319
call_stack.push(CallStack::Goto(left));
312320
}
313321
node::Inner::Case(..) | node::Inner::AssertL(..) | node::Inner::AssertR(..) => {
314322
let choice_bit = self.read[self.read.len() - 1].peek_bit(&self.data);
315323

316-
let (sum_a_b, _c) = ip.arrow().source.split_product().unwrap();
317-
let (a, b) = sum_a_b.split_sum().unwrap();
324+
let (sum_a_b, _c) = ip.arrow().source.as_product().unwrap();
325+
let (a, b) = sum_a_b.as_sum().unwrap();
318326
let size_a = a.bit_width();
319327
let size_b = b.bit_width();
320328

@@ -484,6 +492,8 @@ impl BitMachine {
484492
/// Errors related to simplicity Execution
485493
#[derive(Debug)]
486494
pub enum ExecutionError {
495+
/// Provided input is of wrong type
496+
InputWrongType(Arc<Final>),
487497
/// Reached a fail node
488498
ReachedFailNode(FailEntropy),
489499
/// Reached a pruned branch
@@ -495,6 +505,9 @@ pub enum ExecutionError {
495505
impl fmt::Display for ExecutionError {
496506
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
497507
match self {
508+
ExecutionError::InputWrongType(expected_ty) => {
509+
write!(f, "Expected input of type: {expected_ty}")
510+
}
498511
ExecutionError::ReachedFailNode(entropy) => {
499512
write!(f, "Execution reached a fail node: {}", entropy)
500513
}

src/merkle/amr.rs

+7-7
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ impl Amr {
4848
/// Produce a CMR for an injl combinator
4949
pub fn injl(ty: &FinalArrow, child: Amr) -> Self {
5050
let a = &ty.source;
51-
let (b, c) = ty.target.split_sum().unwrap();
51+
let (b, c) = ty.target.as_sum().unwrap();
5252
Self::INJL_IV
5353
.update(a.tmr().into(), b.tmr().into())
5454
.update(c.tmr().into(), child)
@@ -57,15 +57,15 @@ impl Amr {
5757
/// Produce a CMR for an injr combinator
5858
pub fn injr(ty: &FinalArrow, child: Amr) -> Self {
5959
let a = &ty.source;
60-
let (b, c) = ty.target.split_sum().unwrap();
60+
let (b, c) = ty.target.as_sum().unwrap();
6161
Self::INJR_IV
6262
.update(a.tmr().into(), b.tmr().into())
6363
.update(c.tmr().into(), child)
6464
}
6565

6666
/// Produce a CMR for a take combinator
6767
pub fn take(ty: &FinalArrow, child: Amr) -> Self {
68-
let (a, b) = ty.source.split_product().unwrap();
68+
let (a, b) = ty.source.as_sum().unwrap();
6969
let c = &ty.target;
7070
Self::TAKE_IV
7171
.update(a.tmr().into(), b.tmr().into())
@@ -74,7 +74,7 @@ impl Amr {
7474

7575
/// Produce a CMR for a drop combinator
7676
pub fn drop(ty: &FinalArrow, child: Amr) -> Self {
77-
let (a, b) = ty.source.split_product().unwrap();
77+
let (a, b) = ty.source.as_product().unwrap();
7878
let c = &ty.target;
7979
Self::DROP_IV
8080
.update(a.tmr().into(), b.tmr().into())
@@ -93,8 +93,8 @@ impl Amr {
9393
}
9494

9595
fn case_helper(iv: Amr, ty: &FinalArrow, left: Amr, right: Amr) -> Self {
96-
let (sum_a_b, c) = ty.source.split_product().unwrap();
97-
let (a, b) = sum_a_b.split_sum().unwrap();
96+
let (sum_a_b, c) = ty.source.as_product().unwrap();
97+
let (a, b) = sum_a_b.as_sum().unwrap();
9898
let d = &ty.target;
9999
iv.update(a.tmr().into(), b.tmr().into())
100100
.update(c.tmr().into(), d.tmr().into())
@@ -136,7 +136,7 @@ impl Amr {
136136
/// Produce a CMR for a disconnect combinator
137137
pub fn disconnect(ty: &FinalArrow, right_arrow: &FinalArrow, left: Amr, right: Amr) -> Self {
138138
let a = &ty.source;
139-
let (b, d) = ty.target.split_product().unwrap();
139+
let (b, d) = ty.target.as_product().unwrap();
140140
let c = &right_arrow.source;
141141
Self::DISCONNECT_IV
142142
.update(a.tmr().into(), b.tmr().into())

src/types/final_data.rs

+13-7
Original file line numberDiff line numberDiff line change
@@ -191,28 +191,34 @@ impl Final {
191191
self.bit_width
192192
}
193193

194+
/// Check if the type is a nested product of units.
195+
/// In this case, values contain no information.
196+
pub fn is_empty(&self) -> bool {
197+
self.bit_width() == 0
198+
}
199+
194200
/// Accessor for the type bound
195201
pub fn bound(&self) -> &CompleteBound {
196202
&self.bound
197203
}
198204

199-
/// Returns whether this is the unit type
205+
/// Check if the type is a unit.
200206
pub fn is_unit(&self) -> bool {
201207
self.bound == CompleteBound::Unit
202208
}
203209

204-
/// Return both children, if the type is a sum type
205-
pub fn split_sum(&self) -> Option<(Arc<Self>, Arc<Self>)> {
210+
/// Access the inner types of a sum type.
211+
pub fn as_sum(&self) -> Option<(&Self, &Self)> {
206212
match &self.bound {
207-
CompleteBound::Sum(left, right) => Some((left.clone(), right.clone())),
213+
CompleteBound::Sum(left, right) => Some((left.as_ref(), right.as_ref())),
208214
_ => None,
209215
}
210216
}
211217

212-
/// Return both children, if the type is a product type
213-
pub fn split_product(&self) -> Option<(Arc<Self>, Arc<Self>)> {
218+
/// Access the inner types of a product type.
219+
pub fn as_product(&self) -> Option<(&Self, &Self)> {
214220
match &self.bound {
215-
CompleteBound::Product(left, right) => Some((left.clone(), right.clone())),
221+
CompleteBound::Product(left, right) => Some((left.as_ref(), right.as_ref())),
216222
_ => None,
217223
}
218224
}

src/value.rs

+86-1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
//! i.e., inputs, intermediate results and outputs.
77
88
use crate::dag::{Dag, DagLike, NoSharing};
9+
use crate::types::Final;
910

1011
use std::collections::VecDeque;
1112
use std::convert::TryInto;
@@ -83,14 +84,48 @@ impl Value {
8384
Arc::new(Value::Prod(left, right))
8485
}
8586

86-
#[allow(clippy::len_without_is_empty)]
8787
/// The length, in bits, of the value when encoded in the Bit Machine
8888
pub fn len(&self) -> usize {
8989
self.pre_order_iter::<NoSharing>()
9090
.filter(|inner| matches!(inner, Value::SumL(_) | Value::SumR(_)))
9191
.count()
9292
}
9393

94+
/// Check if the value is a nested product of units.
95+
/// In this case, the value contains no information.
96+
pub fn is_empty(&self) -> bool {
97+
self.len() == 0
98+
}
99+
100+
/// Check if the value is a unit.
101+
pub fn is_unit(&self) -> bool {
102+
matches!(self, Value::Unit)
103+
}
104+
105+
/// Access the inner value of a left sum value.
106+
pub fn as_left(&self) -> Option<&Self> {
107+
match self {
108+
Value::SumL(inner) => Some(inner.as_ref()),
109+
_ => None,
110+
}
111+
}
112+
113+
/// Access the inner value of a right sum value.
114+
pub fn as_right(&self) -> Option<&Self> {
115+
match self {
116+
Value::SumR(inner) => Some(inner.as_ref()),
117+
_ => None,
118+
}
119+
}
120+
121+
/// Access the inner values of a product value.
122+
pub fn as_product(&self) -> Option<(&Self, &Self)> {
123+
match self {
124+
Value::Prod(left, right) => Some((left.as_ref(), right.as_ref())),
125+
_ => None,
126+
}
127+
}
128+
94129
/// Encode a single bit as a value. Will panic if the input is out of range
95130
pub fn u1(n: u8) -> Arc<Self> {
96131
match n {
@@ -264,6 +299,36 @@ impl Value {
264299

265300
(bytes, bit_length)
266301
}
302+
303+
/// Check if the value is of the given type.
304+
pub fn is_of_type(&self, ty: &Final) -> bool {
305+
let mut stack = vec![(self, ty)];
306+
307+
while let Some((value, ty)) = stack.pop() {
308+
if ty.is_unit() {
309+
if !value.is_unit() {
310+
return false;
311+
}
312+
} else if let Some((ty_l, ty_r)) = ty.as_sum() {
313+
if let Some(value_l) = value.as_left() {
314+
stack.push((value_l, ty_l));
315+
} else if let Some(value_r) = value.as_right() {
316+
stack.push((value_r, ty_r));
317+
} else {
318+
return false;
319+
}
320+
} else if let Some((ty_l, ty_r)) = ty.as_product() {
321+
if let Some((value_l, value_r)) = value.as_product() {
322+
stack.push((value_r, ty_r));
323+
stack.push((value_l, ty_l));
324+
} else {
325+
return false;
326+
}
327+
}
328+
}
329+
330+
true
331+
}
267332
}
268333

269334
impl fmt::Debug for Value {
@@ -308,6 +373,7 @@ impl fmt::Display for Value {
308373
#[cfg(test)]
309374
mod tests {
310375
use super::*;
376+
use crate::jet::type_name::TypeName;
311377

312378
#[test]
313379
fn value_display() {
@@ -317,4 +383,23 @@ mod tests {
317383
assert_eq!(Value::u1(1).to_string(), "1",);
318384
assert_eq!(Value::u4(6).to_string(), "((0,1),(1,0))",);
319385
}
386+
387+
#[test]
388+
fn is_of_type() {
389+
let value_typename = [
390+
(Value::unit(), TypeName(b"1")),
391+
(Value::sum_l(Value::unit()), TypeName(b"+11")),
392+
(Value::sum_r(Value::unit()), TypeName(b"+11")),
393+
(Value::sum_l(Value::unit()), TypeName(b"+1h")),
394+
(Value::sum_r(Value::unit()), TypeName(b"+h1")),
395+
(Value::prod(Value::unit(), Value::unit()), TypeName(b"*11")),
396+
(Value::u8(u8::MAX), TypeName(b"c")),
397+
(Value::u64(u64::MAX), TypeName(b"l")),
398+
];
399+
400+
for (value, typename) in value_typename {
401+
let ty = typename.to_final();
402+
assert!(value.is_of_type(ty.as_ref()));
403+
}
404+
}
320405
}

0 commit comments

Comments
 (0)