Skip to content

Commit

Permalink
feat: Use IncomingPort and OutgoingPort instead of Port where p…
Browse files Browse the repository at this point in the history
…ossible. (#296)

Mostly a refactor. The `Units` struct required some API changes so the
iterator can return the kind of port matching its direction.

Closes #220
  • Loading branch information
aborgna-q authored Mar 1, 2024
1 parent 8bd2441 commit 93f83ac
Show file tree
Hide file tree
Showing 5 changed files with 184 additions and 186 deletions.
26 changes: 13 additions & 13 deletions tket2/src/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,16 @@ use hugr::hugr::NodeType;
use hugr::ops::dataflow::IOTrait;
use hugr::ops::{Input, Output, DFG};
use hugr::types::FunctionType;
use hugr::HugrView;
use hugr::PortIndex;
use hugr::{HugrView, OutgoingPort};
use itertools::Itertools;
use portgraph::Direction;
use thiserror::Error;

pub use hugr::ops::OpType;
pub use hugr::types::{EdgeKind, Type, TypeRow};
pub use hugr::{Node, Port, Wire};

use self::units::{filter, FilteredUnits, Units};
use self::units::{filter, LinearUnit, Units};

/// An object behaving like a quantum circuit.
//
Expand Down Expand Up @@ -85,7 +84,7 @@ pub trait Circuit: HugrView {

/// Get the input units of the circuit and their types.
#[inline]
fn units(&self) -> Units
fn units(&self) -> Units<OutgoingPort>
where
Self: Sized,
{
Expand All @@ -94,29 +93,29 @@ pub trait Circuit: HugrView {

/// Get the linear input units of the circuit and their types.
#[inline]
fn linear_units(&self) -> FilteredUnits<filter::Linear>
fn linear_units(&self) -> impl Iterator<Item = (LinearUnit, OutgoingPort, Type)> + '_
where
Self: Sized,
{
self.units().filter_units::<filter::Linear>()
self.units().filter_map(filter::filter_linear)
}

/// Get the non-linear input units of the circuit and their types.
#[inline]
fn nonlinear_units(&self) -> FilteredUnits<filter::NonLinear>
fn nonlinear_units(&self) -> impl Iterator<Item = (Wire, OutgoingPort, Type)> + '_
where
Self: Sized,
{
self.units().filter_units::<filter::NonLinear>()
self.units().filter_map(filter::filter_non_linear)
}

/// Returns the units corresponding to qubits inputs to the circuit.
#[inline]
fn qubits(&self) -> FilteredUnits<filter::Qubits>
fn qubits(&self) -> impl Iterator<Item = (LinearUnit, OutgoingPort, Type)> + '_
where
Self: Sized,
{
self.units().filter_units::<filter::Qubits>()
self.units().filter_map(filter::filter_qubit)
}

/// Returns all the commands in the circuit, in some topological order.
Expand Down Expand Up @@ -175,9 +174,9 @@ pub(crate) fn remove_empty_wire(
if input_port >= circ.num_outputs(inp) {
return Err(CircuitMutError::InvalidPortOffset(input_port));
}
let input_port = Port::new(Direction::Outgoing, input_port);
let input_port = OutgoingPort::from(input_port);
let link = circ
.linked_ports(inp, input_port)
.linked_inputs(inp, input_port)
.at_most_one()
.map_err(|_| CircuitMutError::DeleteNonEmptyWire(input_port.index()))?;
if link.is_some() && link.unwrap().0 != out {
Expand Down Expand Up @@ -223,9 +222,10 @@ pub enum CircuitMutError {
fn shift_ports<C: HugrMut + ?Sized>(
circ: &mut C,
node: Node,
mut free_port: Port,
free_port: impl Into<Port>,
max_ind: usize,
) -> Result<Port, hugr::hugr::HugrError> {
let mut free_port = free_port.into();
let dir = free_port.direction();
let port_range = (free_port.index() + 1..max_ind).map(|p| Port::new(dir, p));
for port in port_range {
Expand Down
142 changes: 72 additions & 70 deletions tket2/src/circuit/command.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@ use std::iter::FusedIterator;

use hugr::hugr::NodeType;
use hugr::ops::{OpTag, OpTrait};
use itertools::Either::{Left, Right};
use hugr::{IncomingPort, OutgoingPort};
use itertools::Either::{self, Left, Right};
use petgraph::visit as pv;

use super::units::filter::FilteredUnits;
use super::units::{filter, DefaultUnitLabeller, LinearUnit, UnitLabeller, Units};
use super::Circuit;

Expand Down Expand Up @@ -53,78 +53,84 @@ impl<'circ, Circ: Circuit> Command<'circ, Circ> {

/// Returns the units of this command in a given direction.
#[inline]
pub fn units(&self, direction: Direction) -> Units<&'_ Self> {
Units::new(self.circ, self.node, direction, self)
}

/// Returns the linear units of this command in a given direction.
#[inline]
pub fn linear_units(&self, direction: Direction) -> FilteredUnits<filter::Linear, &Self> {
Units::new(self.circ, self.node, direction, self).filter_units::<filter::Linear>()
pub fn units(
&self,
direction: Direction,
) -> impl Iterator<Item = (CircuitUnit, Port, Type)> + '_ {
match direction {
Direction::Incoming => Either::Left(self.inputs().map(|(u, p, t)| (u, p.into(), t))),
Direction::Outgoing => Either::Right(self.outputs().map(|(u, p, t)| (u, p.into(), t))),
}
}

/// Returns the linear units of this command in a given direction.
#[inline]
pub fn qubits(&self, direction: Direction) -> FilteredUnits<filter::Qubits, &Self> {
Units::new(self.circ, self.node, direction, self).filter_units::<filter::Qubits>()
pub fn linear_units(
&self,
direction: Direction,
) -> impl Iterator<Item = (LinearUnit, Port, Type)> + '_ {
match direction {
Direction::Incoming => {
Either::Left(self.linear_inputs().map(|(u, p, t)| (u, p.into(), t)))
}
Direction::Outgoing => {
Either::Right(self.linear_outputs().map(|(u, p, t)| (u, p.into(), t)))
}
}
}

/// Returns the linear units of this command in a given direction.
#[inline]
pub fn input_qubits(&self) -> FilteredUnits<filter::Qubits, &Self> {
self.qubits(Direction::Incoming)
pub fn input_qubits(&self) -> impl Iterator<Item = (LinearUnit, IncomingPort, Type)> + '_ {
self.inputs().filter_map(filter::filter_qubit)
}

/// Returns the linear units of this command in a given direction.
#[inline]
pub fn output_qubits(&self) -> FilteredUnits<filter::Qubits, &Self> {
self.qubits(Direction::Outgoing)
}

/// Returns the units and wires of this command in a given direction.
#[inline]
pub fn unit_wires(
&self,
direction: Direction,
) -> impl IntoIterator<Item = (CircuitUnit, Wire)> + '_ {
self.units(direction)
.filter_map(move |(unit, port, _)| Some((unit, self.assign_wire(self.node, port)?)))
pub fn output_qubits(&self) -> impl Iterator<Item = (LinearUnit, OutgoingPort, Type)> + '_ {
self.outputs().filter_map(filter::filter_qubit)
}

/// Returns the output units of this command. See [`Command::units`].
#[inline]
pub fn outputs(&self) -> Units<&'_ Self> {
self.units(Direction::Outgoing)
pub fn outputs(&self) -> Units<OutgoingPort, &'_ Self> {
Units::new_outgoing(self.circ, self.node, self)
}

/// Returns the linear output units of this command. See [`Command::linear_units`].
#[inline]
pub fn linear_outputs(&self) -> FilteredUnits<filter::Linear, &Self> {
self.linear_units(Direction::Outgoing)
pub fn linear_outputs(&self) -> impl Iterator<Item = (LinearUnit, OutgoingPort, Type)> + '_ {
self.outputs().filter_map(filter::filter_linear)
}

/// Returns the output units and wires of this command. See [`Command::unit_wires`].
/// Returns the output units and wires of this command.
#[inline]
pub fn output_wires(&self) -> impl IntoIterator<Item = (CircuitUnit, Wire)> + '_ {
self.unit_wires(Direction::Outgoing)
pub fn output_wires(&self) -> impl Iterator<Item = (CircuitUnit, Wire)> + '_ {
self.outputs().filter_map(move |(unit, port, _typ)| {
let w = self.assign_wire(self.node, port.into())?;
Some((unit, w))
})
}

/// Returns the output units of this command.
#[inline]
pub fn inputs(&self) -> Units<&'_ Self> {
self.units(Direction::Incoming)
pub fn inputs(&self) -> Units<IncomingPort, &'_ Self> {
Units::new_incoming(self.circ, self.node, self)
}

/// Returns the linear input units of this command. See [`Command::linear_units`].
#[inline]
pub fn linear_inputs(&self) -> FilteredUnits<filter::Linear, &Self> {
self.linear_units(Direction::Incoming)
pub fn linear_inputs(&self) -> impl Iterator<Item = (LinearUnit, IncomingPort, Type)> + '_ {
self.inputs().filter_map(filter::filter_linear)
}

/// Returns the input units and wires of this command. See [`Command::unit_wires`].
/// Returns the input units and wires of this command.
#[inline]
pub fn input_wires(&self) -> impl IntoIterator<Item = (CircuitUnit, Wire)> + '_ {
self.unit_wires(Direction::Incoming)
self.inputs().filter_map(move |(unit, port, _typ)| {
let w = self.assign_wire(self.node, port.into())?;
Some((unit, w))
})
}

/// Returns the number of inputs of this command.
Expand Down Expand Up @@ -274,10 +280,7 @@ where
// TODO: `with_wires` combinator for `Units`?
let wire_unit = circ
.linear_units()
.map(|(linear_unit, port, _)| {
let port = port.as_outgoing().unwrap();
(Wire::new(circ.input(), port), linear_unit.index())
})
.map(|(linear_unit, port, _)| (Wire::new(circ.input(), port), linear_unit.index()))
.collect();

let nodes = pv::Topo::new(&circ.as_petgraph());
Expand Down Expand Up @@ -374,32 +377,30 @@ where
// required to construct a `Command`.
//
// Updates the map tracking the last wire of linear units.
let linear_units: Vec<_> =
Units::new(self.circ, node, Direction::Outgoing, DefaultUnitLabeller)
.filter_units::<filter::Linear>()
.map(|(_, port, _)| {
// Find the linear unit id for this port.
let linear_id = self
.follow_linear_port(node, port)
.and_then(|input_port| {
let input_port = input_port.as_incoming().unwrap();
self.circ.linked_outputs(node, input_port).next()
})
.and_then(|(from, from_port)| {
// Remove the old wire from the map (if there was one)
self.wire_unit.remove(&Wire::new(from, from_port))
})
.unwrap_or({
// New linear unit found. Assign it a new id.
self.wire_unit.len()
});
// Update the map tracking the linear units
let port = port.as_outgoing().unwrap();
let new_wire = Wire::new(node, port);
self.wire_unit.insert(new_wire, linear_id);
LinearUnit::new(linear_id)
})
.collect();
let linear_units: Vec<_> = Units::new_outgoing(self.circ, node, DefaultUnitLabeller)
.filter_map(filter::filter_linear)
.map(|(_, port, _)| {
// Find the linear unit id for this port.
let linear_id = self
.follow_linear_port(node, port)
.and_then(|input_port| {
let input_port = input_port.as_incoming().unwrap();
self.circ.linked_outputs(node, input_port).next()
})
.and_then(|(from, from_port)| {
// Remove the old wire from the map (if there was one)
self.wire_unit.remove(&Wire::new(from, from_port))
})
.unwrap_or({
// New linear unit found. Assign it a new id.
self.wire_unit.len()
});
// Update the map tracking the linear units
let new_wire = Wire::new(node, port);
self.wire_unit.insert(new_wire, linear_id);
LinearUnit::new(linear_id)
})
.collect();

Some(linear_units)
}
Expand All @@ -410,7 +411,8 @@ where
/// In the future we may want to have a more general mechanism to handle this.
//
// Note that `Command::linear_units` assumes this behaviour.
fn follow_linear_port(&self, node: Node, port: Port) -> Option<Port> {
fn follow_linear_port(&self, node: Node, port: impl Into<Port>) -> Option<Port> {
let port = port.into();
let optype = self.circ.get_optype(node);
if !optype.port_kind(port)?.is_linear() {
return None;
Expand Down
Loading

0 comments on commit 93f83ac

Please sign in to comment.