From 4d2d43019891609598c1650040c7102a6b61d78d Mon Sep 17 00:00:00 2001 From: "Tobin C. Harding" Date: Mon, 9 Oct 2023 15:21:36 +1100 Subject: [PATCH] Use Threshold type in concrete::Policy::Thresh Use the `Threshold` type in `policy::concrete::Policy::Thresh` to help maintain invariants on n and k. --- src/iter/mod.rs | 4 +- src/policy/compiler.rs | 42 ++++++++------- src/policy/concrete.rs | 116 +++++++++++++++++++++++------------------ src/policy/mod.rs | 6 +-- 4 files changed, 93 insertions(+), 75 deletions(-) diff --git a/src/iter/mod.rs b/src/iter/mod.rs index c91c58c8c..771bd4d81 100644 --- a/src/iter/mod.rs +++ b/src/iter/mod.rs @@ -77,7 +77,7 @@ impl<'a, Pk: MiniscriptKey> TreeLike for &'a policy::Concrete { | Ripemd160(_) | Hash160(_) => Tree::Nullary, And(ref subs) => Tree::Nary(subs.iter().map(Arc::as_ref).collect()), Or(ref v) => Tree::Nary(v.iter().map(|(_, p)| p.as_ref()).collect()), - Thresh(_, ref subs) => Tree::Nary(subs.iter().map(Arc::as_ref).collect()), + Thresh(thresh) => Tree::Nary(thresh.iter().map(Arc::as_ref).collect()), } } } @@ -90,7 +90,7 @@ impl<'a, Pk: MiniscriptKey> TreeLike for Arc> { | Ripemd160(_) | Hash160(_) => Tree::Nullary, And(ref subs) => Tree::Nary(subs.iter().map(Arc::clone).collect()), Or(ref v) => Tree::Nary(v.iter().map(|(_, p)| Arc::clone(p)).collect()), - Thresh(_, ref subs) => Tree::Nary(subs.iter().map(Arc::clone).collect()), + Thresh(thresh) => Tree::Nary(thresh.iter().map(Arc::clone).collect()), } } } diff --git a/src/policy/compiler.rs b/src/policy/compiler.rs index f80d06c64..0f12fabe2 100644 --- a/src/policy/compiler.rs +++ b/src/policy/compiler.rs @@ -920,8 +920,9 @@ where compile_binary!(&mut l_comp[3], &mut r_comp[2], [lw, rw], Terminal::OrI); compile_binary!(&mut r_comp[3], &mut l_comp[2], [rw, lw], Terminal::OrI); } - Concrete::Thresh(k, ref subs) => { - let n = subs.len(); + Concrete::Thresh(ref thresh) => { + let k = thresh.k(); + let n = thresh.n(); let k_over_n = k as f64 / n as f64; let mut sub_ast = Vec::with_capacity(n); @@ -931,7 +932,7 @@ where let mut best_ws = Vec::with_capacity(n); let mut min_value = (0, f64::INFINITY); - for (i, ast) in subs.iter().enumerate() { + for (i, ast) in thresh.iter().enumerate() { let sp = sat_prob * k_over_n; //Expressions must be dissatisfiable let dp = Some(dissat_prob.unwrap_or(0 as f64) + (1.0 - k_over_n) * sat_prob); @@ -949,7 +950,7 @@ where } sub_ext_data.push(best_es[min_value.0].0); sub_ast.push(Arc::clone(&best_es[min_value.0].1.ms)); - for (i, _ast) in subs.iter().enumerate() { + for (i, _ast) in thresh.iter().enumerate() { if i != min_value.0 { sub_ext_data.push(best_ws[i].0); sub_ast.push(Arc::clone(&best_ws[i].1.ms)); @@ -966,7 +967,7 @@ where insert_wrap!(ast_ext); } - let key_vec: Vec = subs + let key_vec: Vec = thresh .iter() .filter_map(|s| { if let Concrete::Key(ref pk) = s.as_ref() { @@ -978,16 +979,16 @@ where .collect(); match Ctx::sig_type() { - SigType::Schnorr if key_vec.len() == subs.len() => { + SigType::Schnorr if key_vec.len() == thresh.n() => { insert_wrap!(AstElemExt::terminal(Terminal::MultiA(k, key_vec))) } SigType::Ecdsa - if key_vec.len() == subs.len() && subs.len() <= MAX_PUBKEYS_PER_MULTISIG => + if key_vec.len() == thresh.n() && thresh.n() <= MAX_PUBKEYS_PER_MULTISIG => { insert_wrap!(AstElemExt::terminal(Terminal::Multi(k, key_vec))) } - _ if k == subs.len() => { - let mut it = subs.iter(); + _ if k == thresh.n() => { + let mut it = thresh.iter(); let mut policy = it.next().expect("No sub policy in thresh() ?").clone(); policy = it.fold(policy, |acc, pol| Concrete::And(vec![acc, pol.clone()]).into()); @@ -1157,6 +1158,7 @@ mod tests { use super::*; use crate::miniscript::{Legacy, Segwitv0, Tap}; use crate::policy::Liftable; + use crate::threshold::Threshold; use crate::{script_num_size, ToPublicKey}; type SPolicy = Concrete; @@ -1301,19 +1303,19 @@ mod tests { let policy: BPolicy = Concrete::Or(vec![ ( 127, - Arc::new(Concrete::Thresh( + Arc::new(Concrete::Thresh(Threshold::new_unchecked( 3, key_pol[0..5].iter().map(|p| (p.clone()).into()).collect(), - )), + ))), ), ( 1, Arc::new(Concrete::And(vec![ Arc::new(Concrete::Older(Sequence::from_height(10000))), - Arc::new(Concrete::Thresh( + Arc::new(Concrete::Thresh(Threshold::new_unchecked( 2, key_pol[5..8].iter().map(|p| (p.clone()).into()).collect(), - )), + ))), ])), ), ]); @@ -1430,7 +1432,7 @@ mod tests { .iter() .map(|pubkey| Arc::new(Concrete::Key(*pubkey))) .collect(); - let big_thresh = Concrete::Thresh(*k, pubkeys); + let big_thresh = Concrete::Thresh(Threshold::new_unchecked(*k, pubkeys)); let big_thresh_ms: SegwitMiniScript = big_thresh.compile().unwrap(); if *k == 21 { // N * (PUSH + pubkey + CHECKSIGVERIFY) @@ -1466,8 +1468,8 @@ mod tests { .collect(); let thresh_res: Result = Concrete::Or(vec![ - (1, Arc::new(Concrete::Thresh(keys_a.len(), keys_a))), - (1, Arc::new(Concrete::Thresh(keys_b.len(), keys_b))), + (1, Arc::new(Concrete::Thresh(Threshold::new_unchecked(keys_a.len(), keys_a)))), + (1, Arc::new(Concrete::Thresh(Threshold::new_unchecked(keys_b.len(), keys_b)))), ]) .compile(); let script_size = thresh_res.clone().and_then(|m| Ok(m.script_size())); @@ -1484,7 +1486,8 @@ mod tests { .iter() .map(|pubkey| Arc::new(Concrete::Key(*pubkey))) .collect(); - let thresh_res: Result = Concrete::Thresh(keys.len(), keys).compile(); + let thresh_res: Result = + Concrete::Thresh(Threshold::new_unchecked(keys.len(), keys)).compile(); let n_elements = thresh_res .clone() .and_then(|m| Ok(m.max_satisfaction_witness_elements())); @@ -1505,7 +1508,7 @@ mod tests { .map(|pubkey| Arc::new(Concrete::Key(*pubkey))) .collect(); let thresh_res: Result = - Concrete::Thresh(keys.len() - 1, keys).compile(); + Concrete::Thresh(Threshold::new_unchecked(keys.len() - 1, keys)).compile(); let ops_count = thresh_res.clone().and_then(|m| Ok(m.ext.ops.op_count())); assert_eq!( thresh_res, @@ -1519,7 +1522,8 @@ mod tests { .iter() .map(|pubkey| Arc::new(Concrete::Key(*pubkey))) .collect(); - let thresh_res = Concrete::Thresh(keys.len() - 1, keys).compile::(); + let thresh_res = + Concrete::Thresh(Threshold::new_unchecked(keys.len() - 1, keys)).compile::(); let ops_count = thresh_res.clone().and_then(|m| Ok(m.ext.ops.op_count())); assert_eq!( thresh_res, diff --git a/src/policy/concrete.rs b/src/policy/concrete.rs index 5e1675ebf..d1b4b6c7d 100644 --- a/src/policy/concrete.rs +++ b/src/policy/concrete.rs @@ -27,6 +27,7 @@ use crate::iter::TreeLike; use crate::miniscript::types::extra_props::TimelockInfo; use crate::prelude::*; use crate::sync::Arc; +use crate::threshold::Threshold; #[cfg(all(doc, not(feature = "compiler")))] use crate::Descriptor; use crate::{errstr, AbsLockTime, Error, ForEachKey, MiniscriptKey, Translator}; @@ -67,7 +68,7 @@ pub enum Policy { /// relative probabilities for each one. Or(Vec<(usize, Arc>)>), /// A set of descriptors, satisfactions must be provided for `k` of them. - Thresh(usize, Vec>>), + Thresh(Threshold>>), } impl Policy @@ -210,9 +211,10 @@ impl Policy { }) .collect::>() } - Policy::Thresh(k, ref subs) if *k == 1 => { - let total_odds = subs.len(); - subs.iter() + Policy::Thresh(thresh) if thresh.k() == 1 => { + let total_odds = thresh.n(); + thresh + .iter() .flat_map(|policy| policy.to_tapleaf_prob_vec(prob / total_odds as f64)) .collect::>() } @@ -430,13 +432,16 @@ impl Policy { .map(|(odds, pol)| (prob * *odds as f64 / total_odds as f64, pol.clone())) .collect::>() } - Policy::Thresh(k, subs) if *k == 1 => { - let total_odds = subs.len(); - subs.iter() + Policy::Thresh(thresh) if thresh.k() == 1 => { + let total_odds = thresh.n(); + thresh + .iter() .map(|pol| (prob / total_odds as f64, pol.clone())) .collect::>() } - Policy::Thresh(k, subs) if *k != subs.len() => generate_combination(subs, prob, *k), + Policy::Thresh(thresh) if thresh.k() != thresh.n() => { + generate_combination(thresh, prob) + } pol => vec![(prob, Arc::new(pol.clone()))], } } @@ -585,7 +590,7 @@ impl Policy { .enumerate() .map(|(i, (prob, _))| (*prob, child_n(i))) .collect()), - Thresh(ref k, ref subs) => Thresh(*k, (0..subs.len()).map(child_n).collect()), + Thresh(ref thresh) => Thresh(thresh.mapped((0..thresh.n()).map(child_n).collect())), }; translated.push(Arc::new(new_policy)); } @@ -611,7 +616,9 @@ impl Policy { .enumerate() .map(|(i, (prob, _))| (*prob, child_n(i))) .collect())), - Thresh(k, ref subs) => Some(Thresh(*k, (0..subs.len()).map(child_n).collect())), + Thresh(ref thresh) => { + Some(Thresh(thresh.mapped((0..thresh.n()).map(child_n).collect()))) + } _ => None, }; match new_policy { @@ -647,7 +654,7 @@ impl Policy { let num = match data.node { Or(subs) => (0..subs.len()).map(num_for_child_n).sum(), - Thresh(k, subs) if *k == 1 => (0..subs.len()).map(num_for_child_n).sum(), + Thresh(thresh) if thresh.k() == 1 => (0..thresh.n()).map(num_for_child_n).sum(), _ => 1, }; nums.push(num); @@ -730,9 +737,9 @@ impl Policy { let iter = (0..subs.len()).map(info_for_child_n); TimelockInfo::combine_threshold(1, iter) } - Thresh(ref k, subs) => { - let iter = (0..subs.len()).map(info_for_child_n); - TimelockInfo::combine_threshold(*k, iter) + Thresh(ref thresh) => { + let iter = (0..thresh.n()).map(info_for_child_n); + TimelockInfo::combine_threshold(thresh.k(), iter) } _ => TimelockInfo::default(), }; @@ -768,15 +775,11 @@ impl Policy { return Err(PolicyError::TimeTooFar); } } - Or(ref subs) => { - if subs.len() != 2 { - return Err(PolicyError::NonBinaryArgOr); - } - } - Thresh(k, ref subs) => { - if k == 0 || k > subs.len() { - return Err(PolicyError::IncorrectThresh); - } + Policy::Thresh(ref thresh) => { + thresh + .iter() + .map(|pol| pol.is_valid()) + .collect::, PolicyError>>()?; } _ => {} } @@ -817,16 +820,16 @@ impl Policy { }); (all_safe, atleast_one_safe && all_non_mall) } - Thresh(k, ref subs) => { - let (safe_count, non_mall_count) = (0..subs.len()).map(acc_for_child_n).fold( - (0, 0), - |(safe_count, non_mall_count), (safe, non_mall)| { + Policy::Thresh(ref thresh) => { + let (safe_count, non_mall_count) = thresh + .iter() + .map(|sub| sub.is_safe_nonmalleable()) + .fold((0, 0), |(safe_count, non_mall_count), (safe, non_mall)| { (safe_count + safe as usize, non_mall_count + non_mall as usize) - }, - ); + }); ( - safe_count >= (subs.len() - k + 1), - non_mall_count == subs.len() && safe_count >= (subs.len() - k), + safe_count >= (thresh.n() - thresh.k() + 1), + non_mall_count == thresh.n() && safe_count >= (thresh.n() - thresh.k()), ) } }; @@ -869,10 +872,10 @@ impl fmt::Debug for Policy { } f.write_str(")") } - Policy::Thresh(k, ref subs) => { - write!(f, "thresh({}", k)?; - for sub in subs { - write!(f, ",{:?}", sub)?; + Policy::Thresh(ref thresh) => { + write!(f, "thresh({}", thresh.k())?; + for policy in thresh.iter() { + write!(f, ",{:?}", policy)?; } f.write_str(")") } @@ -912,10 +915,10 @@ impl fmt::Display for Policy { } f.write_str(")") } - Policy::Thresh(k, ref subs) => { - write!(f, "thresh({}", k)?; - for sub in subs { - write!(f, ",{}", sub)?; + Policy::Thresh(ref thresh) => { + write!(f, "thresh({}", thresh.k())?; + for policy in thresh.iter() { + write!(f, ",{}", policy)?; } f.write_str(")") } @@ -1028,8 +1031,9 @@ impl_block_str!( return Err(Error::PolicyError(PolicyError::IncorrectThresh)); } - let thresh = expression::parse_num(top.args[0].name)?; - if thresh >= nsubs || thresh == 0 { + // TODO: Find out why this cast ok. + let k = expression::parse_num(top.args[0].name)?; + if k >= nsubs || k == 0 { return Err(Error::PolicyError(PolicyError::IncorrectThresh)); } @@ -1037,7 +1041,13 @@ impl_block_str!( for arg in &top.args[1..] { subs.push(Policy::from_tree(arg)?); } - Ok(Policy::Thresh(thresh as usize, subs.into_iter().map(Arc::new).collect())) + let v = subs.into_iter().map(Arc::new).collect(); + + // TODO: OK to cast from u32 to usize in this codebase? + let k = k as usize; + + let thresh = Threshold::new(k, v).map_err(|_| PolicyError::IncorrectThresh)?; + Ok(Policy::Thresh(thresh)) } _ => Err(errstr(top.name)), } @@ -1089,20 +1099,20 @@ fn with_huffman_tree( /// any one of the conditions exclusively. #[cfg(feature = "compiler")] fn generate_combination( - policy_vec: &Vec>>, + policy_thresh: &Threshold>>, prob: f64, - k: usize, ) -> Vec<(f64, Arc>)> { - debug_assert!(k <= policy_vec.len()); - let mut ret: Vec<(f64, Arc>)> = vec![]; - for i in 0..policy_vec.len() { - let policies: Vec>> = policy_vec + let k = policy_thresh.k(); + for i in 0..policy_thresh.n() { + let policies: Vec>> = policy_thresh .iter() .enumerate() .filter_map(|(j, sub)| if j != i { Some(Arc::clone(sub)) } else { None }) .collect(); - ret.push((prob / policy_vec.len() as f64, Arc::new(Policy::Thresh(k, policies)))); + if let Ok(thresh) = Threshold::new(k, policies) { + ret.push((prob / policy_thresh.n() as f64, Arc::new(Policy::Thresh(thresh)))); + } } ret } @@ -1123,7 +1133,8 @@ mod compiler_tests { .map(|p| Arc::new(p)) .collect(); - let combinations = generate_combination(&policies, 1.0, 2); + let thresh = Threshold::new_unchecked(2, policies); + let combinations = generate_combination(&thresh, 1.0); let comb_a: Vec> = vec![ policy_str!("pk(B)"), @@ -1150,7 +1161,10 @@ mod compiler_tests { .map(|sub_pol| { ( 0.25, - Arc::new(Policy::Thresh(2, sub_pol.into_iter().map(|p| Arc::new(p)).collect())), + Arc::new(Policy::Thresh(Threshold::new_unchecked( + 2, + sub_pol.into_iter().map(|p| Arc::new(p)).collect(), + ))), ) }) .collect::>(); diff --git a/src/policy/mod.rs b/src/policy/mod.rs index aa6d4ef75..122ce5b3f 100644 --- a/src/policy/mod.rs +++ b/src/policy/mod.rs @@ -205,9 +205,9 @@ impl Liftable for Concrete { subs.iter().map(|(_p, sub)| sub.lift()).collect(); Semantic::Thresh(1, semantic_subs?) } - Concrete::Thresh(k, ref subs) => { - let semantic_subs: Result<_, Error> = subs.iter().map(Liftable::lift).collect(); - Semantic::Thresh(k, semantic_subs?) + Concrete::Thresh(ref thresh) => { + let semantic_subs: Result<_, Error> = thresh.iter().map(Liftable::lift).collect(); + Semantic::Thresh(thresh.k(), semantic_subs?) } } .normalized();