diff --git a/hugr-core/src/ops.rs b/hugr-core/src/ops.rs index e509917bf..952490596 100644 --- a/hugr-core/src/ops.rs +++ b/hugr-core/src/ops.rs @@ -16,7 +16,7 @@ use std::borrow::Cow; use crate::extension::simple_op::MakeExtensionOp; use crate::extension::{ExtensionId, ExtensionRegistry, ExtensionSet}; -use crate::types::{EdgeKind, Signature}; +use crate::types::{EdgeKind, Signature, Substitution}; use crate::{Direction, OutgoingPort, Port}; use crate::{IncomingPort, PortIndex}; use derive_more::Display; @@ -369,7 +369,7 @@ pub trait StaticTag { #[enum_dispatch] /// Trait implemented by all OpType variants. -pub trait OpTrait { +pub trait OpTrait: Sized + Clone { /// A human-readable description of the operation. fn description(&self) -> &str; @@ -433,6 +433,12 @@ pub trait OpTrait { } .is_some() as usize } + + /// Apply a type-level substitution to this OpType, i.e. replace + /// [type variables](crate::types::TypeArg::new_var_use) with new types. + fn substitute(&self, _subst: &Substitution) -> Self { + self.clone() + } } /// Properties of child graphs of ops, if the op has children. diff --git a/hugr-core/src/ops/constant.rs b/hugr-core/src/ops/constant.rs index fe145d6a7..6f3cc5659 100644 --- a/hugr-core/src/ops/constant.rs +++ b/hugr-core/src/ops/constant.rs @@ -92,6 +92,8 @@ impl OpTrait for Const { fn static_output(&self) -> Option { Some(EdgeKind::Const(self.get_type())) } + + // Constants cannot refer to TypeArgs of the enclosing Hugr, so no substitute(). 
} impl From for Value { diff --git a/hugr-core/src/ops/controlflow.rs b/hugr-core/src/ops/controlflow.rs index 94c817a53..44c78eae3 100644 --- a/hugr-core/src/ops/controlflow.rs +++ b/hugr-core/src/ops/controlflow.rs @@ -41,6 +41,15 @@ impl DataflowOpTrait for TailLoop { Signature::new(inputs, outputs).with_extension_delta(self.extension_delta.clone()), ) } + + fn substitute(&self, subst: &crate::types::Substitution) -> Self { + Self { + just_inputs: self.just_inputs.substitute(subst), + just_outputs: self.just_outputs.substitute(subst), + rest: self.rest.substitute(subst), + extension_delta: self.extension_delta.substitute(subst), + } + } } impl TailLoop { @@ -111,6 +120,15 @@ impl DataflowOpTrait for Conditional { .with_extension_delta(self.extension_delta.clone()), ) } + + fn substitute(&self, subst: &crate::types::Substitution) -> Self { + Self { + sum_rows: self.sum_rows.iter().map(|r| r.substitute(subst)).collect(), + other_inputs: self.other_inputs.substitute(subst), + outputs: self.outputs.substitute(subst), + extension_delta: self.extension_delta.substitute(subst), + } + } } impl Conditional { @@ -140,6 +158,12 @@ impl DataflowOpTrait for CFG { fn signature(&self) -> Cow<'_, Signature> { Cow::Borrowed(&self.signature) } + + fn substitute(&self, subst: &crate::types::Substitution) -> Self { + Self { + signature: self.signature.substitute(subst), + } + } } #[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)] @@ -223,6 +247,15 @@ impl OpTrait for DataflowBlock { Direction::Outgoing => self.sum_rows.len(), } } + + fn substitute(&self, subst: &crate::types::Substitution) -> Self { + Self { + inputs: self.inputs.substitute(subst), + other_outputs: self.other_outputs.substitute(subst), + sum_rows: self.sum_rows.iter().map(|r| r.substitute(subst)).collect(), + extension_delta: self.extension_delta.substitute(subst), + } + } } impl OpTrait for ExitBlock { @@ -248,6 +281,12 @@ impl OpTrait for ExitBlock { Direction::Outgoing => 0, } } + + fn 
substitute(&self, subst: &crate::types::Substitution) -> Self { + Self { + cfg_outputs: self.cfg_outputs.substitute(subst), + } + } } /// Functionality shared by DataflowBlock and Exit CFG block types. @@ -311,6 +350,12 @@ impl OpTrait for Case { fn tag(&self) -> OpTag { ::TAG } + + fn substitute(&self, subst: &crate::types::Substitution) -> Self { + Self { + signature: self.signature.substitute(subst), + } + } } impl Case { @@ -324,3 +369,100 @@ impl Case { &self.signature.output } } + +#[cfg(test)] +mod test { + use crate::{ + extension::{ + prelude::{qb_t, usize_t, PRELUDE_ID}, + ExtensionSet, PRELUDE_REGISTRY, + }, + ops::{Conditional, DataflowOpTrait, DataflowParent}, + types::{Signature, Substitution, Type, TypeArg, TypeBound, TypeRV}, + }; + + use super::{DataflowBlock, TailLoop}; + + #[test] + fn test_subst_dataflow_block() { + use crate::ops::OpTrait; + let tv0 = Type::new_var_use(0, TypeBound::Any); + let dfb = DataflowBlock { + inputs: vec![usize_t(), tv0.clone()].into(), + other_outputs: vec![tv0.clone()].into(), + sum_rows: vec![usize_t().into(), vec![qb_t(), tv0.clone()].into()], + extension_delta: ExtensionSet::type_var(1), + }; + let dfb2 = dfb.substitute(&Substitution::new( + &[ + qb_t().into(), + TypeArg::Extensions { + es: PRELUDE_ID.into(), + }, + ], + &PRELUDE_REGISTRY, + )); + let st = Type::new_sum(vec![vec![usize_t()], vec![qb_t(); 2]]); + assert_eq!( + dfb2.inner_signature(), + Signature::new(vec![usize_t(), qb_t()], vec![st, qb_t()]) + .with_extension_delta(PRELUDE_ID) + ); + } + + #[test] + fn test_subst_conditional() { + let tv1 = Type::new_var_use(1, TypeBound::Any); + let cond = Conditional { + sum_rows: vec![usize_t().into(), tv1.clone().into()], + other_inputs: vec![Type::new_tuple(TypeRV::new_row_var_use(0, TypeBound::Any))].into(), + outputs: vec![usize_t(), tv1].into(), + extension_delta: ExtensionSet::new(), + }; + let cond2 = cond.substitute(&Substitution::new( + &[ + TypeArg::Sequence { + elems: vec![usize_t().into(); 3], + }, 
+ qb_t().into(), + ], + &PRELUDE_REGISTRY, + )); + let st = Type::new_sum(vec![usize_t(), qb_t()]); //both single-element variants + assert_eq!( + cond2.signature(), + Signature::new( + vec![st, Type::new_tuple(vec![usize_t(); 3])], + vec![usize_t(), qb_t()] + ) + ); + } + + #[test] + fn test_tail_loop() { + let tv0 = Type::new_var_use(0, TypeBound::Copyable); + let tail_loop = TailLoop { + just_inputs: vec![qb_t(), tv0.clone()].into(), + just_outputs: vec![tv0.clone(), qb_t()].into(), + rest: vec![tv0.clone()].into(), + extension_delta: ExtensionSet::type_var(1), + }; + let tail2 = tail_loop.substitute(&Substitution::new( + &[ + usize_t().into(), + TypeArg::Extensions { + es: PRELUDE_ID.into(), + }, + ], + &PRELUDE_REGISTRY, + )); + assert_eq!( + tail2.signature(), + Signature::new( + vec![qb_t(), usize_t(), usize_t()], + vec![usize_t(), qb_t(), usize_t()] + ) + .with_extension_delta(PRELUDE_ID) + ); + } +} diff --git a/hugr-core/src/ops/custom.rs b/hugr-core/src/ops/custom.rs index 91b93e332..a0ad220e7 100644 --- a/hugr-core/src/ops/custom.rs +++ b/hugr-core/src/ops/custom.rs @@ -166,6 +166,26 @@ impl DataflowOpTrait for ExtensionOp { fn signature(&self) -> Cow<'_, Signature> { Cow::Borrowed(&self.signature) } + + fn substitute(&self, subst: &crate::types::Substitution) -> Self { + let args = self + .args + .iter() + .map(|ta| ta.substitute(subst)) + .collect::>(); + let signature = self.signature.substitute(subst); + debug_assert_eq!( + self.def + .compute_signature(&args, subst.extension_registry()) + .as_ref(), + Ok(&signature) + ); + Self { + def: self.def.clone(), + args, + signature, + } + } } /// An opaquely-serialized op that refers to an as-yet-unresolved [`OpDef`]. 
@@ -259,6 +279,14 @@ impl DataflowOpTrait for OpaqueOp { fn signature(&self) -> Cow<'_, Signature> { Cow::Borrowed(&self.signature) } + + fn substitute(&self, subst: &crate::types::Substitution) -> Self { + Self { + args: self.args.iter().map(|ta| ta.substitute(subst)).collect(), + signature: self.signature.substitute(subst), + ..self.clone() + } + } } /// Errors that arise after loading a Hugr containing opaque ops (serialized just as their names) diff --git a/hugr-core/src/ops/dataflow.rs b/hugr-core/src/ops/dataflow.rs index 093857df6..9e6208b65 100644 --- a/hugr-core/src/ops/dataflow.rs +++ b/hugr-core/src/ops/dataflow.rs @@ -6,14 +6,14 @@ use super::{impl_op_name, OpTag, OpTrait}; use crate::extension::{ExtensionRegistry, ExtensionSet, SignatureError}; use crate::ops::StaticTag; -use crate::types::{EdgeKind, PolyFuncType, Signature, Type, TypeArg, TypeRow}; +use crate::types::{EdgeKind, PolyFuncType, Signature, Substitution, Type, TypeArg, TypeRow}; use crate::{type_row, IncomingPort}; #[cfg(test)] use proptest_derive::Arbitrary; /// Trait implemented by all dataflow operations. -pub trait DataflowOpTrait { +pub trait DataflowOpTrait: Sized { /// Tag identifying the operation. const TAG: OpTag; @@ -51,6 +51,10 @@ pub trait DataflowOpTrait { fn static_input(&self) -> Option { None } + + /// Apply a type-level substitution to this OpType, i.e. replace + /// [type variables](TypeArg::new_var_use) with new types. 
+ fn substitute(&self, _subst: &Substitution) -> Self; } /// Helpers to construct input and output nodes @@ -111,6 +115,12 @@ impl DataflowOpTrait for Input { // TODO: Store a cached signature Cow::Owned(Signature::new(TypeRow::new(), self.types.clone())) } + + fn substitute(&self, subst: &Substitution) -> Self { + Self { + types: self.types.substitute(subst), + } + } } impl DataflowOpTrait for Output { const TAG: OpTag = OpTag::Output; @@ -129,9 +139,15 @@ impl DataflowOpTrait for Output { fn other_output(&self) -> Option { None } + + fn substitute(&self, subst: &Substitution) -> Self { + Self { + types: self.types.substitute(subst), + } + } } -impl OpTrait for T { +impl OpTrait for T { fn description(&self) -> &str { DataflowOpTrait::description(self) } @@ -155,6 +171,10 @@ impl OpTrait for T { fn static_input(&self) -> Option { DataflowOpTrait::static_input(self) } + + fn substitute(&self, subst: &crate::types::Substitution) -> Self { + DataflowOpTrait::substitute(self, subst) + } } impl StaticTag for T { const TAG: OpTag = T::TAG; @@ -191,6 +211,26 @@ impl DataflowOpTrait for Call { fn static_input(&self) -> Option { Some(EdgeKind::Function(self.called_function_type().clone())) } + + fn substitute(&self, subst: &Substitution) -> Self { + let type_args = self + .type_args + .iter() + .map(|ta| ta.substitute(subst)) + .collect::>(); + let instantiation = self.instantiation.substitute(subst); + debug_assert_eq!( + self.func_sig + .instantiate(&type_args, subst.extension_registry()) + .as_ref(), + Ok(&instantiation) + ); + Self { + type_args, + instantiation, + func_sig: self.func_sig.clone(), + } + } } impl Call { /// Try to make a new Call. 
Returns an error if the `type_args`` do not fit the [TypeParam]s @@ -284,6 +324,12 @@ impl DataflowOpTrait for CallIndirect { .insert(0, Type::new_function(self.signature.clone())); Cow::Owned(s) } + + fn substitute(&self, subst: &Substitution) -> Self { + Self { + signature: self.signature.substitute(subst), + } + } } /// Load a static constant in to the local dataflow graph. @@ -309,7 +355,13 @@ impl DataflowOpTrait for LoadConstant { fn static_input(&self) -> Option { Some(EdgeKind::Const(self.constant_type().clone())) } + + fn substitute(&self, _subst: &Substitution) -> Self { + // Constants cannot refer to TypeArgs, so neither can loading them + self.clone() + } } + impl LoadConstant { #[inline] /// The type of the constant loaded by this op. @@ -367,6 +419,26 @@ impl DataflowOpTrait for LoadFunction { fn static_input(&self) -> Option { Some(EdgeKind::Function(self.func_sig.clone())) } + + fn substitute(&self, subst: &Substitution) -> Self { + let type_args = self + .type_args + .iter() + .map(|ta| ta.substitute(subst)) + .collect::>(); + let instantiation = self.instantiation.substitute(subst); + debug_assert_eq!( + self.func_sig + .instantiate(&type_args, subst.extension_registry()) + .as_ref(), + Ok(&instantiation) + ); + Self { + func_sig: self.func_sig.clone(), + type_args, + instantiation, + } + } } impl LoadFunction { /// Try to make a new LoadFunction op. 
Returns an error if the `type_args`` do not fit @@ -455,4 +527,10 @@ impl DataflowOpTrait for DFG { fn signature(&self) -> Cow<'_, Signature> { self.inner_signature() } + + fn substitute(&self, subst: &Substitution) -> Self { + Self { + signature: self.signature.substitute(subst), + } + } } diff --git a/hugr-core/src/ops/module.rs b/hugr-core/src/ops/module.rs index f321cfa9e..f18d5572d 100644 --- a/hugr-core/src/ops/module.rs +++ b/hugr-core/src/ops/module.rs @@ -83,6 +83,8 @@ impl OpTrait for FuncDefn { fn static_output(&self) -> Option { Some(EdgeKind::Function(self.signature.clone())) } + + // Cannot refer to TypeArgs of enclosing Hugr (it binds its own), so no substitute() } /// External function declaration, linked at runtime. @@ -113,6 +115,8 @@ impl OpTrait for FuncDecl { fn static_output(&self) -> Option { Some(EdgeKind::Function(self.signature.clone())) } + + // Cannot refer to TypeArgs of enclosing Hugr (the type binds its own), so no substitute() } /// A type alias definition, used only for debug/metadata. @@ -137,6 +141,9 @@ impl OpTrait for AliasDefn { fn tag(&self) -> OpTag { ::TAG } + + // Cannot refer to TypeArgs of enclosing Hugr (? - we planned to make this + // polymorphic so it binds its own, and we never combine binders), so no substitute() } /// A type alias declaration. Resolved at link time. @@ -177,4 +184,7 @@ impl OpTrait for AliasDecl { fn tag(&self) -> OpTag { ::TAG } + + // Cannot refer to TypeArgs of enclosing Hugr (? 
- we planned to make this + // polymorphic so it binds its own, and we never combine binders), so no substitute() } diff --git a/hugr-core/src/ops/sum.rs b/hugr-core/src/ops/sum.rs index b07179d00..7cc9fb577 100644 --- a/hugr-core/src/ops/sum.rs +++ b/hugr-core/src/ops/sum.rs @@ -55,4 +55,11 @@ impl DataflowOpTrait for Tag { fn other_output(&self) -> Option { Some(EdgeKind::StateOrder) } + + fn substitute(&self, subst: &crate::types::Substitution) -> Self { + Self { + variants: self.variants.iter().map(|r| r.substitute(subst)).collect(), + tag: self.tag, + } + } } diff --git a/hugr-core/src/types.rs b/hugr-core/src/types.rs index 6fcb292b6..23db1f1d3 100644 --- a/hugr-core/src/types.rs +++ b/hugr-core/src/types.rs @@ -588,9 +588,19 @@ impl From for TypeRV { /// Details a replacement of type variables with a finite list of known values. /// (Variables out of the range of the list will result in a panic) -pub(crate) struct Substitution<'a>(&'a [TypeArg], &'a ExtensionRegistry); +pub struct Substitution<'a>(&'a [TypeArg], &'a ExtensionRegistry); + +impl<'a> Substitution<'a> { + /// Create a new Substitution given the replacement values (indexed + /// as the variables they replace). `exts` must contain the [TypeDef] + /// for every custom [Type] (to which the Substitution is applied) + /// containing a type-variable. 
+ /// + /// [TypeDef]: crate::extension::TypeDef + pub fn new(items: &'a [TypeArg], exts: &'a ExtensionRegistry) -> Self { + Self(items, exts) + } -impl Substitution<'_> { pub(crate) fn apply_var(&self, idx: usize, decl: &TypeParam) -> TypeArg { let arg = self .0 @@ -630,7 +640,7 @@ impl Substitution<'_> { } } - fn extension_registry(&self) -> &ExtensionRegistry { + pub(crate) fn extension_registry(&self) -> &ExtensionRegistry { self.1 } } diff --git a/hugr-core/src/types/type_param.rs b/hugr-core/src/types/type_param.rs index 3c40552aa..9aa19963e 100644 --- a/hugr-core/src/types/type_param.rs +++ b/hugr-core/src/types/type_param.rs @@ -155,18 +155,26 @@ pub enum TypeArg { n: u64, }, ///Instance of [TypeParam::String]. UTF-8 encoded string argument. + #[display("\"{arg}\"")] String { #[allow(missing_docs)] arg: String, }, /// Instance of [TypeParam::List] or [TypeParam::Tuple], defined by a /// sequence of elements. - #[display("Sequence({})", "elems.iter().map(|t|t.to_string()).join(\", \")")] + #[display("Sequence({})", { + use itertools::Itertools as _; + elems.iter().map(|t|t.to_string()).join(",") + })] Sequence { #[allow(missing_docs)] elems: Vec, }, /// Instance of [TypeParam::Extensions], providing the extension ids. + #[display("Exts({})", { + use itertools::Itertools as _; + es.iter().map(|t|t.to_string()).join(",") + })] Extensions { #[allow(missing_docs)] es: ExtensionSet, @@ -202,6 +210,14 @@ impl From for TypeArg { } } +impl From<&str> for TypeArg { + fn from(arg: &str) -> Self { + TypeArg::String { + arg: arg.to_string(), + } + } +} + impl From> for TypeArg { fn from(elems: Vec) -> Self { Self::Sequence { elems } diff --git a/hugr-core/src/types/type_row.rs b/hugr-core/src/types/type_row.rs index b8cb6d116..c807b872f 100644 --- a/hugr-core/src/types/type_row.rs +++ b/hugr-core/src/types/type_row.rs @@ -71,7 +71,7 @@ impl TypeRowBase { /// Applies a substitution to the row. /// For `TypeRowRV`, note this may change the length of the row. 
/// For `TypeRow`, guaranteed not to change the length of the row. - pub(super) fn substitute(&self, s: &Substitution) -> Self { + pub(crate) fn substitute(&self, s: &Substitution) -> Self { self.iter() .flat_map(|ty| ty.substitute(s)) .collect::>() diff --git a/hugr-passes/src/lib.rs b/hugr-passes/src/lib.rs index 56300e345..86bdc92f9 100644 --- a/hugr-passes/src/lib.rs +++ b/hugr-passes/src/lib.rs @@ -6,6 +6,8 @@ pub mod force_order; mod half_node; pub mod lower; pub mod merge_bbs; +mod monomorphize; +pub use monomorphize::{monomorphize, remove_polyfuncs}; pub mod nest_cfgs; pub mod non_local; pub mod validation; diff --git a/hugr-passes/src/monomorphize.rs b/hugr-passes/src/monomorphize.rs new file mode 100644 index 000000000..b85a2b580 --- /dev/null +++ b/hugr-passes/src/monomorphize.rs @@ -0,0 +1,681 @@ +use std::{ + collections::{hash_map::Entry, HashMap}, + fmt::Write, + ops::Deref, +}; + +use hugr_core::{ + extension::ExtensionRegistry, + ops::{Call, FuncDefn, LoadFunction, OpTrait}, + types::{Signature, Substitution, TypeArg}, + Node, +}; + +use hugr_core::hugr::{hugrmut::HugrMut, internal::HugrMutInternals, Hugr, HugrView, OpType}; +use itertools::Itertools as _; + +/// Replaces calls to polymorphic functions with calls to new monomorphic +/// instantiations of the polymorphic ones. +/// +/// If the Hugr is [Module](OpType::Module)-rooted, +/// * then the original polymorphic [FuncDefn]s are left untouched (including Calls inside them) +/// - call [remove_polyfuncs] when no other Hugr will be linked in that might instantiate these +/// * else, the originals are removed (they are invisible from outside the Hugr). +/// +/// If the Hugr is [FuncDefn](OpType::FuncDefn)-rooted with polymorphic +/// signature then the HUGR will not be modified. +/// +/// Monomorphic copies of polymorphic functions will be added to the HUGR as +/// children of the root node. 
We make best effort to ensure that names (derived +/// from parent function names and concrete type args) of new functions are unique +/// whenever the names of their parents are unique, but this is not guaranteed. +pub fn monomorphize(mut h: Hugr) -> Hugr { + let validate = |h: &Hugr| h.validate().unwrap_or_else(|e| panic!("{e}")); + + // We clone the extension registry because we will need a reference to + // create our mutable substitutions. This cannot cause a problem because + // we will not be adding any new types or extension ops to the HUGR. + let reg = h.extensions().to_owned(); + + #[cfg(debug_assertions)] + validate(&h); + + let root = h.root(); + // If the root is a polymorphic function, then there are no external calls, so nothing to do + if !is_polymorphic_funcdefn(h.get_optype(root)) { + mono_scan(&mut h, root, None, &mut HashMap::new(), &reg); + if !h.get_optype(root).is_module() { + return remove_polyfuncs(h); + } + } + #[cfg(debug_assertions)] + validate(&h); + h +} + +/// Removes any polymorphic [FuncDefn]s from the Hugr. Note that if these have +/// calls from *monomorphic* code, this will make the Hugr invalid (call [monomorphize] +/// first). 
+/// +/// TODO replace this with a more general remove-unused-functions pass +/// +pub fn remove_polyfuncs(mut h: Hugr) -> Hugr { + let mut pfs_to_delete = Vec::new(); + let mut to_scan = Vec::from_iter(h.children(h.root())); + while let Some(n) = to_scan.pop() { + if is_polymorphic_funcdefn(h.get_optype(n)) { + pfs_to_delete.push(n) + } else { + to_scan.extend(h.children(n)); + } + } + for n in pfs_to_delete { + h.remove_subtree(n); + } + h +} + +fn is_polymorphic(fd: &FuncDefn) -> bool { + !fd.signature.params().is_empty() +} + +fn is_polymorphic_funcdefn(t: &OpType) -> bool { + t.as_func_defn().is_some_and(is_polymorphic) +} + +struct Instantiating<'a> { + subst: &'a Substitution<'a>, + target_container: Node, + node_map: &'a mut HashMap<Node, Node>, +} + +type Instantiations = HashMap<Node, HashMap<Vec<TypeArg>, Node>>; + +/// Scans a subtree for polymorphic calls and monomorphizes them by instantiating the +/// called functions (if instantiations not already in `cache`). +/// Optionally copies the subtree into a new location whilst applying a substitution. +/// The subtree should be monomorphic after the substitution (if provided) has been applied. 
+fn mono_scan( + h: &mut Hugr, + parent: Node, + mut subst_into: Option<&mut Instantiating>, + cache: &mut Instantiations, + reg: &ExtensionRegistry, +) { + for old_ch in h.children(parent).collect_vec() { + let ch_op = h.get_optype(old_ch); + debug_assert!(!ch_op.is_func_defn() || subst_into.is_none()); // If substituting, should have flattened already + if is_polymorphic_funcdefn(ch_op) { + continue; + }; + // Perform substitution, and recurse into containers (mono_scan does nothing if no children) + let ch = if let Some(ref mut inst) = subst_into { + let new_ch = + h.add_node_with_parent(inst.target_container, ch_op.substitute(inst.subst)); + inst.node_map.insert(old_ch, new_ch); + let mut inst = Instantiating { + target_container: new_ch, + node_map: inst.node_map, + ..**inst + }; + mono_scan(h, old_ch, Some(&mut inst), cache, reg); + new_ch + } else { + mono_scan(h, old_ch, None, cache, reg); + old_ch + }; + + // Now instantiate the target of any Call/LoadFunction to a polymorphic function... 
+ let ch_op = h.get_optype(ch); + let (type_args, mono_sig, new_op) = match ch_op { + OpType::Call(c) => { + let mono_sig = c.instantiation.clone(); + ( + &c.type_args, + mono_sig.clone(), + OpType::from(Call::try_new(mono_sig.into(), [], reg).unwrap()), + ) + } + OpType::LoadFunction(lf) => { + let mono_sig = lf.instantiation.clone(); + ( + &lf.type_args, + mono_sig.clone(), + LoadFunction::try_new(mono_sig.into(), [], reg) + .unwrap() + .into(), + ) + } + _ => continue, + }; + if type_args.is_empty() { + continue; + }; + let fn_inp = ch_op.static_input_port().unwrap(); + let tgt = h.static_source(old_ch).unwrap(); // Use old_ch as edges not copied yet + let new_tgt = instantiate(h, tgt, type_args.clone(), mono_sig.clone(), cache, reg); + let fn_out = { + let func = h.get_optype(new_tgt).as_func_defn().unwrap(); + debug_assert_eq!(func.signature, mono_sig.into()); + h.get_optype(new_tgt).static_output_port().unwrap() + }; + h.disconnect(ch, fn_inp); // No-op if copying+substituting + h.connect(new_tgt, fn_out, ch, fn_inp); + + h.replace_op(ch, new_op).unwrap(); + } +} + +fn instantiate( + h: &mut Hugr, + poly_func: Node, + type_args: Vec<TypeArg>, + mono_sig: Signature, + cache: &mut Instantiations, + reg: &ExtensionRegistry, +) -> Node { + let for_func = cache.entry(poly_func).or_insert_with(|| { + // First time we've instantiated poly_func. Lift any nested FuncDefn's out to the same level. 
+ let outer_name = h.get_optype(poly_func).as_func_defn().unwrap().name.clone(); + let mut to_scan = Vec::from_iter(h.children(poly_func)); + while let Some(n) = to_scan.pop() { + if let OpType::FuncDefn(fd) = h.get_optype(n) { + let fd = FuncDefn { + name: mangle_inner_func(&outer_name, &fd.name), + signature: fd.signature.clone(), + }; + h.replace_op(n, fd).unwrap(); + h.move_after_sibling(n, poly_func); + } else { + to_scan.extend(h.children(n)) + } + } + HashMap::new() + }); + + let ve = match for_func.entry(type_args.clone()) { + Entry::Occupied(n) => return *n.get(), + Entry::Vacant(ve) => ve, + }; + + let name = mangle_name( + &h.get_optype(poly_func).as_func_defn().unwrap().name, + &type_args, + ); + let mono_tgt = h.add_node_after( + poly_func, + FuncDefn { + name, + signature: mono_sig.into(), + }, + ); + // Insert BEFORE we scan (in case of recursion), hence we cannot use Entry::or_insert + ve.insert(mono_tgt); + // Now make the instantiation + let mut node_map = HashMap::new(); + let mut inst = Instantiating { + subst: &Substitution::new(&type_args, reg), + target_container: mono_tgt, + node_map: &mut node_map, + }; + mono_scan(h, poly_func, Some(&mut inst), cache, reg); + // Copy edges...we have built a node_map for every node in the function. + // Note we could avoid building the "large" map (smaller than the Hugr we've just created) + // by doing this during recursion, but we'd need to be careful with nonlocal edges - + // 'ext' edges by copying every node before recursing on any of them, + // 'dom' edges would *also* require recursing in dominator-tree preorder. 
+ for (&old_ch, &new_ch) in node_map.iter() { + for inport in h.node_inputs(old_ch).collect::<Vec<_>>() { + // Edges from monomorphized functions to their calls already added during mono_scan() + // as these depend not just on the original FuncDefn but also the TypeArgs + if h.linked_outputs(new_ch, inport).next().is_some() { + continue; + }; + let srcs = h.linked_outputs(old_ch, inport).collect::<Vec<_>>(); + for (src, outport) in srcs { + // Sources could be a mixture of within this polymorphic FuncDefn, and Static edges from outside + h.connect( + node_map.get(&src).copied().unwrap_or(src), + outport, + new_ch, + inport, + ); + } + } + } + + mono_tgt +} + +struct TypeArgsList<'a>(&'a [TypeArg]); + +impl std::fmt::Display for TypeArgsList<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + for arg in self.0 { + f.write_char('$')?; + write_type_arg_str(arg, f)?; + } + Ok(()) + } +} + +fn escape_dollar(str: impl AsRef<str>) -> String { + str.as_ref().replace("$", "\\$") +} + +fn write_type_arg_str(arg: &TypeArg, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match arg { + TypeArg::Type { ty } => f.write_fmt(format_args!("t({})", escape_dollar(ty.to_string()))), + TypeArg::BoundedNat { n } => f.write_fmt(format_args!("n({n})")), + TypeArg::String { arg } => f.write_fmt(format_args!("s({})", escape_dollar(arg))), + TypeArg::Sequence { elems } => f.write_fmt(format_args!("seq({})", TypeArgsList(elems))), + TypeArg::Extensions { es } => f.write_fmt(format_args!( + "es({})", + es.iter().map(|x| x.deref()).join(",") + )), + // We are monomorphizing. We will never monomorphize to a signature + // containing a variable. + TypeArg::Variable { .. } => panic!("type_arg_str variable: {arg}"), + _ => panic!("unknown type arg: {arg}"), + } +} + +/// We do our best to generate unique names. Our strategy is to pick out '$' as +/// a special character. 
+/// +/// We: +/// - construct a new name of the form `${func_name}$$arg0$arg1$arg2` etc +/// - replace any existing `$` in the function name or type args string +/// representation with `r"\$"` +/// - We depend on the `Display` impl of `Type` to generate the string +/// representation of a `TypeArg::Type`. For other constructors we do the +/// simple obvious thing. +/// - For all TypeArg Constructors we choose a short prefix (e.g. `t` for type) +/// and use "t({arg})" as the string representation of that arg. +fn mangle_name(name: &str, type_args: impl AsRef<[TypeArg]>) -> String { + let name = escape_dollar(name); + format!("${name}${}", TypeArgsList(type_args.as_ref())) +} + +fn mangle_inner_func(outer_name: &str, inner_name: &str) -> String { + format!("${outer_name}${inner_name}") +} + +#[cfg(test)] +mod test { + use std::collections::HashMap; + use std::iter; + + use hugr_core::extension::simple_op::MakeRegisteredOp as _; + use hugr_core::std_extensions::collections::array::{array_type_parametric, ArrayOpDef}; + use hugr_core::std_extensions::{collections, STD_REG}; + use hugr_core::types::type_param::TypeParam; + use itertools::Itertools; + + use hugr_core::builder::{ + Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, FunctionBuilder, + HugrBuilder, ModuleBuilder, + }; + use hugr_core::extension::prelude::{ + usize_t, ConstUsize, UnpackTuple, UnwrapBuilder, PRELUDE_ID, + }; + use hugr_core::extension::ExtensionSet; + use hugr_core::ops::handle::{FuncID, NodeHandle}; + use hugr_core::ops::{CallIndirect, DataflowOpTrait as _, FuncDefn, Tag}; + use hugr_core::std_extensions::arithmetic::int_types::{self, INT_TYPES}; + use hugr_core::types::{ + PolyFuncType, Signature, SumType, Type, TypeArg, TypeBound, TypeEnum, TypeRow, + }; + use hugr_core::{Hugr, HugrView, Node}; + use rstest::rstest; + + use super::{is_polymorphic, mangle_inner_func, mangle_name, monomorphize, remove_polyfuncs}; + + fn pair_type(ty: Type) -> Type { + 
Type::new_tuple(vec![ty.clone(), ty])
    }

    /// Builds the tuple type `(ty, ty, ty)`.
    fn triple_type(ty: Type) -> Type {
        Type::new_tuple(vec![ty.clone(), ty.clone(), ty])
    }

    /// Builds a [`Signature`] from `ins` to `outs` whose extension delta is
    /// [`PRELUDE_ID`].
    fn prelusig(ins: impl Into<TypeRow>, outs: impl Into<TypeRow>) -> Signature {
        Signature::new(ins, outs).with_extension_delta(PRELUDE_ID)
    }

    /// Monomorphizing a DFG-rooted Hugr that merely wires its input to its
    /// output must return an identical Hugr.
    #[test]
    fn test_null() {
        let dfg_builder =
            DFGBuilder::new(Signature::new(vec![usize_t()], vec![usize_t()])).unwrap();
        let [i1] = dfg_builder.input_wires_arr();
        let hugr = dfg_builder.finish_hugr_with_outputs([i1]).unwrap();
        let hugr2 = monomorphize(hugr.clone());
        assert_eq!(hugr, hugr2);
    }

    /// Builds a module with two polymorphic functions (`double`, which calls
    /// itself, and `triple`, which calls `double`) plus a monomorphic `main`
    /// that instantiates them at `usize` and at `(usize, usize)`.
    ///
    /// Checks that:
    /// * `monomorphize` adds one non-polymorphic, mangled-name copy per
    ///   instantiation while keeping the three original definitions;
    /// * `monomorphize` is idempotent;
    /// * `remove_polyfuncs` afterwards leaves only `main` and the mangled copies.
    #[test]
    fn test_recursion_module() -> Result<(), Box<dyn std::error::Error>> {
        let tv0 = || Type::new_var_use(0, TypeBound::Copyable);
        let mut mb = ModuleBuilder::new();
        let db = {
            let pfty = PolyFuncType::new(
                [TypeBound::Copyable.into()],
                Signature::new(tv0(), pair_type(tv0())),
            );
            let mut fb = mb.define_function("double", pfty)?;
            let [elem] = fb.input_wires_arr();
            // A "genuine" impl might:
            //   let tag = Tag::new(0, vec![vec![elem_ty; 2].into()]);
            //   let tag = fb.add_dataflow_op(tag, [elem, elem]).unwrap();
            // ...but since this will never execute, we can test recursion here
            let tag = fb.call(
                &FuncID::<true>::from(fb.container_node()),
                &[tv0().into()],
                [elem],
            )?;
            fb.finish_with_outputs(tag.outputs())?
        };

        let tr = {
            let sig = prelusig(tv0(), Type::new_tuple(vec![tv0(); 3]));
            let mut fb = mb.define_function(
                "triple",
                PolyFuncType::new([TypeBound::Copyable.into()], sig),
            )?;
            let [elem] = fb.input_wires_arr();
            let pair = fb.call(db.handle(), &[tv0().into()], [elem])?;

            let [elem1, elem2] = fb
                .add_dataflow_op(UnpackTuple::new(vec![tv0(); 2].into()), pair.outputs())?
                .outputs_arr();
            let tag = Tag::new(0, vec![vec![tv0(); 3].into()]);
            let trip = fb.add_dataflow_op(tag, [elem1, elem2, elem])?;
            fb.finish_with_outputs(trip.outputs())?
        };
        {
            // `main` uses `triple` at usize and at (usize, usize), and `double` at usize.
            let outs = vec![triple_type(usize_t()), triple_type(pair_type(usize_t()))];
            let mut fb = mb.define_function("main", prelusig(usize_t(), outs))?;
            let [elem] = fb.input_wires_arr();
            let [res1] = fb
                .call(tr.handle(), &[usize_t().into()], [elem])?
                .outputs_arr();
            let pair = fb.call(db.handle(), &[usize_t().into()], [elem])?;
            let pty = pair_type(usize_t()).into();
            let [res2] = fb.call(tr.handle(), &[pty], pair.outputs())?.outputs_arr();
            fb.finish_with_outputs([res1, res2])?;
        }
        let hugr = mb.finish_hugr()?;
        assert_eq!(
            hugr.nodes()
                .filter(|n| hugr.get_optype(*n).is_func_defn())
                .count(),
            3
        );
        let mono = monomorphize(hugr);
        mono.validate()?;

        let mut funcs = list_funcs(&mono);
        let expected_mangled_names = [
            mangle_name("double", &[usize_t().into()]),
            mangle_name("triple", &[usize_t().into()]),
            mangle_name("double", &[pair_type(usize_t()).into()]),
            mangle_name("triple", &[pair_type(usize_t()).into()]),
        ];

        // Each expected instantiation exists and is no longer polymorphic.
        for n in expected_mangled_names.iter() {
            assert!(!is_polymorphic(funcs.remove(n).unwrap().1));
        }

        // What remains are exactly the three original definitions.
        assert_eq!(
            funcs.into_keys().sorted().collect_vec(),
            ["double", "main", "triple"]
        );

        assert_eq!(monomorphize(mono.clone()), mono); // Idempotent

        let nopoly = remove_polyfuncs(mono);
        let mut funcs = list_funcs(&nopoly);

        assert!(funcs.values().all(|(_, fd)| !is_polymorphic(fd)));
        for n in expected_mangled_names {
            assert!(funcs.remove(&n).is_some());
        }
        assert_eq!(funcs.keys().collect_vec(), vec![&"main"]);
        Ok(())
    }

    /// Nested polymorphic functions (`pf1` containing `pf2` containing the
    /// monomorphic `get_usz`) taking bounded-nat and type parameters.
    ///
    /// After `monomorphize`, every `FuncDefn` must be non-polymorphic, the set
    /// of function names must be exactly the expected mangled names, and every
    /// function other than the root `mainish` must be a child of the root.
    #[test]
    fn test_flattening_multiargs_nats() {
        // pf1 contains pf2 contains mono_func -> pf1<a> and pf1<b> share pf2's and they share mono_func

        let tv = |i| Type::new_var_use(i, TypeBound::Copyable);
        let sv = |i| TypeArg::new_var_use(i, TypeParam::max_nat());
        let sa = |n| TypeArg::BoundedNat { n };

        let n: u64 = 5;
        let mut outer = FunctionBuilder::new(
            "mainish",
            prelusig(
                array_type_parametric(sa(n), array_type_parametric(sa(2), usize_t()).unwrap())
                    .unwrap(),
                vec![usize_t(); 2],
            )
            .with_extension_delta(collections::array::EXTENSION_ID),
        )
        .unwrap();

        let arr2u = || array_type_parametric(sa(2), usize_t()).unwrap();
        let pf1t = PolyFuncType::new(
            [TypeParam::max_nat()],
            prelusig(array_type_parametric(sv(0), arr2u()).unwrap(), usize_t())
                .with_extension_delta(collections::array::EXTENSION_ID),
        );
        let mut pf1 = outer.define_function("pf1", pf1t).unwrap();

        let pf2t = PolyFuncType::new(
            [TypeParam::max_nat(), TypeBound::Copyable.into()],
            prelusig(vec![array_type_parametric(sv(0), tv(1)).unwrap()], tv(1))
                .with_extension_delta(collections::array::EXTENSION_ID),
        );
        let mut pf2 = pf1.define_function("pf2", pf2t).unwrap();

        let mono_func = {
            let mut fb = pf2
                .define_function(
                    "get_usz",
                    prelusig(vec![], usize_t())
                        .with_extension_delta(collections::array::EXTENSION_ID),
                )
                .unwrap();
            let cst0 = fb.add_load_value(ConstUsize::new(1));
            fb.finish_with_outputs([cst0]).unwrap()
        };
        let pf2 = {
            let [inw] = pf2.input_wires_arr();
            let [idx] = pf2.call(mono_func.handle(), &[], []).unwrap().outputs_arr();
            let op_def = collections::array::EXTENSION.get_op("get").unwrap();
            let op = hugr_core::ops::ExtensionOp::new(
                op_def.clone(),
                vec![sv(0), tv(1).into()],
                &STD_REG,
            )
            .unwrap();
            let [get] = pf2.add_dataflow_op(op, [inw, idx]).unwrap().outputs_arr();
            let [got] = pf2
                .build_unwrap_sum(&STD_REG, 1, SumType::new([vec![], vec![tv(1)]]), get)
                .unwrap();
            pf2.finish_with_outputs([got]).unwrap()
        };
        // pf1: Two calls to pf2, one depending on pf1's TypeArg, the other not
        let inner = pf1
            .call(pf2.handle(), &[sv(0), arr2u().into()], pf1.input_wires())
            .unwrap();
        let elem = pf1
            .call(
                pf2.handle(),
                &[TypeArg::BoundedNat { n: 2 }, usize_t().into()],
                inner.outputs(),
            )
            .unwrap();
        let pf1 = pf1.finish_with_outputs(elem.outputs()).unwrap();
        // Outer: two calls to pf1 with different TypeArgs
        let [e1] = outer
            .call(pf1.handle(), &[sa(n)], outer.input_wires())
            .unwrap()
            .outputs_arr();
        let popleft = ArrayOpDef::pop_left.to_concrete(arr2u(), n);
        let ar2 = outer
            .add_dataflow_op(popleft.clone(), outer.input_wires())
            .unwrap();
        let sig = popleft.to_extension_op().unwrap().signature().into_owned();
        let TypeEnum::Sum(st) = sig.output().get(0).unwrap().as_type_enum() else {
            panic!()
        };
        let [_, ar2_unwrapped] = outer
            .build_unwrap_sum(&STD_REG, 1, st.clone(), ar2.out_wire(0))
            .unwrap();
        let [e2] = outer
            .call(pf1.handle(), &[sa(n - 1)], [ar2_unwrapped])
            .unwrap()
            .outputs_arr();
        let hugr = outer.finish_hugr_with_outputs([e1, e2]).unwrap();

        let mono_hugr = monomorphize(hugr);
        mono_hugr.validate().unwrap();
        let funcs = list_funcs(&mono_hugr);
        let pf2_name = mangle_inner_func("pf1", "pf2");
        assert_eq!(
            funcs.keys().copied().sorted().collect_vec(),
            vec![
                &mangle_name("pf1", &[TypeArg::BoundedNat { n: 5 }]),
                &mangle_name("pf1", &[TypeArg::BoundedNat { n: 4 }]),
                &mangle_name(&pf2_name, &[TypeArg::BoundedNat { n: 5 }, arr2u().into()]), // from pf1<5>
                &mangle_name(&pf2_name, &[TypeArg::BoundedNat { n: 4 }, arr2u().into()]), // from pf1<4>
                &mangle_name(&pf2_name, &[TypeArg::BoundedNat { n: 2 }, usize_t().into()]), // from both pf1<4> and <5>
                &mangle_inner_func(&pf2_name, "get_usz"),
                "mainish"
            ]
            .into_iter()
            .sorted()
            .collect_vec()
        );
        for (n, fd) in funcs.into_values() {
            assert!(!is_polymorphic(fd));
            // Every function except "mainish" must have the root as its parent;
            // for "mainish" the expected parent is `None`.
            assert_eq!(
                mono_hugr.get_parent(n),
                (fd.name != "mainish").then_some(mono_hugr.root())
            );
        }
    }

    /// Collects every `FuncDefn` node of `h` into a map from the function's
    /// name to its node and its `FuncDefn` op.
    ///
    /// Being a `HashMap` keyed by name, definitions sharing a name collapse
    /// into a single entry.
    fn list_funcs(h: &Hugr) -> HashMap<&String, (Node, &FuncDefn)> {
        h.nodes()
            .filter_map(|n| h.get_optype(n).as_func_defn().map(|fd| (&fd.name, (n, fd))))
            .collect::<HashMap<_, _>>()
    }

    /// A polymorphic function `id` defined inside the monomorphic function
    /// `id2`: its instantiations must stay children of `id2` rather than being
    /// lifted to the root, and `id2` itself must keep its node and parent.
    #[test]
    fn test_no_flatten_out_of_mono_func() -> Result<(), Box<dyn std::error::Error>> {
        let ity = || INT_TYPES[4].clone();
        let sig = Signature::new_endo(vec![usize_t(), ity()]);
        let mut dfg = DFGBuilder::new(sig.clone()).unwrap();
        let mut mono = dfg.define_function("id2", sig).unwrap();
        let pf = mono
            .define_function(
                "id",
                PolyFuncType::new(
                    [TypeBound::Any.into()],
                    Signature::new_endo(Type::new_var_use(0, TypeBound::Any)),
                ),
            )
            .unwrap();
        let outs = pf.input_wires();
        let pf = pf.finish_with_outputs(outs).unwrap();
        let [a, b] = mono.input_wires_arr();
        let [a] = mono
            .call(pf.handle(), &[usize_t().into()], [a])
            .unwrap()
            .outputs_arr();
        let [b] = mono
            .call(pf.handle(), &[ity().into()], [b])
            .unwrap()
            .outputs_arr();
        let mono = mono.finish_with_outputs([a, b]).unwrap();
        let c = dfg.call(mono.handle(), &[], dfg.input_wires()).unwrap();
        let hugr = dfg.finish_hugr_with_outputs(c.outputs()).unwrap();
        let mono_hugr = monomorphize(hugr);

        let mut funcs = list_funcs(&mono_hugr);
        assert!(funcs.values().all(|(_, fd)| !is_polymorphic(fd)));
        #[allow(clippy::unnecessary_to_owned)] // It is necessary
        let (m, _) = funcs.remove(&"id2".to_string()).unwrap();
        assert_eq!(m, mono.handle().node());
        assert_eq!(mono_hugr.get_parent(m), Some(mono_hugr.root()));
        for t in [usize_t(), ity()] {
            let (n, _) = funcs.remove(&mangle_name("id", &[t.into()])).unwrap();
            assert_eq!(mono_hugr.get_parent(n), Some(m)); // Not lifted to top
        }
        Ok(())
    }

    /// A polymorphic function that is only referenced through `load_func`
    /// (and then invoked with `CallIndirect`): after `monomorphize` followed by
    /// `remove_polyfuncs`, no polymorphic `FuncDefn` may remain.
    #[test]
    fn load_function() {
        let hugr = {
            let mut module_builder = ModuleBuilder::new();
            let foo = {
                let builder = module_builder
                    .define_function(
                        "foo",
                        PolyFuncType::new(
                            [TypeBound::Any.into()],
                            Signature::new_endo(Type::new_var_use(0, TypeBound::Any)),
                        ),
                    )
                    .unwrap();
                let inputs = builder.input_wires();
                builder.finish_with_outputs(inputs).unwrap()
            };

            let _main = {
                let mut builder = module_builder
                    .define_function("main", Signature::new_endo(Type::UNIT))
                    .unwrap();
                let func_ptr = builder
                    .load_func(foo.handle(), &[Type::UNIT.into()])
                    .unwrap();
                let [r] = {
                    let signature = Signature::new_endo(Type::UNIT);
                    builder
                        .add_dataflow_op(
                            CallIndirect { signature },
                            iter::once(func_ptr).chain(builder.input_wires()),
                        )
                        .unwrap()
                        .outputs_arr()
                };

                builder.finish_with_outputs([r]).unwrap()
            };
            module_builder.finish_hugr().unwrap()
        };

        let mono_hugr = remove_polyfuncs(monomorphize(hugr));

        let funcs = list_funcs(&mono_hugr);
        assert!(funcs.values().all(|(_, fd)| !is_polymorphic(fd)));
    }

    /// Table-driven check of `mangle_name("foo", args)` against the expected
    /// mangled string for each kind of `TypeArg`.
    ///
    /// NOTE(review): `#[should_panic]` sits directly before the
    /// `typeargvariable` case, so presumably only that case (an unsubstituted
    /// type variable) is expected to panic — confirm against rstest's
    /// per-case attribute rules.
    #[rstest]
    #[case::bounded_nat(vec![0.into()], "$foo$$n(0)")]
    #[case::type_unit(vec![Type::UNIT.into()], "$foo$$t(Unit)")]
    #[case::type_int(vec![INT_TYPES[2].to_owned().into()], "$foo$$t(int([BoundedNat { n: 2 }]))")]
    #[case::string(vec!["arg".into()], "$foo$$s(arg)")]
    #[case::dollar_string(vec!["$arg".into()], "$foo$$s(\\$arg)")]
    #[case::sequence(vec![vec![0.into(), Type::UNIT.into()].into()], "$foo$$seq($n(0)$t(Unit))")]
    #[case::extensionset(vec![ExtensionSet::from_iter([PRELUDE_ID,int_types::EXTENSION_ID]).into()],
                         "$foo$$es(arithmetic.int.types,prelude)")] // alphabetic ordering of extension names
    #[should_panic]
    #[case::typeargvariable(vec![TypeArg::new_var_use(1, TypeParam::String)],
                            "$foo$$v(1)")]
    #[case::multiple(vec![0.into(), "arg".into()], "$foo$$n(0)$s(arg)")]
    fn test_mangle_name(#[case] args: Vec<TypeArg>, #[case] expected: String) {
        assert_eq!(mangle_name("foo", &args), expected);
    }
}