diff --git a/lib/src/dsl_util.rs b/lib/src/dsl_util.rs
index c92f01cef3..ac6615e0db 100644
--- a/lib/src/dsl_util.rs
+++ b/lib/src/dsl_util.rs
@@ -458,16 +458,20 @@ impl
AliasesMap
{
self.symbol_aliases.insert(name, defn.into());
}
AliasDeclaration::Function(name, params) => {
- let overloads = self.function_aliases.entry(name).or_default();
- match overloads.binary_search_by_key(¶ms.len(), |(params, _)| params.len()) {
- Ok(i) => overloads[i] = (params, defn.into()),
- Err(i) => overloads.insert(i, (params, defn.into())),
- }
+ self.insert_function(name, params, defn.into());
}
}
Ok(())
}
+ fn insert_function(&mut self, name: String, params: Vec, defn: V) {
+ let overloads = self.function_aliases.entry(name).or_default();
+ match overloads.binary_search_by_key(¶ms.len(), |(params, _)| params.len()) {
+ Ok(i) => overloads[i] = (params, defn),
+ Err(i) => overloads.insert(i, (params, defn)),
+ }
+ }
+
/// Iterates symbol names in arbitrary order.
pub fn symbol_names(&self) -> impl Iterator- {
self.symbol_aliases.keys().map(|n| n.as_ref())
@@ -492,6 +496,18 @@ impl
AliasesMap
{
overloads.find_by_arity(arity)
}
+ /// Looks up a function alias by name and arity, assuming that the function
+ /// can take extra parameters (eg. *args in python).
+ /// Returns identifier, list of parameter names, and definition text.
+ pub fn get_function_with_leftovers(
+ &self,
+ name: &str,
+ arity: usize,
+ ) -> Option<(AliasId<'_>, &[String], &V)> {
+ let overloads = self.get_function_overloads(name)?;
+ overloads.find_by_arity_with_leftovers(arity)
+ }
+
/// Looks up function aliases by name.
fn get_function_overloads(&self, name: &str) -> Option> {
let (name, overloads) = self.function_aliases.get_key_value(name)?;
@@ -518,16 +534,39 @@ impl<'a, V> AliasFunctionOverloads<'a, V> {
self.arities().next_back().unwrap()
}
- fn find_by_arity(&self, arity: usize) -> Option<(AliasId<'a>, &'a [String], &'a V)> {
- let index = self
- .overloads
- .binary_search_by_key(&arity, |(params, _)| params.len())
- .ok()?;
+ fn get_overload(&self, index: usize) -> (AliasId<'a>, &'a [String], &'a V) {
let (params, defn) = &self.overloads[index];
// Exact parameter names aren't needed to identify a function, but they
// provide a better error indication. (e.g. "foo(x, y)" is easier to
// follow than "foo/2".)
- Some((AliasId::Function(self.name, params), params, defn))
+ (AliasId::Function(self.name, params), params, defn)
+ }
+
+ fn find_by_arity(&self, arity: usize) -> Option<(AliasId<'a>, &'a [String], &'a V)> {
+ Some(
+ self.get_overload(
+ self.overloads
+ .binary_search_by_key(&arity, |(params, _)| params.len())
+ .ok()?,
+ ),
+ )
+ }
+
+ // `find_arity(3)` is equivalent to finding the correct overload for `fn(a, b, c, *args)`.
+ // We need to find the longest match that is still less than arity.
+ fn find_by_arity_with_leftovers(
+ &self,
+ arity: usize,
+ ) -> Option<(AliasId<'a>, &'a [String], &'a V)> {
+ if arity < self.min_arity() {
+ // This is like calling `fn(a, b)` when the definition is `fn(a, b, c ,*args)`
+ None
+ } else {
+ let first_invalid = self
+ .overloads
+ .partition_point(|(args, _)| arity >= args.len());
+ Some(self.get_overload(first_invalid - 1))
+ }
}
}
@@ -849,4 +888,65 @@ mod tests {
let f = function("foo", [], [keyword("a", 0), keyword("a", 1)]);
assert!(f.expect_named_arguments::<1, 1>(&["", "a"]).is_err());
}
+
+ #[test]
+ fn test_aliases_map_arity() {
+ let mut aliases_map = AliasesMap::<(), i32>::default();
+ aliases_map.insert_function("single".to_string(), vec![], 0);
+ aliases_map.insert_function("overload".to_string(), vec!["first".to_string()], 1);
+ aliases_map.insert_function(
+ "overload".to_string(),
+ vec![
+ "first".to_string(),
+ "second".to_string(),
+ "third".to_string(),
+ ],
+ 3,
+ );
+
+ let get_alias = |name, arity| {
+ let (_, params, defn) = aliases_map.get_function_with_leftovers(name, arity)?;
+ Some((params, *defn))
+ };
+
+ assert_eq!(get_alias("nonexistent", 1), None);
+ assert_eq!(get_alias("single", 3), Some(([].as_slice(), 0)));
+ assert_eq!(get_alias("overload", 0), None);
+
+ assert_eq!(
+ get_alias("overload", 1),
+ Some((["first".to_string()].as_slice(), 1))
+ );
+
+ assert_eq!(
+ get_alias("overload", 2),
+ Some((["first".to_string()].as_slice(), 1))
+ );
+
+ assert_eq!(
+ get_alias("overload", 3),
+ Some((
+ [
+ "first".to_string(),
+ "second".to_string(),
+ "third".to_string()
+ ]
+ .as_slice(),
+ 3
+ ))
+ );
+
+ assert_eq!(
+ get_alias("overload", 4),
+ Some((
+ [
+ "first".to_string(),
+ "second".to_string(),
+ "third".to_string()
+ ]
+ .as_slice(),
+ 3
+ ))
+ );
+ }
}