Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: Allow CustomConsts to (optionally) be hashable #1397

Merged
merged 13 commits into from
Sep 11, 2024
8 changes: 4 additions & 4 deletions hugr-core/src/extension/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ pub const STRING_CUSTOM_TYPE: CustomType =
/// String type.
pub const STRING_TYPE: Type = Type::new_extension(STRING_CUSTOM_TYPE);

#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
#[derive(Debug, Clone, PartialEq, Hash, serde::Serialize, serde::Deserialize)]
/// Structure for holding constant string values.
pub struct ConstString(String);

Expand Down Expand Up @@ -276,7 +276,7 @@ pub fn sum_with_error(ty: Type) -> SumType {
SumType::new([ty, ERROR_TYPE])
}

#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
#[derive(Debug, Clone, PartialEq, Hash, serde::Serialize, serde::Deserialize)]
/// Structure for holding constant usize values.
pub struct ConstUsize(u64);

Expand Down Expand Up @@ -311,7 +311,7 @@ impl CustomConst for ConstUsize {
}
}

#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
#[derive(Debug, Clone, PartialEq, Hash, serde::Serialize, serde::Deserialize)]
/// Structure for holding constant usize values.
pub struct ConstError {
/// Integer tag/signal for the error.
Expand Down Expand Up @@ -348,7 +348,7 @@ impl CustomConst for ConstError {
}
}

#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
/// A structure for holding references to external symbols.
pub struct ConstExternalSymbol {
/// The symbol name that this value refers to. Must be nonempty.
Expand Down
36 changes: 34 additions & 2 deletions hugr-core/src/ops/constant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

mod custom;

use std::collections::hash_map::DefaultHasher; // Moves into std::hash in Rust 1.76.
use std::hash::{Hash, Hasher};

use super::{NamedOp, OpName, OpTrait, StaticTag};
use super::{OpTag, OpType};
use crate::extension::ExtensionSet;
Expand All @@ -16,7 +19,7 @@ use thiserror::Error;

pub use custom::{
downcast_equal_consts, get_pair_of_input_values, get_single_input_value, CustomConst,
CustomSerialized,
CustomSerialized, MaybeHash,
};

#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
Expand Down Expand Up @@ -134,6 +137,24 @@ impl Sum {
// For valid instances, the type row will not have any row variables.
self.sum_type.as_tuple().map(|_| self.values.as_ref())
}

fn maybe_hash<H: Hasher>(&self, st: &mut H) -> bool {
maybe_hash_values(&self.values, st) && {
st.write_usize(self.tag);
self.sum_type.hash(st);
true
}
}
}

pub(crate) fn maybe_hash_values<H: Hasher>(vals: &[Value], st: &mut H) -> bool {
// We can't mutate the Hasher with the first element
// if any element, even the last, fails.
let mut hasher = DefaultHasher::new();
vals.iter().all(|e| e.maybe_hash(&mut hasher)) && {
st.write_u64(hasher.finish());
true
}
}

impl TryFrom<SerialSum> for Sum {
Expand Down Expand Up @@ -508,6 +529,17 @@ impl Value {
None
}
}

/// Hashes this value, if possible. [Value::Extension]s are hashable according
/// to their implementation of [MaybeHash]; [Value::Function]s never are;
/// [Value::Sum]s are if their contents are.
pub fn maybe_hash<H: Hasher>(&self, st: &mut H) -> bool {
match self {
Value::Extension { e } => e.value().maybe_hash(&mut Box::new(st)),
acl-cqc marked this conversation as resolved.
Show resolved Hide resolved
Value::Function { .. } => false,
Value::Sum(s) => s.maybe_hash(st),
}
}
}

