Skip to content

Commit

Permalink
Cherry pick binary search fix for 6.0 (#345)
Browse files Browse the repository at this point in the history
Co-authored-by: Lőrinc Serfőző <[email protected]>
  • Loading branch information
stanleytsang-amd and mfep authored Dec 5, 2023
1 parent 5d9e939 commit 44020d6
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 6 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ Full documentation for rocThrust is available at [https://rocthrust.readthedocs.
### Removed
- Removed cub symlink from the root of the repository.
- Removed support for deprecated macros (THRUST_DEVICE_BACKEND and THRUST_HOST_BACKEND).
### Fixed
- Fixed a segmentation fault when binary search / upper bound / lower bound / equal range was invoked with `hip_rocprim::execute_on_stream_base` policy.
### Known issues
- For NVIDIA backend, `NV_IF_TARGET` and `THRUST_RDC_ENABLED` intend to substitute the `THRUST_HAS_CUDART` macro, which is now no longer used in Thrust (provided for legacy support only). However, there is no `THRUST_RDC_ENABLED` macro available for the HIP backend, so some branches in Thrust's code may be unreachable in the HIP backend.

Expand Down
23 changes: 23 additions & 0 deletions test/test_binary_search.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -683,6 +683,29 @@ TEST(BinarySearchTests, TestScalarEqualRangeDispatchImplicit)
ASSERT_EQ(13, vec.front());
}

TEST(BinarySearchTests, TestEqualRangeExecutionPolicy)
{
using thrust_exec_policy_t
= thrust::detail::execute_with_allocator<thrust::device_allocator<char>,
thrust::hip_rocprim::execute_on_stream_base>;

constexpr int data[] = {1, 2, 3, 4, 4, 5, 6, 7, 8, 9};
constexpr size_t size = sizeof(data) / sizeof(data[0]);
constexpr int key = 4;
thrust::device_vector<int> d_data(data, data + size);

thrust::pair<thrust::device_vector<int>::iterator, thrust::device_vector<int>::iterator> range
= thrust::equal_range(
thrust_exec_policy_t(thrust::hip_rocprim::execute_on_stream_base<thrust_exec_policy_t>(
hipStreamPerThread),
thrust::device_allocator<char>()),
d_data.begin(),
d_data.end(),
key);

ASSERT_EQ(*range.first, 4);
ASSERT_EQ(*range.second, 5);
}

__global__
THRUST_HIP_LAUNCH_BOUNDS_DEFAULT
Expand Down
89 changes: 83 additions & 6 deletions thrust/system/hip/detail/binary_search.h
Original file line number Diff line number Diff line change
Expand Up @@ -464,12 +464,38 @@ HaystackIt lower_bound(execution_policy<Derived>& policy,
values_type values(policy, 1);
results_type result(policy, 1);

values[0] = value;
{
typedef typename thrust::iterator_system<const T*>::type value_in_system_t;
value_in_system_t value_in_system;
using thrust::system::detail::generic::select_system;
thrust::copy_n(
select_system(
thrust::detail::derived_cast(thrust::detail::strip_const(value_in_system)),
thrust::detail::derived_cast(thrust::detail::strip_const(policy))),
&value,
1,
values.begin());
}

__binary_search::lower_bound(
policy, first, last, values.begin(), values.end(), result.begin(), compare_op);

return first + result[0];
difference_type h_result;
{
typedef
typename thrust::iterator_system<difference_type*>::type result_out_system_t;
result_out_system_t result_out_system;
using thrust::system::detail::generic::select_system;
thrust::copy_n(
select_system(thrust::detail::derived_cast(thrust::detail::strip_const(policy)),
thrust::detail::derived_cast(
thrust::detail::strip_const(result_out_system))),
result.begin(),
1,
&h_result);
}

return first + h_result;
}

__device__
Expand Down Expand Up @@ -524,13 +550,39 @@ HaystackIt upper_bound(execution_policy<Derived>& policy,
values_type values(policy, 1);
results_type result(policy, 1);

values[0] = value;
{
typedef typename thrust::iterator_system<const T*>::type value_in_system_t;
value_in_system_t value_in_system;
using thrust::system::detail::generic::select_system;
thrust::copy_n(
select_system(
thrust::detail::derived_cast(thrust::detail::strip_const(value_in_system)),
thrust::detail::derived_cast(thrust::detail::strip_const(policy))),
&value,
1,
values.begin());
}

__binary_search::upper_bound(
policy, first, last, values.begin(), values.end(), result.begin(), compare_op
);

return first + result[0];
difference_type h_result;
{
typedef
typename thrust::iterator_system<difference_type*>::type result_out_system_t;
result_out_system_t result_out_system;
using thrust::system::detail::generic::select_system;
thrust::copy_n(
select_system(thrust::detail::derived_cast(thrust::detail::strip_const(policy)),
thrust::detail::derived_cast(
thrust::detail::strip_const(result_out_system))),
result.begin(),
1,
&h_result);
}

return first + h_result;
}

__device__
Expand Down Expand Up @@ -583,13 +635,38 @@ bool binary_search(execution_policy<Derived>& policy,
values_type values(policy, 1);
results_type result(policy, 1);

values[0] = value;
{
typedef typename thrust::iterator_system<const T*>::type value_in_system_t;
value_in_system_t value_in_system;
using thrust::system::detail::generic::select_system;
thrust::copy_n(
select_system(
thrust::detail::derived_cast(thrust::detail::strip_const(value_in_system)),
thrust::detail::derived_cast(thrust::detail::strip_const(policy))),
&value,
1,
values.begin());
}

__binary_search::binary_search(
policy, first, last, values.begin(), values.end(), result.begin(), compare_op
);

return result[0] != 0;
int h_result;
{
typedef typename thrust::iterator_system<int*>::type result_out_system_t;
result_out_system_t result_out_system;
using thrust::system::detail::generic::select_system;
thrust::copy_n(
select_system(thrust::detail::derived_cast(thrust::detail::strip_const(policy)),
thrust::detail::derived_cast(
thrust::detail::strip_const(result_out_system))),
result.begin(),
1,
&h_result);
}

return h_result != 0;
}

__device__
Expand Down

0 comments on commit 44020d6

Please sign in to comment.