diff --git a/src/miniscript/astelem.rs b/src/miniscript/astelem.rs index 22cd07565..f9e8661d2 100644 --- a/src/miniscript/astelem.rs +++ b/src/miniscript/astelem.rs @@ -15,8 +15,7 @@ use bitcoin::{absolute, opcodes, script, Sequence}; use sync::Arc; use crate::miniscript::context::SigType; -use crate::miniscript::types::{self, Property}; -use crate::miniscript::ScriptContext; +use crate::miniscript::{types, ScriptContext}; use crate::prelude::*; use crate::util::MsKeyBuilder; use crate::{ diff --git a/src/miniscript/decode.rs b/src/miniscript/decode.rs index 950612881..a058cf4b4 100644 --- a/src/miniscript/decode.rs +++ b/src/miniscript/decode.rs @@ -17,7 +17,7 @@ use sync::Arc; use crate::miniscript::lex::{Token as Tk, TokenIter}; use crate::miniscript::limits::MAX_PUBKEYS_PER_MULTISIG; use crate::miniscript::types::extra_props::ExtData; -use crate::miniscript::types::{Property, Type}; +use crate::miniscript::types::Type; use crate::miniscript::ScriptContext; use crate::prelude::*; #[cfg(doc)] diff --git a/src/miniscript/mod.rs b/src/miniscript/mod.rs index cf634a954..b8fd8a902 100644 --- a/src/miniscript/mod.rs +++ b/src/miniscript/mod.rs @@ -40,7 +40,6 @@ use core::cmp; use sync::Arc; use self::lex::{lex, TokenIter}; -use self::types::Property; pub use crate::miniscript::context::ScriptContext; use crate::miniscript::decode::Terminal; use crate::miniscript::types::extra_props::ExtData; diff --git a/src/miniscript/types/extra_props.rs b/src/miniscript/types/extra_props.rs index 8cbb4aa3c..32d0dd74a 100644 --- a/src/miniscript/types/extra_props.rs +++ b/src/miniscript/types/extra_props.rs @@ -856,22 +856,12 @@ impl Property for ExtData { exec_stack_elem_count_dissat, }) } +} - fn type_check_with_child( - _fragment: &Terminal, - mut _child: C, - ) -> Result> - where - C: FnMut(usize) -> Self, - Pk: MiniscriptKey, - Ctx: ScriptContext, - { - unreachable!() - } - +impl ExtData { /// Compute the type of a fragment assuming all the children of /// Miniscript have been computed already. - fn type_check(fragment: &Terminal) -> Result> + pub fn type_check(fragment: &Terminal) -> Result> where Ctx: ScriptContext, Pk: MiniscriptKey, diff --git a/src/miniscript/types/mod.rs b/src/miniscript/types/mod.rs index e95585c9b..d9d579231 100644 --- a/src/miniscript/types/mod.rs +++ b/src/miniscript/types/mod.rs @@ -344,175 +344,6 @@ pub trait Property: Sized { fn threshold(k: usize, n: usize, sub_ck: S) -> Result where S: FnMut(usize) -> Result; - - /// Compute the type of a fragment, given a function to look up - /// the types of its children, if available and relevant for the - /// given fragment - fn type_check_common<'a, Pk, Ctx, C>( - fragment: &'a Terminal, - mut get_child: C, - ) -> Result> - where - C: FnMut(&'a Terminal, usize) -> Result>, - Pk: MiniscriptKey, - Ctx: ScriptContext, - { - let wrap_err = |result: Result| { - result.map_err(|kind| Error { fragment: fragment.clone(), error: kind }) - }; - - let ret = match *fragment { - Terminal::True => Ok(Self::from_true()), - Terminal::False => Ok(Self::from_false()), - Terminal::PkK(..) => Ok(Self::from_pk_k::()), - Terminal::PkH(..) | Terminal::RawPkH(..) => Ok(Self::from_pk_h::()), - Terminal::Multi(k, ref pks) | Terminal::MultiA(k, ref pks) => { - if k == 0 { - return Err(Error { - fragment: fragment.clone(), - error: ErrorKind::ZeroThreshold, - }); - } - if k > pks.len() { - return Err(Error { - fragment: fragment.clone(), - error: ErrorKind::OverThreshold(k, pks.len()), - }); - } - match *fragment { - Terminal::Multi(..) => Ok(Self::from_multi(k, pks.len())), - Terminal::MultiA(..) => Ok(Self::from_multi_a(k, pks.len())), - _ => unreachable!(), - } - } - Terminal::After(t) => { - // Note that for CLTV this is a limitation not of Bitcoin but Miniscript. The - // number on the stack would be a 5 bytes signed integer but Miniscript's B type - // only consumes 4 bytes from the stack. - if t == absolute::LockTime::ZERO.into() { - return Err(Error { - fragment: fragment.clone(), - error: ErrorKind::InvalidTime, - }); - } - Ok(Self::from_after(t.into())) - } - Terminal::Older(t) => { - if t == Sequence::ZERO || !t.is_relative_lock_time() { - return Err(Error { - fragment: fragment.clone(), - error: ErrorKind::InvalidTime, - }); - } - Ok(Self::from_older(t)) - } - Terminal::Sha256(..) => Ok(Self::from_sha256()), - Terminal::Hash256(..) => Ok(Self::from_hash256()), - Terminal::Ripemd160(..) => Ok(Self::from_ripemd160()), - Terminal::Hash160(..) => Ok(Self::from_hash160()), - Terminal::Alt(ref sub) => wrap_err(Self::cast_alt(get_child(&sub.node, 0)?)), - Terminal::Swap(ref sub) => wrap_err(Self::cast_swap(get_child(&sub.node, 0)?)), - Terminal::Check(ref sub) => wrap_err(Self::cast_check(get_child(&sub.node, 0)?)), - Terminal::DupIf(ref sub) => wrap_err(Self::cast_dupif(get_child(&sub.node, 0)?)), - Terminal::Verify(ref sub) => wrap_err(Self::cast_verify(get_child(&sub.node, 0)?)), - Terminal::NonZero(ref sub) => wrap_err(Self::cast_nonzero(get_child(&sub.node, 0)?)), - Terminal::ZeroNotEqual(ref sub) => { - wrap_err(Self::cast_zeronotequal(get_child(&sub.node, 0)?)) - } - Terminal::AndB(ref l, ref r) => { - let ltype = get_child(&l.node, 0)?; - let rtype = get_child(&r.node, 1)?; - wrap_err(Self::and_b(ltype, rtype)) - } - Terminal::AndV(ref l, ref r) => { - let ltype = get_child(&l.node, 0)?; - let rtype = get_child(&r.node, 1)?; - wrap_err(Self::and_v(ltype, rtype)) - } - Terminal::OrB(ref l, ref r) => { - let ltype = get_child(&l.node, 0)?; - let rtype = get_child(&r.node, 1)?; - wrap_err(Self::or_b(ltype, rtype)) - } - Terminal::OrD(ref l, ref r) => { - let ltype = get_child(&l.node, 0)?; - let rtype = get_child(&r.node, 1)?; - wrap_err(Self::or_d(ltype, rtype)) - } - Terminal::OrC(ref l, ref r) => { - let ltype = get_child(&l.node, 0)?; - let rtype = get_child(&r.node, 1)?; - wrap_err(Self::or_c(ltype, rtype)) - } - Terminal::OrI(ref l, ref r) => { - let ltype = get_child(&l.node, 0)?; - let rtype = get_child(&r.node, 1)?; - wrap_err(Self::or_i(ltype, rtype)) - } - Terminal::AndOr(ref a, ref b, ref c) => { - let atype = get_child(&a.node, 0)?; - let btype = get_child(&b.node, 1)?; - let ctype = get_child(&c.node, 2)?; - wrap_err(Self::and_or(atype, btype, ctype)) - } - Terminal::Thresh(k, ref subs) => { - if k == 0 { - return Err(Error { - fragment: fragment.clone(), - error: ErrorKind::ZeroThreshold, - }); - } - if k > subs.len() { - return Err(Error { - fragment: fragment.clone(), - error: ErrorKind::OverThreshold(k, subs.len()), - }); - } - - let mut last_err_frag = None; - let res = Self::threshold(k, subs.len(), |n| match get_child(&subs[n].node, n) { - Ok(x) => Ok(x), - Err(e) => { - last_err_frag = Some(e.fragment); - Err(e.error) - } - }); - - res.map_err(|kind| Error { - fragment: last_err_frag.unwrap_or_else(|| fragment.clone()), - error: kind, - }) - } - }; - if let Ok(ref ret) = ret { - ret.sanity_checks() - } - ret - } - - /// Compute the type of a fragment, given a function to look up - /// the types of its children. - fn type_check_with_child( - fragment: &Terminal, - mut child: C, - ) -> Result> - where - C: FnMut(usize) -> Self, - Pk: MiniscriptKey, - Ctx: ScriptContext, - { - let get_child = |_sub, n| Ok(child(n)); - Self::type_check_common(fragment, get_child) - } - - /// Compute the type of a fragment. - fn type_check(fragment: &Terminal) -> Result> - where - Pk: MiniscriptKey, - Ctx: ScriptContext, - { - Self::type_check_common(fragment, |sub, _n| Self::type_check(sub)) - } } impl Property for Type { @@ -693,22 +524,12 @@ impl Property for Type { mall: Property::threshold(k, n, |n| Ok(sub_ck(n)?.mall))?, }) } +} - fn type_check_with_child( - _fragment: &Terminal, - mut _child: C, - ) -> Result> - where - C: FnMut(usize) -> Self, - Pk: MiniscriptKey, - Ctx: ScriptContext, - { - unreachable!() - } - +impl Type { /// Compute the type of a fragment assuming all the children of /// Miniscript have been computed already. - fn type_check(fragment: &Terminal) -> Result> + pub fn type_check(fragment: &Terminal) -> Result> where Pk: MiniscriptKey, Ctx: ScriptContext, diff --git a/src/policy/compiler.rs b/src/policy/compiler.rs index fabf27d3c..50ae3a6d2 100644 --- a/src/policy/compiler.rs +++ b/src/policy/compiler.rs @@ -9,6 +9,7 @@ use core::{cmp, f64, fmt, hash, mem}; #[cfg(feature = "std")] use std::error; +use bitcoin::{absolute, Sequence}; use sync::Arc; use crate::miniscript::context::SigType; @@ -413,6 +414,165 @@ impl Property for CompilerExtData { } } +/// None-returning function to help type inference when we need a +/// closure that simply returns `None` +fn return_none(_: usize) -> Option { + None +} + +impl CompilerExtData { + /// Compute the type of a fragment, given a function to look up + /// the types of its children, if available and relevant for the + /// given fragment + pub fn type_check<'a, Pk, Ctx, C>( + fragment: &'a Terminal, + mut get_child: C, + ) -> Result> + where + C: FnMut(usize) -> Option, + Pk: MiniscriptKey, + Ctx: ScriptContext, + { + let mut get_child = |sub, n| { + get_child(n) + .map(Ok) + .unwrap_or_else(|| Self::type_check(sub, return_none)) + }; + + let wrap_err = |result: Result| { + result.map_err(|kind| types::Error { fragment: fragment.clone(), error: kind }) + }; + + let ret = match *fragment { + Terminal::True => Ok(Self::from_true()), + Terminal::False => Ok(Self::from_false()), + Terminal::PkK(..) => Ok(Self::from_pk_k::()), + Terminal::PkH(..) | Terminal::RawPkH(..) => Ok(Self::from_pk_h::()), + Terminal::Multi(k, ref pks) | Terminal::MultiA(k, ref pks) => { + if k == 0 { + return Err(types::Error { + fragment: fragment.clone(), + error: types::ErrorKind::ZeroThreshold, + }); + } + if k > pks.len() { + return Err(types::Error { + fragment: fragment.clone(), + error: types::ErrorKind::OverThreshold(k, pks.len()), + }); + } + match *fragment { + Terminal::Multi(..) => Ok(Self::from_multi(k, pks.len())), + Terminal::MultiA(..) => Ok(Self::from_multi_a(k, pks.len())), + _ => unreachable!(), + } + } + Terminal::After(t) => { + // Note that for CLTV this is a limitation not of Bitcoin but Miniscript. The + // number on the stack would be a 5 bytes signed integer but Miniscript's B type + // only consumes 4 bytes from the stack. + if t == absolute::LockTime::ZERO.into() { + return Err(types::Error { + fragment: fragment.clone(), + error: types::ErrorKind::InvalidTime, + }); + } + Ok(Self::from_after(t.into())) + } + Terminal::Older(t) => { + if t == Sequence::ZERO || !t.is_relative_lock_time() { + return Err(types::Error { + fragment: fragment.clone(), + error: types::ErrorKind::InvalidTime, + }); + } + Ok(Self::from_older(t)) + } + Terminal::Sha256(..) => Ok(Self::from_sha256()), + Terminal::Hash256(..) => Ok(Self::from_hash256()), + Terminal::Ripemd160(..) => Ok(Self::from_ripemd160()), + Terminal::Hash160(..) => Ok(Self::from_hash160()), + Terminal::Alt(ref sub) => wrap_err(Self::cast_alt(get_child(&sub.node, 0)?)), + Terminal::Swap(ref sub) => wrap_err(Self::cast_swap(get_child(&sub.node, 0)?)), + Terminal::Check(ref sub) => wrap_err(Self::cast_check(get_child(&sub.node, 0)?)), + Terminal::DupIf(ref sub) => wrap_err(Self::cast_dupif(get_child(&sub.node, 0)?)), + Terminal::Verify(ref sub) => wrap_err(Self::cast_verify(get_child(&sub.node, 0)?)), + Terminal::NonZero(ref sub) => wrap_err(Self::cast_nonzero(get_child(&sub.node, 0)?)), + Terminal::ZeroNotEqual(ref sub) => { + wrap_err(Self::cast_zeronotequal(get_child(&sub.node, 0)?)) + } + Terminal::AndB(ref l, ref r) => { + let ltype = get_child(&l.node, 0)?; + let rtype = get_child(&r.node, 1)?; + wrap_err(Self::and_b(ltype, rtype)) + } + Terminal::AndV(ref l, ref r) => { + let ltype = get_child(&l.node, 0)?; + let rtype = get_child(&r.node, 1)?; + wrap_err(Self::and_v(ltype, rtype)) + } + Terminal::OrB(ref l, ref r) => { + let ltype = get_child(&l.node, 0)?; + let rtype = get_child(&r.node, 1)?; + wrap_err(Self::or_b(ltype, rtype)) + } + Terminal::OrD(ref l, ref r) => { + let ltype = get_child(&l.node, 0)?; + let rtype = get_child(&r.node, 1)?; + wrap_err(Self::or_d(ltype, rtype)) + } + Terminal::OrC(ref l, ref r) => { + let ltype = get_child(&l.node, 0)?; + let rtype = get_child(&r.node, 1)?; + wrap_err(Self::or_c(ltype, rtype)) + } + Terminal::OrI(ref l, ref r) => { + let ltype = get_child(&l.node, 0)?; + let rtype = get_child(&r.node, 1)?; + wrap_err(Self::or_i(ltype, rtype)) + } + Terminal::AndOr(ref a, ref b, ref c) => { + let atype = get_child(&a.node, 0)?; + let btype = get_child(&b.node, 1)?; + let ctype = get_child(&c.node, 2)?; + wrap_err(Self::and_or(atype, btype, ctype)) + } + Terminal::Thresh(k, ref subs) => { + if k == 0 { + return Err(types::Error { + fragment: fragment.clone(), + error: types::ErrorKind::ZeroThreshold, + }); + } + if k > subs.len() { + return Err(types::Error { + fragment: fragment.clone(), + error: types::ErrorKind::OverThreshold(k, subs.len()), + }); + } + + let mut last_err_frag = None; + let res = Self::threshold(k, subs.len(), |n| match get_child(&subs[n].node, n) { + Ok(x) => Ok(x), + Err(e) => { + last_err_frag = Some(e.fragment); + Err(e.error) + } + }); + + res.map_err(|kind| types::Error { + fragment: last_err_frag.unwrap_or_else(|| fragment.clone()), + error: kind, + }) + } + }; + if let Ok(ref ret) = ret { + ret.sanity_checks() + } + ret + } +} + /// Miniscript AST fragment with additional data needed by the compiler #[derive(Clone, Debug)] struct AstElemExt { @@ -441,7 +601,7 @@ impl AstElemExt { impl AstElemExt { fn terminal(ast: Terminal) -> AstElemExt { AstElemExt { - comp_ext_data: CompilerExtData::type_check(&ast).unwrap(), + comp_ext_data: CompilerExtData::type_check(&ast, return_none).unwrap(), ms: Arc::new(Miniscript::from_ast(ast).expect("Terminal creation must always succeed")), } } @@ -452,15 +612,15 @@ impl AstElemExt { r: &AstElemExt, ) -> Result, types::Error> { let lookup_ext = |n| match n { - 0 => l.comp_ext_data, - 1 => r.comp_ext_data, + 0 => Some(l.comp_ext_data), + 1 => Some(r.comp_ext_data), _ => unreachable!(), }; //Types and ExtData are already cached and stored in children. So, we can //type_check without cache. For Compiler extra data, we supply a cache. let ty = types::Type::type_check(&ast)?; let ext = types::ExtData::type_check(&ast)?; - let comp_ext_data = CompilerExtData::type_check_with_child(&ast, lookup_ext)?; + let comp_ext_data = CompilerExtData::type_check(&ast, lookup_ext)?; Ok(AstElemExt { ms: Arc::new(Miniscript::from_components_unchecked(ast, ty, ext)), comp_ext_data, @@ -474,16 +634,16 @@ impl AstElemExt { c: &AstElemExt, ) -> Result, types::Error> { let lookup_ext = |n| match n { - 0 => a.comp_ext_data, - 1 => b.comp_ext_data, - 2 => c.comp_ext_data, + 0 => Some(a.comp_ext_data), + 1 => Some(b.comp_ext_data), + 2 => Some(c.comp_ext_data), _ => unreachable!(), }; //Types and ExtData are already cached and stored in children. So, we can //type_check without cache. For Compiler extra data, we supply a cache. let ty = types::Type::type_check(&ast)?; let ext = types::ExtData::type_check(&ast)?; - let comp_ext_data = CompilerExtData::type_check_with_child(&ast, lookup_ext)?; + let comp_ext_data = CompilerExtData::type_check(&ast, lookup_ext)?; Ok(AstElemExt { ms: Arc::new(Miniscript::from_components_unchecked(ast, ty, ext)), comp_ext_data,