Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into refactor/reorder-infe…
Browse files Browse the repository at this point in the history
…rence
  • Loading branch information
croyzor committed May 8, 2024
2 parents 5f0039c + 4409d1d commit 035a44d
Showing 1 changed file with 109 additions and 22 deletions.
131 changes: 109 additions & 22 deletions hugr/src/hugr/validate/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,22 @@ use cool_asserts::assert_matches;
use super::*;
use crate::builder::test::closed_dfg_root_hugr;
use crate::builder::{
BuildError, Container, Dataflow, DataflowHugr, DataflowSubContainer, FunctionBuilder,
HugrBuilder, ModuleBuilder,
BuildError, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer,
FunctionBuilder, HugrBuilder, ModuleBuilder, SubContainer,
};
use crate::extension::prelude::{BOOL_T, PRELUDE, USIZE_T};
use crate::extension::{Extension, ExtensionId, TypeDefBound, EMPTY_REG, PRELUDE_REGISTRY};
use crate::extension::prelude::{BOOL_T, PRELUDE, PRELUDE_ID, USIZE_T};
use crate::extension::{Extension, ExtensionSet, TypeDefBound, EMPTY_REG, PRELUDE_REGISTRY};
use crate::hugr::hugrmut::sealed::HugrMutInternals;
use crate::hugr::HugrMut;
use crate::ops::dataflow::IOTrait;
use crate::ops::{self, Noop, Value};
use crate::ops::handle::NodeHandle;
use crate::ops::leaf::MakeTuple;
use crate::ops::{self, Noop, OpType, Value};
use crate::std_extensions::logic::test::{and_op, or_op};
use crate::std_extensions::logic::{self, NotOp};
use crate::types::type_param::{TypeArg, TypeArgError};
use crate::types::{CustomType, FunctionType, PolyFuncType, Type, TypeBound, TypeRow};
use crate::{type_row, IncomingPort};
use crate::{const_extension_ids, type_row, Direction, IncomingPort, Node};

const NAT: Type = crate::extension::prelude::USIZE_T;

Expand Down Expand Up @@ -336,10 +338,12 @@ fn unregistered_extension() {
h.update_validate(&PRELUDE_REGISTRY).unwrap();
}

