Skip to content

Commit

Permalink
feat(traits): Add trait impl for buildin types (noir-lang#2964)
Browse files Browse the repository at this point in the history
Co-authored-by: Yordan Madzhunkov <[email protected]>
  • Loading branch information
2 people authored and Sakapoi committed Oct 19, 2023
1 parent 4a317d8 commit 4fde6bb
Show file tree
Hide file tree
Showing 7 changed files with 233 additions and 18 deletions.
16 changes: 9 additions & 7 deletions compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -554,12 +554,6 @@ fn collect_trait_impl(
errors.push((err.into(), trait_impl.file_id));
}
}
} else {
let error = DefCollectorErrorKind::NonStructTraitImpl {
trait_path: trait_impl.trait_path.clone(),
span: trait_impl.trait_path.span(),
};
errors.push((error.into(), trait_impl.file_id));
}
}
}
Expand Down Expand Up @@ -978,6 +972,8 @@ fn resolve_trait_impls(
let path_resolver = StandardPathResolver::new(module_id);
let trait_definition_ident = trait_impl.trait_path.last_segment();

let self_type_span = unresolved_type.span;

let self_type = {
let mut resolver =
Resolver::new(interner, &path_resolver, &context.def_maps, trait_impl.file_id);
Expand Down Expand Up @@ -1024,7 +1020,13 @@ fn resolve_trait_impls(
trait_id,
methods: vecmap(&impl_methods, |(_, func_id)| *func_id),
});
interner.add_trait_implementation(&key, resolved_trait_impl.clone());
if !interner.add_trait_implementation(&key, resolved_trait_impl.clone()) {
let error = DefCollectorErrorKind::TraitImplNotAllowedFor {
trait_path: trait_impl.trait_path.clone(),
span: self_type_span.unwrap_or_else(|| trait_impl.trait_path.span()),
};
errors.push((error.into(), trait_impl.file_id));
}
}

