diff --git a/src/libcollectionstest/slice.rs b/src/libcollectionstest/slice.rs index ca2ee0c512bfe..236c151891d11 100644 --- a/src/libcollectionstest/slice.rs +++ b/src/libcollectionstest/slice.rs @@ -574,18 +574,48 @@ fn test_slice_2() { assert_eq!(v[1], 3); } +macro_rules! assert_order { + (Greater, $a:expr, $b:expr) => { + assert_eq!($a.cmp($b), Greater); + assert!($a > $b); + }; + (Less, $a:expr, $b:expr) => { + assert_eq!($a.cmp($b), Less); + assert!($a < $b); + }; + (Equal, $a:expr, $b:expr) => { + assert_eq!($a.cmp($b), Equal); + assert_eq!($a, $b); + } +} + +#[test] +fn test_total_ord_u8() { + let c = &[1u8, 2, 3]; + assert_order!(Greater, &[1u8, 2, 3, 4][..], &c[..]); + let c = &[1u8, 2, 3, 4]; + assert_order!(Less, &[1u8, 2, 3][..], &c[..]); + let c = &[1u8, 2, 3, 6]; + assert_order!(Equal, &[1u8, 2, 3, 6][..], &c[..]); + let c = &[1u8, 2, 3, 4, 5, 6]; + assert_order!(Less, &[1u8, 2, 3, 4, 5, 5, 5, 5][..], &c[..]); + let c = &[1u8, 2, 3, 4]; + assert_order!(Greater, &[2u8, 2][..], &c[..]); +} + + #[test] -fn test_total_ord() { +fn test_total_ord_i32() { let c = &[1, 2, 3]; - [1, 2, 3, 4][..].cmp(c) == Greater; + assert_order!(Greater, &[1, 2, 3, 4][..], &c[..]); let c = &[1, 2, 3, 4]; - [1, 2, 3][..].cmp(c) == Less; + assert_order!(Less, &[1, 2, 3][..], &c[..]); let c = &[1, 2, 3, 6]; - [1, 2, 3, 4][..].cmp(c) == Equal; + assert_order!(Equal, &[1, 2, 3, 6][..], &c[..]); let c = &[1, 2, 3, 4, 5, 6]; - [1, 2, 3, 4, 5, 5, 5, 5][..].cmp(c) == Less; + assert_order!(Less, &[1, 2, 3, 4, 5, 5, 5, 5][..], &c[..]); let c = &[1, 2, 3, 4]; - [2, 2][..].cmp(c) == Greater; + assert_order!(Greater, &[2, 2][..], &c[..]); } #[test] diff --git a/src/libcore/lib.rs b/src/libcore/lib.rs index 648b652772aec..fa5e90562d80e 100644 --- a/src/libcore/lib.rs +++ b/src/libcore/lib.rs @@ -75,6 +75,7 @@ #![feature(unwind_attributes)] #![feature(repr_simd, platform_intrinsics)] #![feature(rustc_attrs)] +#![feature(specialization)] #![feature(staged_api)] #![feature(unboxed_closures)] #![feature(question_mark)] diff --git a/src/libcore/slice.rs b/src/libcore/slice.rs index aa555b44e899f..25082eed2fe6f 100644 --- a/src/libcore/slice.rs +++ b/src/libcore/slice.rs @@ -1630,12 +1630,60 @@ pub unsafe fn from_raw_parts_mut<'a, T>(p: *mut T, len: usize) -> &'a mut [T] { } // -// Boilerplate traits +// Comparison traits // +extern { + /// Call implementation provided memcmp + /// + /// Interprets the data as u8. + /// + /// Return 0 for equal, < 0 for less than and > 0 for greater + /// than. + // FIXME(#32610): Return type should be c_int + fn memcmp(s1: *const u8, s2: *const u8, n: usize) -> i32; +} + #[stable(feature = "rust1", since = "1.0.0")] impl PartialEq<[B]> for [A] where A: PartialEq { fn eq(&self, other: &[B]) -> bool { + SlicePartialEq::equal(self, other) + } + + fn ne(&self, other: &[B]) -> bool { + SlicePartialEq::not_equal(self, other) + } +} + +#[stable(feature = "rust1", since = "1.0.0")] +impl Eq for [T] {} + +#[stable(feature = "rust1", since = "1.0.0")] +impl Ord for [T] { + fn cmp(&self, other: &[T]) -> Ordering { + SliceOrd::compare(self, other) + } +} + +#[stable(feature = "rust1", since = "1.0.0")] +impl PartialOrd for [T] { + fn partial_cmp(&self, other: &[T]) -> Option { + SlicePartialOrd::partial_compare(self, other) + } +} + +#[doc(hidden)] +// intermediate trait for specialization of slice's PartialEq +trait SlicePartialEq { + fn equal(&self, other: &[B]) -> bool; + fn not_equal(&self, other: &[B]) -> bool; +} + +// Generic slice equality +impl SlicePartialEq for [A] + where A: PartialEq +{ + default fn equal(&self, other: &[B]) -> bool { if self.len() != other.len() { return false; } @@ -1648,7 +1696,8 @@ impl PartialEq<[B]> for [A] where A: PartialEq { true } - fn ne(&self, other: &[B]) -> bool { + + default fn not_equal(&self, other: &[B]) -> bool { if self.len() != other.len() { return true; } @@ -1663,12 +1712,36 @@ impl PartialEq<[B]> for [A] where A: PartialEq { } } -#[stable(feature = "rust1", since = "1.0.0")] -impl Eq for [T] {} +// Use memcmp for bytewise equality when the types allow +impl SlicePartialEq for [A] + where A: PartialEq + BytewiseEquality +{ + fn equal(&self, other: &[A]) -> bool { + if self.len() != other.len() { + return false; + } + unsafe { + let size = mem::size_of_val(self); + memcmp(self.as_ptr() as *const u8, + other.as_ptr() as *const u8, size) == 0 + } + } -#[stable(feature = "rust1", since = "1.0.0")] -impl Ord for [T] { - fn cmp(&self, other: &[T]) -> Ordering { + fn not_equal(&self, other: &[A]) -> bool { + !self.equal(other) + } +} + +#[doc(hidden)] +// intermediate trait for specialization of slice's PartialOrd +trait SlicePartialOrd { + fn partial_compare(&self, other: &[B]) -> Option; +} + +impl SlicePartialOrd for [A] + where A: PartialOrd +{ + default fn partial_compare(&self, other: &[A]) -> Option { let l = cmp::min(self.len(), other.len()); // Slice to the loop iteration range to enable bound check @@ -1677,19 +1750,33 @@ impl Ord for [T] { let rhs = &other[..l]; for i in 0..l { - match lhs[i].cmp(&rhs[i]) { - Ordering::Equal => (), + match lhs[i].partial_cmp(&rhs[i]) { + Some(Ordering::Equal) => (), non_eq => return non_eq, } } - self.len().cmp(&other.len()) + self.len().partial_cmp(&other.len()) } } -#[stable(feature = "rust1", since = "1.0.0")] -impl PartialOrd for [T] { - fn partial_cmp(&self, other: &[T]) -> Option { +impl SlicePartialOrd for [u8] { + #[inline] + fn partial_compare(&self, other: &[u8]) -> Option { + Some(SliceOrd::compare(self, other)) + } +} + +#[doc(hidden)] +// intermediate trait for specialization of slice's Ord +trait SliceOrd { + fn compare(&self, other: &[B]) -> Ordering; +} + +impl SliceOrd for [A] + where A: Ord +{ + default fn compare(&self, other: &[A]) -> Ordering { let l = cmp::min(self.len(), other.len()); // Slice to the loop iteration range to enable bound check @@ -1698,12 +1785,48 @@ impl PartialOrd for [T] { let rhs = &other[..l]; for i in 0..l { - match lhs[i].partial_cmp(&rhs[i]) { - Some(Ordering::Equal) => (), + match lhs[i].cmp(&rhs[i]) { + Ordering::Equal => (), non_eq => return non_eq, } } - self.len().partial_cmp(&other.len()) + self.len().cmp(&other.len()) } } + +// memcmp compares a sequence of unsigned bytes lexicographically. +// this matches the order we want for [u8], but no others (not even [i8]). +impl SliceOrd for [u8] { + #[inline] + fn compare(&self, other: &[u8]) -> Ordering { + let order = unsafe { + memcmp(self.as_ptr(), other.as_ptr(), + cmp::min(self.len(), other.len())) + }; + if order == 0 { + self.len().cmp(&other.len()) + } else if order < 0 { + Less + } else { + Greater + } + } +} + +#[doc(hidden)] +/// Trait implemented for types that can be compared for equality using +/// their bytewise representation +trait BytewiseEquality { } + +macro_rules! impl_marker_for { + ($traitname:ident, $($ty:ty)*) => { + $( + impl $traitname for $ty { } + )* + } +} + +impl_marker_for!(BytewiseEquality, + u8 i8 u16 i16 u32 i32 u64 i64 usize isize char bool); + diff --git a/src/libcore/str/mod.rs b/src/libcore/str/mod.rs index d5a5e2b47419f..305546df5be2d 100644 --- a/src/libcore/str/mod.rs +++ b/src/libcore/str/mod.rs @@ -1150,16 +1150,7 @@ Section: Comparing strings #[lang = "str_eq"] #[inline] fn eq_slice(a: &str, b: &str) -> bool { - a.len() == b.len() && unsafe { cmp_slice(a, b, a.len()) == 0 } -} - -/// Bytewise slice comparison. -/// NOTE: This uses the system's memcmp, which is currently dramatically -/// faster than comparing each byte in a loop. -#[inline] -unsafe fn cmp_slice(a: &str, b: &str, len: usize) -> i32 { - extern { fn memcmp(s1: *const i8, s2: *const i8, n: usize) -> i32; } - memcmp(a.as_ptr() as *const i8, b.as_ptr() as *const i8, len) + a.as_bytes() == b.as_bytes() } /* @@ -1328,8 +1319,7 @@ Section: Trait implementations */ mod traits { - use cmp::{self, Ordering, Ord, PartialEq, PartialOrd, Eq}; - use cmp::Ordering::{Less, Greater}; + use cmp::{Ord, Ordering, PartialEq, PartialOrd, Eq}; use iter::Iterator; use option::Option; use option::Option::Some; @@ -1340,16 +1330,7 @@ mod traits { impl Ord for str { #[inline] fn cmp(&self, other: &str) -> Ordering { - let cmp = unsafe { - super::cmp_slice(self, other, cmp::min(self.len(), other.len())) - }; - if cmp == 0 { - self.len().cmp(&other.len()) - } else if cmp < 0 { - Less - } else { - Greater - } + self.as_bytes().cmp(other.as_bytes()) } }