impl<T> From<T> for Value
Expand Down Expand Up @@ -547,7 +579,7 @@ mod test {

use super::*;

#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
#[derive(Debug, Clone, PartialEq, Hash, serde::Serialize, serde::Deserialize)]
/// A custom constant value used in testing
pub(crate) struct CustomTestValue(pub CustomType);

Expand Down
45 changes: 39 additions & 6 deletions hugr-core/src/ops/constant/custom.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,17 @@
//! [`Const`]: crate::ops::Const

use std::any::Any;
use std::hash::{Hash, Hasher};

use downcast_rs::{impl_downcast, Downcast};
use thiserror::Error;

use crate::extension::ExtensionSet;
use crate::macros::impl_box_clone;

use crate::types::{CustomCheckFailure, Type};
use crate::IncomingPort;

use super::Value;

use super::ValueName;
use super::{Value, ValueName};

/// Extensible constant values.
///
Expand All @@ -37,7 +35,7 @@ use super::ValueName;
/// extension::ExtensionSet, std_extensions::arithmetic::int_types};
/// use serde_json::json;
///
/// #[derive(std::fmt::Debug, Clone, Serialize,Deserialize)]
/// #[derive(std::fmt::Debug, Clone, Hash, Serialize,Deserialize)]
/// struct CC(i64);
///
/// #[typetag::serde]
Expand All @@ -55,7 +53,7 @@ use super::ValueName;
/// ```
#[typetag::serde(tag = "c", content = "v")]
pub trait CustomConst:
Send + Sync + std::fmt::Debug + CustomConstBoxClone + Any + Downcast
Send + Sync + std::fmt::Debug + MaybeHash + CustomConstBoxClone + Any + Downcast
{
/// An identifier for the constant.
fn name(&self) -> ValueName;
Expand Down Expand Up @@ -90,6 +88,33 @@ pub trait CustomConst:
fn get_type(&self) -> Type;
}

/// Prerequisite for `CustomConst`. Allows to declare a custom hash function, but the easiest
/// options are either to `impl MaybeHash for ... {}` to declare "not hashable", or else
/// to implement (or derive) [Hash].
pub trait MaybeHash {
/// Hashes the value, if possible; else return `false` without mutating the `Hasher`.
/// This relates with [CustomConst::equal_consts] just like [Hash] with [Eq]:
/// * if `x.equal_consts(y)` ==> `x.maybe_hash(s)` behaves equivalently to `y.maybe_hash(s)`
/// * if `x.hash(s)` behaves differently from `y.hash(s)` ==> `x.equal_consts(y) == false`
///
/// As with [Hash], these requirements can trivially be satisfied by either
/// * `equal_consts` always returning `false`, or
/// * `maybe_hash` always behaving the same (e.g. returning `false`, as it does by default)
///
/// Note: this uses `dyn` rather than being parametrized by `<H: Hasher>` so that we can
/// still use `dyn CustomConst`.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I.e. wee need the trait to be object safe

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment updated.

fn maybe_hash(&self, _state: &mut dyn Hasher) -> bool {
false
}
}

impl<T: Hash> MaybeHash for T {
fn maybe_hash(&self, st: &mut dyn Hasher) -> bool {
Hash::hash(self, &mut Box::new(st));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
fn maybe_hash(&self, st: &mut dyn Hasher) -> bool {
Hash::hash(self, &mut Box::new(st));
fn maybe_hash(&self, mut st: &mut dyn Hasher) -> bool {
Hash::hash(self, &mut st);

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That gives me:

error[E0277]: the size for values of type `dyn Hasher` cannot be known at compilation time
   --> hugr-core/src/ops/constant/custom.rs:113:26
    |
113 |         Hash::hash(self, &mut *st);
    |         ----------       ^^^^^^^^ doesn't have a size known at compile-time
    |         |
    |         required by a bound introduced by this call
    |
    = help: the trait `Sized` is not implemented for `dyn Hasher`

which I admit is odd - I'm not trying to pass a Hasher, only a &mut to it. Am I missing something?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah sorry I'd put in an extra *

Copy link
Collaborator

@aborgna-q aborgna-q Aug 7, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

*st is trying to dereference an unsized value into the stack.
The suggestion avoids that by using T : Hasher $\implies$ &mut T : Hasher

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This does look to me as if I'm passing a mutable reference to a mutable reference!

true
}
}

impl PartialEq for dyn CustomConst {
fn eq(&self, other: &Self) -> bool {
(*self).equal_consts(other)
Expand Down Expand Up @@ -253,6 +278,14 @@ impl CustomSerialized {
}
}

impl MaybeHash for CustomSerialized {
fn maybe_hash(&self, state: &mut dyn Hasher) -> bool {
// Consistent with equality, same serialization <=> same hash.
self.value.to_string().hash(&mut Box::new(state));
acl-cqc marked this conversation as resolved.
Show resolved Hide resolved
true
}
}

#[typetag::serde]
impl CustomConst for CustomSerialized {
fn name(&self) -> ValueName {
Expand Down
4 changes: 3 additions & 1 deletion hugr-core/src/std_extensions/arithmetic/float_types.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! Basic floating-point types
use crate::ops::constant::ValueName;
use crate::ops::constant::{MaybeHash, ValueName};
use crate::types::TypeName;
use crate::{
extension::{ExtensionId, ExtensionSet},
Expand Down Expand Up @@ -56,6 +56,8 @@ impl ConstF64 {
}
}

impl MaybeHash for ConstF64 {}

#[typetag::serde]
impl CustomConst for ConstF64 {
fn name(&self) -> ValueName {
Expand Down
2 changes: 1 addition & 1 deletion hugr-core/src/std_extensions/arithmetic/int_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ const fn type_arg(log_width: u8) -> TypeArg {
}

/// An integer (either signed or unsigned)
#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize, serde::Deserialize)]
#[derive(Clone, Debug, Eq, PartialEq, Hash, serde::Serialize, serde::Deserialize)]
pub struct ConstInt {
log_width: u8,
// We always use a u64 for the value. The interpretation is:
Expand Down
14 changes: 13 additions & 1 deletion hugr-core/src/std_extensions/collections.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
//! List type and operations.
use std::hash::{Hash, Hasher};

use itertools::Itertools;
use lazy_static::lazy_static;
use serde::{Deserialize, Serialize};

use crate::ops::constant::ValueName;
use crate::ops::constant::{maybe_hash_values, MaybeHash, ValueName};
use crate::ops::{OpName, Value};
use crate::types::TypeName;
use crate::{
Expand Down Expand Up @@ -56,6 +58,16 @@ impl ListValue {
}
}

impl MaybeHash for ListValue {
fn maybe_hash(&self, st: &mut dyn Hasher) -> bool {
let mut b = Box::new(st);
maybe_hash_values(&self.0, &mut b) && {
self.1.hash(&mut b);
acl-cqc marked this conversation as resolved.
Show resolved Hide resolved
true
}
}
}

#[typetag::serde]
impl CustomConst for ListValue {
fn name(&self) -> ValueName {
Expand Down
Loading