diff --git a/Cargo.lock b/Cargo.lock index e0589a7d..29b3b12b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -138,6 +138,15 @@ dependencies = [ "wyz", ] +[[package]] +name = "block-buffer" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" +dependencies = [ + "generic-array", +] + [[package]] name = "block2" version = "0.5.1" @@ -379,6 +388,15 @@ version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" +[[package]] +name = "cpufeatures" +version = "0.2.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "608697df725056feaccfa42cffdaeeec3fccc4ffc38358ecd19b243e716a78e0" +dependencies = [ + "libc", +] + [[package]] name = "criterion" version = "0.5.1" @@ -455,6 +473,16 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" +[[package]] +name = "crypto-common" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" +dependencies = [ + "generic-array", + "typenum", +] + [[package]] name = "csv" version = "1.3.0" @@ -552,6 +580,16 @@ dependencies = [ "unicode-xid", ] +[[package]] +name = "digest" +version = "0.10.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" +dependencies = [ + "block-buffer", + "crypto-common", +] + [[package]] name = "downcast-rs" version = "1.2.1" @@ -733,6 +771,16 @@ dependencies = [ "byteorder", ] +[[package]] +name = "generic-array" +version = "0.14.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" +dependencies = [ + "typenum", + "version_check", +] + [[package]] name = "glob" version = "0.3.1" @@ -1226,6 +1274,51 @@ version = "2.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" +[[package]] +name = "pest" +version = "2.7.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fdbef9d1d47087a895abd220ed25eb4ad973a5e26f6a4367b038c25e28dfc2d9" +dependencies = [ + "memchr", + "thiserror", + "ucd-trie", +] + +[[package]] +name = "pest_derive" +version = "2.7.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4d3a6e3394ec80feb3b6393c725571754c6188490265c61aaf260810d6b95aa0" +dependencies = [ + "pest", + "pest_generator", +] + +[[package]] +name = "pest_generator" +version = "2.7.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94429506bde1ca69d1b5601962c73f4172ab4726571a59ea95931218cb0e930e" +dependencies = [ + "pest", + "pest_meta", + "proc-macro2", + "quote", + "syn 2.0.71", +] + +[[package]] +name = "pest_meta" +version = "2.7.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac8a071862e93690b6e34e9a5fb8e33ff3734473ac0245b27232222c4906a33f" +dependencies = [ + "once_cell", + "pest", + "sha2", +] + [[package]] name = "petgraph" version = "0.6.5" @@ -1659,6 +1752,17 @@ dependencies = [ "serde", ] +[[package]] +name = "sha2" +version = "0.10.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "793db75ad2bcafc3ffa7c68b215fee268f537982cd901d132f89c6343f3a3dc8" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "sharded-slab" version = "0.1.7" @@ -1912,6 +2016,8 @@ dependencies = [ "lazy_static", "num-complex", "num-rational", + "pest", + "pest_derive", "petgraph", "portgraph 0.12.2", "portmatching", @@ -2049,6 +2155,12 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "059d83cc991e7a42fc37bd50941885db0888e34209f8cfd9aab07ddec03bc9cf" +[[package]] +name = "typenum" +version = "1.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" + [[package]] name = "typetag" version = "0.2.18" @@ -2073,6 +2185,12 @@ dependencies = [ "syn 2.0.71", ] +[[package]] +name = "ucd-trie" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2896d95c02a80c6d6a5d6e953d479f5ddf2dfdb6a244441010e373ac0fb88971" + [[package]] name = "unicode-bidi" version = "0.3.15" @@ -2150,6 +2268,12 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d" +[[package]] +name = "version_check" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" + [[package]] name = "walkdir" version = "2.5.0" diff --git a/Cargo.toml b/Cargo.toml index 71bfd17c..807f3c49 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -51,6 +51,8 @@ num-complex = "0.4" num-rational = "0.4" num_cpus = "1.16.0" peak_alloc = "0.2.0" +pest = "2.7.13" +pest_derive = "2.7.13" petgraph = { version = "0.6.3", default-features = false } priority-queue = "2.1.1" rayon = "1.5" diff --git a/tket2/Cargo.toml b/tket2/Cargo.toml index 36bda100..4b1364d4 100644 --- a/tket2/Cargo.toml +++ b/tket2/Cargo.toml @@ -68,6 +68,8 @@ chrono = { workspace = true } bytemuck = { workspace = true } crossbeam-channel = { workspace = true } tracing = { workspace = true } +pest = { workspace = true } +pest_derive = { workspace = true } zstd = { workspace = true, optional = true } # Required to acces the `Package` type. # Remove once https://github.com/CQCL/hugr/issues/1530 is fixed. diff --git a/tket2/src/circuit.rs b/tket2/src/circuit.rs index d52f6698..cf428b7a 100644 --- a/tket2/src/circuit.rs +++ b/tket2/src/circuit.rs @@ -717,8 +717,6 @@ mod tests { assert_eq!(circ.circuit_signature().input_count(), qubits + bits); assert_eq!(circ.circuit_signature().output_count(), qubits + bits); assert_eq!(circ.qubit_count(), qubits); - assert_eq!(circ.num_operations(), 3); - assert_eq!(circ.operations().count(), 3); assert_eq!(circ.units().count(), qubits + bits); assert_eq!(circ.nonlinear_units().count(), bits); diff --git a/tket2/src/serialize/pytket.rs b/tket2/src/serialize/pytket.rs index 1ffffbc4..44e7e685 100644 --- a/tket2/src/serialize/pytket.rs +++ b/tket2/src/serialize/pytket.rs @@ -3,8 +3,8 @@ mod decoder; mod encoder; mod op; +mod param; -use hugr::std_extensions::arithmetic::float_types::ConstF64; use hugr::types::Type; use hugr::Node; @@ -20,14 +20,13 @@ use std::hash::{Hash, Hasher}; use std::path::Path; use std::{fs, io}; -use hugr::ops::{NamedOp, OpType, Value}; +use hugr::ops::{NamedOp, OpType}; use derive_more::{Display, Error, From}; use tket_json_rs::circuit_json::{self, SerialCircuit}; use tket_json_rs::optype::OpType as SerialOpType; use crate::circuit::Circuit; -use crate::extension::rotation::ConstRotation; use self::decoder::Tk1Decoder; use self::encoder::Tk1Encoder; @@ -282,48 +281,6 @@ pub enum TK1ConvertError { FileLoadError(io::Error), } -/// Try to interpret a TKET1 parameter as a constant value. -/// -/// Angle parameters in TKET1 are encoded as a number of half-turns, -/// whereas HUGR uses radians. -#[inline] -fn try_param_to_constant(param: &str) -> Option { - fn parse_val(n: &str) -> Option { - n.parse::().ok() - } - - let half_turns = if let Some(f) = parse_val(param) { - f - } else if param.split('/').count() == 2 { - // TODO: Use the rational types from `Hugr::extensions::rotation` - let (n, d) = param.split_once('/').unwrap(); - let n = parse_val(n)?; - let d = parse_val(d)?; - n / d - } else { - return None; - }; - - ConstRotation::new(half_turns).ok().map(Into::into) -} - -/// Convert a HUGR rotation or float constant to a TKET1 parameter. -/// -/// Angle parameters in TKET1 are encoded as a number of half-turns, -/// whereas HUGR uses radians. -#[inline] -fn try_constant_to_param(val: &Value) -> Option { - if let Some(const_angle) = val.get_custom_value::() { - let half_turns = const_angle.half_turns(); - Some(half_turns.to_string()) - } else if let Some(const_float) = val.get_custom_value::() { - let float = const_float.value(); - Some(float.to_string()) - } else { - None - } -} - /// A hashed register, used to identify registers in the [`Tk1Decoder::register_wire`] map, /// avoiding string and vector clones on lookup. #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] diff --git a/tket2/src/serialize/pytket/decoder.rs b/tket2/src/serialize/pytket/decoder.rs index aff1cfd0..2763ad47 100644 --- a/tket2/src/serialize/pytket/decoder.rs +++ b/tket2/src/serialize/pytket/decoder.rs @@ -6,21 +6,25 @@ use hugr::builder::{Container, Dataflow, DataflowHugr, FunctionBuilder}; use hugr::extension::prelude::{BOOL_T, QB_T}; use hugr::ops::handle::NodeHandle; -use hugr::ops::OpType; +use hugr::ops::{OpType, Value}; +use hugr::std_extensions::arithmetic::float_types::ConstF64; use hugr::types::Signature; use hugr::{Hugr, Wire}; +use derive_more::Display; use itertools::{EitherOrBoth, Itertools}; use serde_json::json; use tket_json_rs::circuit_json; use tket_json_rs::circuit_json::SerialCircuit; use super::op::Tk1Op; +use super::param::decode::{parse_pytket_param, PytketParam}; use super::{ - try_param_to_constant, OpConvertError, RegisterHash, TK1ConvertError, - METADATA_B_OUTPUT_REGISTERS, METADATA_B_REGISTERS, METADATA_OPGROUP, METADATA_PHASE, - METADATA_Q_OUTPUT_REGISTERS, METADATA_Q_REGISTERS, + OpConvertError, RegisterHash, TK1ConvertError, METADATA_B_OUTPUT_REGISTERS, + METADATA_B_REGISTERS, METADATA_OPGROUP, METADATA_PHASE, METADATA_Q_OUTPUT_REGISTERS, + METADATA_Q_REGISTERS, }; +use crate::extension::rotation::RotationOp; use crate::extension::{REGISTRY, TKET1_EXTENSION_ID}; use crate::symbolic_constant_op; @@ -246,7 +250,7 @@ impl Tk1Decoder { tk1op .param_ports() .zip(params) - .map(|(_port, param)| self.create_param_wire(param)), + .map(|(_port, param)| self.load_parameter(param)), ); Ok((inputs, outputs)) @@ -256,17 +260,60 @@ impl Tk1Decoder { /// /// If the parameter is a constant, a constant definition is added to the Hugr. /// - /// TODO: If the parameter is a variable, returns the corresponding wire from the input. - fn create_param_wire(&mut self, param: String) -> Wire { - match try_param_to_constant(¶m) { - Some(const_op) => self.hugr.add_load_const(const_op), - None => { - // store string in custom op. - let symb_op = symbolic_constant_op(param); - let o = self.hugr.add_dataflow_op(symb_op, []).unwrap(); - o.out_wire(0) + /// The returned wires always have float type. + fn load_parameter(&mut self, param: String) -> Wire { + fn process( + hugr: &mut FunctionBuilder, + parsed: PytketParam, + param: &str, + ) -> LoadedParameter { + match parsed { + PytketParam::Constant(half_turns) => { + let value: Value = ConstF64::new(half_turns).into(); + let wire = hugr.add_load_const(value); + LoadedParameter::float(wire) + } + PytketParam::Sympy(expr) => { + // store string in custom op. + let symb_op = symbolic_constant_op(expr.to_string()); + let wire = hugr.add_dataflow_op(symb_op, []).unwrap().out_wire(0); + LoadedParameter::rotation(wire) + } + PytketParam::InputVariable { name } => { + // Special case for the name "pi", inserts a `ConstRotation::PI` instead. + if name == "pi" { + let value: Value = ConstF64::new(std::f64::consts::PI).into(); + let wire = hugr.add_load_const(value); + return LoadedParameter::float(wire); + } + + // TODO: We need to add a `FunctionBuilder::add_input` function on the hugr side. + // Here we just add an opaque sympy box instead. + // https://github.com/CQCL/hugr/issues/1562 + // https://github.com/CQCL/tket2/issues/628 + process(hugr, PytketParam::Sympy(name), param) + } + PytketParam::Operation { op, args } => { + // We assume all operations take float inputs. + let input_wires = args + .into_iter() + .map(|arg| process(hugr, arg, param).as_float(hugr).wire) + .collect_vec(); + let res = hugr.add_dataflow_op(op, input_wires).unwrap_or_else(|e| { + panic!("Error while decoding pytket operation parameter \"{param}\". {e}",) + }); + // TODO: Replace with `res.num_value_outputs` + // https://github.com/CQCL/hugr/pull/1560 + assert_eq!(res.outputs().count(), 1, "An operation decoded from the pytket op parameter \"{param}\" had {} outputs", res.outputs().count()); + LoadedParameter::float(res.out_wire(0)) + } } } + + let parsed = parse_pytket_param(¶m); + process(&mut self.hugr, parsed, ¶m) + .as_rotation(&mut self.hugr) + .wire } /// Return the [`Wire`] associated with a register. @@ -295,3 +342,84 @@ fn check_register(register: &circuit_json::Register) -> Result<(), TK1ConvertErr Ok(()) } } + +/// The type of a loaded parameter in the Hugr. +#[derive(Debug, Display, Clone, Copy, Hash, PartialEq, Eq)] +enum LoadedParameterType { + Float, + Rotation, +} + +/// A loaded parameter in the Hugr. +/// +/// Tracking the type of the wire lets us delay conversion between the types +/// until they are actually needed. +#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)] +struct LoadedParameter { + /// The type of the parameter. + pub typ: LoadedParameterType, + /// The wire where the parameter is loaded. + pub wire: Wire, +} + +impl LoadedParameter { + /// Returns a `LoadedParameter` for a float param. + pub fn float(wire: Wire) -> LoadedParameter { + LoadedParameter { + typ: LoadedParameterType::Float, + wire, + } + } + + /// Returns a `LoadedParameter` for a rotation param. + pub fn rotation(wire: Wire) -> LoadedParameter { + LoadedParameter { + typ: LoadedParameterType::Rotation, + wire, + } + } + + /// Convert the parameter into a given type, if necessary. + /// + /// Adds the necessary operations to the Hugr and returns a new wire. + pub fn as_type( + &self, + typ: LoadedParameterType, + hugr: &mut FunctionBuilder, + ) -> LoadedParameter { + match (self.typ, typ) { + (LoadedParameterType::Float, LoadedParameterType::Rotation) => { + let wire = hugr + .add_dataflow_op(RotationOp::from_halfturns_unchecked, [self.wire]) + .unwrap() + .out_wire(0); + LoadedParameter::rotation(wire) + } + (LoadedParameterType::Rotation, LoadedParameterType::Float) => { + let wire = hugr + .add_dataflow_op(RotationOp::to_halfturns, [self.wire]) + .unwrap() + .out_wire(0); + LoadedParameter::float(wire) + } + _ => { + debug_assert_eq!(self.typ, typ, "cannot convert {} to {}", self.typ, typ); + *self + } + } + } + + /// Convert the parameter into a float, if necessary. + /// + /// Adds the necessary operations to the Hugr and returns a new wire. + pub fn as_float(&self, hugr: &mut FunctionBuilder) -> LoadedParameter { + self.as_type(LoadedParameterType::Float, hugr) + } + + /// Convert the parameter into a rotation, if necessary. + /// + /// Adds the necessary operations to the Hugr and returns a new wire. + pub fn as_rotation(&self, hugr: &mut FunctionBuilder) -> LoadedParameter { + self.as_type(LoadedParameterType::Rotation, hugr) + } +} diff --git a/tket2/src/serialize/pytket/encoder.rs b/tket2/src/serialize/pytket/encoder.rs index 8f3f0430..e50c3315 100644 --- a/tket2/src/serialize/pytket/encoder.rs +++ b/tket2/src/serialize/pytket/encoder.rs @@ -5,7 +5,6 @@ use std::collections::{HashMap, HashSet, VecDeque}; use hugr::extension::prelude::{BOOL_T, QB_T}; use hugr::ops::{OpTrait, OpType}; -use hugr::std_extensions::arithmetic::float_ops::FloatOps; use hugr::std_extensions::arithmetic::float_types::FLOAT64_TYPE; use hugr::types::Type; use hugr::{HugrView, Wire}; @@ -15,17 +14,15 @@ use tket_json_rs::circuit_json::{self, SerialCircuit}; use crate::circuit::command::{CircuitUnit, Command}; use crate::circuit::Circuit; -use crate::extension::rotation::{RotationOp, ROTATION_TYPE}; -use crate::extension::sympy::SympyOp; -use crate::ops::match_symb_const_op; +use crate::extension::rotation::ROTATION_TYPE; use crate::serialize::pytket::RegisterHash; use crate::Tk2Op; use super::op::Tk1Op; +use super::param::encode::fold_param_op; use super::{ - try_constant_to_param, OpConvertError, TK1ConvertError, METADATA_B_OUTPUT_REGISTERS, - METADATA_B_REGISTERS, METADATA_OPGROUP, METADATA_PHASE, METADATA_Q_OUTPUT_REGISTERS, - METADATA_Q_REGISTERS, + OpConvertError, TK1ConvertError, METADATA_B_OUTPUT_REGISTERS, METADATA_B_REGISTERS, + METADATA_OPGROUP, METADATA_PHASE, METADATA_Q_OUTPUT_REGISTERS, METADATA_Q_REGISTERS, }; /// The state of an in-progress [`SerialCircuit`] being built from a [`Circuit`]. @@ -53,7 +50,7 @@ impl Tk1Encoder { // Check for unsupported input types. for (_, _, typ) in circ.units() { - if ![ROTATION_TYPE, QB_T, BOOL_T].contains(&typ) { + if ![ROTATION_TYPE, FLOAT64_TYPE, QB_T, BOOL_T].contains(&typ) { return Err(TK1ConvertError::NonSerializableInputs { typ }); } } @@ -129,7 +126,7 @@ impl Tk1Encoder { ) }); bit_args.push(reg); - } else if ty == ROTATION_TYPE { + } else if [ROTATION_TYPE, FLOAT64_TYPE].contains(&ty) { let CircuitUnit::Wire(param_wire) = unit else { unreachable!("Angle types are not linear.") }; @@ -557,7 +554,7 @@ impl ParameterTracker { let mut tracker = ParameterTracker::default(); let angle_input_wires = circ.units().filter_map(|u| match u { - (CircuitUnit::Wire(w), _, ty) if ty == ROTATION_TYPE => Some(w), + (CircuitUnit::Wire(w), _, ty) if [ROTATION_TYPE, FLOAT64_TYPE].contains(&ty) => Some(w), _ => None, }); @@ -615,50 +612,11 @@ impl ParameterTracker { node: command.node(), }); }; - inputs.push(param); + inputs.push(param.as_str()); } - let param = match optype { - OpType::Const(const_op) => { - // New constant, register it if it can be interpreted as a parameter. - match try_constant_to_param(const_op.value()) { - Some(param) => param, - None => return Ok(false), - } - } - OpType::LoadConstant(_op_type) => { - // Re-use the parameter from the input. - inputs[0].clone() - } - // Encode some angle and float operations directly as strings using - // the already encoded inputs. Fail if the operation is not - // supported, and let the operation encoding process it instead. - OpType::ExtensionOp(_) => { - if let Some(s) = optype - .cast::() - .and_then(|op| self.encode_rotation_op(&op, inputs.as_slice())) - { - s - } else if let Some(s) = optype - .cast::() - .and_then(|op| self.encode_float_op(&op, inputs.as_slice())) - { - s - } else if let Some(s) = optype - .cast::() - .and_then(|op| self.encode_sympy_op(&op, inputs.as_slice())) - { - s - } else { - return Ok(false); - } - } - _ => { - let Some(s) = match_symb_const_op(optype) else { - return Ok(false); - }; - s.to_string() - } + let Some(param) = fold_param_op(optype, &inputs) else { + return Ok(false); }; for (unit, _, _) in command.outputs() { @@ -678,54 +636,6 @@ impl ParameterTracker { fn get(&self, wire: &Wire) -> Option<&String> { self.parameters.get(wire) } - - /// Encode an [`RotationOp`]s as a string, given its encoded inputs. - /// - /// `inputs` contains the expressions to compute each input. - fn encode_rotation_op(&self, op: &RotationOp, inputs: &[&String]) -> Option { - let s = match op { - RotationOp::radd => format!("({} + {})", inputs[0], inputs[1]), - // Encode/decode the rotation as pytket parameters, expressed as half-turns. - // Note that the tracked parameter strings are always written in half-turns, - // so the conversion here is a no-op. - RotationOp::to_halfturns => inputs[0].clone(), - RotationOp::from_halfturns_unchecked => inputs[0].clone(), - // The checked conversion returns an option, which we do not support. - RotationOp::from_halfturns => return None, - }; - Some(s) - } - - /// Encode an [`FloatOps`] as a string, given its encoded inputs. - fn encode_float_op(&self, op: &FloatOps, inputs: &[&String]) -> Option { - let s = match op { - FloatOps::fadd => format!("({} + {})", inputs[0], inputs[1]), - FloatOps::fsub => format!("({} - {})", inputs[0], inputs[1]), - FloatOps::fneg => format!("(-{})", inputs[0]), - FloatOps::fmul => format!("({} * {})", inputs[0], inputs[1]), - FloatOps::fdiv => format!("({} / {})", inputs[0], inputs[1]), - FloatOps::fpow => format!("({} ** {})", inputs[0], inputs[1]), - FloatOps::ffloor => format!("floor({})", inputs[0]), - FloatOps::fceil => format!("ceil({})", inputs[0]), - FloatOps::fround => format!("round({})", inputs[0]), - FloatOps::fmax => format!("max({}, {})", inputs[0], inputs[1]), - FloatOps::fmin => format!("min({}, {})", inputs[0], inputs[1]), - FloatOps::fabs => format!("abs({})", inputs[0]), - _ => return None, - }; - Some(s) - } - - /// Encode a [`SympyOp`]s as a string. - /// - /// Note that the sympy operation does not have any inputs. - fn encode_sympy_op(&self, op: &SympyOp, inputs: &[&String]) -> Option { - if !inputs.is_empty() { - return None; - } - - Some(op.expr.clone()) - } } /// A utility class for finding new unused qubit/bit names. diff --git a/tket2/src/serialize/pytket/op/native.rs b/tket2/src/serialize/pytket/op/native.rs index cb8c8fee..9a18a2bc 100644 --- a/tket2/src/serialize/pytket/op/native.rs +++ b/tket2/src/serialize/pytket/op/native.rs @@ -3,6 +3,7 @@ use hugr::extension::prelude::{Noop, BOOL_T, QB_T}; use hugr::ops::{OpTrait, OpType}; +use hugr::std_extensions::arithmetic::float_types::FLOAT64_TYPE; use hugr::types::Signature; use hugr::IncomingPort; @@ -159,7 +160,7 @@ impl NativeOp { let types = sig.input_types().to_owned(); sig.input_ports() .zip(types) - .filter(|(_, ty)| ty == &ROTATION_TYPE) + .filter(|(_, ty)| [ROTATION_TYPE, FLOAT64_TYPE].contains(ty)) .map(|(port, _)| port) }) } @@ -179,7 +180,7 @@ impl NativeOp { self.input_qubits += 1; } else if ty == &BOOL_T { self.input_bits += 1; - } else if ty == &ROTATION_TYPE { + } else if [ROTATION_TYPE, FLOAT64_TYPE].contains(ty) { self.num_params += 1; } } diff --git a/tket2/src/serialize/pytket/param.rs b/tket2/src/serialize/pytket/param.rs new file mode 100644 index 00000000..951e7585 --- /dev/null +++ b/tket2/src/serialize/pytket/param.rs @@ -0,0 +1,4 @@ +//! Encoding and decoding pytket operation parameters as hugr ops. + +pub mod decode; +pub mod encode; diff --git a/tket2/src/serialize/pytket/param/decode.rs b/tket2/src/serialize/pytket/param/decode.rs new file mode 100644 index 00000000..48cbb7ae --- /dev/null +++ b/tket2/src/serialize/pytket/param/decode.rs @@ -0,0 +1,250 @@ +//! Definitions for decoding parameter expressions from pytket operations. +//! +//! This is based on the `pest` grammar defined in `param.pest`. + +use derive_more::Display; +use hugr::ops::{NamedOp, OpType}; +use hugr::std_extensions::arithmetic::float_ops::FloatOps; +use itertools::Itertools; +use pest::iterators::{Pair, Pairs}; +use pest::pratt_parser::PrattParser; +use pest::Parser; +use pest_derive::Parser; + +/// The parsed AST for a pytket operation parameter. +/// +/// The leafs of the AST are either a constant value, a variable name, or an +/// unrecognized sympy expression. +/// +/// Return type of [`parse_pytket_param`]. +#[derive(Debug, Display, Clone, PartialEq)] +pub enum PytketParam<'a> { + /// A constant value that can be loaded directly. + #[display("{_0}")] + Constant(f64), + /// A variable that should be routed as an input. + #[display("\"{name}\"")] + InputVariable { + /// The variable name. + name: &'a str, + }, + /// Unrecognized sympy expression. + /// Will be emitted as a [`SympyOp`]. + #[display("Sympy(\"{_0}\")")] + Sympy(&'a str), + /// An operation on some nested expressions. + #[display("{}({})", op.name(), args.iter().map(|a| a.to_string()).join(", "))] + Operation { + op: OpType, + args: Vec>, + }, +} + +/// Parse a TKET1 operation parameter, and return an AST representing the expression. +#[inline] +pub fn parse_pytket_param(param: &str) -> PytketParam<'_> { + let Ok(mut parsed) = ParamParser::parse(Rule::parameter, param) else { + // The parameter could not be parsed, so we just return it as an opaque sympy expression. + return PytketParam::Sympy(param); + }; + let parsed = parsed + .next() + .expect("The `parameter` rule can only be matched once."); + + parse_infix_ops(parsed.into_inner()) +} + +#[derive(Parser)] +#[grammar = "serialize/pytket/param/param.pest"] +struct ParamParser; + +lazy_static::lazy_static! { + /// Precedence parser used to define the order of infix operations. + /// + /// Based on the calculator example from `pest`. + /// https://pest.rs/book/examples/calculator.html + static ref PRATT_PARSER: PrattParser = { + use pest::pratt_parser::{Assoc::*, Op}; + use Rule::*; + + // Precedence is defined lowest to highest + PrattParser::new() + // Addition and subtract have equal precedence + .op(Op::infix(add, Left) | Op::infix(subtract, Left)) + .op(Op::infix(multiply, Left) | Op::infix(divide, Left)) + .op(Op::infix(power, Left)) + }; +} + +/// Parse a match of the [`Rule::expr`] rule. +/// +/// This takes a sequence of rule matches alternating [`Rule::term`]s and infix operations. +fn parse_infix_ops(pairs: Pairs<'_, Rule>) -> PytketParam<'_> { + PRATT_PARSER + .map_primary(|primary| parse_term(primary)) + .map_infix(|lhs, op, rhs| { + let op = match op.as_rule() { + Rule::add => FloatOps::fadd, + Rule::subtract => FloatOps::fsub, + Rule::multiply => FloatOps::fmul, + Rule::divide => FloatOps::fdiv, + Rule::power => FloatOps::fpow, + rule => unreachable!("Expr::parse expected infix operation, found {:?}", rule), + } + .into(); + PytketParam::Operation { + op, + args: vec![lhs, rhs], + } + }) + .parse(pairs) +} + +/// Parse a match of the silent [`Rule::term`] rule. +fn parse_term(pair: Pair<'_, Rule>) -> PytketParam<'_> { + match pair.as_rule() { + Rule::expr => parse_infix_ops(pair.into_inner()), + Rule::implicit_multiply => { + let mut pairs = pair.into_inner(); + let lhs = parse_term(pairs.next().unwrap()); + let rhs = parse_term(pairs.next().unwrap()); + PytketParam::Operation { + op: FloatOps::fmul.into(), + args: vec![lhs, rhs], + } + } + Rule::num => parse_number(pair), + Rule::unary_minus => PytketParam::Operation { + op: FloatOps::fneg.into(), + args: vec![parse_term(pair.into_inner().next().unwrap())], + }, + Rule::function_call => parse_function_call(pair), + Rule::ident => PytketParam::InputVariable { + name: pair.as_str(), + }, + rule => unreachable!("Term::parse expected a term, found {:?}", rule), + } +} + +/// Parse a match of the [`Rule::num`] rule. +fn parse_number(pair: Pair<'_, Rule>) -> PytketParam<'_> { + let num = pair.as_str(); + let half_turns = num + .parse::() + .unwrap_or_else(|_| panic!("`num` rule matched invalid number \"{num}\"")); + PytketParam::Constant(half_turns) +} + +/// Parse a match of the [`Rule::function_call`] rule. +fn parse_function_call(pair: Pair<'_, Rule>) -> PytketParam<'_> { + let pair_str = pair.as_str(); + let mut args = pair.into_inner(); + let name = args + .next() + .expect("Function call must have a name") + .as_str(); + let op = match name { + "max" => FloatOps::fmax.into(), + "min" => FloatOps::fmin.into(), + "abs" => FloatOps::fabs.into(), + "floor" => FloatOps::ffloor.into(), + "ceil" => FloatOps::fceil.into(), + "round" => FloatOps::fround.into(), + // Unrecognized function name. + // Treat it as an opaque sympy expression. + _ => return PytketParam::Sympy(pair_str), + }; + + let args = args.map(|arg| parse_term(arg)).collect::>(); + PytketParam::Operation { op, args } +} + +#[cfg(test)] +mod test { + use super::*; + use rstest::rstest; + + #[rstest] + #[case::int("42", PytketParam::Constant(42.0))] + #[case::float("42.37", PytketParam::Constant(42.37))] + #[case::float_pointless("37.", PytketParam::Constant(37.))] + #[case::exp("42e4", PytketParam::Constant(42e4))] + #[case::neg("-42.55", PytketParam::Constant(-42.55))] + #[case::parens("(42)", PytketParam::Constant(42.))] + #[case::var("f64", PytketParam::InputVariable{name: "f64"})] + #[case::add("42 + f64", PytketParam::Operation { + op: FloatOps::fadd.into(), + args: vec![PytketParam::Constant(42.), PytketParam::InputVariable{name: "f64"}] + })] + #[case::sub("42 - 2", PytketParam::Operation { + op: FloatOps::fsub.into(), + args: vec![PytketParam::Constant(42.), PytketParam::Constant(2.)] + })] + #[case::product_implicit("42 f64", PytketParam::Operation { + op: FloatOps::fmul.into(), + args: vec![PytketParam::Constant(42.), PytketParam::InputVariable{name: "f64"}] + })] + #[case::product_implicit2("42f64", PytketParam::Operation { + op: FloatOps::fmul.into(), + args: vec![PytketParam::Constant(42.), PytketParam::InputVariable{name: "f64"}] + })] + #[case::product_implicit3("42 e4", PytketParam::Operation { + op: FloatOps::fmul.into(), + args: vec![PytketParam::Constant(42.), PytketParam::InputVariable{name: "e4"}] + })] + #[case::max("max(42, f64)", PytketParam::Operation { + op: FloatOps::fmax.into(), + args: vec![PytketParam::Constant(42.), PytketParam::InputVariable{name: "f64"}] + })] + #[case::minus("-f64", PytketParam::Operation { + op: FloatOps::fneg.into(), + args: vec![PytketParam::InputVariable{name: "f64"}] + })] + #[case::unknown("unknown_op(42, f64)", PytketParam::Sympy("unknown_op(42, f64)"))] + #[case::unknown_no_params("unknown_op()", PytketParam::Sympy("unknown_op()"))] + #[case::nested("max(42, unknown_op(37))", PytketParam::Operation { + op: FloatOps::fmax.into(), + args: vec![PytketParam::Constant(42.), PytketParam::Sympy("unknown_op(37)")] + })] + #[case::precedence("5-2/3x+4**6", PytketParam::Operation { + op: FloatOps::fadd.into(), + args: vec![ + PytketParam::Operation { + op: FloatOps::fsub.into(), + args: vec![ + PytketParam::Constant(5.), + PytketParam::Operation { op: FloatOps::fdiv.into(), args: vec![ + PytketParam::Constant(2.), + PytketParam::Operation { op: FloatOps::fmul.into(), args: vec![ + PytketParam::Constant(3.), + PytketParam::InputVariable{name: "x"}, + ]} + ]} + ] + }, + PytketParam::Operation { op: FloatOps::fpow.into(), args: vec![ + PytketParam::Constant(4.), + PytketParam::Constant(6.), + ]} + ] + })] + #[case::associativity("1-2-3+4", PytketParam::Operation { + op: FloatOps::fadd.into(), + args: vec![ + PytketParam::Operation { op: FloatOps::fsub.into(), args: vec![ + PytketParam::Operation { op: FloatOps::fsub.into(), args: vec![ + PytketParam::Constant(1.), + PytketParam::Constant(2.), + ]}, + PytketParam::Constant(3.), + ]}, + PytketParam::Constant(4.), + ] + })] + fn parse_param(#[case] param: &str, #[case] expected: PytketParam) { + let parsed = parse_pytket_param(param); + if parsed != expected { + panic!("Incorrect parameter parsing\n\texpression: \"{param}\"\n\tparsed: {parsed}\n\texpected: {expected}"); + } + } +} diff --git a/tket2/src/serialize/pytket/param/encode.rs b/tket2/src/serialize/pytket/param/encode.rs new file mode 100644 index 00000000..962dcfbc --- /dev/null +++ b/tket2/src/serialize/pytket/param/encode.rs @@ -0,0 +1,125 @@ +//! Definitions for encoding hugr graphs into pytket op parameters. + +use hugr::ops::{OpType, Value}; +use hugr::std_extensions::arithmetic::float_ops::FloatOps; +use hugr::std_extensions::arithmetic::float_types::ConstF64; + +use crate::extension::rotation::{ConstRotation, RotationOp}; +use crate::extension::sympy::SympyOp; +use crate::ops::match_symb_const_op; + +/// Fold a rotation or float operation into a string, given the string +/// representations of its inputs. +/// +/// The folded op must have a single string output. +/// +/// Returns `None` if the operation cannot be folded. +pub fn fold_param_op(optype: &OpType, inputs: &[&str]) -> Option { + let param = match optype { + OpType::Const(const_op) => { + // New constant, register it if it can be interpreted as a parameter. + match try_constant_to_param(const_op.value()) { + Some(param) => param, + None => return None, + } + } + OpType::LoadConstant(_op_type) => { + // Re-use the parameter from the input. + inputs[0].to_string() + } + // Encode some angle and float operations directly as strings using + // the already encoded inputs. Fail if the operation is not + // supported, and let the operation encoding process it instead. + OpType::ExtensionOp(_) => { + if let Some(s) = optype + .cast::() + .and_then(|op| encode_rotation_op(&op, inputs)) + { + s + } else if let Some(s) = optype + .cast::() + .and_then(|op| encode_float_op(&op, inputs)) + { + s + } else if let Some(s) = optype + .cast::() + .and_then(|op| encode_sympy_op(&op, inputs)) + { + s + } else { + return None; + } + } + _ => match_symb_const_op(optype)?.to_string(), + }; + Some(param) +} + +/// Convert a HUGR rotation or float constant to a TKET1 parameter. +/// +/// Angle parameters in TKET1 are encoded as a number of half-turns, +/// whereas HUGR uses radians. +#[inline] +fn try_constant_to_param(val: &Value) -> Option { + if let Some(const_angle) = val.get_custom_value::() { + let half_turns = const_angle.half_turns(); + Some(half_turns.to_string()) + } else if let Some(const_float) = val.get_custom_value::() { + let float = const_float.value(); + + // Special case for pi rotations + if float == std::f64::consts::PI { + Some("pi".to_string()) + } else { + Some(float.to_string()) + } + } else { + None + } +} + +/// Encode an [`RotationOp`]s as a string, given its encoded inputs. +/// +/// `inputs` contains the expressions to compute each input. +fn encode_rotation_op(op: &RotationOp, inputs: &[&str]) -> Option { + let s = match op { + RotationOp::radd => format!("({} + {})", inputs[0], inputs[1]), + // Encode/decode the rotation as pytket parameters, expressed as half-turns. + // Note that the tracked parameter strings are always written in half-turns, + // so the conversion here is a no-op. + RotationOp::to_halfturns => inputs[0].to_string(), + RotationOp::from_halfturns_unchecked => inputs[0].to_string(), + // The checked conversion returns an option, which we do not support. + RotationOp::from_halfturns => return None, + }; + Some(s) +} + +/// Encode an [`FloatOps`] as a string, given its encoded inputs. +fn encode_float_op(op: &FloatOps, inputs: &[&str]) -> Option { + let s = match op { + FloatOps::fadd => format!("({} + {})", inputs[0], inputs[1]), + FloatOps::fsub => format!("({} - {})", inputs[0], inputs[1]), + FloatOps::fneg => format!("(-{})", inputs[0]), + FloatOps::fmul => format!("({} * {})", inputs[0], inputs[1]), + FloatOps::fdiv => format!("({} / {})", inputs[0], inputs[1]), + FloatOps::fpow => format!("({} ** {})", inputs[0], inputs[1]), + FloatOps::ffloor => format!("floor({})", inputs[0]), + FloatOps::fceil => format!("ceil({})", inputs[0]), + FloatOps::fround => format!("round({})", inputs[0]), + FloatOps::fmax => format!("max({}, {})", inputs[0], inputs[1]), + FloatOps::fmin => format!("min({}, {})", inputs[0], inputs[1]), + FloatOps::fabs => format!("abs({})", inputs[0]), + _ => return None, + }; + Some(s) +} + +/// Encode a [`SympyOp`]s as a string. +fn encode_sympy_op(op: &SympyOp, inputs: &[&str]) -> Option { + if !inputs.is_empty() { + return None; + } + + Some(op.expr.clone()) +} diff --git a/tket2/src/serialize/pytket/param/param.pest b/tket2/src/serialize/pytket/param/param.pest new file mode 100644 index 00000000..d672d139 --- /dev/null +++ b/tket2/src/serialize/pytket/param/param.pest @@ -0,0 +1,55 @@ +//! A parser grammar for pytket operation parameters +//! +//! The grammar is a subset of sympy expressions, +//! unrecognised expressions will be boxed opaquely. + +/// Operation and variable identifiers +/// +/// The atomic marker `@` ensures no whitespace is present inside. +ident = @{ ASCII_ALPHA ~ (ASCII_ALPHANUMERIC | "_" )* } +/// Alias for `ident`. +variable = _{ ident } + +/// Explicit numeric constants. +/// +/// The atomic marker `@` ensures no whitespace is present inside. +num = @{ + "-"? + ~ ("0" | ASCII_NONZERO_DIGIT ~ ASCII_DIGIT*) + ~ ("." ~ ASCII_DIGIT*)? + ~ (^"e" ~ ("+" | "-")? ~ ASCII_DIGIT+)? +} + +/// Infix operations (e.g. `2 + 2`), +/// +/// The rust parser defines the operator precedence between these. +infix_operator = _{ add | subtract | power | multiply | divide } + add = { "+" } + subtract = { "-" } + power = { "**" } + multiply = { "*" } + divide = { "/" } + +/// Function calls like (`max(2,4)`) +function_call = { ident ~ ( "()" | "(" ~ expr ~ ("," ~ expr)* ~ ","? ~ ")") } +/// Unary negation, with a special identifier +unary_minus = { "-" ~ term } + +/// Implicit multiplication lets us write expressions like `2x`. +/// This has higher precedence that `infix_operator`s. +/// +/// The second match is a subset of `term`, picked to avoid ambiguity. +/// If we used `term` directly, the expression `5-2` could be interpreted as `5 * (-2)`. +implicit_multiply = { num ~ ("(" ~ expr ~ ")" | function_call | variable) } + +/// Expressions are sequences of terms and infix operators +expr = { term ~ (infix_operator ~ term)* } +/// A standalone term, with no infix operators at the root. +term = _{ "(" ~ expr ~ ")" | implicit_multiply | num | unary_minus | function_call | variable } + +/// Main entry point of the parser. Ensures we always parse the whole input. +parameter = _{ SOI ~ expr ~ EOI } + +/// Allowed whitespace characters. +/// Each `~` separator in the preceding rules can contain any number of these (including zero). +WHITESPACE = _{ " " | "\t" } diff --git a/tket2/src/serialize/pytket/tests.rs b/tket2/src/serialize/pytket/tests.rs index 076efd93..65a35460 100644 --- a/tket2/src/serialize/pytket/tests.rs +++ b/tket2/src/serialize/pytket/tests.rs @@ -63,8 +63,8 @@ const PARAMETERIZED: &str = r#"{ "commands": [ {"args":[["q",[0]]],"op":{"type":"H"}}, {"args":[["q",[1]],["q",[0]]],"op":{"type":"CX"}}, - {"args":[["q",[0]]],"op":{"params":["1.5707963267948966/pi"],"type":"Rz"}}, - {"args": [["q", [0]]], "op": {"params": ["3.141596/pi", "alpha", "0.7853981633974483/pi"], "type": "TK1"}} + {"args":[["q",[0]]],"op":{"params":["(1.5707963267948966 / pi)"],"type":"Rz"}}, + {"args": [["q", [0]]], "op": {"params": ["(3.141596 / pi)", "alpha", "(0.7853981633974483 / pi)"], "type": "TK1"}} ], "created_qubits": [], "discarded_qubits": [], diff --git a/tket2/tests/badger_termination.rs b/tket2/tests/badger_termination.rs index c2efc923..4d15857a 100644 --- a/tket2/tests/badger_termination.rs +++ b/tket2/tests/badger_termination.rs @@ -61,5 +61,5 @@ fn badger_termination(simple_circ: Circuit, nam_4_2: DefaultBadgerOptimiser) { ..Default::default() }, ); - assert_eq!(opt_circ.commands().count(), 11); + assert_eq!(opt_circ.commands().count(), 14); }