From 7cf7bb699abab537506eafa8aa4f39401079922a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= <121866228+aborgna-q@users.noreply.github.com> Date: Mon, 16 Dec 2024 14:16:01 +0000 Subject: [PATCH] feat!: Remove ExtensionRegistry args in UnwrapBuilder and ListOp (#1785) Last two remaining usages of `ExtensionRegistry` as redundant parameters. These should have been included in #1784. BREAKING CHANGE: `UnwrapBuilder` no longer requires an `ExtensionRegistry` as argument. BREAKING CHANGE: `ListOpInst` no longer requires an `ExtensionRegistry` to produce an `ExtensionOp`. --- .../src/extension/prelude/unwrap_builder.rs | 21 ++++-------- hugr-core/src/hugr/rewrite/replace.rs | 4 +-- .../src/std_extensions/collections/list.rs | 33 +++++-------------- hugr-llvm/src/extension/collections/array.rs | 24 +++++++------- hugr-llvm/src/utils/array_op_builder.rs | 29 ++++------------ hugr-passes/src/const_fold/test.rs | 11 ++----- hugr-passes/src/monomorphize.rs | 6 ++-- 7 files changed, 39 insertions(+), 89 deletions(-) diff --git a/hugr-core/src/extension/prelude/unwrap_builder.rs b/hugr-core/src/extension/prelude/unwrap_builder.rs index b3f649a66..2495a464d 100644 --- a/hugr-core/src/extension/prelude/unwrap_builder.rs +++ b/hugr-core/src/extension/prelude/unwrap_builder.rs @@ -2,22 +2,20 @@ use std::iter; use crate::{ builder::{BuildError, BuildHandle, Dataflow, DataflowSubContainer, SubContainer}, - extension::{ - prelude::{ConstError, PANIC_OP_ID, PRELUDE_ID}, - ExtensionRegistry, - }, + extension::prelude::{ConstError, PANIC_OP_ID}, ops::handle::DataflowOpID, types::{SumType, Type, TypeArg, TypeRow}, Wire, }; use itertools::{zip_eq, Itertools as _}; +use super::PRELUDE; + /// Extend dataflow builders with methods for building unwrap operations. pub trait UnwrapBuilder: Dataflow { /// Add a panic operation to the dataflow with the given error. fn add_panic( &mut self, - reg: &ExtensionRegistry, err: ConstError, output_row: impl IntoIterator, inputs: impl IntoIterator, @@ -33,8 +31,7 @@ pub trait UnwrapBuilder: Dataflow { .map(>::from) .collect_vec() .into(); - let prelude = reg.get(&PRELUDE_ID).unwrap(); - let op = prelude.instantiate_extension_op(&PANIC_OP_ID, [input_arg, output_arg])?; + let op = PRELUDE.instantiate_extension_op(&PANIC_OP_ID, [input_arg, output_arg])?; let err = self.add_load_value(err); self.add_dataflow_op(op, iter::once(err).chain(input_wires)) } @@ -43,7 +40,6 @@ pub trait UnwrapBuilder: Dataflow { /// or panic if the tag is not the expected value. fn build_unwrap_sum( &mut self, - reg: &ExtensionRegistry, tag: usize, sum_type: SumType, input: Wire, @@ -70,7 +66,7 @@ pub trait UnwrapBuilder: Dataflow { let inputs = zip_eq(case.input_wires(), variant.iter().cloned()); let err = ConstError::new(1, format!("Expected variant {} but got variant {}", tag, i)); - let outputs = case.add_panic(reg, err, output_row, inputs)?.outputs(); + let outputs = case.add_panic(err, output_row, inputs)?.outputs(); case.finish_with_outputs(outputs)?; } } @@ -85,10 +81,7 @@ mod tests { use super::*; use crate::{ builder::{DFGBuilder, DataflowHugr}, - extension::{ - prelude::{bool_t, option_type}, - PRELUDE_REGISTRY, - }, + extension::prelude::{bool_t, option_type}, types::Signature, }; @@ -102,7 +95,7 @@ mod tests { let [opt] = builder.input_wires_arr(); let [res] = builder - .build_unwrap_sum(&PRELUDE_REGISTRY, 1, option_type(bool_t()), opt) + .build_unwrap_sum(1, option_type(bool_t()), opt) .unwrap(); builder.finish_hugr_with_outputs([res]).unwrap(); } diff --git a/hugr-core/src/hugr/rewrite/replace.rs b/hugr-core/src/hugr/rewrite/replace.rs index 2e4300b2a..80a847eea 100644 --- a/hugr-core/src/hugr/rewrite/replace.rs +++ b/hugr-core/src/hugr/rewrite/replace.rs @@ -471,11 +471,11 @@ mod test { let listy = list::list_type(usize_t()); let pop: ExtensionOp = list::ListOp::pop .with_type(usize_t()) - .to_extension_op(®) + .to_extension_op() .unwrap(); let push: ExtensionOp = list::ListOp::push .with_type(usize_t()) - .to_extension_op(®) + .to_extension_op() .unwrap(); let just_list = TypeRow::from(vec![listy.clone()]); let intermed = TypeRow::from(vec![listy.clone(), usize_t()]); diff --git a/hugr-core/src/std_extensions/collections/list.rs b/hugr-core/src/std_extensions/collections/list.rs index 0a37238c4..79330a68b 100644 --- a/hugr-core/src/std_extensions/collections/list.rs +++ b/hugr-core/src/std_extensions/collections/list.rs @@ -25,7 +25,7 @@ use crate::types::{TypeName, TypeRowRV}; use crate::{ extension::{ simple_op::{MakeExtensionOp, OpLoadError}, - ExtensionId, ExtensionRegistry, ExtensionSet, SignatureError, TypeDef, TypeDefBound, + ExtensionId, ExtensionSet, SignatureError, TypeDef, TypeDefBound, }, ops::constant::CustomConst, ops::{custom::ExtensionOp, NamedOp}, @@ -373,20 +373,8 @@ impl MakeExtensionOp for ListOpInst { impl ListOpInst { /// Convert this list operation to an [`ExtensionOp`] by providing a /// registry to validate the element type against. - pub fn to_extension_op(self, elem_type_registry: &ExtensionRegistry) -> Option { - let registry = ExtensionRegistry::new( - elem_type_registry - .clone() - .into_iter() - // ignore self if already in registry - .filter(|ext| ext.name() != EXTENSION.name()) - .chain(std::iter::once(EXTENSION.to_owned())), - ); - ExtensionOp::new( - registry.get(&EXTENSION_ID)?.get_op(&self.name())?.clone(), - self.type_args(), - ) - .ok() + pub fn to_extension_op(self) -> Option { + ExtensionOp::new(EXTENSION.get_op(&self.name())?.clone(), self.type_args()).ok() } } @@ -398,14 +386,10 @@ mod test { const_fail_tuple, const_none, const_ok_tuple, const_some_tuple, }; use crate::ops::OpTrait; - use crate::std_extensions::STD_REG; use crate::PortIndex; use crate::{ - extension::{ - prelude::{qb_t, usize_t, ConstUsize}, - PRELUDE, - }, - std_extensions::arithmetic::float_types::{self, float64_type, ConstF64}, + extension::prelude::{qb_t, usize_t, ConstUsize}, + std_extensions::arithmetic::float_types::{float64_type, ConstF64}, types::TypeRow, }; @@ -443,9 +427,8 @@ mod test { #[test] fn test_list_ops() { - let reg = ExtensionRegistry::new([PRELUDE.to_owned(), float_types::EXTENSION.to_owned()]); let pop_op = ListOp::pop.with_type(qb_t()); - let pop_ext = pop_op.clone().to_extension_op(®).unwrap(); + let pop_ext = pop_op.clone().to_extension_op().unwrap(); assert_eq!(ListOpInst::from_extension_op(&pop_ext).unwrap(), pop_op); let pop_sig = pop_ext.dataflow_signature().unwrap(); @@ -457,7 +440,7 @@ mod test { assert_eq!(pop_sig.output(), &both_row); let push_op = ListOp::push.with_type(float64_type()); - let push_ext = push_op.clone().to_extension_op(®).unwrap(); + let push_ext = push_op.clone().to_extension_op().unwrap(); assert_eq!(ListOpInst::from_extension_op(&push_ext).unwrap(), push_op); let push_sig = push_ext.dataflow_signature().unwrap(); @@ -531,7 +514,7 @@ mod test { let res = op .with_type(usize_t()) - .to_extension_op(&STD_REG) + .to_extension_op() .unwrap() .constant_fold(&consts) .unwrap(); diff --git a/hugr-llvm/src/extension/collections/array.rs b/hugr-llvm/src/extension/collections/array.rs index e4b9deb75..39d721a29 100644 --- a/hugr-llvm/src/extension/collections/array.rs +++ b/hugr-llvm/src/extension/collections/array.rs @@ -831,7 +831,7 @@ mod test { let hugr = SimpleHugrConfig::new() .with_outs(usize_t()) .with_extensions(exec_registry()) - .finish_with_exts(|mut builder, reg| { + .finish(|mut builder| { let us0 = builder.add_load_value(ConstUsize::new(0)); let us1 = builder.add_load_value(ConstUsize::new(1)); let i1 = builder.add_load_value(ConstInt::new_u(3, 1).unwrap()); @@ -872,13 +872,13 @@ mod test { let [arr_0] = { let r = builder.add_array_get(int_ty.clone(), 2, arr, us0).unwrap(); builder - .build_unwrap_sum(reg, 1, option_type(int_ty.clone()), r) + .build_unwrap_sum(1, option_type(int_ty.clone()), r) .unwrap() }; let [arr_1] = { let r = builder.add_array_get(int_ty.clone(), 2, arr, us1).unwrap(); builder - .build_unwrap_sum(reg, 1, option_type(int_ty.clone()), r) + .build_unwrap_sum(1, option_type(int_ty.clone()), r) .unwrap() }; let elem_eq = builder.add_ieq(3, elem, expected_elem).unwrap(); @@ -943,7 +943,7 @@ mod test { let hugr = SimpleHugrConfig::new() .with_outs(usize_t()) .with_extensions(exec_registry()) - .finish_with_exts(|mut builder, reg| { + .finish(|mut builder| { let us0 = builder.add_load_value(ConstUsize::new(0)); let us1 = builder.add_load_value(ConstUsize::new(1)); let i1 = builder.add_load_value(ConstInt::new_u(3, 1).unwrap()); @@ -983,13 +983,13 @@ mod test { let elem_0 = { let r = builder.add_array_get(int_ty.clone(), 2, arr, us0).unwrap(); builder - .build_unwrap_sum::<1>(reg, 1, option_type(int_ty.clone()), r) + .build_unwrap_sum::<1>(1, option_type(int_ty.clone()), r) .unwrap()[0] }; let elem_1 = { let r = builder.add_array_get(int_ty.clone(), 2, arr, us1).unwrap(); builder - .build_unwrap_sum::<1>(reg, 1, option_type(int_ty), r) + .build_unwrap_sum::<1>(1, option_type(int_ty), r) .unwrap()[0] }; let expected_elem_0 = @@ -1052,7 +1052,7 @@ mod test { let hugr = SimpleHugrConfig::new() .with_outs(int_ty.clone()) .with_extensions(exec_registry()) - .finish_with_exts(|mut builder, reg| { + .finish(|mut builder| { let mut r = builder.add_load_value(ConstInt::new_u(6, 0).unwrap()); let new_array_args = array_contents .iter() @@ -1074,7 +1074,6 @@ mod test { }; let [elem, new_arr] = builder .build_unwrap_sum( - reg, 1, option_type(vec![ int_ty.clone(), @@ -1117,7 +1116,7 @@ mod test { let hugr = SimpleHugrConfig::new() .with_outs(int_ty.clone()) .with_extensions(exec_registry()) - .finish_with_exts(|mut builder, reg| { + .finish(|mut builder| { let mut func = builder .define_function( "foo", @@ -1138,7 +1137,7 @@ mod test { .add_array_get(int_ty.clone(), size, arr, idx_v) .unwrap(); let [elem] = builder - .build_unwrap_sum(reg, 1, option_type(vec![int_ty.clone()]), get_res) + .build_unwrap_sum(1, option_type(vec![int_ty.clone()]), get_res) .unwrap(); builder.finish_with_outputs([elem]).unwrap() }); @@ -1164,7 +1163,7 @@ mod test { let hugr = SimpleHugrConfig::new() .with_outs(int_ty.clone()) .with_extensions(exec_registry()) - .finish_with_exts(|mut builder, reg| { + .finish(|mut builder| { let mut r = builder.add_load_value(ConstInt::new_u(6, 0).unwrap()); let new_array_args = (0..size) .map(|i| builder.add_load_value(ConstInt::new_u(6, i).unwrap())) @@ -1204,7 +1203,6 @@ mod test { .unwrap(); let [elem, new_arr] = builder .build_unwrap_sum( - reg, 1, option_type(vec![ int_ty.clone(), @@ -1240,7 +1238,7 @@ mod test { let hugr = SimpleHugrConfig::new() .with_outs(int_ty.clone()) .with_extensions(exec_registry()) - .finish_with_exts(|mut builder, _reg| { + .finish(|mut builder| { let new_array_args = (0..size) .map(|i| builder.add_load_value(ConstInt::new_u(6, i).unwrap())) .collect_vec(); diff --git a/hugr-llvm/src/utils/array_op_builder.rs b/hugr-llvm/src/utils/array_op_builder.rs index d660f15b8..dfe2faba4 100644 --- a/hugr-llvm/src/utils/array_op_builder.rs +++ b/hugr-llvm/src/utils/array_op_builder.rs @@ -118,10 +118,7 @@ pub mod test { use hugr_core::std_extensions::collections::array::{self, array_type}; use hugr_core::{ builder::{DFGBuilder, HugrBuilder}, - extension::{ - prelude::{either_type, option_type, usize_t, ConstUsize, UnwrapBuilder as _}, - PRELUDE_REGISTRY, - }, + extension::prelude::{either_type, option_type, usize_t, ConstUsize, UnwrapBuilder as _}, types::Signature, Hugr, }; @@ -149,15 +146,13 @@ pub mod test { let array_type = array_type(2, usize_t()); either_type(array_type.clone(), array_type) }; - builder - .build_unwrap_sum(&PRELUDE_REGISTRY, 1, res_sum_ty, r) - .unwrap() + builder.build_unwrap_sum(1, res_sum_ty, r).unwrap() }; let [elem_0] = { let r = builder.add_array_get(usize_t(), 2, arr, us0).unwrap(); builder - .build_unwrap_sum(&PRELUDE_REGISTRY, 1, option_type(usize_t()), r) + .build_unwrap_sum(1, option_type(usize_t()), r) .unwrap() }; @@ -169,31 +164,19 @@ pub mod test { let row = vec![usize_t(), array_type(2, usize_t())]; either_type(row.clone(), row) }; - builder - .build_unwrap_sum(&PRELUDE_REGISTRY, 1, res_sum_ty, r) - .unwrap() + builder.build_unwrap_sum(1, res_sum_ty, r).unwrap() }; let [_elem_left, arr] = { let r = builder.add_array_pop_left(usize_t(), 2, arr).unwrap(); builder - .build_unwrap_sum( - &PRELUDE_REGISTRY, - 1, - option_type(vec![usize_t(), array_type(1, usize_t())]), - r, - ) + .build_unwrap_sum(1, option_type(vec![usize_t(), array_type(1, usize_t())]), r) .unwrap() }; let [_elem_right, arr] = { let r = builder.add_array_pop_right(usize_t(), 1, arr).unwrap(); builder - .build_unwrap_sum( - &PRELUDE_REGISTRY, - 1, - option_type(vec![usize_t(), array_type(0, usize_t())]), - r, - ) + .build_unwrap_sum(1, option_type(vec![usize_t(), array_type(0, usize_t())]), r) .unwrap() }; diff --git a/hugr-passes/src/const_fold/test.rs b/hugr-passes/src/const_fold/test.rs index 9a2ca013b..bc1a5e2d1 100644 --- a/hugr-passes/src/const_fold/test.rs +++ b/hugr-passes/src/const_fold/test.rs @@ -1,7 +1,6 @@ use std::collections::hash_map::RandomState; use std::collections::HashSet; -use hugr_core::std_extensions::STD_REG; use itertools::Itertools; use lazy_static::lazy_static; use rstest::rstest; @@ -184,10 +183,7 @@ fn test_list_ops() -> Result<(), Box> { let [list, maybe_elem] = build .add_dataflow_op( - ListOp::pop - .with_type(bool_t()) - .to_extension_op(&STD_REG) - .unwrap(), + ListOp::pop.with_type(bool_t()).to_extension_op().unwrap(), [list], )? .outputs_arr(); @@ -197,10 +193,7 @@ fn test_list_ops() -> Result<(), Box> { let [list] = build .add_dataflow_op( - ListOp::push - .with_type(bool_t()) - .to_extension_op(&STD_REG) - .unwrap(), + ListOp::push.with_type(bool_t()).to_extension_op().unwrap(), [list, elem], )? .outputs_arr(); diff --git a/hugr-passes/src/monomorphize.rs b/hugr-passes/src/monomorphize.rs index 81cbed5ab..f4144ef72 100644 --- a/hugr-passes/src/monomorphize.rs +++ b/hugr-passes/src/monomorphize.rs @@ -300,8 +300,8 @@ mod test { use std::iter; use hugr_core::extension::simple_op::MakeRegisteredOp as _; + use hugr_core::std_extensions::collections; use hugr_core::std_extensions::collections::array::{array_type_parametric, ArrayOpDef}; - use hugr_core::std_extensions::{collections, STD_REG}; use hugr_core::types::type_param::TypeParam; use itertools::Itertools; @@ -491,7 +491,7 @@ mod test { .unwrap(); let [get] = pf2.add_dataflow_op(op, [inw, idx]).unwrap().outputs_arr(); let [got] = pf2 - .build_unwrap_sum(&STD_REG, 1, SumType::new([vec![], vec![tv(1)]]), get) + .build_unwrap_sum(1, SumType::new([vec![], vec![tv(1)]]), get) .unwrap(); pf2.finish_with_outputs([got]).unwrap() }; @@ -521,7 +521,7 @@ mod test { panic!() }; let [_, ar2_unwrapped] = outer - .build_unwrap_sum(&STD_REG, 1, st.clone(), ar2.out_wire(0)) + .build_unwrap_sum(1, st.clone(), ar2.out_wire(0)) .unwrap(); let [e2] = outer .call(pf1.handle(), &[sa(n - 1)], [ar2_unwrapped])