const_extension_ids! {
const EXT_ID: ExtensionId = "MyExt";
}
#[test]
fn invalid_types() {
let name: ExtensionId = "MyExt".try_into().unwrap();
let mut e = Extension::new(name.clone());
let mut e = Extension::new(EXT_ID);
e.add_type(
"MyContainer".into(),
vec![TypeBound::Copyable.into()],
Expand All @@ -360,7 +364,7 @@ fn invalid_types() {
let valid = Type::new_extension(CustomType::new(
"MyContainer",
vec![TypeArg::Type { ty: USIZE_T }],
name.clone(),
EXT_ID,
TypeBound::Any,
));
assert_eq!(
Expand All @@ -374,7 +378,7 @@ fn invalid_types() {
let element_outside_bound = CustomType::new(
"MyContainer",
vec![TypeArg::Type { ty: valid.clone() }],
name.clone(),
EXT_ID,
TypeBound::Any,
);
assert_eq!(
Expand All @@ -388,7 +392,7 @@ fn invalid_types() {
let bad_bound = CustomType::new(
"MyContainer",
vec![TypeArg::Type { ty: USIZE_T }],
name.clone(),
EXT_ID,
TypeBound::Copyable,
);
assert_eq!(
Expand All @@ -405,7 +409,7 @@ fn invalid_types() {
vec![TypeArg::Type {
ty: Type::new_extension(bad_bound),
}],
name.clone(),
EXT_ID,
TypeBound::Any,
);
assert_eq!(
Expand All @@ -419,7 +423,7 @@ fn invalid_types() {
let too_many_type_args = CustomType::new(
"MyContainer",
vec![TypeArg::Type { ty: USIZE_T }, TypeArg::BoundedNat { n: 3 }],
name.clone(),
EXT_ID,
TypeBound::Any,
);
assert_eq!(
Expand Down Expand Up @@ -544,18 +548,101 @@ fn no_polymorphic_consts() -> Result<(), Box<dyn std::error::Error>> {

#[test]
fn test_polymorphic_call() -> Result<(), Box<dyn std::error::Error>> {
let mut m = ModuleBuilder::new();
let id = m.declare(
"id",
let mut e = Extension::new(EXT_ID);

let params: Vec<TypeParam> = vec![
TypeBound::Any.into(),
TypeParam::Extensions,
TypeBound::Any.into(),
];
let evaled_fn = Type::new_function(
FunctionType::new(
Type::new_var_use(0, TypeBound::Any),
Type::new_var_use(2, TypeBound::Any),
)
.with_extension_delta(ExtensionSet::type_var(1)),
);
// The higher-order "eval" operation - takes a function and its argument.
// Note the extension-delta of the eval node includes that of the input function.
e.add_op(
"eval".into(),
"".into(),
PolyFuncType::new(
vec![TypeBound::Any.into()],
FunctionType::new_endo(vec![Type::new_var_use(0, TypeBound::Any)]),
params.clone(),
FunctionType::new(
vec![evaled_fn, Type::new_var_use(0, TypeBound::Any)],
Type::new_var_use(2, TypeBound::Any),
)
.with_extension_delta(ExtensionSet::type_var(1)),
),
)?;
let mut f = m.define_function("main", FunctionType::new_endo(vec![USIZE_T]).into())?;
let c = f.call(&id, &[USIZE_T.into()], f.input_wires(), &PRELUDE_REGISTRY)?;
f.finish_with_outputs(c.outputs())?;
let _ = m.finish_prelude_hugr()?;

fn utou(e: impl Into<ExtensionSet>) -> Type {
Type::new_function(FunctionType::new_endo(USIZE_T).with_extension_delta(e.into()))
}

let int_pair = Type::new_tuple(type_row![USIZE_T; 2]);
// Root DFG: applies a function int--PRELUDE-->int to each element of a pair of two ints
let mut d = DFGBuilder::new(
FunctionType::new(
vec![utou(PRELUDE_ID), int_pair.clone()],
vec![int_pair.clone()],
)
.with_extension_delta(PRELUDE_ID),
)?;
// ....by calling a function parametrized<extensions E> (int--e-->int, int_pair) -> int_pair
let f = {
let es = ExtensionSet::type_var(0);
let mut f = d.define_function(
"two_ints",
PolyFuncType::new(
vec![TypeParam::Extensions],
FunctionType::new(vec![utou(es.clone()), int_pair.clone()], int_pair.clone())
.with_extension_delta(es.clone()),
),
)?;
let [func, tup] = f.input_wires_arr();
let mut c = f.conditional_builder(
(vec![type_row![USIZE_T; 2]], tup),
vec![],
type_row![USIZE_T;2],
es.clone(),
)?;
let mut cc = c.case_builder(0)?;
let [i1, i2] = cc.input_wires_arr();
let op = e.instantiate_extension_op(
"eval",
vec![USIZE_T.into(), TypeArg::Extensions { es }, USIZE_T.into()],
&PRELUDE_REGISTRY,
)?;
let [f1] = cc.add_dataflow_op(op.clone(), [func, i1])?.outputs_arr();
let [f2] = cc.add_dataflow_op(op, [func, i2])?.outputs_arr();
cc.finish_with_outputs([f1, f2])?;
let res = c.finish_sub_container()?.outputs();
let tup = f.add_dataflow_op(
MakeTuple {
tys: type_row![USIZE_T; 2],
},
res,
)?;
f.finish_with_outputs(tup.outputs())?
};

let reg = ExtensionRegistry::try_new([e, PRELUDE.to_owned()])?;
let [func, tup] = d.input_wires_arr();
let call = d.call(
f.handle(),
&[TypeArg::Extensions {
es: ExtensionSet::singleton(&PRELUDE_ID),
}],
[func, tup],
&reg,
)?;
let h = d.finish_hugr_with_outputs(call.outputs(), &reg)?;
let call_ty = h.get_optype(call.node()).dataflow_signature().unwrap();
let exp_fun_ty = FunctionType::new(vec![utou(PRELUDE_ID), int_pair.clone()], int_pair)
.with_extension_delta(PRELUDE_ID);
assert_eq!(call_ty, exp_fun_ty);
Ok(())
}

Expand Down

0 comments on commit 035a44d

Please sign in to comment.