diff --git a/src/lib.rs b/src/lib.rs index 7b798915e..d5b136b31 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -126,6 +126,7 @@ pub mod miniscript; pub mod plan; pub mod policy; pub mod psbt; +pub mod threshold; #[cfg(test)] mod test_utils; @@ -861,7 +862,7 @@ mod prelude { rc, slice, string::{String, ToString}, sync, - vec::Vec, + vec::{self, Vec}, }; #[cfg(any(feature = "std", test))] pub use std::{ @@ -872,7 +873,7 @@ mod prelude { string::{String, ToString}, sync, sync::Mutex, - vec::Vec, + vec::{self, Vec}, }; #[cfg(all(not(feature = "std"), not(test)))] diff --git a/src/threshold.rs b/src/threshold.rs new file mode 100644 index 000000000..f04aa141c --- /dev/null +++ b/src/threshold.rs @@ -0,0 +1,133 @@ +// SPDX-License-Identifier: CC0-1.0 + +//! A generic (k,n)-threshold type. + +use core::fmt; + +use crate::prelude::{vec, Vec}; + +/// A (k, n)-threshold. +/// +/// This type maintains the following invariants: +/// - n > 0 +/// - k > 0 +/// - k <= n +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct Threshold { + k: usize, + v: Vec, +} + +impl Threshold { + /// Creates a `Theshold` after checking that invariants hold. + pub fn new(k: usize, v: Vec) -> Result, Error> { + if v.len() == 0 { + Err(Error::ZeroN) + } else if k == 0 { + Err(Error::ZeroK) + } else if k > v.len() { + Err(Error::BigK) + } else { + Ok(Threshold { k, v }) + } + } + + /// Creates a `Theshold` without checking that invariants hold. + #[cfg(test)] + pub fn new_unchecked(k: usize, v: Vec) -> Threshold { Threshold { k, v } } + + /// Returns `k`, the threshold value. + pub fn k(&self) -> usize { self.k } + + /// Returns `n`, the total number of elements in the threshold. + pub fn n(&self) -> usize { self.v.len() } + + /// Returns a read-only iterator over the threshold elements. + pub fn iter(&self) -> core::slice::Iter<'_, T> { self.v.iter() } + + /// Creates an iterator over the threshold elements. + pub fn into_iter(self) -> vec::IntoIter { self.v.into_iter() } + + /// Creates an iterator over the threshold elements. + pub fn iter_mut(&mut self) -> core::slice::IterMut<'_, T> { self.v.iter_mut() } + + /// Returns the threshold elements, consuming self. + pub fn into_elements(self) -> Vec { self.v } + + /// Creates a new (k, n)-threshold using a newly mapped vector. + /// + /// Typically this function is called after collecting a vector that was + /// created by iterating this threshold. E.g., + /// + /// `thresh.mapped((0..thresh.n()).map(|element| some_function(element)).collect())` + /// + /// # Panics + /// + /// Panics if the new vector is not the same length as the + /// original i.e., `new.len() != self.n()`. + pub(crate) fn mapped(&self, new: Vec) -> Threshold { + if self.n() != new.len() { + panic!("cannot map to a different length vector") + } + Threshold { k: self.k(), v: new } + } +} + +/// An error attempting to construct a `Threshold`. +#[derive(Debug, Clone, PartialEq, Eq)] +#[non_exhaustive] +pub enum Error { + /// Threshold `n` value must be non-zero. + ZeroN, + /// Threshold `k` value must be non-zero. + ZeroK, + /// Threshold `k` value must be <= `n`. + BigK, +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + use Error::*; + + match *self { + ZeroN => f.write_str("threshold `n` value must be non-zero"), + ZeroK => f.write_str("threshold `k` value must be non-zero"), + BigK => f.write_str("threshold `k` value must be <= `n`"), + } + } +} + +#[cfg(feature = "std")] +impl std::error::Error for Error { + fn cause(&self) -> Option<&dyn std::error::Error> { + use Error::*; + + match *self { + ZeroN | ZeroK | BigK => None, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn threshold_constructor_valid() { + let v = vec![1, 2, 3]; + let n = 3; + + for k in 1..=3 { + let thresh = Threshold::new(k, v.clone()).expect("failed to create threshold"); + assert_eq!(thresh.k(), k); + assert_eq!(thresh.n(), n); + } + } + + #[test] + fn threshold_constructor_invalid() { + let v = vec![1, 2, 3]; + assert!(Threshold::new(0, v.clone()).is_err()); + assert!(Threshold::new(4, v.clone()).is_err()); + } +}