-
Notifications
You must be signed in to change notification settings - Fork 644
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
5148844
commit c8980ab
Showing
5 changed files
with
281 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. | ||
# | ||
# This source code is licensed under the BSD license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import logging | ||
|
||
import pytest | ||
import torch | ||
|
||
# The Triton op is only importable/usable with a CUDA device: start pessimistic
# and flip the flag once both imports have succeeded.
_triton_available = False
if torch.cuda.is_available():
    try:
        from xformers.triton import outer_product_mean
        from xformers.triton.utils import gpu_capabilities_older_than_70
    except ImportError:
        logging.warning(
            "Triton is not available, some optimizations will not be tested."
        )
    else:
        _triton_available = True

# Input shapes used to parametrize the test
SHAPES = [(1, 128, 256), (1, 384, 128), (1, 784, 512)]
|
||
|
||
def reference_opm(a, b):
    """
    PyTorch reference for the (non-averaged) outer product.

    With `a` indexed as [..., b, a, c] and `b` as [..., d, a, e], the shared
    axis `a` (second-to-last dimension of both inputs) is summed out and the
    result is indexed as [..., b, d, c, e].
    """
    return torch.einsum("...bac,...dae->...bdce", a, b)
|
||
|
||
@pytest.mark.skipif(not _triton_available, reason="Triton is not available")
@pytest.mark.skipif(
    not _triton_available or gpu_capabilities_older_than_70(),
    reason="Triton requires a SM70+ GPU",
)
@pytest.mark.parametrize("shape", SHAPES)
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
def test_triton_outer_product_mean(shape, dtype):
    """Compare the Triton op (averaging disabled) with the einsum reference."""
    cuda = torch.device("cuda")
    a = torch.rand(shape, dtype=dtype, device=cuda)
    b = torch.rand(shape, dtype=dtype, device=cuda)

    # The reference sums over dim -2 of its inputs; the Triton op is handed the
    # inputs with their last two dims swapped.
    expected = reference_opm(a, b)
    actual = outer_product_mean(
        a.transpose(-2, -1),
        b.transpose(-2, -1),
        average=False,
    )

    # NOTE(review): `expected` carries one more leading dim than `actual`, so
    # this comparison relies on broadcasting.
    assert torch.allclose(expected, actual, rtol=0.01)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. | ||
# | ||
# This source code is licensed under the BSD license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
|
||
from typing import Any, Dict | ||
|
||
import torch | ||
import triton | ||
|
||
from xformers.benchmarks.utils import TestCase, pretty_plot, pretty_print | ||
from xformers.triton import outer_product_mean | ||
|
||
# Benchmarked (B, M, K) problem sizes: single batch, square M == K
SHAPES = [(1, side, side) for side in (256, 512, 1024, 2048, 4096)]
|
||
|
||
def to_gbs_fw(a, b, ms):
    """
    Turn a duration `ms` (milliseconds) into an effective bandwidth in GB/s.

    The traffic estimate counts every element of `a` and `b` (the reads) plus
    `a.shape[-1] * b.shape[-1]` elements (the write), each sized like an
    element of `a`.
    """
    elements_moved = a.numel() + b.numel() + a.shape[-1] * b.shape[-1]
    gigabytes = elements_moved * a.element_size() * 1e-9
    seconds = ms * 1e-3
    return gigabytes / seconds
