Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: used_extensions calls for both ops and signatures #1739

Merged
merged 24 commits into from
Dec 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
180 changes: 133 additions & 47 deletions hugr-core/src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,17 @@
//! TODO: YAML declaration and parsing. This should be similar to a plugin
//! system (outside the `types` module), which also parses nested [`OpDef`]s.

use itertools::Itertools;
pub use semver::Version;
use serde::{Deserialize, Deserializer, Serialize};
use std::collections::btree_map;
use std::collections::{BTreeMap, BTreeSet};
use std::fmt::{Debug, Display, Formatter};
use std::fmt::Debug;
use std::mem;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Weak};

use derive_more::Display;
use thiserror::Error;

use crate::hugr::IdentList;
Expand Down Expand Up @@ -40,41 +44,73 @@ pub use type_def::{TypeDef, TypeDefBound};
pub mod declarative;

/// Extension Registries store extensions to be looked up e.g. during validation.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct ExtensionRegistry(BTreeMap<ExtensionId, Arc<Extension>>);
#[derive(Debug, Display, Default)]
#[display("ExtensionRegistry[{}]", exts.keys().join(", "))]
pub struct ExtensionRegistry {
/// The extensions in the registry.
exts: BTreeMap<ExtensionId, Arc<Extension>>,
/// A flag indicating whether the current set of extensions has been
/// validated.
///
/// This is used to avoid re-validating the extensions every time the
/// registry is validated, and is set to `false` whenever a new extension is
/// added.
valid: AtomicBool,
}

impl PartialEq for ExtensionRegistry {
fn eq(&self, other: &Self) -> bool {
self.exts == other.exts
}
}

impl Clone for ExtensionRegistry {
fn clone(&self) -> Self {
Self {
exts: self.exts.clone(),
valid: self.valid.load(Ordering::Relaxed).into(),
}
}
}

