Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Use IncomingPort and OutgoingPort instead of Port where possible. #296

Merged
merged 5 commits into from
Mar 1, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 13 additions & 13 deletions tket2/src/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,16 @@ use hugr::hugr::NodeType;
use hugr::ops::dataflow::IOTrait;
use hugr::ops::{Input, Output, DFG};
use hugr::types::FunctionType;
use hugr::HugrView;
use hugr::PortIndex;
use hugr::{HugrView, OutgoingPort};
use itertools::Itertools;
use portgraph::Direction;
use thiserror::Error;

pub use hugr::ops::OpType;
pub use hugr::types::{EdgeKind, Type, TypeRow};
pub use hugr::{Node, Port, Wire};

use self::units::{filter, FilteredUnits, Units};
use self::units::{filter, LinearUnit, Units};

/// An object behaving like a quantum circuit.
//
Expand Down Expand Up @@ -85,7 +84,7 @@ pub trait Circuit: HugrView {

/// Get the input units of the circuit and their types.
#[inline]
fn units(&self) -> Units
fn units(&self) -> Units<OutgoingPort>
where
Self: Sized,
{
Expand All @@ -94,29 +93,29 @@ pub trait Circuit: HugrView {

/// Get the linear input units of the circuit and their types.
#[inline]
fn linear_units(&self) -> FilteredUnits<filter::Linear>
fn linear_units(&self) -> impl Iterator<Item = (LinearUnit, OutgoingPort, Type)> + '_
where
Self: Sized,
{
self.units().filter_units::<filter::Linear>()
self.units().filter_map(filter::filter_linear)
}

/// Get the non-linear input units of the circuit and their types.
#[inline]
fn nonlinear_units(&self) -> FilteredUnits<filter::NonLinear>
fn nonlinear_units(&self) -> impl Iterator<Item = (Wire, OutgoingPort, Type)> + '_
where
Self: Sized,
{
self.units().filter_units::<filter::NonLinear>()
self.units().filter_map(filter::filter_non_linear)
}

/// Returns the units corresponding to qubits inputs to the circuit.
#[inline]
fn qubits(&self) -> FilteredUnits<filter::Qubits>
fn qubits(&self) -> impl Iterator<Item = (LinearUnit, OutgoingPort, Type)> + '_
where
Self: Sized,
{
self.units().filter_units::<filter::Qubits>()
self.units().filter_map(filter::filter_qubit)
}

/// Returns all the commands in the circuit, in some topological order.
Expand Down Expand Up @@ -175,9 +174,9 @@ pub(crate) fn remove_empty_wire(
if input_port >= circ.num_outputs(inp) {
return Err(CircuitMutError::InvalidPortOffset(input_port));
}
let input_port = Port::new(Direction::Outgoing, input_port);
let input_port = OutgoingPort::from(input_port);
let link = circ
.linked_ports(inp, input_port)
.linked_inputs(inp, input_port)
.at_most_one()
.map_err(|_| CircuitMutError::DeleteNonEmptyWire(input_port.index()))?;
if link.is_some() && link.unwrap().0 != out {
Expand Down Expand Up @@ -223,9 +222,10 @@ pub enum CircuitMutError {
fn shift_ports<C: HugrMut + ?Sized>(
circ: &mut C,
node: Node,
mut free_port: Port,
free_port: impl Into<Port>,
max_ind: usize,
) -> Result<Port, hugr::hugr::HugrError> {
let mut free_port = free_port.into();
let dir = free_port.direction();
let port_range = (free_port.index() + 1..max_ind).map(|p| Port::new(dir, p));
for port in port_range {
Expand Down
142 changes: 72 additions & 70 deletions tket2/src/circuit/command.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@ use std::iter::FusedIterator;

use hugr::hugr::NodeType;
use hugr::ops::{OpTag, OpTrait};
use itertools::Either::{Left, Right};
use hugr::{IncomingPort, OutgoingPort};
use itertools::Either::{self, Left, Right};
use petgraph::visit as pv;

use super::units::filter::FilteredUnits;
use super::units::{filter, DefaultUnitLabeller, LinearUnit, UnitLabeller, Units};
use super::Circuit;

Expand Down Expand Up @@ -53,78 +53,84 @@ impl<'circ, Circ: Circuit> Command<'circ, Circ> {

/// Returns the units of this command in a given direction.
#[inline]
pub fn units(&self, direction: Direction) -> Units<&'_ Self> {
Units::new(self.circ, self.node, direction, self)
}

/// Returns the linear units of this command in a given direction.
#[inline]
pub fn linear_units(&self, direction: Direction) -> FilteredUnits<filter::Linear, &Self> {
Units::new(self.circ, self.node, direction, self).filter_units::<filter::Linear>()
pub fn units(
&self,
direction: Direction,
) -> impl Iterator<Item = (CircuitUnit, Port, Type)> + '_ {
match direction {
Direction::Incoming => Either::Left(self.inputs().map(|(u, p, t)| (u, p.into(), t))),
Direction::Outgoing => Either::Right(self.outputs().map(|(u, p, t)| (u, p.into(), t))),
}
}

/// Returns the linear units of this command in a given direction.
#[inline]
pub fn qubits(&self, direction: Direction) -> FilteredUnits<filter::Qubits, &Self> {
Units::new(self.circ, self.node, direction, self).filter_units::<filter::Qubits>()
pub fn linear_units(
&self,
direction: Direction,
) -> impl Iterator<Item = (LinearUnit, Port, Type)> + '_ {
match direction {
Direction::Incoming => {
Either::Left(self.linear_inputs().map(|(u, p, t)| (u, p.into(), t)))
}
Direction::Outgoing => {
Either::Right(self.linear_outputs().map(|(u, p, t)| (u, p.into(), t)))
}
}
}

/// Returns the linear units of this command in a given direction.
#[inline]
pub fn input_qubits(&self) -> FilteredUnits<filter::Qubits, &Self> {
self.qubits(Direction::Incoming)
pub fn input_qubits(&self) -> impl Iterator<Item = (LinearUnit, IncomingPort, Type)> + '_ {
self.inputs().filter_map(filter::filter_qubit)
}

/// Returns the linear units of this command in a given direction.
#[inline]
pub fn output_qubits(&self) -> FilteredUnits<filter::Qubits, &Self> {
self.qubits(Direction::Outgoing)
}

/// Returns the units and wires of this command in a given direction.
#[inline]
pub fn unit_wires(
&self,
direction: Direction,
) -> impl IntoIterator<Item = (CircuitUnit, Wire)> + '_ {
self.units(direction)
.filter_map(move |(unit, port, _)| Some((unit, self.assign_wire(self.node, port)?)))
pub fn output_qubits(&self) -> impl Iterator<Item = (LinearUnit, OutgoingPort, Type)> + '_ {
self.outputs().filter_map(filter::filter_qubit)
}

/// Returns the output units of this command. See [`Command::units`].
#[inline]
pub fn outputs(&self) -> Units<&'_ Self> {
self.units(Direction::Outgoing)
pub fn outputs(&self) -> Units<OutgoingPort, &'_ Self> {
Units::new(self.circ, self.node, self)
}

/// Returns the linear output units of this command. See [`Command::linear_units`].
#[inline]
pub fn linear_outputs(&self) -> FilteredUnits<filter::Linear, &Self> {
self.linear_units(Direction::Outgoing)
pub fn linear_outputs(&self) -> impl Iterator<Item = (LinearUnit, OutgoingPort, Type)> + '_ {
self.outputs().filter_map(filter::filter_linear)
}

/// Returns the output units and wires of this command. See [`Command::unit_wires`].
/// Returns the output units and wires of this command.
#[inline]
pub fn output_wires(&self) -> impl IntoIterator<Item = (CircuitUnit, Wire)> + '_ {
self.unit_wires(Direction::Outgoing)
pub fn output_wires(&self) -> impl Iterator<Item = (CircuitUnit, Wire)> + '_ {
self.outputs().filter_map(move |(unit, port, _typ)| {
let w = self.assign_wire(self.node, port.into())?;
Some((unit, w))
})
}

/// Returns the output units of this command.
#[inline]
pub fn inputs(&self) -> Units<&'_ Self> {
self.units(Direction::Incoming)
pub fn inputs(&self) -> Units<IncomingPort, &'_ Self> {
Units::new_reversed(self.circ, self.node, self)
}

/// Returns the linear input units of this command. See [`Command::linear_units`].
#[inline]
pub fn linear_inputs(&self) -> FilteredUnits<filter::Linear, &Self> {
self.linear_units(Direction::Incoming)
pub fn linear_inputs(&self) -> impl Iterator<Item = (LinearUnit, IncomingPort, Type)> + '_ {
self.inputs().filter_map(filter::filter_linear)
}

/// Returns the input units and wires of this command. See [`Command::unit_wires`].
/// Returns the input units and wires of this command.
#[inline]
pub fn input_wires(&self) -> impl IntoIterator<Item = (CircuitUnit, Wire)> + '_ {
self.unit_wires(Direction::Incoming)
self.outputs().filter_map(move |(unit, port, _typ)| {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should be self.inputs(), surely?

let w = self.assign_wire(self.node, port.into())?;
Some((unit, w))
})
}

/// Returns the number of inputs of this command.
Expand Down Expand Up @@ -274,10 +280,7 @@ where
// TODO: `with_wires` combinator for `Units`?
let wire_unit = circ
.linear_units()
.map(|(linear_unit, port, _)| {
let port = port.as_outgoing().unwrap();
(Wire::new(circ.input(), port), linear_unit.index())
})
.map(|(linear_unit, port, _)| (Wire::new(circ.input(), port), linear_unit.index()))
.collect();

let nodes = pv::Topo::new(&circ.as_petgraph());
Expand Down Expand Up @@ -374,32 +377,30 @@ where
// required to construct a `Command`.
//
// Updates the map tracking the last wire of linear units.
let linear_units: Vec<_> =
Units::new(self.circ, node, Direction::Outgoing, DefaultUnitLabeller)
.filter_units::<filter::Linear>()
.map(|(_, port, _)| {
// Find the linear unit id for this port.
let linear_id = self
.follow_linear_port(node, port)
.and_then(|input_port| {
let input_port = input_port.as_incoming().unwrap();
self.circ.linked_outputs(node, input_port).next()
})
.and_then(|(from, from_port)| {
// Remove the old wire from the map (if there was one)
self.wire_unit.remove(&Wire::new(from, from_port))
})
.unwrap_or({
// New linear unit found. Assign it a new id.
self.wire_unit.len()
});
// Update the map tracking the linear units
let port = port.as_outgoing().unwrap();
let new_wire = Wire::new(node, port);
self.wire_unit.insert(new_wire, linear_id);
LinearUnit::new(linear_id)
})
.collect();
let linear_units: Vec<_> = Units::new(self.circ, node, DefaultUnitLabeller)
.filter_map(filter::filter_linear)
.map(|(_, port, _)| {
// Find the linear unit id for this port.
let linear_id = self
.follow_linear_port(node, port)
.and_then(|input_port| {
let input_port = input_port.as_incoming().unwrap();
self.circ.linked_outputs(node, input_port).next()
})
.and_then(|(from, from_port)| {
// Remove the old wire from the map (if there was one)
self.wire_unit.remove(&Wire::new(from, from_port))
})
.unwrap_or({
// New linear unit found. Assign it a new id.
self.wire_unit.len()
});
// Update the map tracking the linear units
let new_wire = Wire::new(node, port);
self.wire_unit.insert(new_wire, linear_id);
LinearUnit::new(linear_id)
})
.collect();

Some(linear_units)
}
Expand All @@ -410,7 +411,8 @@ where
/// In the future we may want to have a more general mechanism to handle this.
//
// Note that `Command::linear_units` assumes this behaviour.
fn follow_linear_port(&self, node: Node, port: Port) -> Option<Port> {
fn follow_linear_port(&self, node: Node, port: impl Into<Port>) -> Option<Port> {
let port = port.into();
let optype = self.circ.get_optype(node);
if !optype.port_kind(port)?.is_linear() {
return None;
Expand Down
Loading
Loading