Skip to content

Commit

Permalink
feat: symbolic parameters + more T2ops (#90)
Browse files Browse the repository at this point in the history
temporary - when we have a better idea of angle/rotation types can use a
HUGR input for symbols instead


Closes #82
  • Loading branch information
ss2165 authored Sep 6, 2023
1 parent 20d63a2 commit 2bbd8c6
Show file tree
Hide file tree
Showing 7 changed files with 121 additions and 33 deletions.
15 changes: 12 additions & 3 deletions src/json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,16 +140,25 @@ impl From<OpConvertError> for TK1LoadError {
}
}

#[inline]
fn parse_val(n: &str) -> Option<f64> {
if n == "pi" {
// angles are in multiples of pi
Some(1.0)
} else {
n.parse::<f64>().ok()
}
}
/// Try to interpret a TKET1 parameter as a constant value.
#[inline]
fn try_param_to_constant(param: &str) -> Option<Value> {
if let Ok(f) = param.parse::<f64>() {
if let Some(f) = parse_val(param) {
Some(ConstF64::new(f).into())
} else if param.split('/').count() == 2 {
// TODO: Use the rational types from `Hugr::extensions::rotation`
let (n, d) = param.split_once('/').unwrap();
let n = n.parse::<f64>().unwrap();
let d = d.parse::<f64>().unwrap();
let n = parse_val(n)?;
let d = parse_val(d)?;
Some(ConstF64::new(n / d).into())
} else {
None
Expand Down
8 changes: 5 additions & 3 deletions src/json/decoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use tket_json_rs::circuit_json::SerialCircuit;
use super::op::JsonOp;
use super::{try_param_to_constant, METADATA_IMPLICIT_PERM, METADATA_PHASE};
use crate::extension::{LINEAR_BIT, REGISTRY};
use crate::symbolic_constant_op;

/// The state of an in-progress [`DFGBuilder`] being built from a [`SerialCircuit`].
///
Expand Down Expand Up @@ -145,9 +146,10 @@ impl JsonDecoder {
self.hugr.add_load_const(const_op).unwrap()
}
None => {
// TODO: If the parameter is just a variable,
// return the corresponding wire from the input.
todo!("Variable parameters not yet supported")
// store string in custom op.
let symb_op = symbolic_constant_op(param);
let o = self.hugr.add_dataflow_op(symb_op, []).unwrap();
o.out_wire(0)
}
}
}
Expand Down
28 changes: 19 additions & 9 deletions src/json/encoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use tket_json_rs::circuit_json::{self, Permutation, Register, SerialCircuit};
use crate::circuit::command::{CircuitUnit, Command};
use crate::circuit::Circuit;
use crate::extension::LINEAR_BIT;
use crate::ops::match_symb_const_op;

use super::op::JsonOp;
use super::{OpConvertError, METADATA_IMPLICIT_PERM, METADATA_PHASE};
Expand Down Expand Up @@ -89,11 +90,12 @@ impl JsonEncoder {
/// Add a circuit command to the serialization.
pub fn add_command(&mut self, command: Command, optype: &OpType) -> Result<(), OpConvertError> {
// Register any output of the command that can be used as a TKET1 parameter.
self.record_parameters(&command, optype);

if let OpType::Const(_) | OpType::LoadConstant(_) = optype {
if self.record_parameters(&command, optype) {
// for now all ops that record parameters should be ignored (are
// just constants)
return Ok(());
}

let (args, params): (Vec<Register>, Vec<Wire>) =
command
.inputs()
Expand Down Expand Up @@ -141,9 +143,9 @@ impl JsonEncoder {
}

/// Record any output of the command that can be used as a TKET1 parameter.
///
/// Returns whether parameters were recorded.
/// Associates the output wires with the parameter expression.
fn record_parameters(&mut self, command: &Command, optype: &OpType) {
fn record_parameters(&mut self, command: &Command, optype: &OpType) -> bool {
// Only consider commands where all inputs are parameters.
let inputs = command
.inputs()
Expand All @@ -154,7 +156,7 @@ impl JsonEncoder {
})
.collect_vec();
if inputs.len() != command.inputs().len() {
return;
return false;
}

let param = match optype {
Expand All @@ -165,21 +167,28 @@ impl JsonEncoder {
if let Some(f) = v.downcast_ref::<ConstF64>() {
f.to_string()
} else {
return;
return false;
}
}
_ => return,
_ => return false,
}
}
OpType::LoadConstant(_op_type) => {
// Re-use the parameter from the input.
inputs[0].clone()
}
OpType::LeafOp(_) => {
if let Some(s) = match_symb_const_op(optype) {
s.to_string()
} else {
return false;
}
}
_ => {
// In the future we may want to support arithmetic operations.
// (Just concatenating the inputs and the operation symbol, no
// need for evaluation).
return;
return false;
}
};

Expand All @@ -188,6 +197,7 @@ impl JsonEncoder {
self.parameters.insert(*wire, param.clone());
}
}
true
}

fn unit_to_register(&self, unit: CircuitUnit) -> Option<circuit_json::Register> {
Expand Down
20 changes: 11 additions & 9 deletions src/json/op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use itertools::Itertools;
use tket_json_rs::circuit_json;
use tket_json_rs::optype::OpType as JsonOpType;

use super::{try_param_to_constant, OpConvertError};
use super::OpConvertError;
use crate::extension::{try_unwrap_json_op, LINEAR_BIT};
use crate::T2Op;

Expand Down Expand Up @@ -164,14 +164,8 @@ impl JsonOp {
return;
};

let mut p_input_indices = 0..;
let param_inputs = params
.iter()
.map(|param| try_param_to_constant(param).map(|_| p_input_indices.next().unwrap()))
.collect();

self.num_params = p_input_indices.next().unwrap();
self.param_inputs = param_inputs;
self.num_params = params.len();
self.param_inputs = (0..params.len()).map(Some).collect();
}
}

