Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve count vectorization: replace popcnt implementation with vector counting #4614

Merged
merged 27 commits into from
Apr 26, 2024
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
b1aacab
`count` vectorization: replace `popcnt` implementation with vector co…
AlexGuteniev Apr 21, 2024
6f134e8
Don't do extra reduce in the end
AlexGuteniev Apr 21, 2024
35d61ac
As the SSE branch has no masked tail, can reduce in a single pace
AlexGuteniev Apr 21, 2024
ce8d8a5
Reduce as infrequently as possible
AlexGuteniev Apr 22, 2024
af456db
missing range coverage
AlexGuteniev Apr 22, 2024
c232f62
test counting zeros
AlexGuteniev Apr 22, 2024
498408c
compare with expected in new coverage
AlexGuteniev Apr 22, 2024
48efb24
separate _Count_traits_N and reuse reduce
AlexGuteniev Apr 22, 2024
403ef98
formatting
AlexGuteniev Apr 22, 2024
983e93a
Stand away from overflows!
AlexGuteniev Apr 22, 2024
64b6fe6
sizes are bytes
AlexGuteniev Apr 22, 2024
85263d8
counting overflow better coverage
AlexGuteniev Apr 22, 2024
0a92efc
fewer ops to reduce
AlexGuteniev Apr 22, 2024
d816df0
Comments cleanup
AlexGuteniev Apr 22, 2024
dedc0cc
reduce 1-byte with `sad` instruction
AlexGuteniev Apr 22, 2024
015d4f7
Simplify `_Count_traits_8::_Reduce_avx()` by reusing `_Reduce_sse()`.
StephanTLavavej Apr 24, 2024
7e39a04
Fix `_Count_traits_4::_Max_count`.
StephanTLavavej Apr 25, 2024
8d4aab5
Add detailed comments explaining each `_Max_count`.
StephanTLavavej Apr 25, 2024
50d610a
For clarity, scope `__m128i _Count_vector` to each iteration of the S…
StephanTLavavej Apr 25, 2024
c06d4ff
For the AVX2 loop, scope `__m256i _Count_vector` separately for the m…
StephanTLavavej Apr 25, 2024
3f95815
Fix comment typo.
StephanTLavavej Apr 25, 2024
22b561b
Change test_count_zero() to test more lengths.
StephanTLavavej Apr 25, 2024
e587698
Restore popcnt approach
AlexGuteniev Apr 25, 2024
87749d2
Get my bounds back
AlexGuteniev Apr 25, 2024
2e1d6c7
typos
AlexGuteniev Apr 25, 2024
fc6dcbd
restore SSE4.2 comment
AlexGuteniev Apr 25, 2024
c5dd5c2
Clarify comments.
StephanTLavavej Apr 25, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
212 changes: 178 additions & 34 deletions stl/src/vector_algorithms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1734,8 +1734,6 @@ __declspec(noalias) _Min_max_d __stdcall __std_minmax_d(const void* const _First

namespace {
struct _Find_traits_1 {
static constexpr size_t _Shift = 0;

#ifndef _M_ARM64EC
static __m256i _Set_avx(const uint8_t _Val) noexcept {
return _mm256_set1_epi8(_Val);
Expand All @@ -1756,8 +1754,6 @@ namespace {
};

struct _Find_traits_2 {
static constexpr size_t _Shift = 1;

#ifndef _M_ARM64EC
static __m256i _Set_avx(const uint16_t _Val) noexcept {
return _mm256_set1_epi16(_Val);
Expand All @@ -1778,8 +1774,6 @@ namespace {
};

struct _Find_traits_4 {
static constexpr size_t _Shift = 2;

#ifndef _M_ARM64EC
static __m256i _Set_avx(const uint32_t _Val) noexcept {
return _mm256_set1_epi32(_Val);
Expand All @@ -1800,8 +1794,6 @@ namespace {
};

struct _Find_traits_8 {
static constexpr size_t _Shift = 3;

#ifndef _M_ARM64EC
static __m256i _Set_avx(const uint64_t _Val) noexcept {
return _mm256_set1_epi64x(_Val);
Expand Down Expand Up @@ -1978,6 +1970,116 @@ namespace {
}
}

struct _Count_traits_8 : _Find_traits_8 {
#ifndef _M_ARM64EC
static __m256i _Sub_avx(const __m256i _Lhs, const __m256i _Rhs) noexcept {
return _mm256_sub_epi64(_Lhs, _Rhs);
}

static __m128i _Sub_sse(const __m128i _Lhs, const __m128i _Rhs) noexcept {
return _mm_sub_epi64(_Lhs, _Rhs);
}

static size_t _Reduce_avx(const __m256i _Val) noexcept {
const __m128i _Lo64 = _mm256_extracti128_si256(_Val, 0);
const __m128i _Hi64 = _mm256_extracti128_si256(_Val, 1);
const __m128i _Rx8 = _mm_add_epi64(_Lo64, _Hi64);
#ifdef _M_IX86
return _mm_cvtsi128_si32(_Rx8) + _mm_extract_epi32(_Rx8, 2);
#else // ^^^ defined(_M_IX86) / defined(_M_X64) vvv
return _mm_cvtsi128_si64(_Rx8) + _mm_extract_epi64(_Rx8, 1);
#endif // ^^^ defined(_M_X64) ^^^
}

static size_t _Reduce_sse(const __m128i _Val) noexcept {
#ifdef _M_IX86
return _mm_cvtsi128_si32(_Val) + _mm_extract_epi32(_Val, 2);
#else // ^^^ defined(_M_IX86) / defined(_M_X64) vvv
return _mm_cvtsi128_si64(_Val) + _mm_extract_epi64(_Val, 1);
#endif // ^^^ defined(_M_X64) ^^^
}
StephanTLavavej marked this conversation as resolved.
Show resolved Hide resolved
#endif // !_M_ARM64EC
};

struct _Count_traits_4 : _Find_traits_4 {
#ifndef _M_ARM64EC
static constexpr size_t _Max_count = 0x1FFF'FFFF;
StephanTLavavej marked this conversation as resolved.
Show resolved Hide resolved

static __m256i _Sub_avx(const __m256i _Lhs, const __m256i _Rhs) noexcept {
return _mm256_sub_epi32(_Lhs, _Rhs);
}

static __m128i _Sub_sse(const __m128i _Lhs, const __m128i _Rhs) noexcept {
return _mm_sub_epi32(_Lhs, _Rhs);
}

static size_t _Reduce_avx(const __m256i _Val) noexcept {
constexpr auto _Shuf = _MM_SHUFFLE(3, 1, 2, 0); // Cross lane, to reduce further on low lane
const __m256i _Rx4 = _mm256_hadd_epi32(_Val, _mm256_setzero_si256()); // (0+1),(2+3),0,0 per lane
const __m256i _Rx5 = _mm256_permute4x64_epi64(_Rx4, _Shuf); // low lane (0+1),(2+3),(4+5),(6+7)
const __m256i _Rx6 = _mm256_hadd_epi32(_Rx5, _mm256_setzero_si256()); // (0+..+3),(4+...+7),0,0
const __m256i _Rx7 = _mm256_hadd_epi32(_Rx6, _mm256_setzero_si256()); // (0+...+7),0,0,0
return _mm_cvtsi128_si32(_mm256_castsi256_si128(_Rx7));
}

static size_t _Reduce_sse(const __m128i _Val) noexcept {
const __m128i _Rx4 = _mm_hadd_epi32(_Val, _mm_setzero_si128()); // (0+1),(2+3),0,0
const __m128i _Rx5 = _mm_hadd_epi32(_Rx4, _mm_setzero_si128()); // (0+...+3),0,0,0
return _mm_cvtsi128_si32(_Rx5);
}
#endif // !_M_ARM64EC
};

struct _Count_traits_2 : _Find_traits_2 {
#ifndef _M_ARM64EC
static constexpr size_t _Max_count = 0x7FFF;

static __m256i _Sub_avx(const __m256i _Lhs, const __m256i _Rhs) noexcept {
return _mm256_sub_epi16(_Lhs, _Rhs);
}

static __m128i _Sub_sse(const __m128i _Lhs, const __m128i _Rhs) noexcept {
return _mm_sub_epi16(_Lhs, _Rhs);
}

static size_t _Reduce_avx(const __m256i _Val) noexcept {
const __m256i _Rx2 = _mm256_hadd_epi16(_Val, _mm256_setzero_si256());
const __m256i _Rx3 = _mm256_unpacklo_epi16(_Rx2, _mm256_setzero_si256());
return _Count_traits_4::_Reduce_avx(_Rx3);
}

static size_t _Reduce_sse(const __m128i _Val) noexcept {
const __m128i _Rx2 = _mm_hadd_epi16(_Val, _mm_setzero_si128());
const __m128i _Rx3 = _mm_unpacklo_epi16(_Rx2, _mm_setzero_si128());
return _Count_traits_4::_Reduce_sse(_Rx3);
}
#endif // !_M_ARM64EC
};

struct _Count_traits_1 : _Find_traits_1 {
#ifndef _M_ARM64EC
static constexpr size_t _Max_count = 0xFF;

static __m256i _Sub_avx(const __m256i _Lhs, const __m256i _Rhs) noexcept {
return _mm256_sub_epi8(_Lhs, _Rhs);
}

static __m128i _Sub_sse(const __m128i _Lhs, const __m128i _Rhs) noexcept {
return _mm_sub_epi8(_Lhs, _Rhs);
}

static size_t _Reduce_avx(const __m256i _Val) noexcept {
const __m256i _Rx1 = _mm256_sad_epu8(_Val, _mm256_setzero_si256());
return _Count_traits_8::_Reduce_avx(_Rx1);
}

static size_t _Reduce_sse(const __m128i _Val) noexcept {
const __m128i _Rx1 = _mm_sad_epu8(_Val, _mm_setzero_si128());
return _Count_traits_8::_Reduce_sse(_Rx1);
}
#endif // !_M_ARM64EC
};

template <class _Traits, class _Ty>
__declspec(noalias) size_t
__stdcall __std_count_trivial_impl(const void* _First, const void* const _Last, const _Ty _Val) noexcept {
Expand All @@ -1986,47 +2088,89 @@ namespace {
#ifndef _M_ARM64EC
const size_t _Size_bytes = _Byte_length(_First, _Last);

if (const size_t _Avx_size = _Size_bytes & ~size_t{0x1F}; _Avx_size != 0 && _Use_avx2()) {
if (size_t _Avx_size = _Size_bytes & ~size_t{0x1F}; _Avx_size != 0 && _Use_avx2()) {
const __m256i _Comparand = _Traits::_Set_avx(_Val);
const void* _Stop_at = _First;
_Advance_bytes(_Stop_at, _Avx_size);
__m256i _Count_vector = _mm256_setzero_si256();

do {
const __m256i _Data = _mm256_loadu_si256(static_cast<const __m256i*>(_First));
const int _Bingo = _mm256_movemask_epi8(_Traits::_Cmp_avx(_Data, _Comparand));
_Result += __popcnt(_Bingo); // Assume available with SSE4.2
_Advance_bytes(_First, 32);
} while (_First != _Stop_at);
for (;;) {
if constexpr (sizeof(_Ty) >= sizeof(size_t)) {
_Advance_bytes(_Stop_at, _Avx_size);
} else {
constexpr size_t _Max_portion_size = (_Traits::_Max_count - 1) * 32;
StephanTLavavej marked this conversation as resolved.
Show resolved Hide resolved
const size_t _Portion_size = _Avx_size < _Max_portion_size ? _Avx_size : _Max_portion_size;
_Advance_bytes(_Stop_at, _Portion_size);
_Avx_size -= _Portion_size;
}

do {
const __m256i _Data = _mm256_loadu_si256(static_cast<const __m256i*>(_First));
const __m256i _Mask = _Traits::_Cmp_avx(_Data, _Comparand);
_Count_vector = _Traits::_Sub_avx(_Count_vector, _Mask);
_Advance_bytes(_First, 32);
} while (_First != _Stop_at);

if constexpr (sizeof(_Ty) >= sizeof(size_t)) {
break;
} else {
if (_Avx_size == 0) {
break;
}

_Result += _Traits::_Reduce_avx(_Count_vector);
_Count_vector = _mm256_setzero_si256();
}
}

if (const size_t _Avx_tail_size = _Size_bytes & 0x1C; _Avx_tail_size != 0) {
const __m256i _Tail_mask = _Avx2_tail_mask_32(_Avx_tail_size >> 2);
const __m256i _Data = _mm256_maskload_epi32(static_cast<const int*>(_First), _Tail_mask);
const int _Bingo =
_mm256_movemask_epi8(_mm256_and_si256(_Traits::_Cmp_avx(_Data, _Comparand), _Tail_mask));
_Result += __popcnt(_Bingo); // Assume available with SSE4.2
const __m256i _Mask = _Traits::_Cmp_avx(_Data, _Comparand);
_Count_vector = _Traits::_Sub_avx(_Count_vector, _mm256_and_si256(_Mask, _Tail_mask));
_Advance_bytes(_First, _Avx_tail_size);
}

_mm256_zeroupper(); // TRANSITION, DevCom-10331414
_Result += _Traits::_Reduce_avx(_Count_vector);

_Result >>= _Traits::_Shift;
_mm256_zeroupper(); // TRANSITION, DevCom-10331414

if constexpr (sizeof(_Ty) >= 4) {
return _Result;
}
} else if (const size_t _Sse_size = _Size_bytes & ~size_t{0xF}; _Sse_size != 0 && _Use_sse42()) {
} else if (size_t _Sse_size = _Size_bytes & ~size_t{0xF}; _Sse_size != 0 && _Use_sse42()) {
const __m128i _Comparand = _Traits::_Set_sse(_Val);
const void* _Stop_at = _First;
_Advance_bytes(_Stop_at, _Sse_size);
__m128i _Count_vector = _mm_setzero_si128();

do {
const __m128i _Data = _mm_loadu_si128(static_cast<const __m128i*>(_First));
const int _Bingo = _mm_movemask_epi8(_Traits::_Cmp_sse(_Data, _Comparand));
_Result += __popcnt(_Bingo); // Assume available with SSE4.2
_Advance_bytes(_First, 16);
} while (_First != _Stop_at);
for (;;) {
if constexpr (sizeof(_Ty) >= sizeof(size_t)) {
_Advance_bytes(_Stop_at, _Sse_size);
} else {
constexpr size_t _Max_portion_size = _Traits::_Max_count * 16;
const size_t _Portion_size = _Sse_size < _Max_portion_size ? _Sse_size : _Max_portion_size;
_Advance_bytes(_Stop_at, _Portion_size);
_Sse_size -= _Portion_size;
}

_Result >>= _Traits::_Shift;
do {
const __m128i _Data = _mm_loadu_si128(static_cast<const __m128i*>(_First));
const __m128i _Mask = _Traits::_Cmp_sse(_Data, _Comparand);
_Count_vector = _Traits::_Sub_sse(_Count_vector, _Mask);
_Advance_bytes(_First, 16);
} while (_First != _Stop_at);

_Result += _Traits::_Reduce_sse(_Count_vector);

if constexpr (sizeof(_Ty) >= sizeof(size_t)) {
break;
} else {
if (_Sse_size == 0) {
break;
}

_Count_vector = _mm_setzero_si128();
StephanTLavavej marked this conversation as resolved.
Show resolved Hide resolved
}
}
}
#endif // !_M_ARM64EC

Expand Down Expand Up @@ -2549,22 +2693,22 @@ const void* __stdcall __std_find_last_trivial_8(

__declspec(noalias) size_t
__stdcall __std_count_trivial_1(const void* const _First, const void* const _Last, const uint8_t _Val) noexcept {
return __std_count_trivial_impl<_Find_traits_1>(_First, _Last, _Val);
return __std_count_trivial_impl<_Count_traits_1>(_First, _Last, _Val);
}

__declspec(noalias) size_t
__stdcall __std_count_trivial_2(const void* const _First, const void* const _Last, const uint16_t _Val) noexcept {
return __std_count_trivial_impl<_Find_traits_2>(_First, _Last, _Val);
return __std_count_trivial_impl<_Count_traits_2>(_First, _Last, _Val);
}

__declspec(noalias) size_t
__stdcall __std_count_trivial_4(const void* const _First, const void* const _Last, const uint32_t _Val) noexcept {
return __std_count_trivial_impl<_Find_traits_4>(_First, _Last, _Val);
return __std_count_trivial_impl<_Count_traits_4>(_First, _Last, _Val);
}

__declspec(noalias) size_t
__stdcall __std_count_trivial_8(const void* const _First, const void* const _Last, const uint64_t _Val) noexcept {
return __std_count_trivial_impl<_Find_traits_8>(_First, _Last, _Val);
return __std_count_trivial_impl<_Count_traits_8>(_First, _Last, _Val);
}

const void* __stdcall __std_find_first_of_trivial_1(
Expand Down
20 changes: 20 additions & 0 deletions tests/std/tests/VSO_0000000_vector_algorithms/test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,10 @@ void test_case_count(const vector<T>& input, T v) {
auto expected = last_known_good_count(input.begin(), input.end(), v);
auto actual = count(input.begin(), input.end(), v);
assert(expected == actual);
#if _HAS_CXX20
auto actual_r = ranges::count(input, v);
assert(actual_r == expected);
#endif // _HAS_CXX20
}

template <class T>
Expand All @@ -98,6 +102,12 @@ void test_count(mt19937_64& gen) {
}
}

template <class T>
void test_count_zero() { // text that counters don't overflow
StephanTLavavej marked this conversation as resolved.
Show resolved Hide resolved
vector<T> input(1000000, T{0});
test_case_count(input, T{0});
}

template <class FwdIt, class T>
auto last_known_good_find(FwdIt first, FwdIt last, T v) {
for (; first != last; ++first) {
Expand Down Expand Up @@ -738,6 +748,16 @@ void test_vector_algorithms(mt19937_64& gen) {
test_count<long long>(gen);
test_count<unsigned long long>(gen);

test_count_zero<char>();
test_count_zero<signed char>();
test_count_zero<unsigned char>();
test_count_zero<short>();
test_count_zero<unsigned short>();
test_count_zero<int>();
test_count_zero<unsigned int>();
test_count_zero<long long>();
test_count_zero<unsigned long long>();

test_find<char>(gen);
test_find<signed char>(gen);
test_find<unsigned char>(gen);
Expand Down