diff --git a/compiler/noirc_frontend/src/elaborator/expressions.rs b/compiler/noirc_frontend/src/elaborator/expressions.rs index 295297cc738..6e2756f0301 100644 --- a/compiler/noirc_frontend/src/elaborator/expressions.rs +++ b/compiler/noirc_frontend/src/elaborator/expressions.rs @@ -553,7 +553,7 @@ impl<'context> Elaborator<'context> { fn elaborate_cast(&mut self, cast: CastExpression, span: Span) -> (HirExpression, Type) { let (lhs, lhs_type) = self.elaborate_expression(cast.lhs); let r#type = self.resolve_type(cast.r#type); - let result = self.check_cast(lhs_type, &r#type, span); + let result = self.check_cast(&lhs_type, &r#type, span); let expr = HirExpression::Cast(HirCastExpression { lhs, r#type }); (expr, result) } diff --git a/compiler/noirc_frontend/src/elaborator/types.rs b/compiler/noirc_frontend/src/elaborator/types.rs index c134820811e..ec144e8e134 100644 --- a/compiler/noirc_frontend/src/elaborator/types.rs +++ b/compiler/noirc_frontend/src/elaborator/types.rs @@ -779,7 +779,7 @@ impl<'context> Elaborator<'context> { } } - pub(super) fn check_cast(&mut self, from: Type, to: &Type, span: Span) -> Type { + pub(super) fn check_cast(&mut self, from: &Type, to: &Type, span: Span) -> Type { match from.follow_bindings() { Type::Integer(..) | Type::FieldElement @@ -787,9 +787,15 @@ impl<'context> Elaborator<'context> { | Type::TypeVariable(_, TypeVariableKind::Integer) | Type::Bool => (), - Type::TypeVariable(_, _) => { - self.push_err(TypeCheckError::TypeAnnotationsNeeded { span }); - return Type::Error; + Type::TypeVariable(binding, TypeVariableKind::Normal) => { + if let TypeBinding::Bound(typ) = &*binding.borrow() { + return self.check_cast(from, typ, span); + } + let expected = &&Type::polymorphic_integer_or_field_or_bool(self.interner); + self.unify(from, expected, || TypeCheckError::InvalidCast { + from: from.clone(), + span, + }); } Type::Error => return Type::Error, from => { diff --git a/compiler/noirc_frontend/src/hir_def/types.rs b/compiler/noirc_frontend/src/hir_def/types.rs index 0ec975a04db..5d8e8d1b11e 100644 --- a/compiler/noirc_frontend/src/hir_def/types.rs +++ b/compiler/noirc_frontend/src/hir_def/types.rs @@ -439,6 +439,11 @@ pub enum TypeVariableKind { /// Can bind to any type Normal, + /// A generic integer or field or bool type. + /// This type is used in the `x as T` cast expression, where `x` is a type variable. + /// In this case, `x` must be an integer, field or bool for the cast to succeed. + IntegerOrFieldOrBool, + /// A generic integer or field type. This is a more specific kind of TypeVariable /// that can only be bound to Type::Field, Type::Integer, or other polymorphic integers. /// This is the type of undecorated integer literals like `46`. Typing them in this way @@ -456,6 +461,13 @@ pub enum TypeVariableKind { Constant(u32), } +#[derive(Debug, PartialEq, Eq, Clone, Hash)] +pub enum PolymorphicKind { + IntegerOrFieldOrBool, + IntegerOrField, + Integer, +} + /// A TypeVariable is a mutable reference that is either /// bound to some type, or unbound with a given TypeVariableId. #[derive(PartialEq, Eq, Clone, Hash)] @@ -577,6 +589,15 @@ impl std::fmt::Display for Type { write!(f, "{}", binding.borrow()) } } + Type::TypeVariable(binding, TypeVariableKind::IntegerOrFieldOrBool) => { + if let TypeBinding::Unbound(_) = &*binding.borrow() { + // Show a pseudo-syntax for a union type, which might help users understand that the + // type must be one of these. + write!(f, "Integer | Field | bool") + } else { + write!(f, "{}", binding.borrow()) + } + } Type::TypeVariable(binding, TypeVariableKind::Constant(n)) => { if let TypeBinding::Unbound(_) = &*binding.borrow() { // TypeVariableKind::Constant(n) binds to Type::Constant(n) by default, so just show that. @@ -723,6 +744,13 @@ impl Type { Type::TypeVariable(var, kind) } + pub fn polymorphic_integer_or_field_or_bool(interner: &mut NodeInterner) -> Type { + let id = interner.next_type_variable_id(); + let kind = TypeVariableKind::IntegerOrFieldOrBool; + let var = TypeVariable::unbound(id); + Type::TypeVariable(var, kind) + } + pub fn polymorphic_integer(interner: &mut NodeInterner) -> Type { let id = interner.next_type_variable_id(); let kind = TypeVariableKind::Integer; @@ -1249,6 +1277,7 @@ impl Type { } // *length != target_length TypeVariableKind::Constant(_) => Err(UnificationError), + TypeVariableKind::IntegerOrFieldOrBool => Err(UnificationError), TypeVariableKind::IntegerOrField => Err(UnificationError), TypeVariableKind::Integer => Err(UnificationError), }, @@ -1261,11 +1290,11 @@ impl Type { /// Try to bind a PolymorphicInt variable to self, succeeding if self is an integer, field, /// other PolymorphicInt type, or type variable. If successful, the binding is placed in the /// given TypeBindings map rather than linked immediately. - fn try_bind_to_polymorphic_int( + fn try_bind_to_polymorphic( &self, var: &TypeVariable, bindings: &mut TypeBindings, - only_integer: bool, + kind: PolymorphicKind, ) -> Result<(), UnificationError> { let target_id = match &*var.borrow() { TypeBinding::Bound(_) => unreachable!(), @@ -1274,24 +1303,27 @@ impl Type { let this = self.substitute(bindings).follow_bindings(); match &this { + Type::Bool if matches!(kind, PolymorphicKind::IntegerOrFieldOrBool) => { + bindings.insert(target_id, (var.clone(), this)); + Ok(()) + } Type::Integer(..) => { bindings.insert(target_id, (var.clone(), this)); Ok(()) } - Type::FieldElement if !only_integer => { + Type::FieldElement if !matches!(kind, PolymorphicKind::Integer) => { bindings.insert(target_id, (var.clone(), this)); Ok(()) } - Type::TypeVariable(self_var, TypeVariableKind::IntegerOrField) => { + Type::TypeVariable(self_var, TypeVariableKind::IntegerOrField) + | Type::TypeVariable(self_var, TypeVariableKind::IntegerOrFieldOrBool) => { let borrow = self_var.borrow(); match &*borrow { - TypeBinding::Bound(typ) => { - typ.try_bind_to_polymorphic_int(var, bindings, only_integer) - } + TypeBinding::Bound(typ) => typ.try_bind_to_polymorphic(var, bindings, kind), // Avoid infinitely recursive bindings TypeBinding::Unbound(id) if *id == target_id => Ok(()), TypeBinding::Unbound(new_target_id) => { - if only_integer { + if matches!(kind, PolymorphicKind::Integer) { // Integer is more specific than IntegerOrField so we bind the type // variable to Integer instead. let clone = Type::TypeVariable(var.clone(), TypeVariableKind::Integer); @@ -1306,9 +1338,7 @@ impl Type { Type::TypeVariable(self_var, TypeVariableKind::Integer) => { let borrow = self_var.borrow(); match &*borrow { - TypeBinding::Bound(typ) => { - typ.try_bind_to_polymorphic_int(var, bindings, only_integer) - } + TypeBinding::Bound(typ) => typ.try_bind_to_polymorphic(var, bindings, kind), // Avoid infinitely recursive bindings TypeBinding::Unbound(id) if *id == target_id => Ok(()), TypeBinding::Unbound(_) => { @@ -1320,17 +1350,17 @@ impl Type { Type::TypeVariable(self_var, TypeVariableKind::Normal) => { let borrow = self_var.borrow(); match &*borrow { - TypeBinding::Bound(typ) => { - typ.try_bind_to_polymorphic_int(var, bindings, only_integer) - } + TypeBinding::Bound(typ) => typ.try_bind_to_polymorphic(var, bindings, kind), // Avoid infinitely recursive bindings TypeBinding::Unbound(id) if *id == target_id => Ok(()), TypeBinding::Unbound(new_target_id) => { // Bind to the most specific type variable kind - let clone_kind = if only_integer { - TypeVariableKind::Integer - } else { - TypeVariableKind::IntegerOrField + let clone_kind = match kind { + PolymorphicKind::Integer => TypeVariableKind::Integer, + PolymorphicKind::IntegerOrField => TypeVariableKind::IntegerOrField, + PolymorphicKind::IntegerOrFieldOrBool => { + TypeVariableKind::IntegerOrFieldOrBool + } }; let clone = Type::TypeVariable(var.clone(), clone_kind); bindings.insert(*new_target_id, (self_var.clone(), clone)); @@ -1424,19 +1454,26 @@ impl Type { alias.try_unify(other, bindings) } + (TypeVariable(var, Kind::IntegerOrFieldOrBool), other) + | (other, TypeVariable(var, Kind::IntegerOrFieldOrBool)) => other + .try_unify_to_type_variable(var, bindings, |bindings| { + let kind = PolymorphicKind::IntegerOrFieldOrBool; + other.try_bind_to_polymorphic(var, bindings, kind) + }), + (TypeVariable(var, Kind::IntegerOrField), other) | (other, TypeVariable(var, Kind::IntegerOrField)) => { other.try_unify_to_type_variable(var, bindings, |bindings| { - let only_integer = false; - other.try_bind_to_polymorphic_int(var, bindings, only_integer) + let kind = PolymorphicKind::IntegerOrField; + other.try_bind_to_polymorphic(var, bindings, kind) }) } (TypeVariable(var, Kind::Integer), other) | (other, TypeVariable(var, Kind::Integer)) => { other.try_unify_to_type_variable(var, bindings, |bindings| { - let only_integer = true; - other.try_bind_to_polymorphic_int(var, bindings, only_integer) + let kind = PolymorphicKind::Integer; + other.try_bind_to_polymorphic(var, bindings, kind) }) } @@ -2147,6 +2184,7 @@ impl TypeVariableKind { /// during monomorphization. pub(crate) fn default_type(&self) -> Option { match self { + TypeVariableKind::IntegerOrFieldOrBool => None, TypeVariableKind::IntegerOrField => Some(Type::default_int_or_field_type()), TypeVariableKind::Integer => Some(Type::default_int_type()), TypeVariableKind::Constant(length) => Some(Type::Constant(*length)), @@ -2243,6 +2281,9 @@ impl std::fmt::Debug for Type { Signedness::Unsigned => write!(f, "u{num_bits}"), }, Type::TypeVariable(var, TypeVariableKind::Normal) => write!(f, "{:?}", var), + Type::TypeVariable(binding, TypeVariableKind::IntegerOrFieldOrBool) => { + write!(f, "IntOrFieldOrBool{:?}", binding) + } Type::TypeVariable(binding, TypeVariableKind::IntegerOrField) => { write!(f, "IntOrField{:?}", binding) } diff --git a/tooling/lsp/src/requests/inlay_hint.rs b/tooling/lsp/src/requests/inlay_hint.rs index 48d72fa0e9a..bbf3e51fdb2 100644 --- a/tooling/lsp/src/requests/inlay_hint.rs +++ b/tooling/lsp/src/requests/inlay_hint.rs @@ -565,7 +565,8 @@ fn push_type_parts(typ: &Type, parts: &mut Vec, files: &File push_type_variable_parts(binding, parts, files); } } - Type::TypeVariable(binding, TypeVariableKind::IntegerOrField) => { + Type::TypeVariable(binding, TypeVariableKind::IntegerOrField) + | Type::TypeVariable(binding, TypeVariableKind::IntegerOrFieldOrBool) => { if let TypeBinding::Unbound(_) = &*binding.borrow() { parts.push(string_part("Field")); } else {