diff --git a/src/map.rs b/src/map.rs index e71c799c..bb915e30 100644 --- a/src/map.rs +++ b/src/map.rs @@ -477,6 +477,40 @@ where } } + /// Return the values for `N` keys. If any key is missing a value, or there + /// are duplicate keys, `None` is returned. + /// + /// # Examples + /// + /// ``` + /// let mut map = indexmap::IndexMap::from([(1, 'a'), (3, 'b'), (2, 'c')]); + /// assert_eq!(map.get_many_mut([&2, &1]), Some([&mut 'c', &mut 'a'])); + /// ``` + pub fn get_many_mut<'a, 'b, Q: ?Sized, const N: usize>( + &'a mut self, + keys: [&'b Q; N], + ) -> Option<[&'a mut V; N]> + where + Q: Hash + Equivalent, + { + let len = self.len(); + let indices = keys.map(|key| self.get_index_of(key)); + + // Handle out-of-bounds indices with panic as this is an internal error in get_index_of. + for idx in indices { + let idx = idx?; + debug_assert!( + idx < len, + "Index is out of range! Got '{}' but length is '{}'", + idx, + len + ); + } + let indices = indices.map(Option::unwrap); + let entries = self.get_many_index_mut(indices)?; + Some(entries.map(|(_key, value)| value)) + } + /// Remove the key-value pair equivalent to `key` and return /// its value. /// @@ -784,6 +818,44 @@ impl IndexMap { self.as_entries_mut().get_mut(index).map(Bucket::ref_mut) } + /// Get an array of `N` key-value pairs by `N` indices + /// + /// Valid indices are *0 <= index < self.len()* and each index needs to be unique. + /// + /// Computes in **O(1)** time. + /// + /// # Examples + /// + /// ``` + /// let mut map = indexmap::IndexMap::from([(1, 'a'), (3, 'b'), (2, 'c')]); + /// assert_eq!(map.get_many_index_mut([2, 0]), Some([(&2, &mut 'c'), (&1, &mut 'a')])); + /// ``` + pub fn get_many_index_mut( + &mut self, + indices: [usize; N], + ) -> Option<[(&K, &mut V); N]> { + // SAFETY: Can't allow duplicate indices as we would return several mutable refs to the same data. + let len = self.len(); + for i in 0..N { + let idx = indices[i]; + if idx >= len || indices[i + 1..N].contains(&idx) { + return None; + } + } + + let entries_ptr = self.as_entries_mut().as_mut_ptr(); + let out = indices.map(|i| { + // SAFETY: The base pointer is valid as it comes from a slice and the deref is always + // in-bounds as we've already checked the indices above. + #[allow(unsafe_code)] + unsafe { + (*(entries_ptr.add(i))).ref_mut() + } + }); + + Some(out) + } + /// Returns a slice of key-value pairs in the given range of indices. /// /// Valid indices are *0 <= index < self.len()* diff --git a/src/map/tests.rs b/src/map/tests.rs index b6c6a42d..c2ae8ede 100644 --- a/src/map/tests.rs +++ b/src/map/tests.rs @@ -418,3 +418,72 @@ fn from_array() { assert_eq!(map, expected) } + +#[test] +fn many_mut_empty() { + let mut map: IndexMap = IndexMap::default(); + assert!(map.get_many_mut([&0, &1, &2, &3]).is_none()); +} + +#[test] +fn many_mut_single_fail() { + let mut map: IndexMap = IndexMap::default(); + map.insert(1, 10); + assert!(map.get_many_mut([&0]).is_none()); +} + +#[test] +fn many_mut_single_success() { + let mut map: IndexMap = IndexMap::default(); + map.insert(1, 10); + assert_eq!(map.get_many_mut([&1]), Some([&mut 10])); +} + +#[test] +fn many_mut_multi_success() { + let mut map: IndexMap = IndexMap::default(); + map.insert(1, 10); + map.insert(1123, 100); + map.insert(321, 20); + map.insert(1337, 30); + assert_eq!(map.get_many_mut([&1, &1123]), Some([&mut 10, &mut 100])); + assert_eq!(map.get_many_mut([&1, &1337]), Some([&mut 10, &mut 30])); + assert_eq!( + map.get_many_mut([&1337, &321, &1, &1123]), + Some([&mut 30, &mut 20, &mut 10, &mut 100]) + ); +} + +#[test] +fn many_mut_multi_fail_missing() { + let mut map: IndexMap = IndexMap::default(); + map.insert(1, 10); + map.insert(1123, 100); + map.insert(321, 20); + map.insert(1337, 30); + assert_eq!(map.get_many_mut([&121, &1123]), None); + assert_eq!(map.get_many_mut([&1, &1337, &56]), None); + assert_eq!(map.get_many_mut([&1337, &123, &321, &1, &1123]), None); +} + +#[test] +fn many_mut_multi_fail_duplicate() { + let mut map: IndexMap = IndexMap::default(); + map.insert(1, 10); + map.insert(1123, 100); + map.insert(321, 20); + map.insert(1337, 30); + assert_eq!(map.get_many_mut([&1, &1]), None); + assert_eq!( + map.get_many_mut([&1337, &123, &321, &1337, &1, &1123]), + None + ); +} + +#[test] +fn many_index_mut_fail_oob() { + let mut map: IndexMap = IndexMap::default(); + map.insert(1, 10); + map.insert(321, 20); + assert_eq!(map.get_many_index_mut([1, 3]), None); +}