Expand All @@ -189,7 +183,11 @@ impl From<&JsonOp> for OpType {
JsonOpType::Tdg => T2Op::Tdg.into(),
JsonOpType::X => T2Op::X.into(),
JsonOpType::Rz => T2Op::RzF64.into(),
JsonOpType::Rx => T2Op::RxF64.into(),
JsonOpType::TK1 => T2Op::TK1.into(),
JsonOpType::PhasedX => T2Op::PhasedX.into(),
JsonOpType::ZZMax => T2Op::ZZMax.into(),
JsonOpType::ZZPhase => T2Op::ZZPhase.into(),
JsonOpType::noop => LeafOp::Noop { ty: QB_T }.into(),
_ => LeafOp::CustomOp(Box::new(json_op.as_opaque_op())).into(),
}
Expand All @@ -216,7 +214,11 @@ impl TryFrom<&OpType> for JsonOp {
T2Op::H => JsonOpType::H,
T2Op::Measure => JsonOpType::Measure,
T2Op::RzF64 => JsonOpType::Rz,
T2Op::RxF64 => JsonOpType::Rx,
T2Op::TK1 => JsonOpType::TK1,
T2Op::PhasedX => JsonOpType::PhasedX,
T2Op::ZZMax => JsonOpType::ZZMax,
T2Op::ZZPhase => JsonOpType::ZZPhase,
_ => return Err(err()),
}
} else if let LeafOp::CustomOp(b) = leaf {
Expand Down
2 changes: 1 addition & 1 deletion src/json/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ const PARAMETRIZED: &str = r#"{
{"args":[["q",[0]]],"op":{"type":"H"}},
{"args":[["q",[1]],["q",[0]]],"op":{"type":"CX"}},
{"args":[["q",[0]]],"op":{"params":["0.1"],"type":"Rz"}},
{"args": [["q", [0]]], "op": {"params": ["0.1", "0.2", "0.3"], "type": "TK1"}}
{"args": [["q", [0]]], "op": {"params": ["0.1", "alpha", "0.3"], "type": "TK1"}}
],
"created_qubits": [],
"discarded_qubits": [],
Expand Down
4 changes: 2 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
pub mod circuit;
pub mod extension;
pub mod json;
mod ops;
pub(crate) mod ops;
pub mod passes;
pub use ops::{Pauli, T2Op};
pub use ops::{symbolic_constant_op, Pauli, T2Op};

#[cfg(feature = "portmatching")]
pub mod portmatching;
Expand Down
77 changes: 71 additions & 6 deletions src/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,16 @@ use hugr::{
ops::{custom::ExternalOp, LeafOp, OpType},
std_extensions::arithmetic::float_types::FLOAT64_TYPE,
type_row,
types::FunctionType,
types::{
type_param::{CustomTypeArg, TypeArg, TypeParam},
CustomType, FunctionType, TypeBound,
},
Extension,
};
use lazy_static::lazy_static;
use serde::{Deserialize, Serialize};
use smol_str::SmolStr;

