diff --git a/devenv.lock b/devenv.lock index c326bd43b..4b625b566 100644 --- a/devenv.lock +++ b/devenv.lock @@ -39,10 +39,10 @@ "flake-compat": { "flake": false, "locked": { - "lastModified": 1696426674, + "lastModified": 1733328505, "owner": "edolstra", "repo": "flake-compat", - "rev": "0f9255e01c2351cc7d116c072cb317785dd33b33", + "rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec", "type": "github" }, "original": { @@ -51,10 +51,32 @@ "type": "github" } }, + "git-hooks": { + "inputs": { + "flake-compat": "flake-compat", + "gitignore": "gitignore", + "nixpkgs": [ + "nixpkgs" + ], + "nixpkgs-stable": "nixpkgs-stable" + }, + "locked": { + "lastModified": 1733665616, + "owner": "cachix", + "repo": "git-hooks.nix", + "rev": "d8c02f0ffef0ef39f6063731fc539d8c71eb463a", + "type": "github" + }, + "original": { + "owner": "cachix", + "repo": "git-hooks.nix", + "type": "github" + } + }, "gitignore": { "inputs": { "nixpkgs": [ - "pre-commit-hooks", + "git-hooks", "nixpkgs" ] }, @@ -88,63 +110,44 @@ }, "nixpkgs-stable": { "locked": { - "lastModified": 1704290814, + "lastModified": 1733730953, "owner": "NixOS", "repo": "nixpkgs", - "rev": "70bdadeb94ffc8806c0570eb5c2695ad29f0e421", + "rev": "7109b680d161993918b0a126f38bc39763e5a709", "type": "github" }, "original": { "owner": "NixOS", - "ref": "nixos-23.05", + "ref": "nixos-24.05", "repo": "nixpkgs", "type": "github" } }, "nixpkgs-stable_2": { "locked": { - "lastModified": 1728193676, + "lastModified": 1704290814, "owner": "NixOS", "repo": "nixpkgs", - "rev": "ecbc1ca8ffd6aea8372ad16be9ebbb39889e55b6", + "rev": "70bdadeb94ffc8806c0570eb5c2695ad29f0e421", "type": "github" }, "original": { "owner": "NixOS", - "ref": "nixos-24.05", + "ref": "nixos-23.05", "repo": "nixpkgs", "type": "github" } }, - "pre-commit-hooks": { - "inputs": { - "flake-compat": "flake-compat", - "gitignore": "gitignore", - "nixpkgs": [ - "nixpkgs" - ], - "nixpkgs-stable": "nixpkgs-stable_2" - }, - "locked": { - "lastModified": 1728092656, - "owner": "cachix", - "repo": "pre-commit-hooks.nix", - "rev": "1211305a5b237771e13fcca0c51e60ad47326a9a", - "type": "github" - }, - "original": { - "owner": "cachix", - "repo": "pre-commit-hooks.nix", - "type": "github" - } - }, "root": { "inputs": { "devenv": "devenv", "fenix": "fenix", + "git-hooks": "git-hooks", "nixpkgs": "nixpkgs", - "nixpkgs-stable": "nixpkgs-stable", - "pre-commit-hooks": "pre-commit-hooks" + "nixpkgs-stable": "nixpkgs-stable_2", + "pre-commit-hooks": [ + "git-hooks" + ] } }, "rust-analyzer-src": { diff --git a/hugr-core/src/ops.rs b/hugr-core/src/ops.rs index a2ff0eed8..312127e84 100644 --- a/hugr-core/src/ops.rs +++ b/hugr-core/src/ops.rs @@ -15,9 +15,8 @@ use crate::extension::resolution::{ use std::borrow::Cow; use crate::extension::simple_op::MakeExtensionOp; -use crate::extension::{ExtensionId, ExtensionSet}; +use crate::extension::{ExtensionId, ExtensionSet, ExtensionRegistry}; use crate::types::{EdgeKind, Signature, Substitution}; -use crate::extension::{ExtensionId, ExtensionRegistry, ExtensionSet}; use crate::{Direction, OutgoingPort, Port}; use crate::{IncomingPort, PortIndex}; use derive_more::Display; @@ -437,7 +436,7 @@ pub trait OpTrait: Sized + Clone { /// Apply a type-level substitution to this OpType, i.e. replace /// [type variables](crate::types::TypeArg::new_var_use) with new types. - fn substitute(&self, _subst: &Substitution) -> Self { + fn substitute(&self, _subst: &Substitution, _reg: &ExtensionRegistry) -> Self { self.clone() } } diff --git a/hugr-core/src/ops/controlflow.rs b/hugr-core/src/ops/controlflow.rs index 44c78eae3..fcf841acb 100644 --- a/hugr-core/src/ops/controlflow.rs +++ b/hugr-core/src/ops/controlflow.rs @@ -2,7 +2,7 @@ use std::borrow::Cow; -use crate::extension::ExtensionSet; +use crate::extension::{ExtensionRegistry, ExtensionSet}; use crate::types::{EdgeKind, Signature, Type, TypeRow}; use crate::Direction; @@ -42,11 +42,11 @@ impl DataflowOpTrait for TailLoop { ) } - fn substitute(&self, subst: &crate::types::Substitution) -> Self { + fn substitute(&self, subst: &crate::types::Substitution, reg: &ExtensionRegistry) -> Self { Self { - just_inputs: self.just_inputs.substitute(subst), - just_outputs: self.just_outputs.substitute(subst), - rest: self.rest.substitute(subst), + just_inputs: self.just_inputs.substitute(subst, reg), + just_outputs: self.just_outputs.substitute(subst, reg), + rest: self.rest.substitute(subst, reg), extension_delta: self.extension_delta.substitute(subst), } } @@ -121,11 +121,11 @@ impl DataflowOpTrait for Conditional { ) } - fn substitute(&self, subst: &crate::types::Substitution) -> Self { + fn substitute(&self, subst: &crate::types::Substitution, reg: &ExtensionRegistry) -> Self { Self { - sum_rows: self.sum_rows.iter().map(|r| r.substitute(subst)).collect(), - other_inputs: self.other_inputs.substitute(subst), - outputs: self.outputs.substitute(subst), + sum_rows: self.sum_rows.iter().map(|r| r.substitute(subst, reg)).collect(), + other_inputs: self.other_inputs.substitute(subst, reg), + outputs: self.outputs.substitute(subst, reg), extension_delta: self.extension_delta.substitute(subst), } } @@ -159,9 +159,9 @@ impl DataflowOpTrait for CFG { Cow::Borrowed(&self.signature) } - fn substitute(&self, subst: &crate::types::Substitution) -> Self { + fn substitute(&self, subst: &crate::types::Substitution, reg: &ExtensionRegistry) -> Self { Self { - signature: self.signature.substitute(subst), + signature: self.signature.substitute(subst, reg), } } } @@ -248,11 +248,11 @@ impl OpTrait for DataflowBlock { } } - fn substitute(&self, subst: &crate::types::Substitution) -> Self { + fn substitute(&self, subst: &crate::types::Substitution, reg: &ExtensionRegistry) -> Self { Self { - inputs: self.inputs.substitute(subst), - other_outputs: self.other_outputs.substitute(subst), - sum_rows: self.sum_rows.iter().map(|r| r.substitute(subst)).collect(), + inputs: self.inputs.substitute(subst, reg), + other_outputs: self.other_outputs.substitute(subst, reg), + sum_rows: self.sum_rows.iter().map(|r| r.substitute(subst, reg)).collect(), extension_delta: self.extension_delta.substitute(subst), } } @@ -282,9 +282,9 @@ impl OpTrait for ExitBlock { } } - fn substitute(&self, subst: &crate::types::Substitution) -> Self { + fn substitute(&self, subst: &crate::types::Substitution, reg: &ExtensionRegistry) -> Self { Self { - cfg_outputs: self.cfg_outputs.substitute(subst), + cfg_outputs: self.cfg_outputs.substitute(subst, reg), } } } @@ -351,9 +351,9 @@ impl OpTrait for Case { ::TAG } - fn substitute(&self, subst: &crate::types::Substitution) -> Self { + fn substitute(&self, subst: &crate::types::Substitution, reg: &ExtensionRegistry) -> Self { Self { - signature: self.signature.substitute(subst), + signature: self.signature.substitute(subst, reg), } } } @@ -372,6 +372,8 @@ impl Case { #[cfg(test)] mod test { + use std::borrow::Borrow; + use crate::{ extension::{ prelude::{qb_t, usize_t, PRELUDE_ID}, @@ -399,9 +401,9 @@ mod test { TypeArg::Extensions { es: PRELUDE_ID.into(), }, - ], + ]), &PRELUDE_REGISTRY, - )); + ); let st = Type::new_sum(vec![vec![usize_t()], vec![qb_t(); 2]]); assert_eq!( dfb2.inner_signature(), @@ -425,9 +427,9 @@ mod test { elems: vec![usize_t().into(); 3], }, qb_t().into(), - ], + ]), &PRELUDE_REGISTRY, - )); + ); let st = Type::new_sum(vec![usize_t(), qb_t()]); //both single-element variants assert_eq!( cond2.signature(), @@ -453,12 +455,12 @@ mod test { TypeArg::Extensions { es: PRELUDE_ID.into(), }, - ], + ]), &PRELUDE_REGISTRY, - )); + ); assert_eq!( tail2.signature(), - Signature::new( + ow r::new( vec![qb_t(), usize_t(), usize_t()], vec![usize_t(), qb_t(), usize_t()] ) diff --git a/hugr-core/src/ops/custom.rs b/hugr-core/src/ops/custom.rs index a0ad220e7..7c91595fe 100644 --- a/hugr-core/src/ops/custom.rs +++ b/hugr-core/src/ops/custom.rs @@ -167,16 +167,16 @@ impl DataflowOpTrait for ExtensionOp { Cow::Borrowed(&self.signature) } - fn substitute(&self, subst: &crate::types::Substitution) -> Self { + fn substitute(&self, subst: &crate::types::Substitution, reg: &ExtensionRegistry) -> Self { let args = self .args .iter() - .map(|ta| ta.substitute(subst)) + .map(|ta| ta.substitute(subst, reg)) .collect::>(); - let signature = self.signature.substitute(subst); + let signature = self.signature.substitute(subst, reg); debug_assert_eq!( self.def - .compute_signature(&args, subst.extension_registry()) + .compute_signature(&args, reg) .as_ref(), Ok(&signature) ); @@ -280,10 +280,10 @@ impl DataflowOpTrait for OpaqueOp { Cow::Borrowed(&self.signature) } - fn substitute(&self, subst: &crate::types::Substitution) -> Self { + fn substitute(&self, subst: &crate::types::Substitution, reg: &ExtensionRegistry) -> Self { Self { - args: self.args.iter().map(|ta| ta.substitute(subst)).collect(), - signature: self.signature.substitute(subst), + args: self.args.iter().map(|ta| ta.substitute(subst, reg)).collect(), + signature: self.signature.substitute(subst, reg), ..self.clone() } } diff --git a/hugr-core/src/ops/dataflow.rs b/hugr-core/src/ops/dataflow.rs index 9e6208b65..dfc65e18a 100644 --- a/hugr-core/src/ops/dataflow.rs +++ b/hugr-core/src/ops/dataflow.rs @@ -54,7 +54,7 @@ pub trait DataflowOpTrait: Sized { /// Apply a type-level substitution to this OpType, i.e. replace /// [type variables](TypeArg::new_var_use) with new types. - fn substitute(&self, _subst: &Substitution) -> Self; + fn substitute(&self, subst: &Substitution, reg: &ExtensionRegistry) -> Self; } /// Helpers to construct input and output nodes @@ -116,9 +116,9 @@ impl DataflowOpTrait for Input { Cow::Owned(Signature::new(TypeRow::new(), self.types.clone())) } - fn substitute(&self, subst: &Substitution) -> Self { + fn substitute(&self, subst: &Substitution, reg: &ExtensionRegistry) -> Self { Self { - types: self.types.substitute(subst), + types: self.types.substitute(subst, reg), } } } @@ -140,9 +140,9 @@ impl DataflowOpTrait for Output { None } - fn substitute(&self, subst: &Substitution) -> Self { + fn substitute(&self, subst: &Substitution, reg: &ExtensionRegistry) -> Self { Self { - types: self.types.substitute(subst), + types: self.types.substitute(subst, reg), } } } @@ -172,8 +172,8 @@ impl OpTrait for T { DataflowOpTrait::static_input(self) } - fn substitute(&self, subst: &crate::types::Substitution) -> Self { - DataflowOpTrait::substitute(self, subst) + fn substitute(&self, subst: &crate::types::Substitution, reg: &ExtensionRegistry) -> Self { + DataflowOpTrait::substitute(self, subst, reg) } } impl StaticTag for T { @@ -212,16 +212,16 @@ impl DataflowOpTrait for Call { Some(EdgeKind::Function(self.called_function_type().clone())) } - fn substitute(&self, subst: &Substitution) -> Self { + fn substitute(&self, subst: &Substitution, reg: &ExtensionRegistry) -> Self { let type_args = self .type_args .iter() - .map(|ta| ta.substitute(subst)) + .map(|ta| ta.substitute(subst, reg)) .collect::>(); - let instantiation = self.instantiation.substitute(subst); + let instantiation = self.instantiation.substitute(subst, reg); debug_assert_eq!( self.func_sig - .instantiate(&type_args, subst.extension_registry()) + .instantiate(&type_args, reg) .as_ref(), Ok(&instantiation) ); @@ -325,9 +325,9 @@ impl DataflowOpTrait for CallIndirect { Cow::Owned(s) } - fn substitute(&self, subst: &Substitution) -> Self { + fn substitute(&self, subst: &Substitution, reg: &ExtensionRegistry) -> Self { Self { - signature: self.signature.substitute(subst), + signature: self.signature.substitute(subst, reg), } } } @@ -356,7 +356,7 @@ impl DataflowOpTrait for LoadConstant { Some(EdgeKind::Const(self.constant_type().clone())) } - fn substitute(&self, _subst: &Substitution) -> Self { + fn substitute(&self, _subst: &Substitution, _reg: &ExtensionRegistry) -> Self { // Constants cannot refer to TypeArgs, so neither can loading them self.clone() } @@ -420,16 +420,16 @@ impl DataflowOpTrait for LoadFunction { Some(EdgeKind::Function(self.func_sig.clone())) } - fn substitute(&self, subst: &Substitution) -> Self { + fn substitute(&self, subst: &Substitution, reg: &ExtensionRegistry) -> Self { let type_args = self .type_args .iter() - .map(|ta| ta.substitute(subst)) + .map(|ta| ta.substitute(subst, reg)) .collect::>(); - let instantiation = self.instantiation.substitute(subst); + let instantiation = self.instantiation.substitute(subst, reg); debug_assert_eq!( self.func_sig - .instantiate(&type_args, subst.extension_registry()) + .instantiate(&type_args, reg) .as_ref(), Ok(&instantiation) ); @@ -528,9 +528,9 @@ impl DataflowOpTrait for DFG { self.inner_signature() } - fn substitute(&self, subst: &Substitution) -> Self { + fn substitute(&self, subst: &Substitution, reg: &ExtensionRegistry) -> Self { Self { - signature: self.signature.substitute(subst), + signature: self.signature.substitute(subst, reg), } } } diff --git a/hugr-core/src/ops/sum.rs b/hugr-core/src/ops/sum.rs index 7cc9fb577..55b268cb4 100644 --- a/hugr-core/src/ops/sum.rs +++ b/hugr-core/src/ops/sum.rs @@ -4,6 +4,7 @@ use std::borrow::Cow; use super::dataflow::DataflowOpTrait; use super::{impl_op_name, OpTag}; +use crate::extension::ExtensionRegistry; use crate::types::{EdgeKind, Signature, Type, TypeRow}; /// An operation that creates a tagged sum value from one of its variants. @@ -56,9 +57,9 @@ impl DataflowOpTrait for Tag { Some(EdgeKind::StateOrder) } - fn substitute(&self, subst: &crate::types::Substitution) -> Self { + fn substitute(&self, subst: &crate::types::Substitution, reg: &ExtensionRegistry) -> Self { Self { - variants: self.variants.iter().map(|r| r.substitute(subst)).collect(), + variants: self.variants.iter().map(|r| r.substitute(subst, reg)).collect(), tag: self.tag, } } diff --git a/hugr-core/src/types.rs b/hugr-core/src/types.rs index 72aae7b7e..dc91fc01e 100644 --- a/hugr-core/src/types.rs +++ b/hugr-core/src/types.rs @@ -448,7 +448,7 @@ impl TypeBase { /// * If [Type::validate]`(false)` returns successfully, this method will return a Vec containing exactly one type /// * If [Type::validate]`(false)` fails, but `(true)` succeeds, this method may (depending on structure of self) /// return a Vec containing any number of [Type]s. These may (or not) pass [Type::validate] - fn substitute(&self, t: &Substitution) -> Vec { + fn substitute(&self, t: &Substitution, reg: &ExtensionRegistry) -> Vec { match &self.0 { TypeEnum::RowVar(rv) => rv.substitute(t), TypeEnum::Alias(_) | TypeEnum::Sum(SumType::Unit { .. }) => vec![self.clone()], @@ -458,18 +458,18 @@ impl TypeBase { }; vec![ty.into_()] } - TypeEnum::Extension(cty) => vec![TypeBase::new_extension(cty.substitute(t))], - TypeEnum::Function(bf) => vec![TypeBase::new_function(bf.substitute(t))], + TypeEnum::Extension(cty) => vec![TypeBase::new_extension(cty.substitute(t, reg))], + TypeEnum::Function(bf) => vec![TypeBase::new_function(bf.substitute(t, reg))], TypeEnum::Sum(SumType::General { rows }) => { - vec![TypeBase::new_sum(rows.iter().map(|r| r.substitute(t)))] + vec![TypeBase::new_sum(rows.iter().map(|r| r.substitute(t, reg)))] } } } } impl Type { - fn substitute1(&self, s: &Substitution) -> Self { - let v = self.substitute(s); + fn substitute1(&self, s: &Substitution, reg: &ExtensionRegistry) -> Self { + let v = self.substitute(s, reg); let [r] = v.try_into().unwrap(); // No row vars, so every Type produces exactly one r } @@ -548,7 +548,7 @@ impl From for TypeRV { /// Details a replacement of type variables with a finite list of known values. /// (Variables out of the range of the list will result in a panic) -pub struct Substitution<'a>(&'a [TypeArg], &'a ExtensionRegistry); +pub struct Substitution<'a>(&'a [TypeArg]); impl<'a> Substitution<'a> { /// Create a new Substitution given the replacement values (indexed @@ -557,8 +557,8 @@ impl<'a> Substitution<'a> { /// containing a type-variable. /// /// [TypeDef]: crate::extension::TypeDef - pub fn new(items: &'a [TypeArg], exts: &'a ExtensionRegistry) -> Self { - Self(items, exts) + pub fn new(items: &'a [TypeArg]) -> Self { + Self(items) } pub(crate) fn apply_var(&self, idx: usize, decl: &TypeParam) -> TypeArg { @@ -599,10 +599,6 @@ impl<'a> Substitution<'a> { _ => panic!("Not a type or list of types - call validate() ?"), } } - - pub(crate) fn extension_registry(&self) -> &ExtensionRegistry { - self.1 - } } pub(crate) fn check_typevar_decl( diff --git a/hugr-core/src/types/custom.rs b/hugr-core/src/types/custom.rs index 81bc814ac..f1a08098f 100644 --- a/hugr-core/src/types/custom.rs +++ b/hugr-core/src/types/custom.rs @@ -109,14 +109,14 @@ impl CustomType { }) } - pub(super) fn substitute(&self, tr: &Substitution) -> Self { + pub(super) fn substitute(&self, tr: &Substitution, reg: &ExtensionRegistry) -> Self { let args = self .args .iter() - .map(|arg| arg.substitute(tr)) + .map(|arg| arg.substitute(tr, reg)) .collect::>(); let bound = self - .get_type_def(tr.extension_registry()) + .get_type_def(reg) .expect("validate should rule this out") .bound(&args); debug_assert!(self.bound.contains(bound)); diff --git a/hugr-core/src/types/poly_func.rs b/hugr-core/src/types/poly_func.rs index 66dff8245..23d65453a 100644 --- a/hugr-core/src/types/poly_func.rs +++ b/hugr-core/src/types/poly_func.rs @@ -125,7 +125,7 @@ impl PolyFuncTypeBase { // Check that args are applicable, and that we have a value for each binder, // i.e. each possible free variable within the body. check_type_args(args, &self.params)?; - Ok(self.body.substitute(&Substitution(args, ext_reg))) + Ok(self.body.substitute(&Substitution(args), ext_reg)) } /// Validates this instance, checking that the types in the body are diff --git a/hugr-core/src/types/signature.rs b/hugr-core/src/types/signature.rs index c5d3459e1..2112acd34 100644 --- a/hugr-core/src/types/signature.rs +++ b/hugr-core/src/types/signature.rs @@ -61,10 +61,10 @@ impl FuncTypeBase { self.with_extension_delta(crate::extension::prelude::PRELUDE_ID) } - pub(crate) fn substitute(&self, tr: &Substitution) -> Self { + pub(crate) fn substitute(&self, tr: &Substitution, reg: &ExtensionRegistry) -> Self { Self { - input: self.input.substitute(tr), - output: self.output.substitute(tr), + input: self.input.substitute(tr, reg), + output: self.output.substitute(tr, reg), extension_reqs: self.extension_reqs.substitute(tr), } } diff --git a/hugr-core/src/types/type_param.rs b/hugr-core/src/types/type_param.rs index be688ad7c..9d39c327e 100644 --- a/hugr-core/src/types/type_param.rs +++ b/hugr-core/src/types/type_param.rs @@ -295,11 +295,11 @@ impl TypeArg { } } - pub(crate) fn substitute(&self, t: &Substitution) -> Self { + pub(crate) fn substitute(&self, t: &Substitution, reg: &ExtensionRegistry) -> Self { match self { TypeArg::Type { ty } => { // RowVariables are represented as TypeArg::Variable - ty.substitute1(t).into() + ty.substitute1(t, reg).into() } TypeArg::BoundedNat { .. } | TypeArg::String { .. } => self.clone(), // We do not allow variables as bounds on BoundedNat's TypeArg::Sequence { elems } => { @@ -314,7 +314,7 @@ impl TypeArg { // So, anything that doesn't produce a Type, was a row variable => multiple Types elems .iter() - .flat_map(|ta| match ta.substitute(t) { + .flat_map(|ta| match ta.substitute(t, reg) { ty @ TypeArg::Type { .. } => vec![ty], TypeArg::Sequence { elems } => elems, _ => panic!("Expected Type or row of Types"), @@ -323,7 +323,7 @@ impl TypeArg { } _ => { // not types, no need to flatten (and mustn't, in case of nested Sequences) - elems.iter().map(|ta| ta.substitute(t)).collect() + elems.iter().map(|ta| ta.substitute(t, reg)).collect() } }; TypeArg::Sequence { elems } @@ -551,7 +551,7 @@ mod test { }; check_type_arg(&outer_arg, &outer_param).unwrap(); - let outer_arg2 = outer_arg.substitute(&Substitution(&[row_arg], &PRELUDE_REGISTRY)); + let outer_arg2 = outer_arg.substitute(&Substitution(&[row_arg]), &PRELUDE_REGISTRY); assert_eq!( outer_arg2, vec![bool_t().into(), TypeArg::UNIT, usize_t().into()].into() @@ -594,7 +594,7 @@ mod test { let row_var_arg = vec![usize_t().into(), bool_t().into()].into(); check_type_arg(&row_var_arg, &row_var_decl).unwrap(); let subst_arg = - good_arg.substitute(&Substitution(&[row_var_arg.clone()], &PRELUDE_REGISTRY)); + good_arg.substitute(&Substitution(&[row_var_arg.clone()]), &PRELUDE_REGISTRY); check_type_arg(&subst_arg, &outer_param).unwrap(); // invariance of substitution assert_eq!( subst_arg, diff --git a/hugr-core/src/types/type_row.rs b/hugr-core/src/types/type_row.rs index c807b872f..8dfb2a434 100644 --- a/hugr-core/src/types/type_row.rs +++ b/hugr-core/src/types/type_row.rs @@ -71,9 +71,9 @@ impl TypeRowBase { /// Applies a substitution to the row. /// For `TypeRowRV`, note this may change the length of the row. /// For `TypeRow`, guaranteed not to change the length of the row. - pub(crate) fn substitute(&self, s: &Substitution) -> Self { + pub(crate) fn substitute(&self, s: &Substitution, reg: &ExtensionRegistry) -> Self { self.iter() - .flat_map(|ty| ty.substitute(s)) + .flat_map(|ty| ty.substitute(s, reg)) .collect::>() .into() } diff --git a/hugr-passes/src/monomorphize.rs b/hugr-passes/src/monomorphize.rs index d25eff8cd..f0b3dd94c 100644 --- a/hugr-passes/src/monomorphize.rs +++ b/hugr-passes/src/monomorphize.rs @@ -20,8 +20,8 @@ use itertools::Itertools as _; /// /// If the Hugr is [FuncDefn](OpType::FuncDefn)-rooted with polymorphic /// signature then the hugr is untouched. -pub fn monomorphize(mut h: Hugr, reg: &ExtensionRegistry) -> Hugr { - let validate = |h: &Hugr| h.validate(reg).unwrap_or_else(|e| panic!("{e}")); +pub fn monomorphize(mut h: Hugr) -> Hugr { + let validate = |h: &Hugr| h.validate().unwrap_or_else(|e| panic!("{e}")); #[cfg(debug_assertions)] validate(&h); @@ -29,7 +29,7 @@ pub fn monomorphize(mut h: Hugr, reg: &ExtensionRegistry) -> Hugr { let root = h.root(); // If the root is a polymorphic function, then there are no external calls, so nothing to do if !is_polymorphic_funcdefn(h.get_optype(root)) { - mono_scan(&mut h, root, None, &mut HashMap::new(), reg); + mono_scan(&mut h, root, None, &mut HashMap::new()); if !h.get_optype(root).is_module() { return remove_polyfuncs(h); } @@ -86,7 +86,6 @@ fn mono_scan( parent: Node, mut subst_into: Option<&mut Instantiating>, cache: &mut Instantiations, - reg: &ExtensionRegistry, ) { for old_ch in h.children(parent).collect_vec() { let ch_op = h.get_optype(old_ch); @@ -97,17 +96,17 @@ fn mono_scan( // Perform substitution, and recurse into containers (mono_scan does nothing if no children) let ch = if let Some(ref mut inst) = subst_into { let new_ch = - h.add_node_with_parent(inst.target_container, ch_op.substitute(inst.subst)); + h.add_node_with_parent(inst.target_container, ch_op.substitute(inst.subst, h.extensions())); inst.node_map.insert(old_ch, new_ch); let mut inst = Instantiating { target_container: new_ch, node_map: inst.node_map, ..**inst }; - mono_scan(h, old_ch, Some(&mut inst), cache, reg); + mono_scan(h, old_ch, Some(&mut inst), cache); new_ch } else { - mono_scan(h, old_ch, None, cache, reg); + mono_scan(h, old_ch, None, cache); old_ch }; @@ -119,7 +118,7 @@ fn mono_scan( ( &c.type_args, mono_sig.clone(), - OpType::from(Call::try_new(mono_sig.into(), [], reg).unwrap()), + OpType::from(Call::try_new(mono_sig.into(), [], h.extensions()).unwrap()), ) } OpType::LoadFunction(lf) => { @@ -128,7 +127,7 @@ fn mono_scan( ( &lf.type_args, mono_sig.clone(), - LoadFunction::try_new(mono_sig.into(), [], reg) + LoadFunction::try_new(mono_sig.into(), [], h.extensions()) .unwrap() .into(), ) @@ -140,7 +139,7 @@ fn mono_scan( }; let fn_inp = ch_op.static_input_port().unwrap(); let tgt = h.static_source(old_ch).unwrap(); // Use old_ch as edges not copied yet - let new_tgt = instantiate(h, tgt, type_args.clone(), mono_sig.clone(), cache, reg); + let new_tgt = instantiate(h, tgt, type_args.clone(), mono_sig.clone(), cache); let fn_out = { let func = h.get_optype(new_tgt).as_func_defn().unwrap(); debug_assert_eq!(func.signature, mono_sig.into()); @@ -159,7 +158,6 @@ fn instantiate( type_args: Vec, mono_sig: Signature, cache: &mut Instantiations, - reg: &ExtensionRegistry, ) -> Node { let for_func = cache.entry(poly_func).or_insert_with(|| { // First time we've instantiated poly_func. Lift any nested FuncDefn's out to the same level. @@ -201,11 +199,11 @@ fn instantiate( // Now make the instantiation let mut node_map = HashMap::new(); let mut inst = Instantiating { - subst: &Substitution::new(&type_args, reg), + subst: &Substitution::new(&type_args), target_container: mono_tgt, node_map: &mut node_map, }; - mono_scan(h, poly_func, Some(&mut inst), cache, reg); + mono_scan(h, poly_func, Some(&mut inst), cache); // Copy edges...we have built a node_map for every node in the function. // Note we could avoid building the "large" map (smaller than the Hugr we've just created) // by doing this during recursion, but we'd need to be careful with nonlocal edges - @@ -324,8 +322,8 @@ mod test { let dfg_builder = DFGBuilder::new(Signature::new(vec![usize_t()], vec![usize_t()])).unwrap(); let [i1] = dfg_builder.input_wires_arr(); - let hugr = dfg_builder.finish_prelude_hugr_with_outputs([i1]).unwrap(); - let hugr2 = monomorphize(hugr.clone(), &PRELUDE_REGISTRY); + let hugr = dfg_builder.finish_hugr_with_outputs([i1]).unwrap(); + let hugr2 = monomorphize(hugr.clone()); assert_eq!(hugr, hugr2); } @@ -348,7 +346,6 @@ mod test { &FuncID::::from(fb.container_node()), &[tv0().into()], [elem], - &EMPTY_REG, )?; fb.finish_with_outputs(tag.outputs())? }; @@ -360,7 +357,7 @@ mod test { PolyFuncType::new([TypeBound::Copyable.into()], sig), )?; let [elem] = fb.input_wires_arr(); - let pair = fb.call(db.handle(), &[tv0().into()], [elem], &PRELUDE_REGISTRY)?; + let pair = fb.call(db.handle(), &[tv0().into()], [elem])?; let [elem1, elem2] = fb .add_dataflow_op(UnpackTuple::new(vec![tv0(); 2].into()), pair.outputs())? @@ -374,24 +371,24 @@ mod test { let mut fb = mb.define_function("main", prelusig(usize_t(), outs))?; let [elem] = fb.input_wires_arr(); let [res1] = fb - .call(tr.handle(), &[usize_t().into()], [elem], &PRELUDE_REGISTRY)? + .call(tr.handle(), &[usize_t().into()], [elem])? .outputs_arr(); - let pair = fb.call(db.handle(), &[usize_t().into()], [elem], &PRELUDE_REGISTRY)?; + let pair = fb.call(db.handle(), &[usize_t().into()], [elem])?; let pty = pair_type(usize_t()).into(); let [res2] = fb - .call(tr.handle(), &[pty], pair.outputs(), &PRELUDE_REGISTRY)? + .call(tr.handle(), &[pty], pair.outputs())? .outputs_arr(); fb.finish_with_outputs([res1, res2])?; } - let hugr = mb.finish_hugr(&PRELUDE_REGISTRY)?; + let hugr = mb.finish_hugr()?; assert_eq!( hugr.nodes() .filter(|n| hugr.get_optype(*n).is_func_defn()) .count(), 3 ); - let mono = monomorphize(hugr, &PRELUDE_REGISTRY); - mono.validate(&PRELUDE_REGISTRY)?; + let mono = monomorphize(hugr); + mono.validate()?; let mut funcs = list_funcs(&mono); let expected_mangled_names = [ @@ -410,7 +407,7 @@ mod test { ["double", "main", "triple"] ); - assert_eq!(monomorphize(mono.clone(), &PRELUDE_REGISTRY), mono); // Idempotent + assert_eq!(monomorphize(mono.clone()), mono); // Idempotent let nopoly = remove_polyfuncs(mono); let mut funcs = list_funcs(&nopoly); @@ -427,8 +424,6 @@ mod test { fn test_flattening() -> Result<(), Box> { //pf1 contains pf2 contains mono_func -> pf1 and pf1 share pf2's and they share mono_func - let reg = - ExtensionRegistry::try_new([int_types::EXTENSION.to_owned(), PRELUDE.to_owned()])?; let tv0 = || Type::new_var_use(0, TypeBound::Any); let ity = || INT_TYPES[3].clone(); @@ -448,28 +443,28 @@ mod test { }; let pf2 = { let [inw] = pf2.input_wires_arr(); - let [usz] = pf2.call(mono_func.handle(), &[], [], ®)?.outputs_arr(); + let [usz] = pf2.call(mono_func.handle(), &[], [])?.outputs_arr(); pf2.finish_with_outputs([inw, usz])? }; // pf1: Two calls to pf2, one depending on pf1's TypeArg, the other not let [a, u] = pf1 - .call(pf2.handle(), &[tv0().into()], pf1.input_wires(), ®)? + .call(pf2.handle(), &[tv0().into()], pf1.input_wires())? .outputs_arr(); let [u1, u2] = pf1 - .call(pf2.handle(), &[usize_t().into()], [u], ®)? + .call(pf2.handle(), &[usize_t().into()], [u])? .outputs_arr(); let pf1 = pf1.finish_with_outputs([a, u1, u2])?; // Outer: two calls to pf1 with different TypeArgs let [_, u, _] = outer - .call(pf1.handle(), &[ity().into()], outer.input_wires(), ®)? + .call(pf1.handle(), &[ity().into()], outer.input_wires())? .outputs_arr(); let [_, u, _] = outer - .call(pf1.handle(), &[usize_t().into()], [u], ®)? + .call(pf1.handle(), &[usize_t().into()], [u])? .outputs_arr(); - let hugr = outer.finish_hugr_with_outputs([u], ®)?; + let hugr = outer.finish_hugr_with_outputs([u])?; - let mono_hugr = monomorphize(hugr, ®); - mono_hugr.validate(®)?; + let mono_hugr = monomorphize(hugr); + mono_hugr.validate()?; let funcs = list_funcs(&mono_hugr); let pf2_name = mangle_inner_func("pf1", "pf2"); assert_eq!( @@ -502,8 +497,6 @@ mod test { #[test] fn test_no_flatten_out_of_mono_func() -> Result<(), Box> { let ity = || INT_TYPES[4].clone(); - let reg = - ExtensionRegistry::try_new([PRELUDE.to_owned(), int_types::EXTENSION.to_owned()])?; let sig = Signature::new_endo(vec![usize_t(), ity()]); let mut dfg = DFGBuilder::new(sig.clone())?; let mut mono = dfg.define_function("id2", sig)?; @@ -518,15 +511,15 @@ mod test { let pf = pf.finish_with_outputs(outs)?; let [a, b] = mono.input_wires_arr(); let [a] = mono - .call(pf.handle(), &[usize_t().into()], [a], ®)? + .call(pf.handle(), &[usize_t().into()], [a])? .outputs_arr(); let [b] = mono - .call(pf.handle(), &[ity().into()], [b], ®)? + .call(pf.handle(), &[ity().into()], [b])? .outputs_arr(); let mono = mono.finish_with_outputs([a, b])?; - let c = dfg.call(mono.handle(), &[], dfg.input_wires(), ®)?; - let hugr = dfg.finish_hugr_with_outputs(c.outputs(), ®)?; - let mono_hugr = monomorphize(hugr, ®); + let c = dfg.call(mono.handle(), &[], dfg.input_wires())?; + let hugr = dfg.finish_hugr_with_outputs(c.outputs())?; + let mono_hugr = monomorphize(hugr); let mut funcs = list_funcs(&mono_hugr); assert!(funcs.values().all(|(_, fd)| !is_polymorphic(fd))); @@ -564,7 +557,7 @@ mod test { .define_function("main", Signature::new_endo(Type::UNIT)) .unwrap(); let func_ptr = builder - .load_func(foo.handle(), &[Type::UNIT.into()], &EMPTY_REG) + .load_func(foo.handle(), &[Type::UNIT.into()]) .unwrap(); let [r] = builder .call_indirect( @@ -576,10 +569,10 @@ mod test { .outputs_arr(); builder.finish_with_outputs([r]).unwrap() }; - module_builder.finish_hugr(&EMPTY_REG).unwrap() + module_builder.finish_hugr().unwrap() }; - let mono_hugr = remove_polyfuncs(monomorphize(hugr, &EMPTY_REG)); + let mono_hugr = remove_polyfuncs(monomorphize(hugr)); let funcs = list_funcs(&mono_hugr); assert!(funcs.values().all(|(_, fd)| !is_polymorphic(fd)));