Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: symbolic parameters + more T2ops #90

Merged
merged 4 commits into from
Sep 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions src/json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,16 +140,25 @@ impl From<OpConvertError> for TK1LoadError {
}
}

#[inline]
fn parse_val(n: &str) -> Option<f64> {
if n == "pi" {
// angles are in multiples of pi
Some(1.0)
} else {
n.parse::<f64>().ok()
}
}
/// Try to interpret a TKET1 parameter as a constant value.
#[inline]
fn try_param_to_constant(param: &str) -> Option<Value> {
if let Ok(f) = param.parse::<f64>() {
if let Some(f) = parse_val(param) {
Some(ConstF64::new(f).into())
} else if param.split('/').count() == 2 {
// TODO: Use the rational types from `Hugr::extensions::rotation`
let (n, d) = param.split_once('/').unwrap();
let n = n.parse::<f64>().unwrap();
let d = d.parse::<f64>().unwrap();
let n = parse_val(n)?;
let d = parse_val(d)?;
Some(ConstF64::new(n / d).into())
} else {
None
Expand Down
8 changes: 5 additions & 3 deletions src/json/decoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use tket_json_rs::circuit_json::SerialCircuit;
use super::op::JsonOp;
use super::{try_param_to_constant, METADATA_IMPLICIT_PERM, METADATA_PHASE};
use crate::extension::{LINEAR_BIT, REGISTRY};
use crate::symbolic_constant_op;

/// The state of an in-progress [`DFGBuilder`] being built from a [`SerialCircuit`].
///
Expand Down Expand Up @@ -145,9 +146,10 @@ impl JsonDecoder {
self.hugr.add_load_const(const_op).unwrap()
}
None => {
// TODO: If the parameter is just a variable,
// return the corresponding wire from the input.
todo!("Variable parameters not yet supported")
// store string in custom op.
let symb_op = symbolic_constant_op(param);
let o = self.hugr.add_dataflow_op(symb_op, []).unwrap();
o.out_wire(0)
}
}
}
Expand Down
28 changes: 19 additions & 9 deletions src/json/encoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use tket_json_rs::circuit_json::{self, Permutation, Register, SerialCircuit};
use crate::circuit::command::{CircuitUnit, Command};
use crate::circuit::Circuit;
use crate::extension::LINEAR_BIT;
use crate::ops::match_symb_const_op;

use super::op::JsonOp;
use super::{OpConvertError, METADATA_IMPLICIT_PERM, METADATA_PHASE};
Expand Down Expand Up @@ -89,11 +90,12 @@ impl JsonEncoder {
/// Add a circuit command to the serialization.
pub fn add_command(&mut self, command: Command, optype: &OpType) -> Result<(), OpConvertError> {
// Register any output of the command that can be used as a TKET1 parameter.
self.record_parameters(&command, optype);

if let OpType::Const(_) | OpType::LoadConstant(_) = optype {
if self.record_parameters(&command, optype) {
// for now all ops that record parameters should be ignored (are
// just constants)
return Ok(());
}

let (args, params): (Vec<Register>, Vec<Wire>) =
command
.inputs()
Expand Down Expand Up @@ -141,9 +143,9 @@ impl JsonEncoder {
}

/// Record any output of the command that can be used as a TKET1 parameter.
///
/// Returns whether parameters were recorded.
/// Associates the output wires with the parameter expression.
fn record_parameters(&mut self, command: &Command, optype: &OpType) {
fn record_parameters(&mut self, command: &Command, optype: &OpType) -> bool {
// Only consider commands where all inputs are parameters.
let inputs = command
.inputs()
Expand All @@ -154,7 +156,7 @@ impl JsonEncoder {
})
.collect_vec();
if inputs.len() != command.inputs().len() {
return;
return false;
}

let param = match optype {
Expand All @@ -165,21 +167,28 @@ impl JsonEncoder {
if let Some(f) = v.downcast_ref::<ConstF64>() {
f.to_string()
} else {
return;
return false;
}
}
_ => return,
_ => return false,
}
}
OpType::LoadConstant(_op_type) => {
// Re-use the parameter from the input.
inputs[0].clone()
}
OpType::LeafOp(_) => {
if let Some(s) = match_symb_const_op(optype) {
s.to_string()
} else {
return false;
}
}
_ => {
// In the future we may want to support arithmetic operations.
// (Just concatenating the inputs and the operation symbol, no
// need for evaluation).
return;
return false;
}
};

Expand All @@ -188,6 +197,7 @@ impl JsonEncoder {
self.parameters.insert(*wire, param.clone());
}
}
true
}

fn unit_to_register(&self, unit: CircuitUnit) -> Option<circuit_json::Register> {
Expand Down
20 changes: 11 additions & 9 deletions src/json/op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use itertools::Itertools;
use tket_json_rs::circuit_json;
use tket_json_rs::optype::OpType as JsonOpType;

use super::{try_param_to_constant, OpConvertError};
use super::OpConvertError;
use crate::extension::{try_unwrap_json_op, LINEAR_BIT};
use crate::T2Op;

Expand Down Expand Up @@ -164,14 +164,8 @@ impl JsonOp {
return;
};

let mut p_input_indices = 0..;
let param_inputs = params
.iter()
.map(|param| try_param_to_constant(param).map(|_| p_input_indices.next().unwrap()))
.collect();

self.num_params = p_input_indices.next().unwrap();
self.param_inputs = param_inputs;
self.num_params = params.len();
self.param_inputs = (0..params.len()).map(Some).collect();
}
}