use std::str::FromStr;
use strum::IntoEnumIterator;
use strum_macros::{Display, EnumIter, EnumString, IntoStaticStr};
Expand Down Expand Up @@ -51,6 +56,9 @@ pub enum T2Op {
ZZMax,
Measure,
RzF64,
RxF64,
PhasedX,
ZZPhase,
TK1,
}
#[derive(Clone, Copy, Debug, Serialize, Deserialize, EnumIter, Display, PartialEq, PartialOrd)]
Expand Down Expand Up @@ -111,8 +119,10 @@ impl SimpleOpEnum for T2Op {
match self {
H | T | S | X | Y | Z | Tdg | Sdg => FunctionType::new(one_qb_row.clone(), one_qb_row),
CX | ZZMax => FunctionType::new(two_qb_row.clone(), two_qb_row),
ZZPhase => FunctionType::new(type_row![QB_T, QB_T, FLOAT64_TYPE], two_qb_row),
Measure => FunctionType::new(one_qb_row, type_row![QB_T, BOOL_T]),
RzF64 => FunctionType::new(type_row![QB_T, FLOAT64_TYPE], one_qb_row),
RzF64 | RxF64 => FunctionType::new(type_row![QB_T, FLOAT64_TYPE], one_qb_row),
PhasedX => FunctionType::new(type_row![QB_T, FLOAT64_TYPE, FLOAT64_TYPE], one_qb_row),
TK1 => FunctionType::new(
type_row![QB_T, FLOAT64_TYPE, FLOAT64_TYPE, FLOAT64_TYPE],
one_qb_row,
Expand Down Expand Up @@ -152,20 +162,74 @@ impl T2Op {
use T2Op::*;

match self {
X => vec![(0, Pauli::X)],
T | Z | S | Tdg | Sdg | RzF64 => vec![(0, Pauli::Z)],
X | RxF64 => vec![(0, Pauli::X)],
T | Z | S | Tdg | Sdg | RzF64 | Measure => vec![(0, Pauli::Z)],
CX => vec![(0, Pauli::Z), (1, Pauli::X)],
ZZMax => vec![(0, Pauli::Z), (1, Pauli::Z)],
ZZMax | ZZPhase => vec![(0, Pauli::Z), (1, Pauli::Z)],
// by default, no commutation
_ => vec![],
}
}
}

/// The type of the symbolic expression opaque type arg.
pub const SYM_EXPR_T: CustomType =
CustomType::new_simple(SmolStr::new_inline("SymExpr"), EXTENSION_ID, TypeBound::Eq);

const SYM_OP_ID: SmolStr = SmolStr::new_inline("symbolic_float");

/// Initialize a new custom symbolic expression constant op from a string.
pub fn symbolic_constant_op(s: &str) -> OpType {
let value: serde_yaml::Value = s.into();
let l: LeafOp = EXTENSION
.instantiate_extension_op(
&SYM_OP_ID,
vec![TypeArg::Opaque(
CustomTypeArg::new(SYM_EXPR_T, value).unwrap(),
)],
)
.unwrap()
.into();
l.into()
}

/// match against a symbolic constant
pub(crate) fn match_symb_const_op(op: &OpType) -> Option<&str> {
if let OpType::LeafOp(LeafOp::CustomOp(e)) = op {
match e.as_ref() {
ExternalOp::Extension(e)
if e.def().name() == &SYM_OP_ID && e.def().extension() == &EXTENSION_ID =>
{
// TODO also check extension name

let Some(TypeArg::Opaque(s)) = e.args().get(0) else {
panic!("should be an opaque type arg.")
};

let serde_yaml::Value::String(s) = &s.value else {
panic!("unexpected yaml value.")
};

Some(s)
}
ExternalOp::Opaque(_) => todo!(),
_ => None,
}
} else {
None
}
}

fn extension() -> Extension {
let mut e = Extension::new(EXTENSION_ID);
load_all_ops::<T2Op>(&mut e).expect("add fail");

e.add_op_custom_sig_simple(
SYM_OP_ID,
"Store a sympy expression that can be evaluated to a float.".to_string(),
vec![TypeParam::Opaque(SYM_EXPR_T)],
|_: &[TypeArg]| Ok(FunctionType::new(type_row![], type_row![FLOAT64_TYPE])),
)
.unwrap();
e
}

Expand Down Expand Up @@ -213,6 +277,7 @@ impl TryFrom<LeafOp> for T2Op {
}
}

/// load all variants of a `SimpleOpEnum` in to an extension as op defs.
fn load_all_ops<T: SimpleOpEnum>(extension: &mut Extension) -> Result<(), ExtensionBuildError> {
for op in T::all_variants() {
op.add_to_extension(extension)?;
Expand Down

0 comments on commit 2bbd8c6

Please sign in to comment.