From 2bbd8c69181ff0b63953c5597abe10a0d25074c7 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Wed, 6 Sep 2023 15:45:52 +0100 Subject: [PATCH] feat: symbolic parameters + more T2ops (#90) temporary - when we have a better idea of angle/rotation types can use a HUGR input for symbols instead Closes #82 --- src/json.rs | 15 +++++++-- src/json/decoder.rs | 8 +++-- src/json/encoder.rs | 28 +++++++++++------ src/json/op.rs | 20 ++++++------ src/json/tests.rs | 2 +- src/lib.rs | 4 +-- src/ops.rs | 77 +++++++++++++++++++++++++++++++++++++++++---- 7 files changed, 121 insertions(+), 33 deletions(-) diff --git a/src/json.rs b/src/json.rs index d9659666..8470f35a 100644 --- a/src/json.rs +++ b/src/json.rs @@ -140,16 +140,25 @@ impl From for TK1LoadError { } } +#[inline] +fn parse_val(n: &str) -> Option { + if n == "pi" { + // angles are in multiples of pi + Some(1.0) + } else { + n.parse::().ok() + } +} /// Try to interpret a TKET1 parameter as a constant value. #[inline] fn try_param_to_constant(param: &str) -> Option { - if let Ok(f) = param.parse::() { + if let Some(f) = parse_val(param) { Some(ConstF64::new(f).into()) } else if param.split('/').count() == 2 { // TODO: Use the rational types from `Hugr::extensions::rotation` let (n, d) = param.split_once('/').unwrap(); - let n = n.parse::().unwrap(); - let d = d.parse::().unwrap(); + let n = parse_val(n)?; + let d = parse_val(d)?; Some(ConstF64::new(n / d).into()) } else { None diff --git a/src/json/decoder.rs b/src/json/decoder.rs index 28ec5e5a..1d5432c3 100644 --- a/src/json/decoder.rs +++ b/src/json/decoder.rs @@ -21,6 +21,7 @@ use tket_json_rs::circuit_json::SerialCircuit; use super::op::JsonOp; use super::{try_param_to_constant, METADATA_IMPLICIT_PERM, METADATA_PHASE}; use crate::extension::{LINEAR_BIT, REGISTRY}; +use crate::symbolic_constant_op; /// The state of an in-progress [`DFGBuilder`] being built from a [`SerialCircuit`]. /// @@ -145,9 +146,10 @@ impl JsonDecoder { self.hugr.add_load_const(const_op).unwrap() } None => { - // TODO: If the parameter is just a variable, - // return the corresponding wire from the input. - todo!("Variable parameters not yet supported") + // store string in custom op. + let symb_op = symbolic_constant_op(param); + let o = self.hugr.add_dataflow_op(symb_op, []).unwrap(); + o.out_wire(0) } } } diff --git a/src/json/encoder.rs b/src/json/encoder.rs index cb9f20eb..3711e401 100644 --- a/src/json/encoder.rs +++ b/src/json/encoder.rs @@ -13,6 +13,7 @@ use tket_json_rs::circuit_json::{self, Permutation, Register, SerialCircuit}; use crate::circuit::command::{CircuitUnit, Command}; use crate::circuit::Circuit; use crate::extension::LINEAR_BIT; +use crate::ops::match_symb_const_op; use super::op::JsonOp; use super::{OpConvertError, METADATA_IMPLICIT_PERM, METADATA_PHASE}; @@ -89,11 +90,12 @@ impl JsonEncoder { /// Add a circuit command to the serialization. pub fn add_command(&mut self, command: Command, optype: &OpType) -> Result<(), OpConvertError> { // Register any output of the command that can be used as a TKET1 parameter. - self.record_parameters(&command, optype); - - if let OpType::Const(_) | OpType::LoadConstant(_) = optype { + if self.record_parameters(&command, optype) { + // for now all ops that record parameters should be ignored (are + // just constants) return Ok(()); } + let (args, params): (Vec, Vec) = command .inputs() @@ -141,9 +143,9 @@ impl JsonEncoder { } /// Record any output of the command that can be used as a TKET1 parameter. - /// + /// Returns whether parameters were recorded. /// Associates the output wires with the parameter expression. - fn record_parameters(&mut self, command: &Command, optype: &OpType) { + fn record_parameters(&mut self, command: &Command, optype: &OpType) -> bool { // Only consider commands where all inputs are parameters. let inputs = command .inputs() @@ -154,7 +156,7 @@ impl JsonEncoder { }) .collect_vec(); if inputs.len() != command.inputs().len() { - return; + return false; } let param = match optype { @@ -165,21 +167,28 @@ impl JsonEncoder { if let Some(f) = v.downcast_ref::() { f.to_string() } else { - return; + return false; } } - _ => return, + _ => return false, } } OpType::LoadConstant(_op_type) => { // Re-use the parameter from the input. inputs[0].clone() } + OpType::LeafOp(_) => { + if let Some(s) = match_symb_const_op(optype) { + s.to_string() + } else { + return false; + } + } _ => { // In the future we may want to support arithmetic operations. // (Just concatenating the inputs and the operation symbol, no // need for evaluation). - return; + return false; } }; @@ -188,6 +197,7 @@ impl JsonEncoder { self.parameters.insert(*wire, param.clone()); } } + true } fn unit_to_register(&self, unit: CircuitUnit) -> Option { diff --git a/src/json/op.rs b/src/json/op.rs index a4b340db..6eea35e4 100644 --- a/src/json/op.rs +++ b/src/json/op.rs @@ -17,7 +17,7 @@ use itertools::Itertools; use tket_json_rs::circuit_json; use tket_json_rs::optype::OpType as JsonOpType; -use super::{try_param_to_constant, OpConvertError}; +use super::OpConvertError; use crate::extension::{try_unwrap_json_op, LINEAR_BIT}; use crate::T2Op; @@ -164,14 +164,8 @@ impl JsonOp { return; }; - let mut p_input_indices = 0..; - let param_inputs = params - .iter() - .map(|param| try_param_to_constant(param).map(|_| p_input_indices.next().unwrap())) - .collect(); - - self.num_params = p_input_indices.next().unwrap(); - self.param_inputs = param_inputs; + self.num_params = params.len(); + self.param_inputs = (0..params.len()).map(Some).collect(); } } @@ -189,7 +183,11 @@ impl From<&JsonOp> for OpType { JsonOpType::Tdg => T2Op::Tdg.into(), JsonOpType::X => T2Op::X.into(), JsonOpType::Rz => T2Op::RzF64.into(), + JsonOpType::Rx => T2Op::RxF64.into(), JsonOpType::TK1 => T2Op::TK1.into(), + JsonOpType::PhasedX => T2Op::PhasedX.into(), + JsonOpType::ZZMax => T2Op::ZZMax.into(), + JsonOpType::ZZPhase => T2Op::ZZPhase.into(), JsonOpType::noop => LeafOp::Noop { ty: QB_T }.into(), _ => LeafOp::CustomOp(Box::new(json_op.as_opaque_op())).into(), } @@ -216,7 +214,11 @@ impl TryFrom<&OpType> for JsonOp { T2Op::H => JsonOpType::H, T2Op::Measure => JsonOpType::Measure, T2Op::RzF64 => JsonOpType::Rz, + T2Op::RxF64 => JsonOpType::Rx, T2Op::TK1 => JsonOpType::TK1, + T2Op::PhasedX => JsonOpType::PhasedX, + T2Op::ZZMax => JsonOpType::ZZMax, + T2Op::ZZPhase => JsonOpType::ZZPhase, _ => return Err(err()), } } else if let LeafOp::CustomOp(b) = leaf { diff --git a/src/json/tests.rs b/src/json/tests.rs index 1ccaf8cc..88489a49 100644 --- a/src/json/tests.rs +++ b/src/json/tests.rs @@ -43,7 +43,7 @@ const PARAMETRIZED: &str = r#"{ {"args":[["q",[0]]],"op":{"type":"H"}}, {"args":[["q",[1]],["q",[0]]],"op":{"type":"CX"}}, {"args":[["q",[0]]],"op":{"params":["0.1"],"type":"Rz"}}, - {"args": [["q", [0]]], "op": {"params": ["0.1", "0.2", "0.3"], "type": "TK1"}} + {"args": [["q", [0]]], "op": {"params": ["0.1", "alpha", "0.3"], "type": "TK1"}} ], "created_qubits": [], "discarded_qubits": [], diff --git a/src/lib.rs b/src/lib.rs index 2f0ada9c..bce5b59b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -10,9 +10,9 @@ pub mod circuit; pub mod extension; pub mod json; -mod ops; +pub(crate) mod ops; pub mod passes; -pub use ops::{Pauli, T2Op}; +pub use ops::{symbolic_constant_op, Pauli, T2Op}; #[cfg(feature = "portmatching")] pub mod portmatching; diff --git a/src/ops.rs b/src/ops.rs index 736637ac..47c5ee2e 100644 --- a/src/ops.rs +++ b/src/ops.rs @@ -8,11 +8,16 @@ use hugr::{ ops::{custom::ExternalOp, LeafOp, OpType}, std_extensions::arithmetic::float_types::FLOAT64_TYPE, type_row, - types::FunctionType, + types::{ + type_param::{CustomTypeArg, TypeArg, TypeParam}, + CustomType, FunctionType, TypeBound, + }, Extension, }; use lazy_static::lazy_static; use serde::{Deserialize, Serialize}; +use smol_str::SmolStr; + use std::str::FromStr; use strum::IntoEnumIterator; use strum_macros::{Display, EnumIter, EnumString, IntoStaticStr}; @@ -51,6 +56,9 @@ pub enum T2Op { ZZMax, Measure, RzF64, + RxF64, + PhasedX, + ZZPhase, TK1, } #[derive(Clone, Copy, Debug, Serialize, Deserialize, EnumIter, Display, PartialEq, PartialOrd)] @@ -111,8 +119,10 @@ impl SimpleOpEnum for T2Op { match self { H | T | S | X | Y | Z | Tdg | Sdg => FunctionType::new(one_qb_row.clone(), one_qb_row), CX | ZZMax => FunctionType::new(two_qb_row.clone(), two_qb_row), + ZZPhase => FunctionType::new(type_row![QB_T, QB_T, FLOAT64_TYPE], two_qb_row), Measure => FunctionType::new(one_qb_row, type_row![QB_T, BOOL_T]), - RzF64 => FunctionType::new(type_row![QB_T, FLOAT64_TYPE], one_qb_row), + RzF64 | RxF64 => FunctionType::new(type_row![QB_T, FLOAT64_TYPE], one_qb_row), + PhasedX => FunctionType::new(type_row![QB_T, FLOAT64_TYPE, FLOAT64_TYPE], one_qb_row), TK1 => FunctionType::new( type_row![QB_T, FLOAT64_TYPE, FLOAT64_TYPE, FLOAT64_TYPE], one_qb_row, @@ -152,20 +162,74 @@ impl T2Op { use T2Op::*; match self { - X => vec![(0, Pauli::X)], - T | Z | S | Tdg | Sdg | RzF64 => vec![(0, Pauli::Z)], + X | RxF64 => vec![(0, Pauli::X)], + T | Z | S | Tdg | Sdg | RzF64 | Measure => vec![(0, Pauli::Z)], CX => vec![(0, Pauli::Z), (1, Pauli::X)], - ZZMax => vec![(0, Pauli::Z), (1, Pauli::Z)], + ZZMax | ZZPhase => vec![(0, Pauli::Z), (1, Pauli::Z)], // by default, no commutation _ => vec![], } } } +/// The type of the symbolic expression opaque type arg. +pub const SYM_EXPR_T: CustomType = + CustomType::new_simple(SmolStr::new_inline("SymExpr"), EXTENSION_ID, TypeBound::Eq); + +const SYM_OP_ID: SmolStr = SmolStr::new_inline("symbolic_float"); + +/// Initialize a new custom symbolic expression constant op from a string. +pub fn symbolic_constant_op(s: &str) -> OpType { + let value: serde_yaml::Value = s.into(); + let l: LeafOp = EXTENSION + .instantiate_extension_op( + &SYM_OP_ID, + vec![TypeArg::Opaque( + CustomTypeArg::new(SYM_EXPR_T, value).unwrap(), + )], + ) + .unwrap() + .into(); + l.into() +} + +/// match against a symbolic constant +pub(crate) fn match_symb_const_op(op: &OpType) -> Option<&str> { + if let OpType::LeafOp(LeafOp::CustomOp(e)) = op { + match e.as_ref() { + ExternalOp::Extension(e) + if e.def().name() == &SYM_OP_ID && e.def().extension() == &EXTENSION_ID => + { + // TODO also check extension name + + let Some(TypeArg::Opaque(s)) = e.args().get(0) else { + panic!("should be an opaque type arg.") + }; + + let serde_yaml::Value::String(s) = &s.value else { + panic!("unexpected yaml value.") + }; + + Some(s) + } + ExternalOp::Opaque(_) => todo!(), + _ => None, + } + } else { + None + } +} + fn extension() -> Extension { let mut e = Extension::new(EXTENSION_ID); load_all_ops::(&mut e).expect("add fail"); - + e.add_op_custom_sig_simple( + SYM_OP_ID, + "Store a sympy expression that can be evaluated to a float.".to_string(), + vec![TypeParam::Opaque(SYM_EXPR_T)], + |_: &[TypeArg]| Ok(FunctionType::new(type_row![], type_row![FLOAT64_TYPE])), + ) + .unwrap(); e } @@ -213,6 +277,7 @@ impl TryFrom for T2Op { } } +/// load all variants of a `SimpleOpEnum` in to an extension as op defs. fn load_all_ops(extension: &mut Extension) -> Result<(), ExtensionBuildError> { for op in T::all_variants() { op.add_to_extension(extension)?;