|
||
|
||
def bench_outer_product_mean(avg):
    """
    Benchmark the einsum baseline against the Triton op for every entry of
    SHAPES, in float16 and float32, then print and plot the bandwidth figures.

    `avg` selects whether the averaged variant is measured.
    """
    device = torch.device("cuda")

    # Both steps only capture `avg`, so they are defined once up front.
    def torch_step(x, y):
        z = torch.einsum("...bac,...dae->...bdce", x, y)
        return z / x.shape[-2] if avg else z

    # NOTE(review): the inputs are passed untransposed here, so this step and
    # the einsum baseline reduce over different axes; the sizes match because
    # SHAPES is square. Confirm this is intended.
    def triton_step(x, y):
        return outer_product_mean(x, y, average=avg)

    for dtype in (torch.float16, torch.float32):
        results: Dict[str, Any] = {}

        for B, M, K in SHAPES:
            a = torch.rand(B, M, K, device=device, dtype=dtype, requires_grad=False)
            b = torch.rand(B, M, K, device=device, dtype=dtype, requires_grad=False)
            key = f"B={B}, M={M}, K={K}"

            testcases = (
                TestCase(torch_step, "pytorch - avg{}".format(avg)),
                TestCase(triton_step, "triton - avg{}".format(avg)),
            )
            for testcase in testcases:
                # First entry of do_bench's result is used as the timing;
                # presumably milliseconds - confirm against the Triton version.
                elapsed = triton.testing.do_bench(lambda: testcase.function(a, b))[0]

                # Record BW
                bandwidth = to_gbs_fw(a, b, elapsed)
                results.setdefault(key, {})[testcase.name] = f"{bandwidth:.1f}"

        pretty_print(results, title="\n --- Type: {} --- ".format(dtype), units="GB/s")
        pretty_plot(
            results,
            title="OuterProduct-AVG{}-{}".format(avg, dtype),
            units="GB/s",
            dash_key="pytorch",
        )
|
||
|
||
# Run the benchmark without, then with, averaging
for avg in (False, True):
    bench_outer_product_mean(avg)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
# Copyright (c) Meta, Inc. and its affiliates. All rights reserved. | ||
# | ||
# This source code is licensed under the BSD license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import triton | ||
import triton.language as tl | ||
|
||
|
||
# fmt: off | ||
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_I": 16, "BLOCK_J": 16}, num_stages=5, num_warps=1),
        triton.Config({"BLOCK_I": 32, "BLOCK_J": 32}, num_stages=5, num_warps=1),
        triton.Config({"BLOCK_I": 64, "BLOCK_J": 32}, num_stages=5, num_warps=2),
        triton.Config({"BLOCK_I": 32, "BLOCK_J": 64}, num_stages=5, num_warps=2),
        triton.Config({"BLOCK_I": 128, "BLOCK_J": 64}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_I": 64, "BLOCK_J": 128}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_I": 128, "BLOCK_J": 128}, num_stages=4, num_warps=4),
    ],
    key=["S", "I", "J"],
)
@triton.jit
def k_outer_product_mean(
    OUT,        # out ptr
    A, B,       # in ptrs
    S, I, J,    # dims  # noqa
    **META,     # Optional meta-parameters for the kernel
):
    """
    Implements Algorithm 10 in the supplementary data of
    "Highly accurate protein structure prediction with AlphaFold",
    Jumper et al. (https://doi.org/10.1038/s41586-021-03819-2)

    The notations are preserved, in that we'll compute the outer product in between
    A(i, s) and B(j, s), and then sum over s (each partial product is scaled
    by 1/S when META["AVERAGE"] is set, which yields the mean).
    Note that s and (i, j) are flipped with respect to the paper, which
    helps handling extra dimensions.

    Each program computes one (BLOCK_I, BLOCK_J) tile of OUT, selected by the
    program ids on axes 0 and 1, and walks over S in chunks of GROUP_S.
    BLOCK_I and BLOCK_J are provided by the autotuner configs; GROUP_S and
    AVERAGE are also read from META but are not part of those configs.

    Args:
        OUT (I, J)
        A (I, S)
        B (J, S)
        S, I, J: sizes used for addressing (row stride S for A and B, J for OUT)
            and for masking out-of-range rows/columns
    """
    # fmt: on

    # Each program owns a (BLOCK_I, BLOCK_J) tile over I and J,
    # which helps with coefficient reuse.
    # S is processed in chunks of GROUP_S
    BLOCK_I = META["BLOCK_I"]
    BLOCK_J = META["BLOCK_J"]
    GROUP_S = META["GROUP_S"]

    # First row / column handled by this program
    i_id = tl.program_id(axis=0) * BLOCK_I
    j_id = tl.program_id(axis=1) * BLOCK_J

    # float32 accumulator for this output tile
    running_mean = tl.zeros((BLOCK_I, BLOCK_J), dtype=tl.float32)

    # Row (i) and column (j) indices of the tile, plus the offsets within one S chunk
    rn_i = tl.arange(0, BLOCK_I) + i_id
    rn_j = tl.arange(0, BLOCK_J) + j_id
    rn_s = tl.arange(0, GROUP_S)
    scale = 1. / S

    # `i` counts the S chunks processed so far; the range() only sets the trip count
    i = 0
    for _ in range(S, 0, -GROUP_S):
        rs = rn_s + i * GROUP_S
        # A is read as a (BLOCK_I, GROUP_S) tile, B as a (GROUP_S, BLOCK_J) tile
        a_ptrs = A + rn_i[:, None] * S + rs[None, :]
        b_ptrs = B + rn_j[None, :] * S + rs[:, None]

        # Out-of-range rows / columns / s positions are loaded as 0.0,
        # so they do not contribute to the dot product
        a = tl.load(a_ptrs, mask=((rn_i[:, None] < I) & (rs[None, :] < S)), other=0.0)
        b = tl.load(b_ptrs, mask=((rs[:, None] < S) & (rn_j[None, :] < J)), other=0.0)

        # This will sum over the current S chunk directly
        outer_prod = tl.dot(a, b).to(tl.float32)

        # Accumulate across chunks, scaling each chunk by 1/S when averaging
        if META["AVERAGE"]:
            running_mean += outer_prod * scale
        else:
            running_mean += outer_prod

        i += 1

    # We're done for this tile, save the results (masked to the valid I x J region)
    out_ptr = OUT + rn_i[:, None] * J + rn_j[None, :]
    tl.store(out_ptr, running_mean, mask=(rn_i[:, None] < I) & (rn_j[None, :] < J))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. | ||
