Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
lmondada committed Aug 28, 2023
1 parent 7f5902e commit 5b921f3
Showing 1 changed file with 59 additions and 40 deletions.
99 changes: 59 additions & 40 deletions src/hugr/views/sibling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
//! while the former provide views for subgraphs within a single level of the
//! hierarchy.
use std::collections::HashSet;

use itertools::Itertools;
use portgraph::{algorithms::ConvexChecker, view::Subgraph, Direction, PortView};
use thiserror::Error;
Expand All @@ -18,7 +20,7 @@ use crate::{
handle::{ContainerHandle, DataflowOpID},
OpTag, OpTrait,
},
types::FunctionType,
types::{FunctionType, Type},
Hugr, Node, Port, SimpleReplacement,
};

Expand Down Expand Up @@ -53,14 +55,21 @@ pub struct SiblingSubgraph<'g, Base> {
nodes: Vec<Node>,
/// The input ports of the subgraph.
///
/// Grouped by input parameter. Each port must be unique.
inputs: IncomingPorts,
/// Grouped by input parameter. Each port must be unique and belong to a
/// node in `nodes`.
inputs: Vec<Vec<(Node, Port)>>,
/// The output ports of the subgraph.
///
/// Repeated ports are allowed and correspond to copying the output.
outputs: OutgoingPorts,
/// Repeated ports are allowed and correspond to copying the output. Every
/// port must belong to a node in `nodes`.
outputs: Vec<(Node, Port)>,
}

/// The type of the incoming boundary of [`SiblingSubgraph`].
pub type IncomingPorts = Vec<Vec<(Node, Port)>>;
/// The type of the outgoing boundary of [`SiblingSubgraph`].
pub type OutgoingPorts = Vec<(Node, Port)>;

impl<'g, Base: HugrView> SiblingSubgraph<'g, Base> {
/// A sibling subgraph from a [`crate::ops::OpTag::DataflowParent`]-rooted HUGR.
///
Expand Down Expand Up @@ -364,6 +373,18 @@ impl<'g, Base: HugrView> SiblingSubgraph<'g, Base> {
}
}

/// The type of all ports in the iterator.
///
/// If the array is empty or a port does not exist, returns `None`.
fn get_edge_type<H: HugrView>(hugr: &H, ports: &[(Node, Port)]) -> Option<Type> {
let &(n, p) = ports.first()?;
let edge_t = hugr.get_optype(n).signature().get(p)?.clone();
ports
.iter()
.all(|&(n, p)| hugr.get_optype(n).signature().get(p) == Some(&edge_t))
.then_some(edge_t)
}

/// Whether a subgraph is valid.
///
/// Does NOT check for convexity.
Expand All @@ -382,6 +403,16 @@ fn validate_subgraph<H: HugrView>(
return Err(InvalidSubgraph::NoSharedParent);
}

// Check there are no linked "other" ports
if inputs
.iter()
.flatten()
.chain(outputs)
.any(|&(n, p)| is_order_edge(hugr, n, p))
{
unimplemented!("Linked other ports not supported at boundary")
}

// Check inputs are incoming ports and outputs are outgoing ports
if inputs
.iter()
Expand All @@ -402,6 +433,7 @@ fn validate_subgraph<H: HugrView>(
.clone()
.flat_map(|(n, p)| hugr.linked_ports(n, p));
// Check incoming & outgoing ports have target resp. source inside
let nodes = nodes.iter().copied().collect::<HashSet<_>>();
if ports_inside.any(|(n, _)| !nodes.contains(&n)) {
return Err(InvalidSubgraph::InvalidBoundary);
}
Expand All @@ -420,53 +452,34 @@ fn validate_subgraph<H: HugrView>(
return Err(InvalidSubgraph::InvalidBoundary);
}

// Check edge types are equal within partition
// Check edge types are equal within partition and copyable if partition size > 1
if !inputs.iter().all(|ports| {
ports
.iter()
.filter_map(|&(n, p)| hugr.get_optype(n).signature().get(p).cloned())
.all_equal()
let Some(edge_t) = get_edge_type(hugr, ports) else {
return false;
};
let require_copy = ports.len() > 1;
!require_copy || edge_t.copyable()
}) {
return Err(InvalidSubgraph::InvalidBoundary);
}
// Check there are no state order edges
if inputs
.iter()
.flatten()
.chain(outputs)
.any(|&(n, p)| is_order_edge(hugr, n, p))
{
unimplemented!("State order edges not supported at boundary")
}

Ok(())
}

type IncomingPorts = Vec<Vec<(Node, Port)>>;
type OutgoingPorts = Vec<(Node, Port)>;

fn get_input_output_ports<H: HugrView>(hugr: &H) -> (IncomingPorts, OutgoingPorts) {
let (inp, out) = hugr
.children(hugr.root())
.take(2)
.collect_tuple()
.expect("invalid DFG");
let inp_sig = hugr.get_optype(inp).signature();
let out_sig = hugr.get_optype(out).signature();
let (dfg_inputs, order_inputs): (Vec<_>, Vec<_>) = hugr
.node_outputs(inp)
.partition(|&p| inp_sig.get(p).is_some());
let (dfg_outputs, order_outputs): (Vec<_>, Vec<_>) = hugr
.node_inputs(out)
.partition(|&p| out_sig.get(p).is_some());
if order_inputs
.into_iter()
.any(|p| is_order_edge(hugr, inp, p))
|| order_outputs
.into_iter()
.any(|p| is_order_edge(hugr, out, p))
{
unimplemented!("State order edges not supported at boundary")
if has_other_edge(hugr, inp, Direction::Outgoing) {
unimplemented!("Non-dataflow output not supported at input node")
}
let dfg_inputs = hugr.get_optype(inp).signature().output_ports();
if has_other_edge(hugr, out, Direction::Incoming) {
unimplemented!("Non-dataflow input not supported at output node")
}
let dfg_outputs = hugr.get_optype(out).signature().input_ports();
let inputs = dfg_inputs
.into_iter()
.map(|p| hugr.linked_ports(inp, p).collect())
Expand All @@ -485,8 +498,14 @@ fn get_input_output_ports<H: HugrView>(hugr: &H) -> (IncomingPorts, OutgoingPort

/// Whether a port is linked to a state order edge.
fn is_order_edge<H: HugrView>(hugr: &H, node: Node, port: Port) -> bool {
hugr.get_optype(node).signature().get(port).is_none()
&& hugr.linked_ports(node, port).count() > 0
let op = hugr.get_optype(node);
op.other_port_index(port.direction()) == Some(port) && hugr.is_linked(node, port)
}

/// Whether node has a non-df linked port in the given direction.
fn has_other_edge<H: HugrView>(hugr: &H, node: Node, dir: Direction) -> bool {
let op = hugr.get_optype(node);
op.other_port(dir).is_some() && hugr.is_linked(node, op.other_port_index(dir).unwrap())
}

/// Errors that can occur while constructing a [`SimpleReplacement`].
Expand Down

0 comments on commit 5b921f3

Please sign in to comment.