From 90c18f399be8c5db5f2fcade00674f15aa54888e Mon Sep 17 00:00:00 2001 From: "Tobin C. Harding" Date: Mon, 2 Oct 2023 09:18:18 +1100 Subject: [PATCH 1/3] Remove recursion from is_valid Currently the `is_valid` function recursively calls itself, we can use the `pre_order_iter` to loop over the policy nodes instead and remove the recursion. Note that this whole `is_valid` function is pretty inefficient because we have already iterated the nodes twice already (in `check_timelocks` and `check_duplicate_keys`). --- src/policy/concrete.rs | 75 +++++++++++++++++------------------------- 1 file changed, 31 insertions(+), 44 deletions(-) diff --git a/src/policy/concrete.rs b/src/policy/concrete.rs index 23305087d..2e1a8bb84 100644 --- a/src/policy/concrete.rs +++ b/src/policy/concrete.rs @@ -749,59 +749,46 @@ impl Policy { /// Validity condition also checks whether there is a possible satisfaction /// combination of timelocks and heightlocks pub fn is_valid(&self) -> Result<(), PolicyError> { + use Policy::*; + self.check_timelocks()?; self.check_duplicate_keys()?; - match *self { - Policy::And(ref subs) => { - if subs.len() != 2 { - Err(PolicyError::NonBinaryArgAnd) - } else { - subs.iter() - .map(|sub| sub.is_valid()) - .collect::, PolicyError>>()?; - Ok(()) + + for policy in self.pre_order_iter() { + match *policy { + And(ref subs) => { + if subs.len() != 2 { + return Err(PolicyError::NonBinaryArgAnd); + } } - } - Policy::Or(ref subs) => { - if subs.len() != 2 { - Err(PolicyError::NonBinaryArgOr) - } else { - subs.iter() - .map(|(_prob, sub)| sub.is_valid()) - .collect::, PolicyError>>()?; - Ok(()) + Or(ref subs) => { + if subs.len() != 2 { + return Err(PolicyError::NonBinaryArgOr); + } } - } - Policy::Threshold(k, ref subs) => { - if k == 0 || k > subs.len() { - Err(PolicyError::IncorrectThresh) - } else { - subs.iter() - .map(|sub| sub.is_valid()) - .collect::, PolicyError>>()?; - Ok(()) + Threshold(k, ref subs) => { + if k == 0 || k > subs.len() { + return Err(PolicyError::IncorrectThresh); + } } - } - Policy::After(n) => { - if n == absolute::LockTime::ZERO.into() { - Err(PolicyError::ZeroTime) - } else if n.to_u32() > 2u32.pow(31) { - Err(PolicyError::TimeTooFar) - } else { - Ok(()) + After(n) => { + if n == absolute::LockTime::ZERO.into() { + return Err(PolicyError::ZeroTime); + } else if n.to_u32() > 2u32.pow(31) { + return Err(PolicyError::TimeTooFar); + } } - } - Policy::Older(n) => { - if n == Sequence::ZERO { - Err(PolicyError::ZeroTime) - } else if n.to_consensus_u32() > 2u32.pow(31) { - Err(PolicyError::TimeTooFar) - } else { - Ok(()) + Older(n) => { + if n == Sequence::ZERO { + return Err(PolicyError::ZeroTime); + } else if n.to_consensus_u32() > 2u32.pow(31) { + return Err(PolicyError::TimeTooFar); + } } + _ => {} } - _ => Ok(()), } + Ok(()) } /// Checks if any possible compilation of the policy could be compiled From d2278d5a74590ba697bd139ff9285c63ef7a4b06 Mon Sep 17 00:00:00 2001 From: "Tobin C. Harding" Date: Mon, 2 Oct 2023 09:42:58 +1100 Subject: [PATCH 2/3] Remove recursion from is_safe_nonmalleable As we have done for other `Policy` functions remove the recursive calls in `is_safe_nonmalleable` and use the `post_order_iter` to process each node accumulating the required results in a vector during iteration. --- src/policy/concrete.rs | 75 ++++++++++++++++++++++-------------------- 1 file changed, 40 insertions(+), 35 deletions(-) diff --git a/src/policy/concrete.rs b/src/policy/concrete.rs index 2e1a8bb84..0c1812a8f 100644 --- a/src/policy/concrete.rs +++ b/src/policy/concrete.rs @@ -799,43 +799,48 @@ impl Policy { /// Returns a tuple `(safe, non-malleable)` to avoid the fact that /// non-malleability depends on safety and we would like to cache results. pub fn is_safe_nonmalleable(&self) -> (bool, bool) { - match *self { - Policy::Unsatisfiable | Policy::Trivial => (true, true), - Policy::Key(_) => (true, true), - Policy::Sha256(_) - | Policy::Hash256(_) - | Policy::Ripemd160(_) - | Policy::Hash160(_) - | Policy::After(_) - | Policy::Older(_) => (false, true), - Policy::Threshold(k, ref subs) => { - let (safe_count, non_mall_count) = subs - .iter() - .map(|sub| sub.is_safe_nonmalleable()) - .fold((0, 0), |(safe_count, non_mall_count), (safe, non_mall)| { - (safe_count + safe as usize, non_mall_count + non_mall as usize) - }); - ( - safe_count >= (subs.len() - k + 1), - non_mall_count == subs.len() && safe_count >= (subs.len() - k), - ) - } - Policy::And(ref subs) => { - let (atleast_one_safe, all_non_mall) = subs - .iter() - .map(|sub| sub.is_safe_nonmalleable()) - .fold((false, true), |acc, x| (acc.0 || x.0, acc.1 && x.1)); - (atleast_one_safe, all_non_mall) - } + use Policy::*; - Policy::Or(ref subs) => { - let (all_safe, atleast_one_safe, all_non_mall) = subs - .iter() - .map(|(_, sub)| sub.is_safe_nonmalleable()) - .fold((true, false, true), |acc, x| (acc.0 && x.0, acc.1 || x.0, acc.2 && x.1)); - (all_safe, atleast_one_safe && all_non_mall) - } + let mut acc = vec![]; + for data in Arc::new(self).post_order_iter() { + let acc_for_child_n = |n| acc[data.child_indices[n]]; + + let new = match data.node { + Unsatisfiable | Trivial | Key(_) => (true, true), + Sha256(_) | Hash256(_) | Ripemd160(_) | Hash160(_) | After(_) | Older(_) => { + (false, true) + } + Threshold(k, ref subs) => { + let (safe_count, non_mall_count) = (0..subs.len()).map(acc_for_child_n).fold( + (0, 0), + |(safe_count, non_mall_count), (safe, non_mall)| { + (safe_count + safe as usize, non_mall_count + non_mall as usize) + }, + ); + ( + safe_count >= (subs.len() - k + 1), + non_mall_count == subs.len() && safe_count >= (subs.len() - k), + ) + } + And(ref subs) => { + let (atleast_one_safe, all_non_mall) = (0..subs.len()) + .map(acc_for_child_n) + .fold((false, true), |acc, x: (bool, bool)| (acc.0 || x.0, acc.1 && x.1)); + (atleast_one_safe, all_non_mall) + } + Or(ref subs) => { + let (all_safe, atleast_one_safe, all_non_mall) = (0..subs.len()) + .map(acc_for_child_n) + .fold((true, false, true), |acc, x| { + (acc.0 && x.0, acc.1 || x.0, acc.2 && x.1) + }); + (all_safe, atleast_one_safe && all_non_mall) + } + }; + acc.push(new); } + // Ok to unwrap because we know we processed at least one node. + acc.pop().unwrap() } } From 0452040f6b31b981ee89a5c2480193c879336f65 Mon Sep 17 00:00:00 2001 From: "Tobin C. Harding" Date: Mon, 2 Oct 2023 09:59:35 +1100 Subject: [PATCH 3/3] Order arms as in enum definition Currently there are a bunch of places where when matching on the `Policy` enum variants are matched in different order to the enum. During recent work this was maintained to make diffs easier to review however since we just modified a lot of this file lets clean it up while we are here. Move the match arms around so that they are ordered in the same order as the variants are defined in the `Policy` enum. Refactor only, no logic changes. --- src/policy/concrete.rs | 68 +++++++++++++++++++++--------------------- 1 file changed, 34 insertions(+), 34 deletions(-) diff --git a/src/policy/concrete.rs b/src/policy/concrete.rs index 0c1812a8f..19e7771b9 100644 --- a/src/policy/concrete.rs +++ b/src/policy/concrete.rs @@ -579,13 +579,13 @@ impl Policy { Hash160(ref h) => t.hash160(h).map(Hash160)?, Older(ref n) => Older(*n), After(ref n) => After(*n), - Threshold(ref k, ref subs) => Threshold(*k, (0..subs.len()).map(child_n).collect()), And(ref subs) => And((0..subs.len()).map(child_n).collect()), Or(ref subs) => Or(subs .iter() .enumerate() .map(|(i, (prob, _))| (*prob, child_n(i))) .collect()), + Threshold(ref k, ref subs) => Threshold(*k, (0..subs.len()).map(child_n).collect()), }; translated.push(Arc::new(new_policy)); } @@ -605,15 +605,15 @@ impl Policy { let new_policy = match data.node.as_ref() { Policy::Key(ref k) if k.clone() == *key => Some(Policy::Unsatisfiable), - Threshold(k, ref subs) => { - Some(Threshold(*k, (0..subs.len()).map(child_n).collect())) - } And(ref subs) => Some(And((0..subs.len()).map(child_n).collect())), Or(ref subs) => Some(Or(subs .iter() .enumerate() .map(|(i, (prob, _))| (*prob, child_n(i))) .collect())), + Threshold(k, ref subs) => { + Some(Threshold(*k, (0..subs.len()).map(child_n).collect())) + } _ => None, }; match new_policy { @@ -724,10 +724,6 @@ impl Policy { cltv_with_time: false, contains_combination: false, }, - Threshold(ref k, subs) => { - let iter = (0..subs.len()).map(info_for_child_n); - TimelockInfo::combine_threshold(*k, iter) - } And(ref subs) => { let iter = (0..subs.len()).map(info_for_child_n); TimelockInfo::combine_threshold(subs.len(), iter) @@ -736,6 +732,10 @@ impl Policy { let iter = (0..subs.len()).map(info_for_child_n); TimelockInfo::combine_threshold(1, iter) } + Threshold(ref k, subs) => { + let iter = (0..subs.len()).map(info_for_child_n); + TimelockInfo::combine_threshold(*k, iter) + } _ => TimelockInfo::default(), }; infos.push(info); @@ -756,6 +756,20 @@ impl Policy { for policy in self.pre_order_iter() { match *policy { + After(n) => { + if n == absolute::LockTime::ZERO.into() { + return Err(PolicyError::ZeroTime); + } else if n.to_u32() > 2u32.pow(31) { + return Err(PolicyError::TimeTooFar); + } + } + Older(n) => { + if n == Sequence::ZERO { + return Err(PolicyError::ZeroTime); + } else if n.to_consensus_u32() > 2u32.pow(31) { + return Err(PolicyError::TimeTooFar); + } + } And(ref subs) => { if subs.len() != 2 { return Err(PolicyError::NonBinaryArgAnd); @@ -771,20 +785,6 @@ impl Policy { return Err(PolicyError::IncorrectThresh); } } - After(n) => { - if n == absolute::LockTime::ZERO.into() { - return Err(PolicyError::ZeroTime); - } else if n.to_u32() > 2u32.pow(31) { - return Err(PolicyError::TimeTooFar); - } - } - Older(n) => { - if n == Sequence::ZERO { - return Err(PolicyError::ZeroTime); - } else if n.to_consensus_u32() > 2u32.pow(31) { - return Err(PolicyError::TimeTooFar); - } - } _ => {} } } @@ -810,18 +810,6 @@ impl Policy { Sha256(_) | Hash256(_) | Ripemd160(_) | Hash160(_) | After(_) | Older(_) => { (false, true) } - Threshold(k, ref subs) => { - let (safe_count, non_mall_count) = (0..subs.len()).map(acc_for_child_n).fold( - (0, 0), - |(safe_count, non_mall_count), (safe, non_mall)| { - (safe_count + safe as usize, non_mall_count + non_mall as usize) - }, - ); - ( - safe_count >= (subs.len() - k + 1), - non_mall_count == subs.len() && safe_count >= (subs.len() - k), - ) - } And(ref subs) => { let (atleast_one_safe, all_non_mall) = (0..subs.len()) .map(acc_for_child_n) @@ -836,6 +824,18 @@ impl Policy { }); (all_safe, atleast_one_safe && all_non_mall) } + Threshold(k, ref subs) => { + let (safe_count, non_mall_count) = (0..subs.len()).map(acc_for_child_n).fold( + (0, 0), + |(safe_count, non_mall_count), (safe, non_mall)| { + (safe_count + safe as usize, non_mall_count + non_mall as usize) + }, + ); + ( + safe_count >= (subs.len() - k + 1), + non_mall_count == subs.len() && safe_count >= (subs.len() - k), + ) + } }; acc.push(new); }