methods.append(&mut impl_methods);
Expand Down
10 changes: 5 additions & 5 deletions compiler/noirc_frontend/src/hir/def_collector/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ pub enum DefCollectorErrorKind {
PathResolutionError(PathResolutionError),
#[error("Non-struct type used in impl")]
NonStructTypeInImpl { span: Span },
#[error("Non-struct type used in trait impl")]
NonStructTraitImpl { trait_path: Path, span: Span },
#[error("Trait implementation is not allowed for this")]
TraitImplNotAllowedFor { trait_path: Path, span: Span },
#[error("Cannot `impl` a type defined outside the current crate")]
ForeignImpl { span: Span, type_name: String },
#[error("Mismatch number of parameters in of trait implementation")]
Expand Down Expand Up @@ -125,10 +125,10 @@ impl From<DefCollectorErrorKind> for Diagnostic {
"Only struct types may have implementation methods".into(),
span,
),
DefCollectorErrorKind::NonStructTraitImpl { trait_path, span } => {
DefCollectorErrorKind::TraitImplNotAllowedFor { trait_path, span } => {
Diagnostic::simple_error(
format!("Only struct types may implement trait `{trait_path}`"),
"Only struct types may implement traits".into(),
format!("Only limited types may implement trait `{trait_path}`"),
"Only limited types may implement traits".into(),
span,
)
}
Expand Down
12 changes: 10 additions & 2 deletions compiler/noirc_frontend/src/hir/type_check/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -859,14 +859,22 @@ impl<'interner> TypeChecker<'interner> {
}
// Mutable references to another type should resolve to methods of their element type.
// This may be a struct or a primitive type.
Type::MutableReference(element) => self.lookup_method(element, method_name, expr_id),
Type::MutableReference(element) => self
.interner
.lookup_mut_primitive_trait_method(element.as_ref(), method_name)
.map(HirMethodReference::FuncId)
.or_else(|| self.lookup_method(element, method_name, expr_id)),
// If we fail to resolve the object to a struct type, we have no way of type
// checking its arguments as we can't even resolve the name of the function
Type::Error => None,

// In the future we could support methods for non-struct types if we have a context
// (in the interner?) essentially resembling HashMap<Type, Methods>
other => match self.interner.lookup_primitive_method(other, method_name) {
other => match self
.interner
.lookup_primitive_method(other, method_name)
.or_else(|| self.interner.lookup_primitive_trait_method(other, method_name))
{
Some(method_id) => Some(HirMethodReference::FuncId(method_id)),
None => {
self.errors.push(TypeCheckError::UnresolvedMethodCall {
Expand Down
64 changes: 60 additions & 4 deletions compiler/noirc_frontend/src/node_interner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,9 @@ pub struct NodeInterner {

// For trait implementation functions, this is their self type and trait they belong to
func_id_to_trait: HashMap<FuncId, (Type, TraitId)>,

/// Trait implementations on primitive types
primitive_trait_impls: HashMap<(Type, String), FuncId>,
}

/// All the information from a function that is filled out during definition collection rather than
Expand Down Expand Up @@ -382,6 +385,7 @@ impl Default for NodeInterner {
globals: HashMap::new(),
struct_methods: HashMap::new(),
primitive_methods: HashMap::new(),
primitive_trait_impls: HashMap::new(),
};

// An empty block expression is used often, we add this into the `node` on startup
Expand Down Expand Up @@ -881,11 +885,49 @@ impl NodeInterner {
self.trait_implementations.get(key).cloned()
}

pub fn add_trait_implementation(&mut self, key: &TraitImplKey, trait_impl: Shared<TraitImpl>) {
pub fn add_trait_implementation(
&mut self,
key: &TraitImplKey,
trait_impl: Shared<TraitImpl>,
) -> bool {
self.trait_implementations.insert(key.clone(), trait_impl.clone());

for func_id in &trait_impl.borrow().methods {
self.add_method(&key.typ, self.function_name(func_id).to_owned(), *func_id);
match &key.typ {
Type::Struct(struct_type, _generics) => {
for func_id in &trait_impl.borrow().methods {
let method_name = self.function_name(func_id).to_owned();
let key = (struct_type.borrow().id, method_name);
self.struct_methods.insert(key, *func_id);
}
true
}
Type::FieldElement
| Type::Unit
| Type::Array(..)
| Type::Integer(..)
| Type::Bool
| Type::Tuple(..)
| Type::String(..)
| Type::FmtString(..)
| Type::Function(..)
| Type::MutableReference(..) => {
for func_id in &trait_impl.borrow().methods {
let method_name = self.function_name(func_id).to_owned();
let key = (key.typ.clone(), method_name);
self.primitive_trait_impls.insert(key, *func_id);
}
true
}
// We should allow implementing traits NamedGenerics will also eventually be possible once we support generics
// impl<T> Foo for T
// but it's fine not to include these until we do.
Type::NamedGeneric(..) => false,
// prohibited are internal types (like NotConstant, TypeVariable, Forall, and Error) that
// aren't possible for users to write anyway
Type::TypeVariable(..)
| Type::Forall(..)
| Type::NotConstant
| Type::Constant(..)
| Type::Error => false,
}
}

Expand All @@ -899,6 +941,20 @@ impl NodeInterner {
get_type_method_key(typ)
.and_then(|key| self.primitive_methods.get(&(key, method_name.to_owned())).copied())
}

pub fn lookup_primitive_trait_method(&self, typ: &Type, method_name: &str) -> Option<FuncId> {
self.primitive_trait_impls.get(&(typ.clone(), method_name.to_string())).copied()
}

pub fn lookup_mut_primitive_trait_method(
&self,
typ: &Type,
method_name: &str,
) -> Option<FuncId> {
self.primitive_trait_impls
.get(&(Type::MutableReference(Box::new(typ.clone())), method_name.to_string()))
.copied()
}
}

/// These are the primitive type variants that we support adding methods to
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[package]
name = "trait_impl_base_type"
type = "bin"
authors = [""]
compiler_version = "0.10.5"

[dependencies]
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
x = "5"
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
trait Fieldable {
fn to_field(self) -> Field;
}

impl Fieldable for u32 {
fn to_field(self) -> Field {
let res = self as Field;
res * 3
}
}


impl Fieldable for [u32; 3] {
fn to_field(self) -> Field {
let res = self[0] + self[1] + self[2];
res as Field
}
}

impl Fieldable for bool {
fn to_field(self) -> Field {
if self {
14
} else {
3
}
}
}

impl Fieldable for (u32, bool) {
fn to_field(self) -> Field {
if self.1 {
self.0 as Field
} else {
32
}
}
}

impl Fieldable for Field {
fn to_field(self) -> Field {
self
}
}

impl Fieldable for str<6> {
fn to_field(self) -> Field {
6
}
}

impl Fieldable for () {
fn to_field(self) -> Field {
0
}
}

type Point2D = [Field; 2];
type Point2DAlias = Point2D;

impl Fieldable for Point2DAlias {
fn to_field(self) -> Field {
self[0] + self[1]
}
}

impl Fieldable for fmtstr<14, (Field, Field)> {
fn to_field(self) -> Field {
52
}
}

impl Fieldable for fn(u32) -> u32 {
fn to_field(self) -> Field {
self(10) as Field
}
}

fn some_func(x: u32) -> u32 {
x * 2 - 3
}


trait MutFieldable {
fn mut_to_field(self) -> Field;
}

impl MutFieldable for &mut u64 {
fn mut_to_field(self) -> Field {
1337 as Field
}
}

fn a(y: &mut u64) -> Field {
y.mut_to_field()
}

impl Fieldable for &mut u64 {
fn to_field(self) -> Field {
777 as Field
}
}

impl Fieldable for u64 {
fn to_field(self) -> Field {
66 as Field
}
}

// x = 15
fn main(x: u32) {
assert(x.to_field() == 15);
let arr: [u32; 3] = [3, 5, 8];
assert(arr.to_field() == 16);
let b_true = 2 == 2;
assert(b_true.to_field() == 14);
let b_false = 2 == 3;
assert(b_false.to_field() == 3);
let f = 13 as Field;
assert(f.to_field() == 13);
let k_true = (12 as u32, true);
assert(k_true.to_field() == 12);
let k_false = (11 as u32, false);
assert(k_false.to_field() == 32);
let m = "String";
assert(m.to_field() == 6);
let unit = ();
assert(unit.to_field() == 0);
let point: Point2DAlias = [2, 3];
assert(point.to_field() == 5);
let i: Field = 2;
let j: Field = 6;
assert(f"i: {i}, j: {j}".to_field() == 52);
assert(some_func.to_field() == 17);

let mut y = 0 as u64;
assert(a(&mut y) == 1337);
assert((&mut y).mut_to_field() == 1337);
assert((&mut y).to_field() == 777);
assert(y.to_field() == 66);
}

0 comments on commit 4fde6bb

Please sign in to comment.