From e723d10446edc9e5a2c348f6d0d3ab744c7d56d0 Mon Sep 17 00:00:00 2001 From: Matt Stark Date: Fri, 4 Oct 2024 14:52:00 +1000 Subject: [PATCH] feat: Add support for *args to AliasMap This is required in order to support command aliases as functions, because they will require the following to be supported: ```toml 'upload()' = ["upload", "@"] 'upload(r)' = [["fix", "$r"], ["git", "push", "--change", "$r"]] ``` Here, `jj upload` would need to map to the first, `jj upload foo` would need to map to the second, and `jj upload foo --bar` would also need to map to the second. --- lib/src/dsl_util.rs | 122 ++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 111 insertions(+), 11 deletions(-) diff --git a/lib/src/dsl_util.rs b/lib/src/dsl_util.rs index c92f01cef3..ac6615e0db 100644 --- a/lib/src/dsl_util.rs +++ b/lib/src/dsl_util.rs @@ -458,16 +458,20 @@ impl AliasesMap { self.symbol_aliases.insert(name, defn.into()); } AliasDeclaration::Function(name, params) => { - let overloads = self.function_aliases.entry(name).or_default(); - match overloads.binary_search_by_key(¶ms.len(), |(params, _)| params.len()) { - Ok(i) => overloads[i] = (params, defn.into()), - Err(i) => overloads.insert(i, (params, defn.into())), - } + self.insert_function(name, params, defn.into()); } } Ok(()) } + fn insert_function(&mut self, name: String, params: Vec, defn: V) { + let overloads = self.function_aliases.entry(name).or_default(); + match overloads.binary_search_by_key(¶ms.len(), |(params, _)| params.len()) { + Ok(i) => overloads[i] = (params, defn), + Err(i) => overloads.insert(i, (params, defn)), + } + } + /// Iterates symbol names in arbitrary order. pub fn symbol_names(&self) -> impl Iterator { self.symbol_aliases.keys().map(|n| n.as_ref()) @@ -492,6 +496,18 @@ impl AliasesMap { overloads.find_by_arity(arity) } + /// Looks up a function alias by name and arity, assuming that the function + /// can take extra parameters (eg. *args in python). + /// Returns identifier, list of parameter names, and definition text. + pub fn get_function_with_leftovers( + &self, + name: &str, + arity: usize, + ) -> Option<(AliasId<'_>, &[String], &V)> { + let overloads = self.get_function_overloads(name)?; + overloads.find_by_arity_with_leftovers(arity) + } + /// Looks up function aliases by name. fn get_function_overloads(&self, name: &str) -> Option> { let (name, overloads) = self.function_aliases.get_key_value(name)?; @@ -518,16 +534,39 @@ impl<'a, V> AliasFunctionOverloads<'a, V> { self.arities().next_back().unwrap() } - fn find_by_arity(&self, arity: usize) -> Option<(AliasId<'a>, &'a [String], &'a V)> { - let index = self - .overloads - .binary_search_by_key(&arity, |(params, _)| params.len()) - .ok()?; + fn get_overload(&self, index: usize) -> (AliasId<'a>, &'a [String], &'a V) { let (params, defn) = &self.overloads[index]; // Exact parameter names aren't needed to identify a function, but they // provide a better error indication. (e.g. "foo(x, y)" is easier to // follow than "foo/2".) - Some((AliasId::Function(self.name, params), params, defn)) + (AliasId::Function(self.name, params), params, defn) + } + + fn find_by_arity(&self, arity: usize) -> Option<(AliasId<'a>, &'a [String], &'a V)> { + Some( + self.get_overload( + self.overloads + .binary_search_by_key(&arity, |(params, _)| params.len()) + .ok()?, + ), + ) + } + + // `find_arity(3)` is equivalent to finding the correct overload for `fn(a, b, c, *args)`. + // We need to find the longest match that is still less than arity. + fn find_by_arity_with_leftovers( + &self, + arity: usize, + ) -> Option<(AliasId<'a>, &'a [String], &'a V)> { + if arity < self.min_arity() { + // This is like calling `fn(a, b)` when the definition is `fn(a, b, c ,*args)` + None + } else { + let first_invalid = self + .overloads + .partition_point(|(args, _)| arity >= args.len()); + Some(self.get_overload(first_invalid - 1)) + } } } @@ -849,4 +888,65 @@ mod tests { let f = function("foo", [], [keyword("a", 0), keyword("a", 1)]); assert!(f.expect_named_arguments::<1, 1>(&["", "a"]).is_err()); } + + #[test] + fn test_aliases_map_arity() { + let mut aliases_map = AliasesMap::<(), i32>::default(); + aliases_map.insert_function("single".to_string(), vec![], 0); + aliases_map.insert_function("overload".to_string(), vec!["first".to_string()], 1); + aliases_map.insert_function( + "overload".to_string(), + vec![ + "first".to_string(), + "second".to_string(), + "third".to_string(), + ], + 3, + ); + + let get_alias = |name, arity| { + let (_, params, defn) = aliases_map.get_function_with_leftovers(name, arity)?; + Some((params, *defn)) + }; + + assert_eq!(get_alias("nonexistent", 1), None); + assert_eq!(get_alias("single", 3), Some(([].as_slice(), 0))); + assert_eq!(get_alias("overload", 0), None); + + assert_eq!( + get_alias("overload", 1), + Some((["first".to_string()].as_slice(), 1)) + ); + + assert_eq!( + get_alias("overload", 2), + Some((["first".to_string()].as_slice(), 1)) + ); + + assert_eq!( + get_alias("overload", 3), + Some(( + [ + "first".to_string(), + "second".to_string(), + "third".to_string() + ] + .as_slice(), + 3 + )) + ); + + assert_eq!( + get_alias("overload", 4), + Some(( + [ + "first".to_string(), + "second".to_string(), + "third".to_string() + ] + .as_slice(), + 3 + )) + ); + } }