
Improve gemm code and performance #1541

Merged · 10 commits into master · Feb 10, 2024
Conversation

oleksandr-pavlyk
Collaborator

Refactoring of the gemm implementation, adding a faster gemm kernel implementation.

This change gets rid of all non-batch functors, modularizes
duplicated code, and implements the non-batched functions as calls
to batched functors with a trivial constexpr batch indexer.

This change also adds a faster gemm kernel that threads over the (n, m)
space and accumulates the entire K range in a single work-item.

The dispatch logic changed too: we dispatch to the thread-K kernel only
if the (n, m) space is sufficiently small.

A test is added to exercise the new kernel for all supported types.
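The work decomposition described above can be sketched in plain Python (all names here, such as `gemm_nm_threaded` and `nm_threshold`, are illustrative and not dpctl's actual identifiers; the real kernels are SYCL C++):

```python
# Illustrative sketch of the new kernel's decomposition: the work grid
# covers the (n, m) output space, and each work-item accumulates the
# entire K range for its single output element. Names are hypothetical.

def gemm_nm_threaded(A, B, n, k, m):
    # A is n x k, B is k x m, both flat row-major lists.
    C = [0] * (n * m)
    # In SYCL, each (i, j) pair below would be one work-item.
    for i in range(n):
        for j in range(m):
            acc = 0
            for s in range(k):  # full K range in a single work-item
                acc += A[i * k + s] * B[s * m + j]
            C[i * m + j] = acc
    return C

def dispatch(n, m, nm_threshold=64 * 64):
    # Hypothetical dispatch heuristic: use the K-threaded (reduction)
    # kernel only when the (n, m) space is too small to fill the device.
    return "thread_k_kernel" if n * m < nm_threshold else "thread_nm_kernel"
```

The point of the heuristic is that when (n, m) alone offers too little parallelism, splitting the K reduction across work-items is the only way to keep the device busy; otherwise one work-item per output element avoids any cross-work-item reduction.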

  • Have you provided a meaningful PR description?
  • Have you added a test, reproducer or referred to an issue with a reproducer?
  • Have you tested your changes locally for CPU and GPU devices?
  • Have you made sure that new changes do not introduce compiler warnings?
  • Have you checked performance impact of proposed changes?
  • If this PR is a work in progress, are you opening the PR as a draft?

Also support boolean output for boolean input (NumPy does support it)
Made hyperparameter scaling-down logic more lenient.

github-actions bot commented Feb 10, 2024

Deleted rendered PR docs from intelpython.github.com/dpctl, latest should be updated shortly. 🤞

@coveralls
Collaborator

coveralls commented Feb 10, 2024

Coverage Status

Coverage remained the same at 91.138% when pulling 61ec3d3 on improve-gemm into 8757289 on master.


Array API standard conformance tests for dpctl=0.15.1dev3=py310h15de555_116 ran successfully.
Passed: 907
Failed: 2
Skipped: 86

@ndgrigorian
Collaborator

ndgrigorian commented Feb 10, 2024

The change to the complex hyperparameters only slightly impacted performance.

Before

In [4]: x1 = dpt.ones((2000, 2000), dtype="c8")

In [5]: x2 = dpt.ones((2000, 2000), dtype="c8")

In [6]: %timeit z = dpt.matmul(x1, x2)
153 ms ± 17.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [7]: %timeit z = dpt.matmul(x1, x2)
145 ms ± 14.7 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

After

In [2]: x1 = dpt.ones((2000, 2000), dtype="c8")
In [3]: x2 = dpt.ones((2000, 2000), dtype="c8")

In [4]: %timeit z = dpt.matmul(x1, x2)
221 ms ± 3.59 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [5]: %timeit z = dpt.matmul(x1, x2)
213 ms ± 10.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

@ndgrigorian
Collaborator

ndgrigorian commented Feb 10, 2024

Complex batched contiguous case appears to be broken.
The inflection point is at exactly 256.

In [17]: dpt.matmul(x1, x2)
Out[17]:
usm_ndarray([[[256.+0.j, 256.+0.j, 256.+0.j, ..., 256.+0.j, 256.+0.j,
               256.+0.j],
              [256.+0.j, 256.+0.j, 256.+0.j, ..., 256.+0.j, 256.+0.j,
               256.+0.j],
              [256.+0.j, 256.+0.j, 256.+0.j, ..., 256.+0.j, 256.+0.j,
               256.+0.j],
              ...,
              [256.+0.j, 256.+0.j, 256.+0.j, ..., 256.+0.j, 256.+0.j,
               256.+0.j],
              [256.+0.j, 256.+0.j, 256.+0.j, ..., 256.+0.j, 256.+0.j,
               256.+0.j],
              [256.+0.j, 256.+0.j, 256.+0.j, ..., 256.+0.j, 256.+0.j,
               256.+0.j]],

             [[  0.+0.j,   0.+0.j,   0.+0.j, ...,   0.+0.j,   0.+0.j,
                 0.+0.j],
              [  0.+0.j,   0.+0.j,   0.+0.j, ...,   0.+0.j,   0.+0.j,
                 0.+0.j],
              [  0.+0.j,   0.+0.j,   0.+0.j, ...,   0.+0.j,   0.+0.j,
                 0.+0.j],
              ...,
              [  0.+0.j,   0.+0.j,   0.+0.j, ...,   0.+0.j,   0.+0.j,
                 0.+0.j],
              [  0.+0.j,   0.+0.j,   0.+0.j, ...,   0.+0.j,   0.+0.j,
                 0.+0.j],
              [  0.+0.j,   0.+0.j,   0.+0.j, ...,   0.+0.j,   0.+0.j,
                 0.+0.j]],

             [[  0.+0.j,   0.+0.j,   0.+0.j, ...,   0.+0.j,   0.+0.j,
                 0.+0.j],
              [  0.+0.j,   0.+0.j,   0.+0.j, ...,   0.+0.j,   0.+0.j,
                 0.+0.j],
              [  0.+0.j,   0.+0.j,   0.+0.j, ...,   0.+0.j,   0.+0.j,
                 0.+0.j],
              ...,
              [  0.+0.j,   0.+0.j,   0.+0.j, ...,   0.+0.j,   0.+0.j,
                 0.+0.j],
              [  0.+0.j,   0.+0.j,   0.+0.j, ...,   0.+0.j,   0.+0.j,
                 0.+0.j],
              [  0.+0.j,   0.+0.j,   0.+0.j, ...,   0.+0.j,   0.+0.j,
                 0.+0.j]],

             [[  0.+0.j,   0.+0.j,   0.+0.j, ...,   0.+0.j,   0.+0.j,
                 0.+0.j],
              [  0.+0.j,   0.+0.j,   0.+0.j, ...,   0.+0.j,   0.+0.j,
                 0.+0.j],
              [  0.+0.j,   0.+0.j,   0.+0.j, ...,   0.+0.j,   0.+0.j,
                 0.+0.j],
              ...,
              [  0.+0.j,   0.+0.j,   0.+0.j, ...,   0.+0.j,   0.+0.j,
                 0.+0.j],
              [  0.+0.j,   0.+0.j,   0.+0.j, ...,   0.+0.j,   0.+0.j,
                 0.+0.j],
              [  0.+0.j,   0.+0.j,   0.+0.j, ...,   0.+0.j,   0.+0.j,
                 0.+0.j]],

             [[  0.+0.j,   0.+0.j,   0.+0.j, ...,   0.+0.j,   0.+0.j,
                 0.+0.j],
              [  0.+0.j,   0.+0.j,   0.+0.j, ...,   0.+0.j,   0.+0.j,
                 0.+0.j],
              [  0.+0.j,   0.+0.j,   0.+0.j, ...,   0.+0.j,   0.+0.j,
                 0.+0.j],
              ...,
              [  0.+0.j,   0.+0.j,   0.+0.j, ...,   0.+0.j,   0.+0.j,
                 0.+0.j],
              [  0.+0.j,   0.+0.j,   0.+0.j, ...,   0.+0.j,   0.+0.j,
                 0.+0.j],
              [  0.+0.j,   0.+0.j,   0.+0.j, ...,   0.+0.j,   0.+0.j,
                 0.+0.j]]], dtype=complex64)

In [18]: x2 = dpt.ones((5, 255, 255), dtype="c8")

In [19]: x1 = dpt.ones((5, 255, 255), dtype="c8")

In [20]: dpt.matmul(x1, x2)
Out[20]:
usm_ndarray([[[255.+0.j, 255.+0.j, 255.+0.j, ..., 255.+0.j, 255.+0.j,
               255.+0.j],
              [255.+0.j, 255.+0.j, 255.+0.j, ..., 255.+0.j, 255.+0.j,
               255.+0.j],
              [255.+0.j, 255.+0.j, 255.+0.j, ..., 255.+0.j, 255.+0.j,
               255.+0.j],
              ...,
              [255.+0.j, 255.+0.j, 255.+0.j, ..., 255.+0.j, 255.+0.j,
               255.+0.j],
              [255.+0.j, 255.+0.j, 255.+0.j, ..., 255.+0.j, 255.+0.j,
               255.+0.j],
              [255.+0.j, 255.+0.j, 255.+0.j, ..., 255.+0.j, 255.+0.j,
               255.+0.j]],

             [[255.+0.j, 255.+0.j, 255.+0.j, ..., 255.+0.j, 255.+0.j,
               255.+0.j],
              [255.+0.j, 255.+0.j, 255.+0.j, ..., 255.+0.j, 255.+0.j,
               255.+0.j],
              [255.+0.j, 255.+0.j, 255.+0.j, ..., 255.+0.j, 255.+0.j,
               255.+0.j],
              ...,
              [255.+0.j, 255.+0.j, 255.+0.j, ..., 255.+0.j, 255.+0.j,
               255.+0.j],
              [255.+0.j, 255.+0.j, 255.+0.j, ..., 255.+0.j, 255.+0.j,
               255.+0.j],
              [255.+0.j, 255.+0.j, 255.+0.j, ..., 255.+0.j, 255.+0.j,
               255.+0.j]],

             [[255.+0.j, 255.+0.j, 255.+0.j, ..., 255.+0.j, 255.+0.j,
               255.+0.j],
              [255.+0.j, 255.+0.j, 255.+0.j, ..., 255.+0.j, 255.+0.j,
               255.+0.j],
              [255.+0.j, 255.+0.j, 255.+0.j, ..., 255.+0.j, 255.+0.j,
               255.+0.j],
              ...,
              [255.+0.j, 255.+0.j, 255.+0.j, ..., 255.+0.j, 255.+0.j,
               255.+0.j],
              [255.+0.j, 255.+0.j, 255.+0.j, ..., 255.+0.j, 255.+0.j,
               255.+0.j],
              [255.+0.j, 255.+0.j, 255.+0.j, ..., 255.+0.j, 255.+0.j,
               255.+0.j]],

             [[255.+0.j, 255.+0.j, 255.+0.j, ..., 255.+0.j, 255.+0.j,
               255.+0.j],
              [255.+0.j, 255.+0.j, 255.+0.j, ..., 255.+0.j, 255.+0.j,
               255.+0.j],
              [255.+0.j, 255.+0.j, 255.+0.j, ..., 255.+0.j, 255.+0.j,
               255.+0.j],
              ...,
              [255.+0.j, 255.+0.j, 255.+0.j, ..., 255.+0.j, 255.+0.j,
               255.+0.j],
              [255.+0.j, 255.+0.j, 255.+0.j, ..., 255.+0.j, 255.+0.j,
               255.+0.j],
              [255.+0.j, 255.+0.j, 255.+0.j, ..., 255.+0.j, 255.+0.j,
               255.+0.j]],

             [[255.+0.j, 255.+0.j, 255.+0.j, ..., 255.+0.j, 255.+0.j,
               255.+0.j],
              [255.+0.j, 255.+0.j, 255.+0.j, ..., 255.+0.j, 255.+0.j,
               255.+0.j],
              [255.+0.j, 255.+0.j, 255.+0.j, ..., 255.+0.j, 255.+0.j,
               255.+0.j],
              ...,
              [255.+0.j, 255.+0.j, 255.+0.j, ..., 255.+0.j, 255.+0.j,
               255.+0.j],
              [255.+0.j, 255.+0.j, 255.+0.j, ..., 255.+0.j, 255.+0.j,
               255.+0.j],
              [255.+0.j, 255.+0.j, 255.+0.j, ..., 255.+0.j, 255.+0.j,
               255.+0.j]]], dtype=complex64)

This appears to happen for int16, uint16, etc. as well.
Does not impact strided case.
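For reference, with all-ones inputs every entry of every batch of the product should equal K, which is exactly the invariant the output above violates for batches past the first. A tiny pure-Python check of that invariant (illustrative only, not dpctl code):

```python
# For all-ones inputs of shape (b, n, k) @ (b, k, m), a correct batched
# matmul yields the value k in every position of every batch -- the buggy
# kernel above instead produces zeros for every batch after the first.

def batched_matmul_ones(b, n, k, m):
    # Each output entry is a dot product of a length-k row of ones with
    # a length-k column of ones, i.e. k.
    return [[[sum(1 * 1 for _ in range(k)) for _ in range(m)]
             for _ in range(n)] for _ in range(b)]

out = batched_matmul_ones(5, 4, 256, 4)
assert all(v == 256 for batch in out for row in batch for v in row)
```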

There was a lapse in logic in handling batches for contiguous inputs
in the new_nm implementation for types that fall into the tree_contig
implementation.

The hyperparameter selection was refined to address the observed
slowdown for "c8" input types. The hyperparameters must be chosen so
that the register footprint of the private_C matrix stays the same.

This is now accomplished using a constexpr selector helper class.

A few typos discovered during debugging were also fixed; they resulted
in unreferenced errors (n, m, k arguments were passed where n, k, m
were expected).
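The idea behind the selector can be sketched as follows (the budget value, default tile sizes, and all names here are hypothetical; dpctl's actual selector is a C++ constexpr helper class):

```python
# Sketch of the hyperparameter-selector idea: choose per-work-item tile
# dimensions (wi_delta_n, wi_delta_m) so that the private_C accumulator,
# wi_delta_n * wi_delta_m * sizeof(T) bytes, fits a fixed register
# budget regardless of the element type. Budget and names are made up.

REGISTER_BUDGET_BYTES = 256

def select_hyperparameters(elem_size_bytes, wi_delta_n=8, wi_delta_m=8):
    # Halve the tile until the private accumulator fits the budget,
    # alternating dimensions to keep the tile roughly square.
    shrink_n = True
    while (wi_delta_n * wi_delta_m * elem_size_bytes > REGISTER_BUDGET_BYTES
           and (wi_delta_n > 1 or wi_delta_m > 1)):
        if shrink_n and wi_delta_n > 1:
            wi_delta_n //= 2
        elif wi_delta_m > 1:
            wi_delta_m //= 2
        shrink_n = not shrink_n
    return wi_delta_n, wi_delta_m

# float32 (4 bytes): 8x8 tile is 256 bytes, fits as-is -> (8, 8).
# complex64 (8 bytes): 8x8 is 512 bytes, shrinks once -> (4, 8).
```

Because the selection depends only on `sizeof(T)`, in C++ it can be evaluated entirely at compile time, which is why a constexpr helper class is a natural fit.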

Array API standard conformance tests for dpctl=0.15.1dev3=py310h15de555_118 ran successfully.
Passed: 908
Failed: 1
Skipped: 86

@oleksandr-pavlyk
Collaborator Author

> Complex batched contiguous case appears to be broken. Inflection point is at exactly 256.
> This appears to happen for int16, uint16, etc. as well. Does not impact strided case.

Thanks @ndgrigorian. This was caused by an oversight, and would impact all contiguous inputs that would normally dispatch to the tree implementation. This is fixed now.

Collaborator

@ndgrigorian ndgrigorian left a comment


I have tested these changes out for various inputs, including rewriting tests locally to test batched cases in particular. In the future, it may be sensible to add more tests for the batched cases.

For now though, I think this should be fine to go in.

Approved. Thank you for putting in all of this work, @oleksandr-pavlyk!

@oleksandr-pavlyk oleksandr-pavlyk merged commit 7e798a7 into master Feb 10, 2024
49 checks passed
@oleksandr-pavlyk oleksandr-pavlyk deleted the improve-gemm branch February 10, 2024 21:25