From 6c5c5f3e1a444190e7967288f5fe635817d2e346 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Fri, 10 Nov 2023 16:06:41 +0000 Subject: [PATCH] feat: shorthand for retrieving custom constants from `Const`, `Value` Closes #654 --- src/ops/constant.rs | 19 +++++++++++++++++-- src/values.rs | 12 ++++++++++++ 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/src/ops/constant.rs b/src/ops/constant.rs index 308378e3a3..5cdb2a3bd3 100644 --- a/src/ops/constant.rs +++ b/src/ops/constant.rs @@ -76,6 +76,11 @@ impl Const { .unzip(); Self::new(Value::tuple(values), Type::new_tuple(types)).unwrap() } + + /// For a Const holding a CustomConst, extract the CustomConst by downcasting. + pub fn get_custom_value(&self) -> Option<&T> { + self.value().get_custom_value() + } } impl OpName for Const { @@ -123,7 +128,7 @@ mod test { prelude::{ConstUsize, USIZE_T}, ExtensionId, ExtensionSet, }, - std_extensions::arithmetic::float_types::FLOAT64_TYPE, + std_extensions::arithmetic::float_types::{ConstF64, FLOAT64_TYPE}, type_row, types::test::test_registry, types::type_param::TypeArg, @@ -197,11 +202,21 @@ mod test { tuple_ty.check_type(&tuple_val2), Err(ConstTypeError::ValueCheckFail(ty, tv2)) => ty == tuple_ty && tv2 == tuple_val2 ); - let tuple_val3 = Value::tuple([int_value, serialized_float(3.3), serialized_float(2.0)]); + let tuple_val3 = Value::tuple([ + int_value.clone(), + serialized_float(3.3), + serialized_float(2.0), + ]); assert_eq!( tuple_ty.check_type(&tuple_val3), Err(ConstTypeError::TupleWrongLength) ); + + let op = Const::new(int_value, USIZE_T).unwrap(); + + assert_eq!(op.get_custom_value(), Some(&ConstUsize::new(257))); + let try_float: Option<&ConstF64> = op.get_custom_value(); + assert!(try_float.is_none()); } #[test] diff --git a/src/values.rs b/src/values.rs index 987e8a5bee..808cf883ef 100644 --- a/src/values.rs +++ b/src/values.rs @@ -127,6 +127,18 @@ impl Value { val: PrimValue::Extension { c: (Box::new(c),) }, } } + + /// For a Const holding a CustomConst, extract the CustomConst by downcasting. + pub fn get_custom_value(&self) -> Option<&T> { + if let Value::Prim { + val: PrimValue::Extension { c: (custom,) }, + } = self + { + custom.downcast_ref() + } else { + None + } + } } impl From for Value {