Implements `top_k` functions in dpctl.tensor #1921
View rendered docs @ https://intelpython.github.io/dpctl/pulls/1921/index.html |
Array API standard conformance tests for dpctl=0.19.0dev0=py310hdf72452_296 ran successfully. |
@ndgrigorian Please add
diff --git a/docs/doc_sources/api_reference/dpctl/tensor.sorting_functions.rst b/docs/doc_sources/api_reference/dpctl/tensor.sorting_functions.rst
index ae1605d988..ef20f4654c 100644
--- a/docs/doc_sources/api_reference/dpctl/tensor.sorting_functions.rst
+++ b/docs/doc_sources/api_reference/dpctl/tensor.sorting_functions.rst
@@ -10,3 +10,4 @@ Sorting functions
argsort
sort
+ top_k |
79b97d9
to
882c70d
Compare
Array API standard conformance tests for dpctl=0.19.0dev0=py310hdf72452_295 ran successfully. |
Array API standard conformance tests for dpctl=0.19.0dev0=py310hdf72452_297 ran successfully. |
a56e21c
to
26718f3
Compare
Array API standard conformance tests for dpctl=0.19.0dev0=py310h93fe807_326 ran successfully. |
Array API standard conformance tests for dpctl=0.19.0dev0=py310h93fe807_327 ran successfully. |
Array API standard conformance tests for dpctl=0.19.0dev0=py310h93fe807_331 ran successfully. |
8bcb100
to
8f38b80
Compare
Array API standard conformance tests for dpctl=0.19.0dev0=py310h93fe807_388 ran successfully. |
775b423
to
84d1388
Compare
The implementation leverages existing merge-sort code and partially sorts the array in cases where a partial sort reduces the size of the temporary memory allocation.
Reduces amount of casting. `k` will need to fit in `py::ssize_t` regardless.
Instead of using an overload to handle the `axis=None` case, use `std::optional` and check for `trailing_dims_to_search` in validation logic.
Factored out `map_back_impl`, which projects indexing from a flat index to a row-wise index. Removed dead code excluded by a preprocessor conditional.
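As a rough illustration of the flat-to-row-wise projection this commit message describes, here is a minimal Python sketch. It assumes that rows are stored contiguously with a fixed `row_size`, so the within-row position is the flat index modulo the row length. The function name `map_back` and its signature are hypothetical and do not reproduce the C++ `map_back_impl` kernel.

```python
def map_back(flat_indices, row_size):
    """Project flat indices onto positions within their rows.

    Assumption for this sketch: the data is a contiguous batch of rows,
    each holding `row_size` elements, so flat index f belongs to
    row f // row_size and sits at position f % row_size in that row.
    """
    if row_size <= 0:
        raise ValueError("row_size must be positive")
    return [f % row_size for f in flat_indices]


# Two rows of four elements: flat indices 0..3 are row 0, 4..7 are row 1.
print(map_back([0, 5, 7, 3], row_size=4))
```

In this reading, an argsort computed over the flattened batch yields flat indices, and the projection converts them into per-row indices that are meaningful along the sorted axis.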
Replaced it with a hand-written implementation of `ceil_log2(n)`, such that `n <= (decltype(n){1} << ceil_log2(n))` is true for all positive values of `n` in the range.
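The invariant stated in this commit message can be checked with a short Python sketch. This is not the C++ code from the PR, only an illustrative loop that returns the smallest exponent `e` with `n <= (1 << e)`; Python integers are unbounded, so the fixed-width concerns of the C++ version do not arise here.

```python
def ceil_log2(n):
    """Return the smallest e >= 0 such that n <= (1 << e), for n >= 1."""
    if n < 1:
        raise ValueError("n must be a positive integer")
    e = 0
    # Keep doubling the candidate power of two until it covers n.
    while (1 << e) < n:
        e += 1
    return e


# Check the invariant from the commit message over a range of inputs.
for n in range(1, 1025):
    assert n <= (1 << ceil_log2(n))
```

Using integer shifts only avoids the rounding issues that a floating-point `log2` can introduce near exact powers of two.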
Add check of computed against expected indices
One asserts that at least one unique pointer is specified. Another asserts that the specified arguments are unique pointers with USMDeleter.
84d1388
to
809cb70
Compare
Array API standard conformance tests for dpctl=0.19.0dev0=py310h93fe807_391 ran successfully. |
Array API standard conformance tests for dpctl=0.19.0dev0=py310h93fe807_386 ran successfully. |
Array API standard conformance tests for dpctl=0.19.0dev0=py310h93fe807_385 ran successfully. |
Array API standard conformance tests for dpctl=0.19.0dev0=py310h93fe807_387 ran successfully. |
@oleksandr-pavlyk |
Good to see the CI green again! I was suggesting to only skip it in the test command used in the workflow so that we can still provide a reproducer to the CPU team. I was thinking we could add a file of tests to skip, and pass it as argument to I think this is the approach taken by |
…tly SYCL bundle DPC++ compiler
gid-lane_id is already a multiple of sg_size.
Change kernel to process few data elements in the work-item.
Counters cannot exceed the `uint16_t` maximum, because the kernel assumes that the number of elements to sort fits into `uint16_t`. The change reduces the kernel SLM footprint. Also, remove use of `std::move`, and replace `uint16_t`->`std::uint16_t`, `size_t`->`std::size_t`, `uint32_t`->`std::uint32_t`, etc. Use `if constexpr` in order-preserving-cast for better readability.
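For readers unfamiliar with the order-preserving cast mentioned in this commit message, the sketch below shows the general technique as commonly used in radix sorts: values are mapped to unsigned integers whose unsigned ordering matches the ordering of the original values. This is a Python illustration under that assumption, with hypothetical function names; it is not the C++ code from the PR.

```python
import struct


def float32_to_ordered_uint32(x):
    """Map a float32 value to a uint32 whose unsigned order matches float order."""
    # Reinterpret the float32 bit pattern as an unsigned 32-bit integer.
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    if bits & 0x80000000:
        # Negative values: flip every bit so larger magnitudes sort first.
        return bits ^ 0xFFFFFFFF
    # Non-negative values: set the sign bit so they sort after all negatives.
    return bits ^ 0x80000000


def int32_to_ordered_uint32(x):
    """Map a signed 32-bit integer to a uint32 with the same ordering."""
    # Flipping the sign bit of the two's-complement pattern shifts the range
    # [-2**31, 2**31 - 1] onto [0, 2**32 - 1] monotonically.
    return (x & 0xFFFFFFFF) ^ 0x80000000


values = [2.0, -3.5, 0.25, -1.0, 0.0]
print(sorted(values, key=float32_to_ordered_uint32))
```

Because the mapped keys are plain unsigned integers, a radix sort can process them digit by digit without any type-specific comparison logic.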
The team developing OpenCL:CPU device runtime and compiler was notified. See CMPLRLLVM-64592. Once fixed, the work-around should be removed.
was applied in C++. Add tests for 2d input arrays, for axis=0 and axis=1. Add a test for non-contiguous input, 0d input, and validation. 100% coverage of the top_k function implementation achieved.
I think this PR is ready to go in.
Thank you for working on this feature, @ndgrigorian
Any additional changes can be deferred to future PRs.
Array API standard conformance tests for dpctl=0.19.0dev0=py310h93fe807_399 ran successfully. |
I'd suggest rebasing the branch on top of the targeted base branch to remove two cherry-picked commits fixing the workflow for building with nightly DPC++ bundle |
Also, to make GH rule checker happy, add a line to changelog, so that your commit is the last one, otherwise my approve does not count |
This PR implements the functions `top_k`, `top_k_indices`, and `top_k_values` as per proposal in array API spec.

Radix and merge sorting are used, and modified merge-sort kernels are introduced which sort the array in chunks and write out to a temporary the `k` largest or smallest values.
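The chunked approach described above can be sketched in plain Python for a one-dimensional input. This is only an illustration of the idea, assuming each chunk is sorted independently and contributes at most `k` candidates to a temporary buffer; the function name `top_k_chunked`, the `chunk_size` parameter, and the `largest` flag are hypothetical and do not mirror the dpctl API or its kernels.

```python
def top_k_chunked(values, k, chunk_size=4, largest=True):
    """Return (top_values, top_indices) using a chunked partial sort."""
    if not 0 <= k <= len(values):
        raise ValueError("k must be between 0 and len(values)")
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")

    # Stage 1: sort each chunk on its own and keep only its k best entries.
    # The temporary holds at most k entries per chunk instead of every element.
    candidates = []
    for start in range(0, len(values), chunk_size):
        chunk = [(v, i) for i, v in enumerate(values[start:start + chunk_size], start)]
        chunk.sort(key=lambda pair: pair[0], reverse=largest)
        candidates.extend(chunk[:k])

    # Stage 2: sort the smaller candidate buffer and take the global k best.
    candidates.sort(key=lambda pair: pair[0], reverse=largest)
    top = candidates[:k]
    return [v for v, _ in top], [i for _, i in top]


data = [5, 1, 9, 3, 7, 8, 2]
print(top_k_chunked(data, 3))                  # ([9, 8, 7], [2, 5, 4])
print(top_k_chunked(data, 3, largest=False))   # ([1, 2, 3], [1, 6, 3])
```

When `k` is much smaller than the chunk size, the candidate buffer is a small fraction of the input, which is the memory saving a partial sort is meant to provide.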