Expand All @@ -189,7 +183,11 @@ impl From<&JsonOp> for OpType {
JsonOpType::Tdg => T2Op::Tdg.into(),
JsonOpType::X => T2Op::X.into(),
JsonOpType::Rz => T2Op::RzF64.into(),
JsonOpType::Rx => T2Op::RxF64.into(),
JsonOpType::TK1 => T2Op::TK1.into(),
JsonOpType::PhasedX => T2Op::PhasedX.into(),
JsonOpType::ZZMax => T2Op::ZZMax.into(),
JsonOpType::ZZPhase => T2Op::ZZPhase.into(),
JsonOpType::noop => LeafOp::Noop { ty: QB_T }.into(),
_ => LeafOp::CustomOp(Box::new(json_op.as_opaque_op())).into(),
}
Expand All @@ -216,7 +214,11 @@ impl TryFrom<&OpType> for JsonOp {
T2Op::H => JsonOpType::H,
T2Op::Measure => JsonOpType::Measure,
T2Op::RzF64 => JsonOpType::Rz,
T2Op::RxF64 => JsonOpType::Rx,
T2Op::TK1 => JsonOpType::TK1,
T2Op::PhasedX => JsonOpType::PhasedX,
T2Op::ZZMax => JsonOpType::ZZMax,
T2Op::ZZPhase => JsonOpType::ZZPhase,
_ => return Err(err()),
}
} else if let LeafOp::CustomOp(b) = leaf {
Expand Down
2 changes: 1 addition & 1 deletion src/json/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ const PARAMETRIZED: &str = r#"{
{"args":[["q",[0]]],"op":{"type":"H"}},
{"args":[["q",[1]],["q",[0]]],"op":{"type":"CX"}},
{"args":[["q",[0]]],"op":{"params":["0.1"],"type":"Rz"}},
{"args": [["q", [0]]], "op": {"params": ["0.1", "0.2", "0.3"], "type": "TK1"}}
{"args": [["q", [0]]], "op": {"params": ["0.1", "alpha", "0.3"], "type": "TK1"}}
],
"created_qubits": [],
"discarded_qubits": [],
Expand Down
4 changes: 2 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
pub mod circuit;
pub mod extension;
pub mod json;
mod ops;
pub(crate) mod ops;
pub mod passes;
pub use ops::{Pauli, T2Op};
pub use ops::{symbolic_constant_op, Pauli, T2Op};

#[cfg(feature = "portmatching")]
pub mod portmatching;
Expand Down
77 changes: 71 additions & 6 deletions src/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,16 @@ use hugr::{
ops::{custom::ExternalOp, LeafOp, OpType},
std_extensions::arithmetic::float_types::FLOAT64_TYPE,
type_row,
types::FunctionType,
types::{
type_param::{CustomTypeArg, TypeArg, TypeParam},
CustomType, FunctionType, TypeBound,
},
Extension,
};
use lazy_static::lazy_static;
use serde::{Deserialize, Serialize};
use smol_str::SmolStr;

