Skip to content

Commit

Permalink
Vectorize basic_string::find (the string needle overload) (#5048)
Browse files Browse the repository at this point in the history
Co-authored-by: Stephan T. Lavavej <[email protected]>
  • Loading branch information
AlexGuteniev and StephanTLavavej authored Oct 30, 2024
1 parent 30c9391 commit 5e8f003
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 0 deletions.
19 changes: 19 additions & 0 deletions benchmarks/src/search.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,22 @@ void search_default_searcher(benchmark::State& state) {
}
}

// Benchmark for the member function `T::find(const T&)` (the string-needle overload).
//
// `state.range()` picks the entry of the file-level `patterns` table to use; that entry's
// `data` becomes the haystack and its `pattern` becomes the needle, both converted to `T`
// via iterator-range construction before the measured loop starts.
template <class T>
void member_find(benchmark::State& state) {
    const auto pattern_index = static_cast<size_t>(state.range());
    const auto& source_text  = patterns[pattern_index].data;
    const auto& source_key   = patterns[pattern_index].pattern;

    // Build the operands once, outside the measured loop.
    const T haystack(source_text.begin(), source_text.end());
    const T needle(source_key.begin(), source_key.end());

    for (auto _ : state) {
        // Pass both inputs and the result through DoNotOptimize so the search is actually performed each iteration.
        benchmark::DoNotOptimize(haystack);
        benchmark::DoNotOptimize(needle);
        auto found_at = haystack.find(needle);
        benchmark::DoNotOptimize(found_at);
    }
}

template <class T>
void classic_find_end(benchmark::State& state) {
const auto& src_haystack = patterns[static_cast<size_t>(state.range())].data;
Expand Down Expand Up @@ -158,6 +174,9 @@ BENCHMARK(ranges_search<std::uint16_t>)->Apply(common_args);
// Register search_default_searcher for std::uint8_t and std::uint16_t elements.
// Every registration in this group is configured through common_args (defined outside this view).
BENCHMARK(search_default_searcher<std::uint8_t>)->Apply(common_args);
BENCHMARK(search_default_searcher<std::uint16_t>)->Apply(common_args);

// Register member_find (basic_string::find with a string needle) for std::string and std::wstring.
BENCHMARK(member_find<std::string>)->Apply(common_args);
BENCHMARK(member_find<std::wstring>)->Apply(common_args);

// Register classic_find_end for std::uint8_t and std::uint16_t elements.
BENCHMARK(classic_find_end<std::uint8_t>)->Apply(common_args);
BENCHMARK(classic_find_end<std::uint16_t>)->Apply(common_args);

Expand Down
15 changes: 15 additions & 0 deletions stl/inc/__msvc_string_view.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -628,6 +628,21 @@ constexpr size_t _Traits_find(_In_reads_(_Hay_size) const _Traits_ptr_t<_Traits>
return _Start_at;
}

#if _USE_STD_VECTOR_ALGORITHMS
if constexpr (_Is_implementation_handled_char_traits<_Traits> && sizeof(typename _Traits::char_type) <= 2) {
if (!_STD _Is_constant_evaluated()) {
const auto _End = _Haystack + _Hay_size;
const auto _Ptr = _STD _Search_vectorized(_Haystack + _Start_at, _End, _Needle, _Needle_size);

if (_Ptr != _End) {
return static_cast<size_t>(_Ptr - _Haystack);
} else {
return static_cast<size_t>(-1);
}
}
}
#endif // _USE_STD_VECTOR_ALGORITHMS

const auto _Possible_matches_end = _Haystack + (_Hay_size - _Needle_size) + 1;
for (auto _Match_try = _Haystack + _Start_at;; ++_Match_try) {
_Match_try = _Traits::find(_Match_try, static_cast<size_t>(_Possible_matches_end - _Match_try), *_Needle);
Expand Down
35 changes: 35 additions & 0 deletions tests/std/tests/VSO_0000000_vector_algorithms/test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1328,22 +1328,57 @@ void test_case_string_find_last_of(const basic_string<T>& input_haystack, const
assert(expected == actual);
}

// Checks `basic_string<T>::find(const basic_string<T>&)` against the reference search.
//
// The expected offset is computed with `last_known_good_search` over the same haystack and needle;
// a reference result equal to the haystack's end is mapped to -1, which is compared against the
// value returned by `find` after conversion to `ptrdiff_t`.
template <class T>
void test_case_string_find_str(const basic_string<T>& input_haystack, const basic_string<T>& input_needle) {
    // An empty needle is expected to be found at offset 0; the reference search is not consulted for it.
    ptrdiff_t expected = 0;

    if (!input_needle.empty()) {
        const auto hay_first = input_haystack.begin();
        const auto hay_last  = input_haystack.end();
        const auto match = last_known_good_search(hay_first, hay_last, input_needle.begin(), input_needle.end());

        // The reference returns the end iterator when there is no match.
        expected = (match != hay_last) ? static_cast<ptrdiff_t>(match - hay_first) : static_cast<ptrdiff_t>(-1);
    }

    const auto actual = static_cast<ptrdiff_t>(input_haystack.find(input_needle));
    assert(expected == actual);
}

template <class T, class D>
void test_basic_string_dis(mt19937_64& gen, D& dis) {
basic_string<T> input_haystack;
basic_string<T> input_needle;
basic_string<T> temp;
input_haystack.reserve(haystackDataCount);
input_needle.reserve(needleDataCount);
temp.reserve(needleDataCount);

for (;;) {
input_needle.clear();

test_case_string_find_first_of(input_haystack, input_needle);
test_case_string_find_last_of(input_haystack, input_needle);
test_case_string_find_str(input_haystack, input_needle);

for (size_t attempts = 0; attempts < needleDataCount; ++attempts) {
input_needle.push_back(static_cast<T>(dis(gen)));
test_case_string_find_first_of(input_haystack, input_needle);
test_case_string_find_last_of(input_haystack, input_needle);
test_case_string_find_str(input_haystack, input_needle);

// For large needles the chance of a match is low, so test a guaranteed match
if (input_haystack.size() > input_needle.size() * 2) {
uniform_int_distribution<size_t> pos_dis(0, input_haystack.size() - input_needle.size());
const size_t pos = pos_dis(gen);
const auto overwritten_first = input_haystack.begin() + static_cast<ptrdiff_t>(pos);
temp.assign(overwritten_first, overwritten_first + static_cast<ptrdiff_t>(input_needle.size()));
copy(input_needle.begin(), input_needle.end(), overwritten_first);
test_case_string_find_str(input_haystack, input_needle);
copy(temp.begin(), temp.end(), overwritten_first);
}
}

if (input_haystack.size() == haystackDataCount) {
Expand Down

0 comments on commit 5e8f003

Please sign in to comment.