Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: used_extensions calls for both ops and signatures #1739

Merged
merged 24 commits into from
Dec 10, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 47 additions & 15 deletions hugr-core/src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,16 @@
//! TODO: YAML declaration and parsing. This should be similar to a plugin
//! system (outside the `types` module), which also parses nested [`OpDef`]s.

use itertools::Itertools;
pub use semver::Version;
use serde::{Deserialize, Deserializer, Serialize};
use std::collections::btree_map;
use std::collections::{BTreeMap, BTreeSet};
use std::fmt::{Debug, Display, Formatter};
use std::fmt::Debug;
use std::mem;
use std::sync::{Arc, Weak};

use derive_more::Display;
use thiserror::Error;

use crate::hugr::IdentList;
Expand Down Expand Up @@ -40,10 +43,22 @@ pub use type_def::{TypeDef, TypeDefBound};
pub mod declarative;

/// Extension Registries store extensions to be looked up e.g. during validation.
#[derive(Clone, Debug, Default, PartialEq)]
#[derive(Clone, Debug, Display, Default, PartialEq)]
#[display("ExtensionRegistry[{}]", _0.keys().join(", "))]
pub struct ExtensionRegistry(BTreeMap<ExtensionId, Arc<Extension>>);

impl ExtensionRegistry {
/// Create a new empty extension registry.
///
/// For a version that checks the validity of the extensions, see [`ExtensionRegistry::try_new`].
pub fn new(extensions: impl IntoIterator<Item = Arc<Extension>>) -> Self {
aborgna-q marked this conversation as resolved.
Show resolved Hide resolved
let mut res = Self::default();
for ext in extensions.into_iter() {
res.register_updated(ext);
}
res
}

/// Gets the Extension with the given name
pub fn get(&self, name: &str) -> Option<&Arc<Extension>> {
self.0.get(name)
Expand All @@ -55,14 +70,12 @@ impl ExtensionRegistry {
}

/// Makes a new [ExtensionRegistry], validating all the extensions in it.
///
/// For an unvalidated version, see [`ExtensionRegistry::new`].
pub fn try_new(
value: impl IntoIterator<Item = Arc<Extension>>,
) -> Result<Self, ExtensionRegistryError> {
let mut res = ExtensionRegistry(BTreeMap::new());

for ext in value.into_iter() {
res.register(ext)?;
}
let res = ExtensionRegistry::new(value);

// Note this potentially asks extensions to validate themselves against other extensions that
// may *not* be valid themselves yet. It'd be better to order these respecting dependencies,
Expand Down Expand Up @@ -205,6 +218,28 @@ impl Extend<Arc<Extension>> for ExtensionRegistry {
}
}

// Encode/decode ExtensionRegistry as a list of extensions.
// We can get the map key from the extension itself.
impl<'de> Deserialize<'de> for ExtensionRegistry {
fn deserialize<D>(deserializer: D) -> Result<ExtensionRegistry, D::Error>
where
D: Deserializer<'de>,
{
let extensions: Vec<Arc<Extension>> = Vec::deserialize(deserializer)?;
Ok(ExtensionRegistry::new(extensions))
aborgna-q marked this conversation as resolved.
Show resolved Hide resolved
}
}

impl Serialize for ExtensionRegistry {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let extensions: Vec<Arc<Extension>> = self.0.values().cloned().collect();
extensions.serialize(serializer)
}
}

/// An Extension Registry containing no extensions.
pub const EMPTY_REG: ExtensionRegistry = ExtensionRegistry(BTreeMap::new());

Expand All @@ -226,7 +261,7 @@ pub enum SignatureError {
#[error("Invalid type arguments for operation")]
InvalidTypeArgs,
/// The Extension Registry did not contain an Extension referenced by the Signature
#[error("Extension '{missing}' not found. Available extensions: {}",
#[error("Extension '{missing}' is not part of the declared HUGR extensions [{}]",
available.iter().map(|e| e.to_string()).collect::<Vec<_>>().join(", ")
)]
ExtensionNotFound {
Expand Down Expand Up @@ -614,7 +649,10 @@ pub enum ExtensionBuildError {
}

/// A set of extensions identified by their unique [`ExtensionId`].
#[derive(Clone, Debug, Default, Hash, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[derive(
Clone, Debug, Display, Default, Hash, PartialEq, Eq, serde::Serialize, serde::Deserialize,
)]
#[display("[{}]", _0.iter().join(", "))]
pub struct ExtensionSet(BTreeSet<ExtensionId>);

/// A special ExtensionId which indicates that the delta of a non-Function
Expand Down Expand Up @@ -738,12 +776,6 @@ fn as_typevar(e: &ExtensionId) -> Option<usize> {
}
}

impl Display for ExtensionSet {
fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
f.debug_list().entries(self.0.iter()).finish()
}
}

