Improve gemm code and performance #1541
Conversation
Also support boolean output for boolean input (NumPy does support it)
This change gets rid of all non-batch functors, modularizes duplicated code, and implements non-batched functions as calls to batched functors with a trivial constexpr batch indexer. It also adds a faster gemm kernel that threads over the (n, m) space and accumulates the entire range of k in a single work-item. Dispatch logic changed too: the threaded-K kernel is used only if the (n, m) space is sufficiently small.
Made the hyperparameter scaling-down logic more lenient.
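Below is a minimal, hypothetical C++ sketch of the two ideas described above; the names (`TrivialBatchIndexer`, `gemm_batch_impl`, `gemm_impl`, `use_threaded_k_kernel`) and the dispatch threshold are illustrative assumptions and do not correspond to the actual dpctl sources.

```cpp
#include <cstddef>

// Trivial indexer: every batch maps to offset 0, so batch size 1
// reproduces the non-batched case without a dedicated functor.
struct TrivialBatchIndexer
{
    constexpr std::size_t operator()(std::size_t) const { return 0; }
};

// Indexer for genuinely batched, contiguous inputs.
struct StridedBatchIndexer
{
    std::size_t stride;
    std::size_t operator()(std::size_t batch_id) const
    {
        return batch_id * stride;
    }
};

// Naive reference batched gemm: for each batch b, C = A * B with
// A (n x k), B (k x m), C (n x m); per-batch offsets come from indexers.
template <typename T, typename IndexerA, typename IndexerB, typename IndexerC>
void gemm_batch_impl(const T *A, const T *B, T *C, std::size_t batches,
                     std::size_t n, std::size_t k, std::size_t m,
                     const IndexerA &ind_a, const IndexerB &ind_b,
                     const IndexerC &ind_c)
{
    for (std::size_t b = 0; b < batches; ++b) {
        const T *a = A + ind_a(b);
        const T *bm = B + ind_b(b);
        T *c = C + ind_c(b);
        for (std::size_t i = 0; i < n; ++i) {
            for (std::size_t j = 0; j < m; ++j) {
                T acc{};
                for (std::size_t s = 0; s < k; ++s) {
                    acc += a[i * k + s] * bm[s * m + j];
                }
                c[i * m + j] = acc;
            }
        }
    }
}

// Non-batched gemm expressed as a call into the batched functor.
template <typename T>
void gemm_impl(const T *A, const T *B, T *C,
               std::size_t n, std::size_t k, std::size_t m)
{
    constexpr TrivialBatchIndexer trivial{};
    gemm_batch_impl(A, B, C, /*batches=*/1, n, k, m, trivial, trivial,
                    trivial);
}

// Dispatch heuristic: thread over K only when the (n, m) space alone is
// too small to keep the device busy. The threshold value is made up.
inline bool use_threaded_k_kernel(std::size_t n, std::size_t m)
{
    constexpr std::size_t nm_threshold = 256 * 256;
    return (n * m) <= nm_threshold;
}
```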
Deleted rendered PR docs from intelpython.github.com/dpctl, latest should be updated shortly. 🤞
Array API standard conformance tests for dpctl=0.15.1dev3=py310h15de555_116 ran successfully.
The change to the complex hyperparameters only slightly impacted performance. [Before/after benchmark figures not reproduced here.]
Complex batched contiguous case appears to be broken.
This appears to happen for int16, uint16, etc. as well.
There was a lapse in logic in handling batches for contiguous inputs in the new_nm implementation for types that fall into the tree_contig implementation. The hyperparameter selection was refined to address the observed slowdown for "c8" type inputs: the hyperparameters must be chosen to keep the size of the registers needed to store the private_C matrix the same, which is now accomplished using a constexpr selector helper class. A few typos discovered during debugging were also fixed, where n, m, k arguments were passed instead of the expected n, k, m.
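As a rough illustration of the constexpr selector idea (a sketch only; the class name, the byte budget, and the tile-splitting rule are assumptions, not the actual dpctl hyperparameters), the per-work-item tile of `private_C` can be derived from `sizeof(T)` at compile time so that its register footprint is the same for every element type:

```cpp
#include <complex>
#include <cstddef>

template <typename T>
struct GemmHyperParameters
{
    // Byte budget for the private_C accumulator held by one work-item.
    static constexpr std::size_t private_c_bytes = 64;

    // Number of T elements that fit into that budget (at least 1).
    static constexpr std::size_t private_c_elems =
        (private_c_bytes / sizeof(T)) > 0 ? (private_c_bytes / sizeof(T)) : 1;

    // Split the element budget into an n_wi x m_wi tile of private_C.
    static constexpr std::size_t wi_tile_n = (private_c_elems >= 4) ? 4 : 1;
    static constexpr std::size_t wi_tile_m = private_c_elems / wi_tile_n;
};

// float gets a 4x4 tile, std::complex<double> a 4x1 tile: the number of
// bytes (and hence registers) spent on private_C stays the same.
static_assert(GemmHyperParameters<float>::wi_tile_n *
                  GemmHyperParameters<float>::wi_tile_m * sizeof(float) ==
              GemmHyperParameters<std::complex<double>>::wi_tile_n *
                  GemmHyperParameters<std::complex<double>>::wi_tile_m *
                      sizeof(std::complex<double>),
              "private_C footprint must be type-independent");
```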
Array API standard conformance tests for dpctl=0.15.1dev3=py310h15de555_118 ran successfully.
Thanks @ndgrigorian. This was caused by an oversight, and would impact all contiguous inputs that would normally dispatch to the tree implementation. This is fixed now.
I have tested these changes out for various inputs, including rewriting tests locally to test batched cases in particular. In the future, it may be sensible to add more tests for the batched cases.
For now though, I think this should be fine to go in.
Approved, thank you for going to all of this work @oleksandr-pavlyk !
Refactoring of gemm implementation, adding faster gemm kernel implementation.
This change gets rid of all non-batch functors, modularizes
duplicated code, and implements non-batched functions as calls
to batched functors with a trivial constexpr batch indexer.
This change also adds a faster gemm kernel that threads over the
(n, m) space and accumulates the entire range of k in a single
work-item. Dispatch logic changed too: the threaded-K kernel is
used only if the (n, m) space is sufficiently small.
A test is added to exercise the new kernel for all supported types.