diff --git a/src/descriptor/mod.rs b/src/descriptor/mod.rs index a8fd23b0d..c00db0b65 100644 --- a/src/descriptor/mod.rs +++ b/src/descriptor/mod.rs @@ -1492,7 +1492,7 @@ mod tests { #[test] fn roundtrip_tests() { let descriptor = Descriptor::::from_str("multi"); - assert_eq!(descriptor.unwrap_err().to_string(), "unexpected «no arguments given»") + assert_eq!(descriptor.unwrap_err().to_string(), "expected threshold, found terminal",); } #[test] diff --git a/src/descriptor/sortedmulti.rs b/src/descriptor/sortedmulti.rs index 11101c5bb..5b7079292 100644 --- a/src/descriptor/sortedmulti.rs +++ b/src/descriptor/sortedmulti.rs @@ -32,15 +32,21 @@ pub struct SortedMultiVec { } impl SortedMultiVec { - fn constructor_check(&self) -> Result<(), Error> { + fn constructor_check(mut self) -> Result { // Check the limits before creating a new SortedMultiVec // For example, under p2sh context the scriptlen can only be // upto 520 bytes. - let term: Terminal = Terminal::Multi(self.inner.k(), self.inner.data().to_owned()); + let term: Terminal = Terminal::Multi(self.inner); let ms = Miniscript::from_ast(term)?; // This would check all the consensus rules for p2sh/p2wsh and // even tapscript in future - Ctx::check_local_validity(&ms).map_err(From::from) + Ctx::check_local_validity(&ms)?; + if let Terminal::Multi(inner) = ms.node { + self.inner = inner; + Ok(self) + } else { + unreachable!() + } } /// Create a new instance of `SortedMultiVec` given a list of keys and the threshold @@ -49,8 +55,7 @@ impl SortedMultiVec { pub fn new(k: usize, pks: Vec) -> Result { let ret = Self { inner: Threshold::new(k, pks).map_err(Error::Threshold)?, phantom: PhantomData }; - ret.constructor_check()?; - Ok(ret) + ret.constructor_check() } /// Parse an expression tree into a SortedMultiVec @@ -66,8 +71,7 @@ impl SortedMultiVec { .translate_by_index(|i| expression::terminal(&tree.args[i + 1], Pk::from_str))?, phantom: PhantomData, }; - ret.constructor_check()?; - Ok(ret) + ret.constructor_check() } /// This will panic if fpk returns an uncompressed key when @@ -85,8 +89,7 @@ impl SortedMultiVec { inner: self.inner.translate_ref(|pk| t.pk(pk))?, phantom: PhantomData, }; - ret.constructor_check().map_err(TranslateErr::OuterError)?; - Ok(ret) + ret.constructor_check().map_err(TranslateErr::OuterError) } /// The threshold value for the multisig. @@ -113,11 +116,8 @@ impl SortedMultiVec { /// utility function to sanity a sorted multi vec pub fn sanity_check(&self) -> Result<(), Error> { let ms: Miniscript = - Miniscript::from_ast(Terminal::Multi(self.k(), self.pks().to_owned())) - .expect("Must typecheck"); - // '?' for doing From conversion - ms.sanity_check()?; - Ok(()) + Miniscript::from_ast(Terminal::Multi(self.inner.clone())).expect("Must typecheck"); + ms.sanity_check().map_err(From::from) } } @@ -127,16 +127,16 @@ impl SortedMultiVec { where Pk: ToPublicKey, { - let mut pks = self.pks().to_owned(); + let mut thresh = self.inner.clone(); // Sort pubkeys lexicographically according to BIP 67 - pks.sort_by(|a, b| { + thresh.data_mut().sort_by(|a, b| { a.to_public_key() .inner .serialize() .partial_cmp(&b.to_public_key().inner.serialize()) .unwrap() }); - Terminal::Multi(self.k(), pks) + Terminal::Multi(thresh) } /// Encode as a Bitcoin script diff --git a/src/interpreter/mod.rs b/src/interpreter/mod.rs index af405bda9..d9bc3ebab 100644 --- a/src/interpreter/mod.rs +++ b/src/interpreter/mod.rs @@ -845,9 +845,9 @@ where None => return Some(Err(Error::UnexpectedStackEnd)), } } - Terminal::MultiA(k, ref subs) => { - if node_state.n_evaluated == subs.len() { - if node_state.n_satisfied == k { + Terminal::MultiA(ref thresh) => { + if node_state.n_evaluated == thresh.n() { + if node_state.n_satisfied == thresh.k() { self.stack.push(stack::Element::Satisfied); } else { self.stack.push(stack::Element::Dissatisfied); @@ -856,10 +856,10 @@ where // evaluate each key with as a pk // note that evaluate_pk will error on non-empty incorrect sigs // push 1 on satisfied sigs and push 0 on empty sigs - match self - .stack - .evaluate_pk(&mut self.verify_sig, subs[node_state.n_evaluated]) - { + match self.stack.evaluate_pk( + &mut self.verify_sig, + thresh.data()[node_state.n_evaluated], + ) { Some(Ok(x)) => { self.push_evaluation_state( node_state.node, @@ -886,9 +886,9 @@ where } } } - Terminal::Multi(ref k, ref subs) if node_state.n_evaluated == 0 => { + Terminal::Multi(ref thresh) if node_state.n_evaluated == 0 => { let len = self.stack.len(); - if len < k + 1 { + if len < thresh.k() + 1 { return Some(Err(Error::InsufficientSignaturesMultiSig)); } else { //Non-sat case. If the first sig is empty, others k elements must @@ -896,13 +896,13 @@ where match self.stack.last() { Some(&stack::Element::Dissatisfied) => { //Remove the extra zero from multi-sig check - let sigs = self.stack.split_off(len - (k + 1)); + let sigs = self.stack.split_off(len - (thresh.k() + 1)); let nonsat = sigs .iter() .map(|sig| *sig == stack::Element::Dissatisfied) .filter(|empty| *empty) .count(); - if nonsat == *k + 1 { + if nonsat == thresh.k() + 1 { self.stack.push(stack::Element::Dissatisfied); } else { return Some(Err(Error::MissingExtraZeroMultiSig)); @@ -910,10 +910,10 @@ where } None => return Some(Err(Error::UnexpectedStackEnd)), _ => { - match self - .stack - .evaluate_multi(&mut self.verify_sig, &subs[subs.len() - 1]) - { + match self.stack.evaluate_multi( + &mut self.verify_sig, + &thresh.data()[thresh.n() - 1], + ) { Some(Ok(x)) => { self.push_evaluation_state( node_state.node, @@ -933,20 +933,20 @@ where } } } - Terminal::Multi(k, ref subs) => { - if node_state.n_satisfied == k { + Terminal::Multi(ref thresh) => { + if node_state.n_satisfied == thresh.k() { //multi-sig bug: Pop extra 0 if let Some(stack::Element::Dissatisfied) = self.stack.pop() { self.stack.push(stack::Element::Satisfied); } else { return Some(Err(Error::MissingExtraZeroMultiSig)); } - } else if node_state.n_evaluated == subs.len() { + } else if node_state.n_evaluated == thresh.n() { return Some(Err(Error::MultiSigEvaluationError)); } else { match self.stack.evaluate_multi( &mut self.verify_sig, - &subs[subs.len() - node_state.n_evaluated - 1], + &thresh.data()[thresh.n() - node_state.n_evaluated - 1], ) { Some(Ok(x)) => { self.push_evaluation_state( diff --git a/src/miniscript/astelem.rs b/src/miniscript/astelem.rs index d11781a5e..ba2ec6fbd 100644 --- a/src/miniscript/astelem.rs +++ b/src/miniscript/astelem.rs @@ -82,8 +82,20 @@ impl Terminal { fmt_2(f, "or_i(", l, r, is_debug) } Terminal::Thresh(k, ref subs) => fmt_n(f, "thresh(", k, subs, is_debug), - Terminal::Multi(k, ref keys) => fmt_n(f, "multi(", k, keys, is_debug), - Terminal::MultiA(k, ref keys) => fmt_n(f, "multi_a(", k, keys, is_debug), + Terminal::Multi(ref thresh) => { + if is_debug { + fmt::Debug::fmt(&thresh.debug("multi", true), f) + } else { + fmt::Display::fmt(&thresh.display("multi", true), f) + } + } + Terminal::MultiA(ref thresh) => { + if is_debug { + fmt::Debug::fmt(&thresh.debug("multi_a", true), f) + } else { + fmt::Display::fmt(&thresh.display("multi_a", true), f) + } + } // wrappers _ => { if let Some((ch, sub)) = self.wrap_char() { @@ -314,27 +326,16 @@ impl crate::expression::FromTree for Termina Ok(Terminal::Thresh(k, subs?)) } - ("multi", n) | ("multi_a", n) => { - if n == 0 { - return Err(errstr("no arguments given")); - } - let k = expression::terminal(&top.args[0], expression::parse_num)? as usize; - if k > n - 1 { - return Err(errstr("higher threshold than there were keys in multi")); - } - - let pks: Result, _> = top.args[1..] - .iter() - .map(|sub| expression::terminal(sub, Pk::from_str)) - .collect(); - - if frag_name == "multi" { - pks.map(|pks| Terminal::Multi(k, pks)) - } else { - // must be multi_a - pks.map(|pks| Terminal::MultiA(k, pks)) - } - } + ("multi", _) => top + .to_null_threshold() + .map_err(Error::ParseThreshold)? + .translate_by_index(|i| expression::terminal(&top.args[1 + i], Pk::from_str)) + .map(Terminal::Multi), + ("multi_a", _) => top + .to_null_threshold() + .map_err(Error::ParseThreshold)? + .translate_by_index(|i| expression::terminal(&top.args[1 + i], Pk::from_str)) + .map(Terminal::MultiA), _ => Err(Error::Unexpected(format!( "{}({} args) while parsing Miniscript", top.name, @@ -483,27 +484,27 @@ impl Terminal { .push_int(k as i64) .push_opcode(opcodes::all::OP_EQUAL) } - Terminal::Multi(k, ref keys) => { + Terminal::Multi(ref thresh) => { debug_assert!(Ctx::sig_type() == SigType::Ecdsa); - builder = builder.push_int(k as i64); - for pk in keys { + builder = builder.push_int(thresh.k() as i64); + for pk in thresh.data() { builder = builder.push_key(&pk.to_public_key()); } builder - .push_int(keys.len() as i64) + .push_int(thresh.n() as i64) .push_opcode(opcodes::all::OP_CHECKMULTISIG) } - Terminal::MultiA(k, ref keys) => { + Terminal::MultiA(ref thresh) => { debug_assert!(Ctx::sig_type() == SigType::Schnorr); // keys must be atleast len 1 here, guaranteed by typing rules - builder = builder.push_ms_key::<_, Ctx>(&keys[0]); + builder = builder.push_ms_key::<_, Ctx>(&thresh.data()[0]); builder = builder.push_opcode(opcodes::all::OP_CHECKSIG); - for pk in keys.iter().skip(1) { + for pk in thresh.iter().skip(1) { builder = builder.push_ms_key::<_, Ctx>(pk); builder = builder.push_opcode(opcodes::all::OP_CHECKSIGADD); } builder - .push_int(k as i64) + .push_int(thresh.k() as i64) .push_opcode(opcodes::all::OP_NUMEQUAL) } } diff --git a/src/miniscript/context.rs b/src/miniscript/context.rs index 9dc2f828e..584ff4adf 100644 --- a/src/miniscript/context.rs +++ b/src/miniscript/context.rs @@ -10,9 +10,8 @@ use bitcoin::Weight; use super::decode::ParseableKey; use crate::miniscript::limits::{ - MAX_OPS_PER_SCRIPT, MAX_PUBKEYS_PER_MULTISIG, MAX_SCRIPTSIG_SIZE, MAX_SCRIPT_ELEMENT_SIZE, - MAX_SCRIPT_SIZE, MAX_STACK_SIZE, MAX_STANDARD_P2WSH_SCRIPT_SIZE, - MAX_STANDARD_P2WSH_STACK_ITEMS, + MAX_OPS_PER_SCRIPT, MAX_SCRIPTSIG_SIZE, MAX_SCRIPT_ELEMENT_SIZE, MAX_SCRIPT_SIZE, + MAX_STACK_SIZE, MAX_STANDARD_P2WSH_SCRIPT_SIZE, MAX_STANDARD_P2WSH_STACK_ITEMS, }; use crate::miniscript::types; use crate::prelude::*; @@ -61,8 +60,6 @@ pub enum ScriptContextError { TaprootMultiDisabled, /// Stack size exceeded in script execution StackSizeLimitExceeded { actual: usize, limit: usize }, - /// More than 20 keys in a Multi fragment - CheckMultiSigLimitExceeded, /// MultiA is only allowed in post tapscript MultiANotAllowed, } @@ -87,7 +84,6 @@ impl error::Error for ScriptContextError { | ImpossibleSatisfaction | TaprootMultiDisabled | StackSizeLimitExceeded { .. } - | CheckMultiSigLimitExceeded | MultiANotAllowed => None, } } @@ -149,9 +145,6 @@ impl fmt::Display for ScriptContextError { actual, limit ) } - ScriptContextError::CheckMultiSigLimitExceeded => { - write!(f, "CHECkMULTISIG ('multi()' descriptor) only supports up to 20 pubkeys") - } ScriptContextError::MultiANotAllowed => { write!(f, "Multi a(CHECKSIGADD) only allowed post tapscript") } @@ -405,11 +398,8 @@ impl ScriptContext for Legacy { match ms.node { Terminal::PkK(ref pk) => Self::check_pk(pk), - Terminal::Multi(_k, ref pks) => { - if pks.len() > MAX_PUBKEYS_PER_MULTISIG { - return Err(ScriptContextError::CheckMultiSigLimitExceeded); - } - for pk in pks.iter() { + Terminal::Multi(ref thresh) => { + for pk in thresh.iter() { Self::check_pk(pk)?; } Ok(()) @@ -506,11 +496,8 @@ impl ScriptContext for Segwitv0 { match ms.node { Terminal::PkK(ref pk) => Self::check_pk(pk), - Terminal::Multi(_k, ref pks) => { - if pks.len() > MAX_PUBKEYS_PER_MULTISIG { - return Err(ScriptContextError::CheckMultiSigLimitExceeded); - } - for pk in pks.iter() { + Terminal::Multi(ref thresh) => { + for pk in thresh.iter() { Self::check_pk(pk)?; } Ok(()) @@ -620,8 +607,8 @@ impl ScriptContext for Tap { match ms.node { Terminal::PkK(ref pk) => Self::check_pk(pk), - Terminal::MultiA(_, ref keys) => { - for pk in keys.iter() { + Terminal::MultiA(ref thresh) => { + for pk in thresh.iter() { Self::check_pk(pk)?; } Ok(()) @@ -716,11 +703,8 @@ impl ScriptContext for BareCtx { } match ms.node { Terminal::PkK(ref key) => Self::check_pk(key), - Terminal::Multi(_k, ref pks) => { - if pks.len() > MAX_PUBKEYS_PER_MULTISIG { - return Err(ScriptContextError::CheckMultiSigLimitExceeded); - } - for pk in pks.iter() { + Terminal::Multi(ref thresh) => { + for pk in thresh.iter() { Self::check_pk(pk)?; } Ok(()) @@ -749,7 +733,7 @@ impl ScriptContext for BareCtx { Terminal::PkK(_pk) | Terminal::PkH(_pk) => Ok(()), _ => Err(Error::NonStandardBareScript), }, - Terminal::Multi(_k, subs) if subs.len() <= 3 => Ok(()), + Terminal::Multi(ref thresh) if thresh.n() <= 3 => Ok(()), _ => Err(Error::NonStandardBareScript), } } diff --git a/src/miniscript/decode.rs b/src/miniscript/decode.rs index 6500eb8d5..6f6448321 100644 --- a/src/miniscript/decode.rs +++ b/src/miniscript/decode.rs @@ -21,7 +21,9 @@ use crate::miniscript::ScriptContext; use crate::prelude::*; #[cfg(doc)] use crate::Descriptor; -use crate::{hash256, AbsLockTime, Error, Miniscript, MiniscriptKey, RelLockTime, ToPublicKey}; +use crate::{ + hash256, AbsLockTime, Error, Miniscript, MiniscriptKey, RelLockTime, Threshold, ToPublicKey, +}; /// Trait for parsing keys from byte slices pub trait ParseableKey: Sized + ToPublicKey + private::Sealed { @@ -181,9 +183,9 @@ pub enum Terminal { /// `[E] ([W] ADD)* k EQUAL` Thresh(usize, Vec>>), /// `k ()* n CHECKMULTISIG` - Multi(usize, Vec), + Multi(Threshold), /// ` CHECKSIG ( CHECKSIGADD)*(n-1) k NUMEQUAL` - MultiA(usize, Vec), + MultiA(Threshold), } macro_rules! match_token { @@ -428,6 +430,7 @@ pub fn parse( }, // CHECKMULTISIG based multisig Tk::CheckMultiSig, Tk::Num(n) => { + // Check size before allocating keys if n as usize > MAX_PUBKEYS_PER_MULTISIG { return Err(Error::CmsTooManyKeys(n)); } @@ -446,7 +449,8 @@ pub fn parse( Tk::Num(k) => k, ); keys.reverse(); - term.reduce0(Terminal::Multi(k as usize, keys))?; + let thresh = Threshold::new(k as usize, keys).map_err(Error::Threshold)?; + term.reduce0(Terminal::Multi(thresh))?; }, // MultiA Tk::NumEqual, Tk::Num(k) => { @@ -469,7 +473,8 @@ pub fn parse( .map_err(|e| Error::PubKeyCtxError(e, Ctx::name_str()))?), ); keys.reverse(); - term.reduce0(Terminal::MultiA(k as usize, keys))?; + let thresh = Threshold::new(k as usize, keys).map_err(Error::Threshold)?; + term.reduce0(Terminal::MultiA(thresh))?; }, ); } diff --git a/src/miniscript/iter.rs b/src/miniscript/iter.rs index 3b2adad1b..11cbe067e 100644 --- a/src/miniscript/iter.rs +++ b/src/miniscript/iter.rs @@ -29,7 +29,7 @@ impl Miniscript { /// them. pub fn branches(&self) -> Vec<&Miniscript> { match self.node { - Terminal::PkK(_) | Terminal::PkH(_) | Terminal::RawPkH(_) | Terminal::Multi(_, _) => { + Terminal::PkK(_) | Terminal::PkH(_) | Terminal::RawPkH(_) | Terminal::Multi(_) => { vec![] } @@ -94,10 +94,9 @@ impl Miniscript { /// NB: The function analyzes only single miniscript item and not any of its descendants in AST. pub fn get_nth_pk(&self, n: usize) -> Option { match (&self.node, n) { - (&Terminal::PkK(ref key), 0) | (&Terminal::PkH(ref key), 0) => Some(key.clone()), - (&Terminal::Multi(_, ref keys), _) | (&Terminal::MultiA(_, ref keys), _) => { - keys.get(n).cloned() - } + (Terminal::PkK(key), 0) | (Terminal::PkH(key), 0) => Some(key.clone()), + (Terminal::Multi(thresh), _) => thresh.data().get(n).cloned(), + (Terminal::MultiA(thresh), _) => thresh.data().get(n).cloned(), _ => None, } } diff --git a/src/miniscript/mod.rs b/src/miniscript/mod.rs index 26f220a9f..3fa79be48 100644 --- a/src/miniscript/mod.rs +++ b/src/miniscript/mod.rs @@ -155,17 +155,17 @@ impl Miniscript { + subs.len() // ADD - 1 // no ADD on first element } - Terminal::Multi(k, ref pks) => { - script_num_size(k) + Terminal::Multi(ref thresh) => { + script_num_size(thresh.k()) + 1 - + script_num_size(pks.len()) - + pks.iter().map(|pk| Ctx::pk_len(pk)).sum::() + + script_num_size(thresh.n()) + + thresh.iter().map(|pk| Ctx::pk_len(pk)).sum::() } - Terminal::MultiA(k, ref pks) => { - script_num_size(k) + Terminal::MultiA(ref thresh) => { + script_num_size(thresh.k()) + 1 // NUMEQUAL - + pks.iter().map(|pk| Ctx::pk_len(pk)).sum::() // n keys - + pks.len() // n times CHECKSIGADD + + thresh.iter().map(|pk| Ctx::pk_len(pk)).sum::() // n keys + + thresh.n() // n times CHECKSIGADD } } } @@ -415,8 +415,15 @@ impl ForEachKey for Miniscript { - if !keys.iter().all(&mut pred) { + // These branches cannot be combined since technically the two `thresh`es + // have different types (have different maximum values). + Terminal::Multi(ref thresh) => { + if !thresh.iter().all(&mut pred) { + return false; + } + } + Terminal::MultiA(ref thresh) => { + if !thresh.iter().all(&mut pred) { return false; } } @@ -488,13 +495,9 @@ impl Miniscript { Terminal::Thresh(k, ref subs) => { Terminal::Thresh(k, (0..subs.len()).map(child_n).collect()) } - Terminal::Multi(k, ref keys) => { - let keys: Result, _> = keys.iter().map(|k| t.pk(k)).collect(); - Terminal::Multi(k, keys?) - } - Terminal::MultiA(k, ref keys) => { - let keys: Result, _> = keys.iter().map(|k| t.pk(k)).collect(); - Terminal::MultiA(k, keys?) + Terminal::Multi(ref thresh) => Terminal::Multi(thresh.translate_ref(|k| t.pk(k))?), + Terminal::MultiA(ref thresh) => { + Terminal::MultiA(thresh.translate_ref(|k| t.pk(k))?) } }; let new_ms = Miniscript::from_ast(new_term).map_err(TranslateErr::OuterError)?; diff --git a/src/miniscript/satisfy.rs b/src/miniscript/satisfy.rs index 30123a433..ea320a9f6 100644 --- a/src/miniscript/satisfy.rs +++ b/src/miniscript/satisfy.rs @@ -1498,11 +1498,11 @@ impl Satisfaction> { Terminal::Thresh(k, ref subs) => { thresh_fn(k, subs, stfr, root_has_sig, leaf_hash, min_fn) } - Terminal::Multi(k, ref keys) => { + Terminal::Multi(ref thresh) => { // Collect all available signatures let mut sig_count = 0; - let mut sigs = Vec::with_capacity(k); - for pk in keys { + let mut sigs = Vec::with_capacity(thresh.k()); + for pk in thresh.data() { match Witness::signature::<_, Ctx>(stfr, pk, leaf_hash) { Witness::Stack(sig) => { sigs.push(sig); @@ -1515,7 +1515,7 @@ impl Satisfaction> { } } - if sig_count < k { + if sig_count < thresh.k() { Satisfaction { stack: Witness::Impossible, has_sig: false, @@ -1524,7 +1524,7 @@ impl Satisfaction> { } } else { // Throw away the most expensive ones - for _ in 0..sig_count - k { + for _ in 0..sig_count - thresh.k() { let max_idx = sigs .iter() .enumerate() @@ -1544,11 +1544,11 @@ impl Satisfaction> { } } } - Terminal::MultiA(k, ref keys) => { + Terminal::MultiA(ref thresh) => { // Collect all available signatures let mut sig_count = 0; - let mut sigs = vec![vec![Placeholder::PushZero]; keys.len()]; - for (i, pk) in keys.iter().rev().enumerate() { + let mut sigs = vec![vec![Placeholder::PushZero]; thresh.n()]; + for (i, pk) in thresh.iter().rev().enumerate() { match Witness::signature::<_, Ctx>(stfr, pk, leaf_hash) { Witness::Stack(sig) => { sigs[i] = sig; @@ -1557,7 +1557,7 @@ impl Satisfaction> { // sigs. Incase pk at pos 1 is not selected, we know we did not have access to it // bitcoin core also implements the same logic for MULTISIG, so I am not bothering // permuting the sigs for now - if sig_count == k { + if sig_count == thresh.k() { break; } } @@ -1568,7 +1568,7 @@ impl Satisfaction> { } } - if sig_count < k { + if sig_count < thresh.k() { Satisfaction { stack: Witness::Impossible, has_sig: false, @@ -1772,14 +1772,14 @@ impl Satisfaction> { relative_timelock: None, absolute_timelock: None, }, - Terminal::Multi(k, _) => Satisfaction { - stack: Witness::Stack(vec![Placeholder::PushZero; k + 1]), + Terminal::Multi(ref thresh) => Satisfaction { + stack: Witness::Stack(vec![Placeholder::PushZero; thresh.k() + 1]), has_sig: false, relative_timelock: None, absolute_timelock: None, }, - Terminal::MultiA(_, ref pks) => Satisfaction { - stack: Witness::Stack(vec![Placeholder::PushZero; pks.len()]), + Terminal::MultiA(ref thresh) => Satisfaction { + stack: Witness::Stack(vec![Placeholder::PushZero; thresh.n()]), has_sig: false, relative_timelock: None, absolute_timelock: None, diff --git a/src/miniscript/types/extra_props.rs b/src/miniscript/types/extra_props.rs index dbdf1657c..983dcc570 100644 --- a/src/miniscript/types/extra_props.rs +++ b/src/miniscript/types/extra_props.rs @@ -902,25 +902,8 @@ impl ExtData { Terminal::False => Self::FALSE, Terminal::PkK(..) => Self::pk_k::(), Terminal::PkH(..) | Terminal::RawPkH(..) => Self::pk_h::(), - Terminal::Multi(k, ref pks) | Terminal::MultiA(k, ref pks) => { - if k == 0 { - return Err(Error { - fragment_string: fragment.to_string(), - error: ErrorKind::ZeroThreshold, - }); - } - if k > pks.len() { - return Err(Error { - fragment_string: fragment.to_string(), - error: ErrorKind::OverThreshold(k, pks.len()), - }); - } - match *fragment { - Terminal::Multi(..) => Self::multi(k, pks.len()), - Terminal::MultiA(..) => Self::multi_a(k, pks.len()), - _ => unreachable!(), - } - } + Terminal::Multi(ref thresh) => Self::multi(thresh.k(), thresh.n()), + Terminal::MultiA(ref thresh) => Self::multi_a(thresh.k(), thresh.n()), Terminal::After(t) => Self::after(t), Terminal::Older(t) => Self::older(t), Terminal::Sha256(..) => Self::sha256(), diff --git a/src/miniscript/types/mod.rs b/src/miniscript/types/mod.rs index 658331c99..7c2a9ccf0 100644 --- a/src/miniscript/types/mod.rs +++ b/src/miniscript/types/mod.rs @@ -436,25 +436,8 @@ impl Type { Terminal::False => Ok(Self::FALSE), Terminal::PkK(..) => Ok(Self::pk_k()), Terminal::PkH(..) | Terminal::RawPkH(..) => Ok(Self::pk_h()), - Terminal::Multi(k, ref pks) | Terminal::MultiA(k, ref pks) => { - if k == 0 { - return Err(Error { - fragment_string: fragment.to_string(), - error: ErrorKind::ZeroThreshold, - }); - } - if k > pks.len() { - return Err(Error { - fragment_string: fragment.to_string(), - error: ErrorKind::OverThreshold(k, pks.len()), - }); - } - match *fragment { - Terminal::Multi(..) => Ok(Self::multi()), - Terminal::MultiA(..) => Ok(Self::multi_a()), - _ => unreachable!(), - } - } + Terminal::Multi(..) => Ok(Self::multi()), + Terminal::MultiA(..) => Ok(Self::multi_a()), Terminal::After(_) => Ok(Self::time()), Terminal::Older(_) => Ok(Self::time()), Terminal::Sha256(..) => Ok(Self::hash()), diff --git a/src/policy/compiler.rs b/src/policy/compiler.rs index e81deb1a1..5d1dbd39a 100644 --- a/src/policy/compiler.rs +++ b/src/policy/compiler.rs @@ -12,7 +12,6 @@ use std::error; use sync::Arc; use crate::miniscript::context::SigType; -use crate::miniscript::limits::{MAX_PUBKEYS_IN_CHECKSIGADD, MAX_PUBKEYS_PER_MULTISIG}; use crate::miniscript::types::{self, ErrorKind, ExtData, Type}; use crate::miniscript::ScriptContext; use crate::policy::Concrete; @@ -426,25 +425,8 @@ impl CompilerExtData { Terminal::False => Ok(Self::FALSE), Terminal::PkK(..) => Ok(Self::pk_k::()), Terminal::PkH(..) | Terminal::RawPkH(..) => Ok(Self::pk_h::()), - Terminal::Multi(k, ref pks) | Terminal::MultiA(k, ref pks) => { - if k == 0 { - return Err(types::Error { - fragment_string: fragment.to_string(), - error: types::ErrorKind::ZeroThreshold, - }); - } - if k > pks.len() { - return Err(types::Error { - fragment_string: fragment.to_string(), - error: types::ErrorKind::OverThreshold(k, pks.len()), - }); - } - match *fragment { - Terminal::Multi(..) => Ok(Self::multi(k, pks.len())), - Terminal::MultiA(..) => Ok(Self::multi_a(k, pks.len())), - _ => unreachable!(), - } - } + Terminal::Multi(ref thresh) => Ok(Self::multi(thresh.k(), thresh.n())), + Terminal::MultiA(ref thresh) => Ok(Self::multi_a(thresh.k(), thresh.n())), Terminal::After(_) => Ok(Self::time()), Terminal::Older(_) => Ok(Self::time()), Terminal::Sha256(..) => Ok(Self::hash()), @@ -1008,8 +990,9 @@ where compile_binary!(&mut l_comp[3], &mut r_comp[2], [lw, rw], Terminal::OrI); compile_binary!(&mut r_comp[3], &mut l_comp[2], [rw, lw], Terminal::OrI); } - Concrete::Thresh(k, ref subs) => { - let n = subs.len(); + Concrete::Thresh(ref thresh) => { + let k = thresh.k(); + let n = thresh.n(); let k_over_n = k as f64 / n as f64; let mut sub_ast = Vec::with_capacity(n); @@ -1019,7 +1002,7 @@ where let mut best_ws = Vec::with_capacity(n); let mut min_value = (0, f64::INFINITY); - for (i, ast) in subs.iter().enumerate() { + for (i, ast) in thresh.iter().enumerate() { let sp = sat_prob * k_over_n; //Expressions must be dissatisfiable let dp = Some(dissat_prob.unwrap_or(0 as f64) + (1.0 - k_over_n) * sat_prob); @@ -1037,7 +1020,7 @@ where } sub_ext_data.push(best_es[min_value.0].0); sub_ast.push(Arc::clone(&best_es[min_value.0].1.ms)); - for (i, _ast) in subs.iter().enumerate() { + for (i, _ast) in thresh.iter().enumerate() { if i != min_value.0 { sub_ext_data.push(best_ws[i].0); sub_ast.push(Arc::clone(&best_ws[i].1.ms)); @@ -1054,40 +1037,40 @@ where insert_wrap!(ast_ext); } - let key_vec: Vec = subs + let key_count = thresh .iter() - .filter_map(|s| { - if let Concrete::Key(ref pk) = s.as_ref() { - Some(pk.clone()) + .filter(|s| matches!(***s, Concrete::Key(_))) + .count(); + if key_count == thresh.n() { + let pk_thresh = thresh.map_ref(|s| { + if let Concrete::Key(ref pk) = **s { + Pk::clone(pk) } else { - None + unreachable!() } - }) - .collect(); - - if key_vec.len() == subs.len() { + }); match Ctx::sig_type() { SigType::Schnorr => { - if key_vec.len() <= MAX_PUBKEYS_IN_CHECKSIGADD { - insert_wrap!(AstElemExt::terminal(Terminal::MultiA(k, key_vec))) + if let Ok(pk_thresh) = pk_thresh.set_maximum() { + insert_wrap!(AstElemExt::terminal(Terminal::MultiA(pk_thresh))) } } SigType::Ecdsa => { - if key_vec.len() <= MAX_PUBKEYS_PER_MULTISIG { - insert_wrap!(AstElemExt::terminal(Terminal::Multi(k, key_vec))) + if let Ok(pk_thresh) = pk_thresh.set_maximum() { + insert_wrap!(AstElemExt::terminal(Terminal::Multi(pk_thresh))) } } } } - if k == subs.len() { - let mut it = subs.iter(); + if thresh.is_and() { + let mut it = thresh.iter(); let mut policy = it.next().expect("No sub policy in thresh() ?").clone(); policy = it.fold(policy, |acc, pol| Concrete::And(vec![acc, pol.clone()]).into()); ret = best_compilations(policy_cache, policy.as_ref(), sat_prob, dissat_prob)?; } - // FIXME: Should we also optimize thresh(1, subs) ? + // FIXME: Should we also special-case thresh.is_or() ? } } for k in ret.keys() { @@ -1247,7 +1230,7 @@ mod tests { use super::*; use crate::miniscript::{Legacy, Segwitv0, Tap}; use crate::policy::Liftable; - use crate::{script_num_size, AbsLockTime, RelLockTime, ToPublicKey}; + use crate::{script_num_size, AbsLockTime, RelLockTime, Threshold, ToPublicKey}; type SPolicy = Concrete; type BPolicy = Concrete; @@ -1393,8 +1376,8 @@ mod tests { ( 127, Arc::new(Concrete::Thresh( - 3, - key_pol[0..5].iter().map(|p| (p.clone()).into()).collect(), + Threshold::from_iter(3, key_pol[0..5].iter().map(|p| (p.clone()).into())) + .unwrap(), )), ), ( @@ -1402,8 +1385,8 @@ mod tests { Arc::new(Concrete::And(vec![ Arc::new(Concrete::Older(RelLockTime::from_height(10000))), Arc::new(Concrete::Thresh( - 2, - key_pol[5..8].iter().map(|p| (p.clone()).into()).collect(), + Threshold::from_iter(2, key_pol[5..8].iter().map(|p| (p.clone()).into())) + .unwrap(), )), ])), ), @@ -1524,11 +1507,12 @@ mod tests { // and to a ms thresh otherwise. // k = 1 (or 2) does not compile, see https://github.com/rust-bitcoin/rust-miniscript/issues/114 for k in &[10, 15, 21] { - let pubkeys: Vec>> = keys - .iter() - .map(|pubkey| Arc::new(Concrete::Key(*pubkey))) - .collect(); - let big_thresh = Concrete::Thresh(*k, pubkeys); + let thresh: Threshold>, 0> = Threshold::from_iter( + *k, + keys.iter().map(|pubkey| Arc::new(Concrete::Key(*pubkey))), + ) + .unwrap(); + let big_thresh = Concrete::Thresh(thresh); let big_thresh_ms: SegwitMiniScript = big_thresh.compile().unwrap(); if *k == 21 { // N * (PUSH + pubkey + CHECKSIGVERIFY) @@ -1564,8 +1548,8 @@ mod tests { .collect(); let thresh_res: Result = Concrete::Or(vec![ - (1, Arc::new(Concrete::Thresh(keys_a.len(), keys_a))), - (1, Arc::new(Concrete::Thresh(keys_b.len(), keys_b))), + (1, Arc::new(Concrete::Thresh(Threshold::and_n(keys_a)))), + (1, Arc::new(Concrete::Thresh(Threshold::and_n(keys_b)))), ]) .compile(); let script_size = thresh_res.clone().map(|m| m.script_size()); @@ -1582,7 +1566,8 @@ mod tests { .iter() .map(|pubkey| Arc::new(Concrete::Key(*pubkey))) .collect(); - let thresh_res: Result = Concrete::Thresh(keys.len(), keys).compile(); + let thresh_res: Result = + Concrete::Thresh(Threshold::and_n(keys)).compile(); let n_elements = thresh_res .clone() .map(|m| m.max_satisfaction_witness_elements()); @@ -1598,12 +1583,12 @@ mod tests { fn shared_limits() { // Test the maximum number of OPs with a 67-of-68 multisig let (keys, _) = pubkeys_and_a_sig(68); - let keys: Vec>> = keys - .iter() - .map(|pubkey| Arc::new(Concrete::Key(*pubkey))) - .collect(); - let thresh_res: Result = - Concrete::Thresh(keys.len() - 1, keys).compile(); + let thresh = Threshold::from_iter( + keys.len() - 1, + keys.iter().map(|pubkey| Arc::new(Concrete::Key(*pubkey))), + ) + .unwrap(); + let thresh_res: Result = Concrete::Thresh(thresh).compile(); let ops_count = thresh_res.clone().map(|m| m.ext.ops.op_count()); assert_eq!( thresh_res, @@ -1613,11 +1598,13 @@ mod tests { ); // For legacy too.. let (keys, _) = pubkeys_and_a_sig(68); - let keys: Vec>> = keys - .iter() - .map(|pubkey| Arc::new(Concrete::Key(*pubkey))) - .collect(); - let thresh_res = Concrete::Thresh(keys.len() - 1, keys).compile::(); + let thresh = Threshold::from_iter( + keys.len() - 1, + keys.iter().map(|pubkey| Arc::new(Concrete::Key(*pubkey))), + ) + .unwrap(); + + let thresh_res = Concrete::Thresh(thresh).compile::(); let ops_count = thresh_res.clone().map(|m| m.ext.ops.op_count()); assert_eq!( thresh_res, diff --git a/src/policy/concrete.rs b/src/policy/concrete.rs index 0c6a957f3..a07bf262a 100644 --- a/src/policy/concrete.rs +++ b/src/policy/concrete.rs @@ -30,7 +30,8 @@ use crate::sync::Arc; #[cfg(all(doc, not(feature = "compiler")))] use crate::Descriptor; use crate::{ - errstr, AbsLockTime, Error, ForEachKey, FromStrKey, MiniscriptKey, RelLockTime, Translator, + errstr, AbsLockTime, Error, ForEachKey, FromStrKey, MiniscriptKey, RelLockTime, Threshold, + Translator, }; /// Maximum TapLeafs allowed in a compiled TapTree @@ -69,7 +70,7 @@ pub enum Policy { /// relative probabilities for each one. Or(Vec<(usize, Arc>)>), /// A set of descriptors, satisfactions must be provided for `k` of them. - Thresh(usize, Vec>>), + Thresh(Threshold>, 0>), } /// Detailed error type for concrete policies. @@ -79,8 +80,6 @@ pub enum PolicyError { NonBinaryArgAnd, /// `Or` fragments only support two args. NonBinaryArgOr, - /// `Thresh` fragment can only have `1<=k<=n`. - IncorrectThresh, /// Semantic Policy Error: `And` `Or` fragments must take args: `k > 1`. InsufficientArgsforAnd, /// Semantic policy error: `And` `Or` fragments must take args: `k > 1`. @@ -115,9 +114,6 @@ impl fmt::Display for PolicyError { f.write_str("And policy fragment must take 2 arguments") } PolicyError::NonBinaryArgOr => f.write_str("Or policy fragment must take 2 arguments"), - PolicyError::IncorrectThresh => { - f.write_str("Threshold k must be greater than 0 and less than or equal to n 0 { f.write_str("Semantic Policy 'And' fragment must have at least 2 args ") } @@ -143,7 +139,6 @@ impl error::Error for PolicyError { match self { NonBinaryArgAnd | NonBinaryArgOr - | IncorrectThresh | InsufficientArgsforAnd | InsufficientArgsforOr | EntailmentMaxTerminals @@ -187,9 +182,10 @@ impl Policy { }) .collect::>() } - Policy::Thresh(k, ref subs) if *k == 1 => { - let total_odds = subs.len(); - subs.iter() + Policy::Thresh(ref thresh) if thresh.is_or() => { + let total_odds = thresh.n(); + thresh + .iter() .flat_map(|policy| policy.to_tapleaf_prob_vec(prob / total_odds as f64)) .collect::>() } @@ -407,13 +403,14 @@ impl Policy { .map(|(odds, pol)| (prob * *odds as f64 / total_odds as f64, pol.clone())) .collect::>() } - Policy::Thresh(k, subs) if *k == 1 => { - let total_odds = subs.len(); - subs.iter() + Policy::Thresh(ref thresh) if thresh.is_or() => { + let total_odds = thresh.n(); + thresh + .iter() .map(|pol| (prob / total_odds as f64, pol.clone())) .collect::>() } - Policy::Thresh(k, subs) if *k != subs.len() => generate_combination(subs, prob, *k), + Policy::Thresh(ref thresh) if !thresh.is_and() => generate_combination(thresh, prob), pol => vec![(prob, Arc::new(pol.clone()))], } } @@ -562,7 +559,9 @@ impl Policy { .enumerate() .map(|(i, (prob, _))| (*prob, child_n(i))) .collect()), - Thresh(ref k, ref subs) => Thresh(*k, (0..subs.len()).map(child_n).collect()), + Thresh(ref thresh) => { + Thresh(thresh.map_from_post_order_iter(&data.child_indices, &translated)) + } }; translated.push(Arc::new(new_policy)); } @@ -588,7 +587,9 @@ impl Policy { .enumerate() .map(|(i, (prob, _))| (*prob, child_n(i))) .collect())), - Thresh(k, ref subs) => Some(Thresh(*k, (0..subs.len()).map(child_n).collect())), + Thresh(ref thresh) => { + Some(Thresh(thresh.map_from_post_order_iter(&data.child_indices, &translated))) + } _ => None, }; match new_policy { @@ -624,7 +625,7 @@ impl Policy { let num = match data.node { Or(subs) => (0..subs.len()).map(num_for_child_n).sum(), - Thresh(k, subs) if *k == 1 => (0..subs.len()).map(num_for_child_n).sum(), + Thresh(thresh) if thresh.is_or() => (0..thresh.n()).map(num_for_child_n).sum(), _ => 1, }; nums.push(num); @@ -707,9 +708,9 @@ impl Policy { let iter = (0..subs.len()).map(info_for_child_n); TimelockInfo::combine_threshold(1, iter) } - Thresh(ref k, subs) => { - let iter = (0..subs.len()).map(info_for_child_n); - TimelockInfo::combine_threshold(*k, iter) + Thresh(ref thresh) => { + let iter = (0..thresh.n()).map(info_for_child_n); + TimelockInfo::combine_threshold(thresh.k(), iter) } _ => TimelockInfo::default(), }; @@ -731,8 +732,6 @@ impl Policy { for policy in self.pre_order_iter() { match *policy { - After(_) => {} - Older(_) => {} And(ref subs) => { if subs.len() != 2 { return Err(PolicyError::NonBinaryArgAnd); @@ -743,11 +742,6 @@ impl Policy { return Err(PolicyError::NonBinaryArgOr); } } - Thresh(k, ref subs) => { - if k == 0 || k > subs.len() { - return Err(PolicyError::IncorrectThresh); - } - } _ => {} } } @@ -787,16 +781,16 @@ impl Policy { }); (all_safe, atleast_one_safe && all_non_mall) } - Thresh(k, ref subs) => { - let (safe_count, non_mall_count) = (0..subs.len()).map(acc_for_child_n).fold( + Thresh(ref thresh) => { + let (safe_count, non_mall_count) = (0..thresh.n()).map(acc_for_child_n).fold( (0, 0), |(safe_count, non_mall_count), (safe, non_mall)| { (safe_count + safe as usize, non_mall_count + non_mall as usize) }, ); ( - safe_count >= (subs.len() - k + 1), - non_mall_count == subs.len() && safe_count >= (subs.len() - k), + safe_count >= (thresh.n() - thresh.k() + 1), + non_mall_count == thresh.n() && safe_count >= (thresh.n() - thresh.k()), ) } }; @@ -839,13 +833,7 @@ impl fmt::Debug for Policy { } f.write_str(")") } - Policy::Thresh(k, ref subs) => { - write!(f, "thresh({}", k)?; - for sub in subs { - write!(f, ",{:?}", sub)?; - } - f.write_str(")") - } + Policy::Thresh(ref thresh) => fmt::Debug::fmt(&thresh.debug("thresh", true), f), } } } @@ -882,13 +870,7 @@ impl fmt::Display for Policy { } f.write_str(")") } - Policy::Thresh(k, ref subs) => { - write!(f, "thresh({}", k)?; - for sub in subs { - write!(f, ",{}", sub)?; - } - f.write_str(")") - } + Policy::Thresh(ref thresh) => fmt::Display::fmt(&thresh.display("thresh", true), f), } } } @@ -907,13 +889,13 @@ impl str::FromStr for Policy { serde_string_impl_pk!(Policy, "a miniscript concrete policy"); -#[rustfmt::skip] impl Policy { /// Helper function for `from_tree` to parse subexpressions with /// names of the form x@y - fn from_tree_prob(top: &expression::Tree, allow_prob: bool,) - -> Result<(usize, Policy), Error> - { + fn from_tree_prob( + top: &expression::Tree, + allow_prob: bool, + ) -> Result<(usize, Policy), Error> { let frag_prob; let frag_name; let mut name_split = top.name.split('@'); @@ -981,24 +963,17 @@ impl Policy { for arg in &top.args { subs.push(Policy::from_tree_prob(arg, true)?); } - Ok(Policy::Or(subs.into_iter().map(|(prob, sub)| (prob, Arc::new(sub))).collect())) - } - ("thresh", nsubs) => { - if top.args.is_empty() || !top.args[0].args.is_empty() { - return Err(Error::PolicyError(PolicyError::IncorrectThresh)); - } - - let thresh = expression::parse_num(top.args[0].name)?; - if thresh >= nsubs || thresh == 0 { - return Err(Error::PolicyError(PolicyError::IncorrectThresh)); - } - - let mut subs = Vec::with_capacity(top.args.len() - 1); - for arg in &top.args[1..] { - subs.push(Policy::from_tree(arg)?); - } - Ok(Policy::Thresh(thresh as usize, subs.into_iter().map(Arc::new).collect())) + Ok(Policy::Or( + subs.into_iter() + .map(|(prob, sub)| (prob, Arc::new(sub))) + .collect(), + )) } + ("thresh", _) => top + .to_null_threshold() + .map_err(Error::ParseThreshold)? + .translate_by_index(|i| Policy::from_tree(&top.args[1 + i]).map(Arc::new)) + .map(Policy::Thresh), _ => Err(errstr(top.name)), } .map(|res| (frag_prob, res)) @@ -1048,20 +1023,23 @@ fn with_huffman_tree( /// any one of the conditions exclusively. #[cfg(feature = "compiler")] fn generate_combination( - policy_vec: &[Arc>], + thresh: &Threshold>, 0>, prob: f64, - k: usize, ) -> Vec<(f64, Arc>)> { - debug_assert!(k <= policy_vec.len()); + debug_assert!(thresh.k() < thresh.n()); + let prob_over_n = prob / thresh.n() as f64; let mut ret: Vec<(f64, Arc>)> = vec![]; - for i in 0..policy_vec.len() { - let policies: Vec>> = policy_vec - .iter() - .enumerate() - .filter_map(|(j, sub)| if j != i { Some(Arc::clone(sub)) } else { None }) - .collect(); - ret.push((prob / policy_vec.len() as f64, Arc::new(Policy::Thresh(k, policies)))); + for i in 0..thresh.n() { + let thresh_less_1 = Threshold::from_iter( + thresh.k(), + thresh + .iter() + .enumerate() + .filter_map(|(j, sub)| if j != i { Some(Arc::clone(sub)) } else { None }), + ) + .expect("k is strictly less than n, so (k, n-1) is a valid threshold"); + ret.push((prob_over_n, Arc::new(Policy::Thresh(thresh_less_1)))); } ret } @@ -1075,7 +1053,7 @@ impl<'a, Pk: MiniscriptKey> TreeLike for &'a Policy { | Ripemd160(_) | Hash160(_) => Tree::Nullary, And(ref subs) => Tree::Nary(subs.iter().map(Arc::as_ref).collect()), Or(ref v) => Tree::Nary(v.iter().map(|(_, p)| p.as_ref()).collect()), - Thresh(_, ref subs) => Tree::Nary(subs.iter().map(Arc::as_ref).collect()), + Thresh(ref thresh) => Tree::Nary(thresh.iter().map(Arc::as_ref).collect()), } } } @@ -1089,7 +1067,7 @@ impl TreeLike for Arc> { | Ripemd160(_) | Hash160(_) => Tree::Nullary, And(ref subs) => Tree::Nary(subs.iter().map(Arc::clone).collect()), Or(ref v) => Tree::Nary(v.iter().map(|(_, p)| Arc::clone(p)).collect()), - Thresh(_, ref subs) => Tree::Nary(subs.iter().map(Arc::clone).collect()), + Thresh(ref thresh) => Tree::Nary(thresh.iter().map(Arc::clone).collect()), } } } @@ -1107,8 +1085,9 @@ mod compiler_tests { .map(|st| policy_str!("{}", st)) .map(Arc::new) .collect(); + let thresh = Threshold::new(2, policies).unwrap(); - let combinations = generate_combination(&policies, 1.0, 2); + let combinations = generate_combination(&thresh, 1.0); let comb_a: Vec> = vec![ policy_str!("pk(B)"), @@ -1133,7 +1112,9 @@ mod compiler_tests { let expected_comb = vec![comb_a, comb_b, comb_c, comb_d] .into_iter() .map(|sub_pol| { - (0.25, Arc::new(Policy::Thresh(2, sub_pol.into_iter().map(Arc::new).collect()))) + let expected_thresh = + Threshold::from_iter(2, sub_pol.into_iter().map(Arc::new)).unwrap(); + (0.25, Arc::new(Policy::Thresh(expected_thresh))) }) .collect::>(); assert_eq!(combinations, expected_comb); diff --git a/src/policy/mod.rs b/src/policy/mod.rs index 86d616a9d..db337b3b5 100644 --- a/src/policy/mod.rs +++ b/src/policy/mod.rs @@ -164,14 +164,15 @@ impl Liftable for Terminal { // unwrap to be removed in a later commit Semantic::Thresh(Threshold::new(k, semantic_subs).unwrap()) } - Terminal::Multi(k, ref keys) | Terminal::MultiA(k, ref keys) => Semantic::Thresh( - Threshold::new( - k, - keys.iter() - .map(|k| Arc::new(Semantic::Key(k.clone()))) - .collect(), - ) - .unwrap(), // unwrap to be removed in a later commit + Terminal::Multi(ref thresh) => Semantic::Thresh( + thresh + .map_ref(|key| Arc::new(Semantic::Key(key.clone()))) + .forget_maximum(), + ), + Terminal::MultiA(ref thresh) => Semantic::Thresh( + thresh + .map_ref(|key| Arc::new(Semantic::Key(key.clone()))) + .forget_maximum(), ), } .normalized(); @@ -223,11 +224,8 @@ impl Liftable for Concrete { let semantic_subs = semantic_subs?.into_iter().map(Arc::new).collect(); Semantic::Thresh(Threshold::new(1, semantic_subs).unwrap()) } - Concrete::Thresh(k, ref subs) => { - let semantic_subs: Result>, Error> = - subs.iter().map(Liftable::lift).collect(); - let semantic_subs = semantic_subs?.into_iter().map(Arc::new).collect(); - Semantic::Thresh(Threshold::new(k, semantic_subs).unwrap()) + Concrete::Thresh(ref thresh) => { + Semantic::Thresh(thresh.translate_ref(|sub| Liftable::lift(sub).map(Arc::new))?) } } .normalized(); @@ -307,13 +305,13 @@ mod tests { ConcretePol::from_str("thresh(2,pk(),thresh(0))") .unwrap_err() .to_string(), - "Threshold k must be greater than 0 and less than or equal to n 0 0", ); assert_eq!( ConcretePol::from_str("and(pk())").unwrap_err().to_string(),