use std::str::FromStr;
use strum::IntoEnumIterator;
use strum_macros::{Display, EnumIter, EnumString, IntoStaticStr};
Expand Down Expand Up @@ -51,6 +56,9 @@ pub enum T2Op {
ZZMax,
Measure,
RzF64,
RxF64,
PhasedX,
ZZPhase,
TK1,
}
#[derive(Clone, Copy, Debug, Serialize, Deserialize, EnumIter, Display, PartialEq, PartialOrd)]
Expand Down Expand Up @@ -111,8 +119,10 @@ impl SimpleOpEnum for T2Op {
match self {
H | T | S | X | Y | Z | Tdg | Sdg => FunctionType::new(one_qb_row.clone(), one_qb_row),
CX | ZZMax => FunctionType::new(two_qb_row.clone(), two_qb_row),
ZZPhase => FunctionType::new(type_row![QB_T, QB_T, FLOAT64_TYPE], two_qb_row),
Measure => FunctionType::new(one_qb_row, type_row![QB_T, BOOL_T]),
RzF64 => FunctionType::new(type_row![QB_T, FLOAT64_TYPE], one_qb_row),
RzF64 | RxF64 => FunctionType::new(type_row![QB_T, FLOAT64_TYPE], one_qb_row),
PhasedX => FunctionType::new(type_row![QB_T, FLOAT64_TYPE, FLOAT64_TYPE], one_qb_row),
TK1 => FunctionType::new(
type_row![QB_T, FLOAT64_TYPE, FLOAT64_TYPE, FLOAT64_TYPE],
one_qb_row,
Expand Down Expand Up @@ -152,20 +162,74 @@ impl T2Op {
use T2Op::*;

match self {
X => vec![(0, Pauli::X)],
T | Z | S | Tdg | Sdg | RzF64 => vec![(0, Pauli::Z)],
X | RxF64 => vec![(0, Pauli::X)],
T | Z | S | Tdg | Sdg | RzF64 | Measure => vec![(0, Pauli::Z)],
CX => vec![(0, Pauli::Z), (1, Pauli::X)],
ZZMax => vec![(0, Pauli::Z), (1, Pauli::Z)],
ZZMax | ZZPhase => vec![(0, Pauli::Z), (1, Pauli::Z)],
// by default, no commutation
_ => vec![],
}
}
}

/// The type of the symbolic expression opaque type arg.
pub const SYM_EXPR_T: CustomType =
CustomType::new_simple(SmolStr::new_inline("SymExpr"), EXTENSION_ID, TypeBound::Eq);

const SYM_OP_ID: SmolStr = SmolStr::new_inline("symbolic_float");

/// Initialize a new custom symbolic expression constant op from a string.
pub fn symbolic_constant_op(s: &str) -> OpType {
let value: serde_yaml::Value = s.into();
let l: LeafOp = EXTENSION
.instantiate_extension_op(
&SYM_OP_ID,
vec![TypeArg::Opaque(
CustomTypeArg::new(SYM_EXPR_T, value).unwrap(),
)],
)
.unwrap()
.into();
l.into()
}

/// match against a symbolic constant
pub(crate) fn match_symb_const_op(op: &OpType) -> Option<&str> {
if let OpType::LeafOp(LeafOp::CustomOp(e)) = op {
match e.as_ref() {
ExternalOp::Extension(e)
if e.def().name() == &SYM_OP_ID && e.def().extension() == &EXTENSION_ID =>
{
// TODO also check extension name

let Some(TypeArg::Opaque(s)) = e.args().get(0) else {
panic!("should be an opaque type arg.")
};

let serde_yaml::Value::String(s) = &s.value else {
panic!("unexpected yaml value.")
};

Some(s)
}
ExternalOp::Opaque(_) => todo!(),
_ => None,
}
} else {
None
}
}

fn extension() -> Extension {
let mut e = Extension::new(EXTENSION_ID);
load_all_ops::<T2Op>(&mut e).expect("add fail");

e.add_op_custom_sig_simple(
SYM_OP_ID,
"Store a sympy expression that can be evaluated to a float.".to_string(),
vec![TypeParam::Opaque(SYM_EXPR_T)],
|_: &[TypeArg]| Ok(FunctionType::new(type_row![], type_row![FLOAT64_TYPE])),
)
.unwrap();
e
}

Expand Down Expand Up @@ -213,6 +277,7 @@ impl TryFrom<LeafOp> for T2Op {
}
}

/// load all variants of a `SimpleOpEnum` in to an extension as op defs.
fn load_all_ops<T: SimpleOpEnum>(extension: &mut Extension) -> Result<(), ExtensionBuildError> {
for op in T::all_variants() {
op.add_to_extension(extension)?;
Expand Down