impl FromIterator<ExtensionId> for ExtensionSet {
fn from_iter<I: IntoIterator<Item = ExtensionId>>(iter: I) -> Self {
Self(BTreeSet::from_iter(iter))
Expand Down
77 changes: 71 additions & 6 deletions hugr-core/src/extension/resolution.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
//! Utilities for resolving operations and types present in a HUGR, and updating
//! the list of used extensions. See [`crate::Hugr::resolve_extension_defs`].
//! the list of used extensions. The functionalities of this module can be
//! called from the type methods [`crate::Hugr::resolve_extension_defs`],
//! [`crate::ops::OpType::used_extensions`], and
//! [`crate::types::Signature::used_extensions`].
//!
//! When listing "used extensions" we only care about _definitional_ extension
//! requirements, i.e., the operations and types that are required to define the
Expand All @@ -13,21 +16,23 @@
//! Note: These procedures are only temporary until `hugr-model` is stabilized.
//! Once that happens, hugrs will no longer be directly deserialized using serde
//! but instead will be created by the methods in `crate::import`. As these
//! (will) automatically resolve extensions as the operations are created,
//! we will no longer require this post-facto resolution step.
//! (will) automatically resolve extensions as the operations are created, we
aborgna-q marked this conversation as resolved.
Show resolved Hide resolved
//! will no longer require this post-facto resolution step.

mod ops;
mod types;
mod types_mut;

pub(crate) use ops::update_op_extensions;
pub(crate) use types::update_op_types_extensions;
pub(crate) use ops::{collect_op_extensions, update_op_extensions};
pub(crate) use types::{collect_op_types_extensions, collect_signature_exts};
pub(crate) use types_mut::update_op_types_extensions;

use derive_more::{Display, Error, From};

use super::{Extension, ExtensionId, ExtensionRegistry};
use crate::ops::custom::OpaqueOpError;
use crate::ops::{NamedOp, OpName, OpType};
use crate::types::TypeName;
use crate::types::{FuncTypeBase, MaybeRV, TypeName};
use crate::Node;

/// Errors that can occur during extension resolution.
Expand Down Expand Up @@ -101,3 +106,63 @@ impl ExtensionResolutionError {
}
}
}

/// Errors that can occur when collecting extension requirements.
#[derive(Debug, Display, Clone, Error, From, PartialEq)]
#[non_exhaustive]
pub enum ExtensionCollectionError {
/// An operation requires an extension that is not in the given registry.
#[display(
"{op}{} contains custom types for which have lost the reference to their defining extensions. Dropped extensions: {}",
if let Some(node) = node { format!(" ({})", node) } else { "".to_string() },
missing_extensions.join(", ")
)]
DroppedOpExtensions {
/// The node that is missing extensions.
node: Option<Node>,
/// The operation that is missing extensions.
op: OpName,
/// The missing extensions.
missing_extensions: Vec<ExtensionId>,
},
/// A signature requires an extension that is not in the given registry.
#[display(
"Signature {signature} contains custom types for which have lost the reference to their defining extensions. Dropped extensions: {}",
missing_extensions.join(", ")
)]
DroppedSignatureExtensions {
/// The signature that is missing extensions.
signature: String,
/// The missing extensions.
missing_extensions: Vec<ExtensionId>,
},
}

impl ExtensionCollectionError {
/// Create a new error when operation extensions have been dropped.
pub fn dropped_op_extension(
node: Option<Node>,
op: &OpType,
missing_extension: impl IntoIterator<Item = ExtensionId>,
) -> Self {
Self::DroppedOpExtensions {
node,
op: NamedOp::name(op),
missing_extensions: missing_extension.into_iter().collect(),
}
}

/// Create a new error when signature extensions have been dropped.
pub fn dropped_signature<RV: MaybeRV>(
signature: &FuncTypeBase<RV>,
missing_extension: impl IntoIterator<Item = ExtensionId>,
) -> Self {
Self::DroppedSignatureExtensions {
signature: format!("{signature}"),
missing_extensions: missing_extension.into_iter().collect(),
}
}
}

#[cfg(test)]
mod test;
42 changes: 40 additions & 2 deletions hugr-core/src/extension/resolution/ops.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,50 @@
//! Resolve `OpaqueOp`s into `ExtensionOp`s and return an operation's required extension.
//! Resolve `OpaqueOp`s into `ExtensionOp`s and return an operation's required
//! extension.
//!
//! Contains both mutable ([`update_op_extensions`]) and immutable
//! ([`collect_operation_extension`]) methods to resolve operations and collect
//! the required extensions respectively.

use std::sync::Arc;

use super::{Extension, ExtensionRegistry, ExtensionResolutionError};
use super::{Extension, ExtensionCollectionError, ExtensionRegistry, ExtensionResolutionError};
use crate::ops::custom::OpaqueOpError;
use crate::ops::{DataflowOpTrait, ExtensionOp, NamedOp, OpType};
use crate::Node;

/// Returns the extension in the registry required by the operation.
///
/// If the operation does not require an extension, returns `None`.
///
/// [`ExtensionOp`]s store a [`Weak`] reference to their extension, which can be
/// invalidated if the original `Arc<Extension>` is dropped. On such cases, we
/// return an error with the missing extension names.
///
/// # Attributes
///
aborgna-q marked this conversation as resolved.
Show resolved Hide resolved
/// - `node`: The node where the operation is located, if available. This is
/// used to provide context in the error message.
/// - `op`: The operation to collect the extensions from.
pub(crate) fn collect_op_extensions(
node: Option<Node>,
aborgna-q marked this conversation as resolved.
Show resolved Hide resolved
op: &OpType,
) -> Result<Option<Arc<Extension>>, ExtensionCollectionError> {
let OpType::ExtensionOp(ext_op) = op else {
// TODO: Extract the extension when the operation is a `Const`.
// https://github.com/CQCL/hugr/issues/1742
return Ok(None);
};
let ext = ext_op.def().extension();
match ext.upgrade() {
Some(e) => Ok(Some(e)),
None => Err(ExtensionCollectionError::dropped_op_extension(
node,
op,
[ext_op.def().extension_id().clone()],
)),
}
}

/// Compute the required extension for an operation.
///
/// If the op is a [`OpType::OpaqueOp`], replace it with a resolved
Expand Down
Loading
Loading