Skip to content

Commit

Permalink
Add set host find APIs taking probe key equality and hash
Browse files Browse the repository at this point in the history
  • Loading branch information
PointKernel committed Nov 25, 2024
1 parent 9fa8904 commit 2ce7048
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 0 deletions.
55 changes: 55 additions & 0 deletions include/cuco/detail/static_set/static_set.inl
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,29 @@ void static_set<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>
impl_->find_async(first, last, output_begin, ref(op::find), stream);
}

template <class Key,
class Extent,
cuda::thread_scope Scope,
class KeyEqual,
class ProbingScheme,
class Allocator,
class Storage>
template <typename InputIt, typename ProbeEqual, typename ProbeHash, typename OutputIt>
void static_set<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::find_async(
InputIt first,
InputIt last,
ProbeEqual const& probe_equal,
ProbeHash const& probe_hash,
OutputIt output_begin,
cuda::stream_ref stream) const
{
impl_->find_async(first,
last,
output_begin,
ref(op::find).rebind_key_eq(probe_equal).rebind_hash_function(probe_hash),
stream);
}

template <class Key,
class Extent,
cuda::thread_scope Scope,
Expand Down Expand Up @@ -376,6 +399,38 @@ void static_set<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>
impl_->find_if_async(first, last, stencil, pred, output_begin, ref(op::find), stream);
}

template <class Key,
class Extent,
cuda::thread_scope Scope,
class KeyEqual,
class ProbingScheme,
class Allocator,
class Storage>
template <typename InputIt,
typename StencilIt,
typename Predicate,
typename ProbeEqual,
typename ProbeHash,
typename OutputIt>
void static_set<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::find_if_async(
InputIt first,
InputIt last,
StencilIt stencil,
Predicate pred,
ProbeEqual const& probe_equal,
ProbeHash const& probe_hash,
OutputIt output_begin,
cuda::stream_ref stream) const
{
impl_->find_if_async(first,
last,
stencil,
pred,
output_begin,
ref(op::find).rebind_key_eq(probe_equal).rebind_hash_function(probe_hash),
stream);
}

template <class Key,
class Extent,
cuda::thread_scope Scope,
Expand Down
71 changes: 71 additions & 0 deletions include/cuco/static_set.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -590,6 +590,34 @@ class static_set {
OutputIt output_begin,
cuda::stream_ref stream = {}) const;

/**
* @brief For all keys in the range `[first, last)`, asynchronously finds an element with key
* equivalent to the query key.
*
* @note If the key `*(first + i)` has a matched `element` in the set, copies `element` to
* `(output_begin + i)`. Else, copies the empty key sentinel.
*
* @tparam InputIt Device accessible input iterator
* @tparam ProbeEqual Binary callable equal type
* @tparam ProbeHash Unary callable hasher type that can be constructed from
* an integer value
* @tparam OutputIt Device accessible output iterator assignable from the set's `key_type`
*
* @param first Beginning of the sequence of keys
* @param last End of the sequence of keys
* @param probe_equal The binary function to compare set keys and probe keys for equality
* @param probe_hash The unary function to hash probe keys
* @param output_begin Beginning of the sequence of elements retrieved for each key
* @param stream Stream used for executing the kernels
*/
template <typename InputIt, typename ProbeEqual, typename ProbeHash, typename OutputIt>
void find_async(InputIt first,
InputIt last,
ProbeEqual const& probe_equal,
ProbeHash const& probe_hash,
OutputIt output_begin,
cuda::stream_ref stream = {}) const;

/**
* @brief For all keys in the range `[first, last)`, finds a match with its key equivalent to the
* query key.
Expand Down Expand Up @@ -654,6 +682,49 @@ class static_set {
OutputIt output_begin,
cuda::stream_ref stream = {}) const;

/**
* @brief For all keys in the range `[first, last)`, asynchronously finds
* a match with its key equivalent to the query key.
*
* @note If `pred( *(stencil + i) )` is true, stores the payload of the
* matched key or the `empty_value_sentienl` to `(output_begin + i)`. If `pred( *(stencil + i) )`
* is false, always stores the `empty_value_sentienl` to `(output_begin + i)`.
*
* @tparam InputIt Device accessible input iterator
* @tparam StencilIt Device accessible random access iterator whose `value_type` is convertible to
* Predicate's argument type
* @tparam Predicate Unary predicate callable whose return type must be convertible to `bool` and
* argument type is convertible from <tt>std::iterator_traits<StencilIt>::value_type</tt>
* @tparam ProbeEqual Binary callable equal type
* @tparam ProbeHash Unary callable hasher type that can be constructed from
* an integer value
* @tparam OutputIt Device accessible output iterator
*
* @param first Beginning of the sequence of keys
* @param last End of the sequence of keys
* @param stencil Beginning of the stencil sequence
* @param pred Predicate to test on every element in the range `[stencil, stencil +
* std::distance(first, last))`
* @param probe_equal The binary function to compare set keys and probe keys for equality
* @param probe_hash The unary function to hash probe keys
* @param output_begin Beginning of the sequence of matches retrieved for each key
* @param stream Stream used for executing the kernels
*/
template <typename InputIt,
typename StencilIt,
typename Predicate,
typename ProbeEqual,
typename ProbeHash,
typename OutputIt>
void find_if_async(InputIt first,
InputIt last,
StencilIt stencil,
Predicate pred,
ProbeEqual const& probe_equal,
ProbeHash const& probe_hash,
OutputIt output_begin,
cuda::stream_ref stream = {}) const;

/**
* @brief Applies the given function object `callback_op` to the copy of every filled slot in the
* container
Expand Down

0 comments on commit 2ce7048

Please sign in to comment.