Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

find_first_of vectorized: generalize fast approach for 4 and 8 bytes elements #4623

Merged
merged 5 commits into from
Apr 26, 2024
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 84 additions & 76 deletions stl/src/vector_algorithms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2204,7 +2204,9 @@ namespace {
#ifndef _M_ARM64EC
template <size_t _Amount>
static __m256i _Spread_avx(__m256i _Val, const size_t _Needle_length_el) noexcept {
if constexpr (_Amount == 1) {
if constexpr (_Amount == 0) {
return _mm256_undefined_si256();
} else if constexpr (_Amount == 1) {
return _mm256_broadcastd_epi32(_mm256_castsi256_si128(_Val));
} else if constexpr (_Amount == 2) {
return _mm256_broadcastq_epi64(_mm256_castsi256_si128(_Val));
Expand Down Expand Up @@ -2248,7 +2250,9 @@ namespace {
#ifndef _M_ARM64EC
template <size_t _Amount>
static __m256i _Spread_avx(const __m256i _Val, const size_t _Needle_length_el) noexcept {
if constexpr (_Amount == 1) {
if constexpr (_Amount == 0) {
return _mm256_undefined_si256();
} else if constexpr (_Amount == 1) {
return _mm256_broadcastq_epi64(_mm256_castsi256_si128(_Val));
} else if constexpr (_Amount == 2) {
return _mm256_permute4x64_epi64(_Val, _MM_SHUFFLE(1, 0, 1, 0));
Expand Down Expand Up @@ -2279,37 +2283,42 @@ namespace {
#ifndef _M_ARM64EC
template <class _Traits, size_t _Needle_length_el_magnitude>
__m256i _Shuffle_step(const __m256i _Data1, const __m256i _Data2s0) noexcept {
__m256i _Eq = _Traits::_Cmp_avx(_Data1, _Data2s0);
if constexpr (_Needle_length_el_magnitude >= 2) {
const __m256i _Data2s1 = _Traits::_Shuffle_avx<1>(_Data2s0);
_Eq = _mm256_or_si256(_Eq, _Traits::_Cmp_avx(_Data1, _Data2s1));
if constexpr (_Needle_length_el_magnitude >= 4) {
const __m256i _Data2s2 = _Traits::_Shuffle_avx<2>(_Data2s0);
_Eq = _mm256_or_si256(_Eq, _Traits::_Cmp_avx(_Data1, _Data2s2));
const __m256i _Data2s3 = _Traits::_Shuffle_avx<1>(_Data2s2);
_Eq = _mm256_or_si256(_Eq, _Traits::_Cmp_avx(_Data1, _Data2s3));
if constexpr (_Needle_length_el_magnitude >= 8) {
const __m256i _Data2s4 = _Traits::_Shuffle_avx<4>(_Data2s0);
_Eq = _mm256_or_si256(_Eq, _Traits::_Cmp_avx(_Data1, _Data2s4));
const __m256i _Data2s5 = _Traits::_Shuffle_avx<1>(_Data2s4);
_Eq = _mm256_or_si256(_Eq, _Traits::_Cmp_avx(_Data1, _Data2s5));
const __m256i _Data2s6 = _Traits::_Shuffle_avx<2>(_Data2s4);
_Eq = _mm256_or_si256(_Eq, _Traits::_Cmp_avx(_Data1, _Data2s6));
const __m256i _Data2s7 = _Traits::_Shuffle_avx<1>(_Data2s6);
_Eq = _mm256_or_si256(_Eq, _Traits::_Cmp_avx(_Data1, _Data2s7));
__m256i _Eq = _mm256_setzero_si256();
if constexpr (_Needle_length_el_magnitude >= 1) {
_Eq = _Traits::_Cmp_avx(_Data1, _Data2s0);
if constexpr (_Needle_length_el_magnitude >= 2) {
const __m256i _Data2s1 = _Traits::_Shuffle_avx<1>(_Data2s0);
_Eq = _mm256_or_si256(_Eq, _Traits::_Cmp_avx(_Data1, _Data2s1));
if constexpr (_Needle_length_el_magnitude >= 4) {
const __m256i _Data2s2 = _Traits::_Shuffle_avx<2>(_Data2s0);
_Eq = _mm256_or_si256(_Eq, _Traits::_Cmp_avx(_Data1, _Data2s2));
const __m256i _Data2s3 = _Traits::_Shuffle_avx<1>(_Data2s2);
_Eq = _mm256_or_si256(_Eq, _Traits::_Cmp_avx(_Data1, _Data2s3));
if constexpr (_Needle_length_el_magnitude >= 8) {
const __m256i _Data2s4 = _Traits::_Shuffle_avx<4>(_Data2s0);
_Eq = _mm256_or_si256(_Eq, _Traits::_Cmp_avx(_Data1, _Data2s4));
const __m256i _Data2s5 = _Traits::_Shuffle_avx<1>(_Data2s4);
_Eq = _mm256_or_si256(_Eq, _Traits::_Cmp_avx(_Data1, _Data2s5));
const __m256i _Data2s6 = _Traits::_Shuffle_avx<2>(_Data2s4);
_Eq = _mm256_or_si256(_Eq, _Traits::_Cmp_avx(_Data1, _Data2s6));
const __m256i _Data2s7 = _Traits::_Shuffle_avx<1>(_Data2s6);
_Eq = _mm256_or_si256(_Eq, _Traits::_Cmp_avx(_Data1, _Data2s7));
}
}
}
}
return _Eq;
}

template <class _Traits, size_t _Needle_length_el_magnitude>
template <class _Traits, bool _Large, size_t _Last2_length_el_magnitude>
const void* _Shuffle_impl(const void* _First1, const void* const _Last1, const void* const _First2,
const size_t _Needle_length_el) noexcept {
using _Ty = _Traits::_Ty;
const __m256i _Data2 = _mm256_maskload_epi32(
reinterpret_cast<const int*>(_First2), _Avx2_tail_mask_32(_Needle_length_el * (sizeof(_Ty) / 4)));
const __m256i _Data2s0 = _Traits::_Spread_avx<_Needle_length_el_magnitude>(_Data2, _Needle_length_el);
const void* const _Stop2, const size_t _Last2_length_el) noexcept {
using _Ty = _Traits::_Ty;
constexpr size_t _Length_el = 32 / sizeof(_Ty);

const __m256i _Last2val = _mm256_maskload_epi32(
reinterpret_cast<const int*>(_Stop2), _Avx2_tail_mask_32(_Last2_length_el * (sizeof(_Ty) / 4)));
const __m256i _Last2s0 = _Traits::_Spread_avx<_Last2_length_el_magnitude>(_Last2val, _Last2_length_el);

const size_t _Haystack_length = _Byte_length(_First1, _Last1);

Expand All @@ -2318,10 +2327,16 @@ namespace {

for (; _First1 != _Stop1; _Advance_bytes(_First1, 32)) {
const __m256i _Data1 = _mm256_loadu_si256(static_cast<const __m256i*>(_First1));
const __m256i _Eq = _Shuffle_step<_Traits, _Needle_length_el_magnitude>(_Data1, _Data2s0);
const int _Bingo = _mm256_movemask_epi8(_Eq);
__m256i _Eq = _Shuffle_step<_Traits, _Last2_length_el_magnitude>(_Data1, _Last2s0);

if (_Bingo != 0) {
if constexpr (_Large) {
for (const void* _Ptr2 = _First2; _Ptr2 != _Stop2; _Advance_bytes(_Ptr2, 32)) {
const __m256i _Data2s0 = _mm256_loadu_si256(static_cast<const __m256i*>(_Ptr2));
_Eq = _mm256_or_si256(_Eq, _Shuffle_step<_Traits, _Length_el>(_Data1, _Data2s0));
}
}

if (const int _Bingo = _mm256_movemask_epi8(_Eq); _Bingo != 0) {
const unsigned long _Offset = _tzcnt_u32(_Bingo);
_Advance_bytes(_First1, _Offset);
return _First1;
Expand All @@ -2331,10 +2346,16 @@ namespace {
if (const size_t _Haystack_tail_length = _Haystack_length & 0x1C; _Haystack_tail_length != 0) {
const __m256i _Tail_mask = _Avx2_tail_mask_32(_Haystack_tail_length >> 2);
const __m256i _Data1 = _mm256_maskload_epi32(static_cast<const int*>(_First1), _Tail_mask);
const __m256i _Eq = _Shuffle_step<_Traits, _Needle_length_el_magnitude>(_Data1, _Data2s0);
const int _Bingo = _mm256_movemask_epi8(_mm256_and_si256(_Eq, _Tail_mask));
__m256i _Eq = _Shuffle_step<_Traits, _Last2_length_el_magnitude>(_Data1, _Last2s0);

if (_Bingo != 0) {
if constexpr (_Large) {
for (const void* _Ptr2 = _First2; _Ptr2 != _Stop2; _Advance_bytes(_Ptr2, 32)) {
const __m256i _Data2s0 = _mm256_loadu_si256(static_cast<const __m256i*>(_Ptr2));
_Eq = _mm256_or_si256(_Eq, _Shuffle_step<_Traits, _Length_el>(_Data1, _Data2s0));
}
}

if (const int _Bingo = _mm256_movemask_epi8(_mm256_and_si256(_Eq, _Tail_mask)); _Bingo != 0) {
const unsigned long _Offset = _tzcnt_u32(_Bingo);
_Advance_bytes(_First1, _Offset);
return _First1;
Expand All @@ -2346,6 +2367,26 @@ namespace {
return _First1;
}

template <class _Traits, bool _Large>
const void* _Shuffle_impl_dispatch_magnitude(const void* _First1, const void* const _Last1,
StephanTLavavej marked this conversation as resolved.
Show resolved Hide resolved
const void* const _First2, const void* const _Stop2, const size_t _Last2_length_el) noexcept {
if (_Last2_length_el == 0) {
return _Shuffle_impl<_Traits, _Large, 0>(_First1, _Last1, _First2, _Stop2, _Last2_length_el);
} else if (_Last2_length_el == 1) {
return _Shuffle_impl<_Traits, _Large, 1>(_First1, _Last1, _First2, _Stop2, _Last2_length_el);
} else if (_Last2_length_el == 2) {
return _Shuffle_impl<_Traits, _Large, 2>(_First1, _Last1, _First2, _Stop2, _Last2_length_el);
} else if (_Last2_length_el <= 4) {
return _Shuffle_impl<_Traits, _Large, 4>(_First1, _Last1, _First2, _Stop2, _Last2_length_el);
} else if (_Last2_length_el <= 8) {
if constexpr (sizeof(_Traits::_Ty) == 4) {
return _Shuffle_impl<_Traits, _Large, 8>(_First1, _Last1, _First2, _Stop2, _Last2_length_el);
}
}

_STL_UNREACHABLE;
}

#endif // !_M_ARM64EC

template <class _Traits>
Expand All @@ -2356,52 +2397,19 @@ namespace {
if (_Use_avx2()) {
_Zeroupper_on_exit _Guard; // TRANSITION, DevCom-10331414

const size_t _Needle_length = _Byte_length(_First2, _Last2);
const int _Needle_length_el = static_cast<int>(_Needle_length / sizeof(_Ty));

// Special handling of small needle
// The generic approach could also handle it but with worse performance
if (_Needle_length_el == 0) {
return _Last1;
} else if (_Needle_length_el == 1) {
_STL_UNREACHABLE; // This is expected to be forwarded to 'find' on an upper level
} else if (_Needle_length_el == 2) {
return _Shuffle_impl<_Traits, 2>(_First1, _Last1, _First2, _Needle_length_el);
} else if (_Needle_length_el <= 4) {
return _Shuffle_impl<_Traits, 4>(_First1, _Last1, _First2, _Needle_length_el);
} else if (_Needle_length_el <= 8) {
if constexpr (sizeof(_Ty) == 4) {
return _Shuffle_impl<_Traits, 8>(_First1, _Last1, _First2, _Needle_length_el);
}
}

// Generic approach
const size_t _Needle_length_tail = _Needle_length & 0x1C;
const __m256i _Tail_mask = _Avx2_tail_mask_32(_Needle_length_tail >> 2);

const void* _Stop2 = _First2;
_Advance_bytes(_Stop2, _Needle_length & ~size_t{0x1F});
const size_t _Needle_length = _Byte_length(_First2, _Last2);
const size_t _Last_needle_length = _Needle_length & 0x1F;
const size_t _Last_needle_length_el = _Last_needle_length / sizeof(_Ty);

for (auto _Ptr1 = static_cast<const _Ty*>(_First1); _Ptr1 != _Last1; ++_Ptr1) {
const auto _Data1 = _Traits::_Set_avx(*_Ptr1);
for (auto _Ptr2 = _First2; _Ptr2 != _Stop2; _Advance_bytes(_Ptr2, 32)) {
const __m256i _Data2 = _mm256_loadu_si256(static_cast<const __m256i*>(_Ptr2));
const __m256i _Eq = _Traits::_Cmp_avx(_Data1, _Data2);
if (!_mm256_testz_si256(_Eq, _Eq)) {
return _Ptr1;
}
}

if (_Needle_length_tail != 0) {
const __m256i _Data2 = _mm256_maskload_epi32(static_cast<const int*>(_Stop2), _Tail_mask);
const __m256i _Eq = _Traits::_Cmp_avx(_Data1, _Data2);
if (!_mm256_testz_si256(_Eq, _Tail_mask)) {
return _Ptr1;
}
}
if (size_t _Needle_length_large = _Needle_length & ~size_t{0x1F}; _Needle_length_large != 0) {
StephanTLavavej marked this conversation as resolved.
Show resolved Hide resolved
const void* _Stop2 = _First2;
_Advance_bytes(_Stop2, _Needle_length & ~size_t{0x1F});
StephanTLavavej marked this conversation as resolved.
Show resolved Hide resolved
return _Shuffle_impl_dispatch_magnitude<_Traits, true>(
_First1, _Last1, _First2, _Stop2, _Last_needle_length_el);
} else {
return _Shuffle_impl_dispatch_magnitude<_Traits, false>(
_First1, _Last1, _First2, _First2, _Last_needle_length_el);
}

return _Last1;
}
#endif // !_M_ARM64EC
return _Fallback<_Ty>(_First1, _Last1, _First2, _Last2);
Expand Down