# | ||
# This source code is licensed under the BSD license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import torch | ||
import triton | ||
|
||
from xformers.triton.k_outer_product_mean import k_outer_product_mean | ||
|
||
|
||
def outer_product_mean(a, b, average: bool = True):
    """
    Implements Algorithm 10 in the supplementary data of
    "Highly accurate protein structure prediction with AlphaFold",
    Jumper et al. (https://doi.org/10.1038/s41586-021-03819-2)

    The notations are preserved, in that we'll compute the outer product in between
    A(i, s) and B(j, s), and then sum over s (mean over s if `average` is set).

    Args:
        a: tensor whose last dimension is `s`. With more than two dimensions,
            all leading dimensions are flattened into `i`.
        b: tensor whose last dimension is `s` (same size as for `a`), with the
            same number of dimensions as `a`. Leading dimensions are flattened
            into `j`.
        average: forwarded to the kernel as AVERAGE; when True the result is
            scaled by 1 / s.

    Returns:
        For 2-D inputs, the (I, J) result.
        Otherwise the result reshaped to
        (a.shape[0], a.shape[-2], b.shape[-2]); this reshape fails if the
        flattened sizes do not match that shape (e.g. when both inputs carry
        a leading batch dimension larger than 1).
    """

    # Make sure that we're in the known [i, s] and [j, s] configuration
    assert (
        a.shape[-1] == b.shape[-1]
    ), "a and b must share the size of their last (reduced) dimension"
    assert a.ndim == b.ndim, "a and b must have the same number of dimensions"

    # Collapse any leading dims so that the kernel sees plain 2-D buffers.
    # `.contiguous()` is a no-op for tensors which are already contiguous.
    a_ = a.reshape(-1, a.shape[-1]) if a.ndim > 2 else a
    b_ = b.reshape(-1, b.shape[-1]) if b.ndim > 2 else b
    a_ = a_.contiguous()
    b_ = b_.contiguous()

    I, S = a_.shape  # noqa  # "ambiguous variable name I -> keeping the paper notations"
    J, _ = b_.shape

    outputs = torch.empty((I, J), device=a.device, dtype=a.dtype)

    # One program per (BLOCK_I, BLOCK_J) output tile
    def grid(META):
        return (
            triton.cdiv(I, META["BLOCK_I"]),
            triton.cdiv(J, META["BLOCK_J"]),
        )

    # fmt: off
    k_outer_product_mean[grid](
        outputs, a_, b_,
        S, I, J,
        GROUP_S=32,
        AVERAGE=average)
    # fmt: on

    if a.ndim == 2:
        # Plain matrices: the kernel output already has the final (I, J) shape.
        # Reshaping it to (a.shape[0], a.shape[-2], b.shape[-2]) would ask for
        # I * I * J elements and fail whenever I != 1.
        return outputs

    return outputs.reshape(a.shape[0], a.shape[-2], b.shape[-2])