impl ExtensionRegistry {
/// Create a new empty extension registry.
pub fn new(extensions: impl IntoIterator<Item = Arc<Extension>>) -> Self {
aborgna-q marked this conversation as resolved.
Show resolved Hide resolved
let mut res = Self::default();
for ext in extensions.into_iter() {
res.register_updated(ext);
}
res
}

/// Gets the Extension with the given name
pub fn get(&self, name: &str) -> Option<&Arc<Extension>> {
self.0.get(name)
self.exts.get(name)
}

/// Returns `true` if the registry contains an extension with the given name.
pub fn contains(&self, name: &str) -> bool {
self.0.contains_key(name)
self.exts.contains_key(name)
}

/// Makes a new [ExtensionRegistry], validating all the extensions in it.
pub fn try_new(
value: impl IntoIterator<Item = Arc<Extension>>,
) -> Result<Self, ExtensionRegistryError> {
let mut res = ExtensionRegistry(BTreeMap::new());

for ext in value.into_iter() {
res.register(ext)?;
/// Validate the set of extensions, ensuring that each extension requirements are also in the registry.
///
/// Note this potentially asks extensions to validate themselves against other extensions that
/// may *not* be valid themselves yet. It'd be better to order these respecting dependencies,
/// or at least to validate the types first - which we don't do at all yet:
//
// TODO https://github.com/CQCL/hugr/issues/624. However, parametrized types could be
// cyclically dependent, so there is no perfect solution, and this is at least simple.
pub fn validate(&self) -> Result<(), ExtensionRegistryError> {
if self.valid.load(Ordering::Relaxed) {
return Ok(());
}

// Note this potentially asks extensions to validate themselves against other extensions that
// may *not* be valid themselves yet. It'd be better to order these respecting dependencies,
// or at least to validate the types first - which we don't do at all yet:
// TODO https://github.com/CQCL/hugr/issues/624. However, parametrized types could be
// cyclically dependent, so there is no perfect solution, and this is at least simple.
for ext in res.0.values() {
ext.validate(&res)
for ext in self.exts.values() {
ext.validate(self)
.map_err(|e| ExtensionRegistryError::InvalidSignature(ext.name().clone(), e))?;
}

Ok(res)
self.valid.store(true, Ordering::Relaxed);
Ok(())
}

/// Registers a new extension to the registry.
Expand All @@ -85,14 +121,17 @@ impl ExtensionRegistry {
extension: impl Into<Arc<Extension>>,
) -> Result<(), ExtensionRegistryError> {
let extension = extension.into();
match self.0.entry(extension.name().clone()) {
match self.exts.entry(extension.name().clone()) {
btree_map::Entry::Occupied(prev) => Err(ExtensionRegistryError::AlreadyRegistered(
extension.name().clone(),
prev.get().version().clone(),
extension.version().clone(),
)),
btree_map::Entry::Vacant(ve) => {
ve.insert(extension);
// Clear the valid flag so that the registry is re-validated.
self.valid.store(false, Ordering::Relaxed);

Ok(())
}
}
Expand All @@ -109,7 +148,7 @@ impl ExtensionRegistry {
/// see [`ExtensionRegistry::register_updated_ref`].
pub fn register_updated(&mut self, extension: impl Into<Arc<Extension>>) {
let extension = extension.into();
match self.0.entry(extension.name().clone()) {
match self.exts.entry(extension.name().clone()) {
btree_map::Entry::Occupied(mut prev) => {
if prev.get().version() < extension.version() {
*prev.get_mut() = extension;
Expand All @@ -119,6 +158,8 @@ impl ExtensionRegistry {
ve.insert(extension);
}
}
// Clear the valid flag so that the registry is re-validated.
self.valid.store(false, Ordering::Relaxed);
}

/// Registers a new extension to the registry, keeping the one most up to
Expand All @@ -131,7 +172,7 @@ impl ExtensionRegistry {
/// Clones the Arc only when required. For no-cloning version see
/// [`ExtensionRegistry::register_updated`].
pub fn register_updated_ref(&mut self, extension: &Arc<Extension>) {
match self.0.entry(extension.name().clone()) {
match self.exts.entry(extension.name().clone()) {
btree_map::Entry::Occupied(mut prev) => {
if prev.get().version() < extension.version() {
*prev.get_mut() = extension.clone();
Expand All @@ -141,31 +182,36 @@ impl ExtensionRegistry {
ve.insert(extension.clone());
}
}
// Clear the valid flag so that the registry is re-validated.
self.valid.store(false, Ordering::Relaxed);
}

/// Returns the number of extensions in the registry.
pub fn len(&self) -> usize {
self.0.len()
self.exts.len()
}

/// Returns `true` if the registry contains no extensions.
pub fn is_empty(&self) -> bool {
self.0.is_empty()
self.exts.is_empty()
}

/// Returns an iterator over the extensions in the registry.
pub fn iter(&self) -> <&Self as IntoIterator>::IntoIter {
self.0.values()
self.exts.values()
}

/// Returns an iterator over the extensions ids in the registry.
pub fn ids(&self) -> impl Iterator<Item = &ExtensionId> {
self.0.keys()
self.exts.keys()
}

/// Delete an extension from the registry and return it if it was present.
pub fn remove_extension(&mut self, name: &ExtensionId) -> Option<Arc<Extension>> {
self.0.remove(name)
// Clear the valid flag so that the registry is re-validated.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Most nitty super-nit ever: these two lines after the removal, for consistency with the other operations :-) :-D

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove returns the Option<Arc<Ext>>, so moving it requires declaring local vars..

aborgna-q marked this conversation as resolved.
Show resolved Hide resolved
self.valid.store(false, Ordering::Relaxed);

self.exts.remove(name)
}
}

Expand All @@ -175,7 +221,7 @@ impl IntoIterator for ExtensionRegistry {
type IntoIter = std::collections::btree_map::IntoValues<ExtensionId, Arc<Extension>>;

fn into_iter(self) -> Self::IntoIter {
self.0.into_values()
self.exts.into_values()
}
}

Expand All @@ -185,7 +231,7 @@ impl<'a> IntoIterator for &'a ExtensionRegistry {
type IntoIter = std::collections::btree_map::Values<'a, ExtensionId, Arc<Extension>>;

fn into_iter(self) -> Self::IntoIter {
self.0.values()
self.exts.values()
}
}

Expand All @@ -205,8 +251,33 @@ impl Extend<Arc<Extension>> for ExtensionRegistry {
}
}

// Encode/decode ExtensionRegistry as a list of extensions.
// We can get the map key from the extension itself.
impl<'de> Deserialize<'de> for ExtensionRegistry {
fn deserialize<D>(deserializer: D) -> Result<ExtensionRegistry, D::Error>
where
D: Deserializer<'de>,
{
let extensions: Vec<Arc<Extension>> = Vec::deserialize(deserializer)?;
Ok(ExtensionRegistry::new(extensions))
aborgna-q marked this conversation as resolved.
Show resolved Hide resolved
}
}

impl Serialize for ExtensionRegistry {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let extensions: Vec<Arc<Extension>> = self.exts.values().cloned().collect();
extensions.serialize(serializer)
}
}

/// An Extension Registry containing no extensions.
pub const EMPTY_REG: ExtensionRegistry = ExtensionRegistry(BTreeMap::new());
pub static EMPTY_REG: ExtensionRegistry = ExtensionRegistry {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

random driveby, feel free to ignore - could this could be a const fn, to which you could add Extensions yourself? (and then static EMPTY_REG: ExtensionRegistry = ExtensionRegistry::new_empty(); or similar)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ExtensionRegistry now has internal mutability, so it should not be used as a const.
(it does compile, but using it just raises linting errors everywhere, with good reason)

exts: BTreeMap::new(),
valid: AtomicBool::new(true),
};

/// An error that can occur in computing the signature of a node.
/// TODO: decide on failure modes
Expand All @@ -226,7 +297,7 @@ pub enum SignatureError {
#[error("Invalid type arguments for operation")]
InvalidTypeArgs,
/// The Extension Registry did not contain an Extension referenced by the Signature
#[error("Extension '{missing}' not found. Available extensions: {}",
#[error("Extension '{missing}' is not part of the declared HUGR extensions [{}]",
available.iter().map(|e| e.to_string()).collect::<Vec<_>>().join(", ")
)]
ExtensionNotFound {
Expand Down Expand Up @@ -614,7 +685,10 @@ pub enum ExtensionBuildError {
}

/// A set of extensions identified by their unique [`ExtensionId`].
#[derive(Clone, Debug, Default, Hash, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[derive(
Clone, Debug, Display, Default, Hash, PartialEq, Eq, serde::Serialize, serde::Deserialize,
)]
#[display("[{}]", _0.iter().join(", "))]
pub struct ExtensionSet(BTreeSet<ExtensionId>);

/// A special ExtensionId which indicates that the delta of a non-Function
Expand All @@ -632,7 +706,7 @@ impl ExtensionSet {
}

/// Adds a extension to the set.
pub fn insert(&mut self, extension: &ExtensionId) {
pub fn insert(&mut self, extension: ExtensionId) {
self.0.insert(extension.clone());
}

Expand Down Expand Up @@ -660,7 +734,7 @@ impl ExtensionSet {
}

/// Create a extension set with a single element.
pub fn singleton(extension: &ExtensionId) -> Self {
pub fn singleton(extension: ExtensionId) -> Self {
let mut set = Self::new();
set.insert(extension);
set
Expand Down Expand Up @@ -724,7 +798,25 @@ impl ExtensionSet {

impl From<ExtensionId> for ExtensionSet {
fn from(id: ExtensionId) -> Self {
Self::singleton(&id)
Self::singleton(id)
}
}

impl IntoIterator for ExtensionSet {
type Item = ExtensionId;
type IntoIter = std::collections::btree_set::IntoIter<ExtensionId>;

fn into_iter(self) -> Self::IntoIter {
self.0.into_iter()
}
}

impl<'a> IntoIterator for &'a ExtensionSet {
type Item = &'a ExtensionId;
type IntoIter = std::collections::btree_set::Iter<'a, ExtensionId>;

fn into_iter(self) -> Self::IntoIter {
self.0.iter()
}
}

Expand All @@ -738,12 +830,6 @@ fn as_typevar(e: &ExtensionId) -> Option<usize> {
}
}

impl Display for ExtensionSet {
fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
f.debug_list().entries(self.0.iter()).finish()
}
}

impl FromIterator<ExtensionId> for ExtensionSet {
fn from_iter<I: IntoIterator<Item = ExtensionId>>(iter: I) -> Self {
Self(BTreeSet::from_iter(iter))
Expand Down Expand Up @@ -783,8 +869,8 @@ pub mod test {
fn test_register_update() {
// Two registers that should remain the same.
// We use them to test both `register_updated` and `register_updated_ref`.
let mut reg = ExtensionRegistry::try_new([]).unwrap();
let mut reg_ref = ExtensionRegistry::try_new([]).unwrap();
let mut reg = ExtensionRegistry::default();
let mut reg_ref = ExtensionRegistry::default();

let ext_1_id = ExtensionId::new("ext1").unwrap();
let ext_2_id = ExtensionId::new("ext2").unwrap();
Expand Down
4 changes: 2 additions & 2 deletions hugr-core/src/extension/declarative.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ impl ExtensionSetDeclaration {
registry.register(PRELUDE.clone())?;
}
if !scope.contains(&PRELUDE_ID) {
scope.insert(&PRELUDE_ID);
scope.insert(PRELUDE_ID);
}

// Registers extensions sequentially, adding them to the current scope.
Expand All @@ -137,7 +137,7 @@ impl ExtensionSetDeclaration {
registry,
};
let ext = decl.make_extension(&self.imports, ctx)?;
scope.insert(ext.name());
scope.insert(ext.name().clone());
registry.register(ext)?;
}

Expand Down
7 changes: 4 additions & 3 deletions hugr-core/src/extension/op_def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ impl SignatureFunc {
SignatureFunc::MissingValidateFunc(ts) => (ts, args),
};
let mut res = pf.instantiate(args, exts)?;
res.extension_reqs.insert(&def.extension);
res.extension_reqs.insert(def.extension.clone());

// If there are any row variables left, this will fail with an error:
res.try_into()
Expand Down Expand Up @@ -658,7 +658,8 @@ pub(super) mod test {
Ok(())
})?;

let reg = ExtensionRegistry::try_new([PRELUDE.clone(), EXTENSION.clone(), ext]).unwrap();
let reg = ExtensionRegistry::new([PRELUDE.clone(), EXTENSION.clone(), ext]);
reg.validate()?;
let e = reg.get(&EXT_ID).unwrap();

let list_usize =
Expand Down Expand Up @@ -822,7 +823,7 @@ pub(super) mod test {
)?;

// Concrete extension set
let es = ExtensionSet::singleton(&EXT_ID);
let es = ExtensionSet::singleton(EXT_ID);
let exp_fun_ty = Signature::new_endo(bool_t()).with_extension_delta(es.clone());
let args = [TypeArg::Extensions { es }];

Expand Down
Loading
Loading