Commit

First PoC

blefaudeux committed Dec 20, 2021
1 parent 5148844 commit c8980ab
Showing 5 changed files with 281 additions and 0 deletions.
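For context: the new outer_product_mean op fuses the "outer product mean" of AlphaFold's Algorithm 10. Given A of shape (I, S) and B of shape (J, S), it forms the outer product of the feature columns and reduces over the shared s axis, optionally averaging. A minimal PyTorch sketch of the intended semantics (sizes and names here are illustrative, not part of the diff):

import torch

# Outer product of A (I, S) and B (J, S) along s, then an optional mean over s
I, J, S = 64, 48, 32  # illustrative sizes
A = torch.rand(I, S)
B = torch.rand(J, S)

out_sum = torch.einsum("is,js->ij", A, B)  # contracts s: same as A @ B.T
out_mean = out_sum / S                     # the average=True variant

assert torch.allclose(out_sum, A @ B.T)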
48 changes: 48 additions & 0 deletions tests/test_triton_outer_product_mean.py
@@ -0,0 +1,48 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import logging

import pytest
import torch

_triton_available = torch.cuda.is_available()
if _triton_available:
    try:
        from xformers.triton import outer_product_mean
        from xformers.triton.utils import gpu_capabilities_older_than_70

    except ImportError:
        logging.warning(
            "Triton is not available, some optimizations will not be tested."
        )
        _triton_available = False

SHAPES = [(1, 128, 256), (1, 384, 128), (1, 784, 512)]


def reference_opm(a, b):
    # Outer product over the feature dims, contracted over the shared "a" axis
    # [*, N_res, N_res, C, C]
    outer = torch.einsum("...bac,...dae->...bdce", a, b)
    return outer


@pytest.mark.skipif(not _triton_available, reason="Triton is not available")
@pytest.mark.skipif(
    not _triton_available or gpu_capabilities_older_than_70(),
    reason="Triton requires a SM70+ GPU",
)
@pytest.mark.parametrize("shape", SHAPES)
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
def test_triton_outer_product_mean(shape, dtype):
    a = torch.rand(shape, dtype=dtype, device=torch.device("cuda"))
    b = torch.rand(shape, dtype=dtype, device=torch.device("cuda"))

    ref_opm = reference_opm(a, b)  # noqa
    triton_opm = outer_product_mean(
        a.transpose(-2, -1), b.transpose(-2, -1), average=False
    )  # noqa

    assert torch.allclose(ref_opm, triton_opm, rtol=0.01)
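A note on the shape juggling in this test (an informal reading, not from the diff): with inputs of shape (B, M, K), the einsum reference contracts over the M axis, so M plays the role of s while K provides the i/j features; the Triton op instead expects s last, hence the transposes. Sketch:

import torch

B, M, K = 1, 128, 256
a, b = torch.rand(B, M, K), torch.rand(B, M, K)

# Reference: contracts the shared axis of size M, leaving (B, B, K, K)
ref = torch.einsum("bac,dae->bdce", a, b)  # torch.Size([1, 1, 256, 256])

# The fused op wants [*, i, s], i.e. s last: (B, K, M)
a_t = a.transpose(-2, -1)  # torch.Size([1, 256, 128])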
82 changes: 82 additions & 0 deletions xformers/benchmarks/benchmark_triton_outer_prod_mean.py
@@ -0,0 +1,82 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


from typing import Any, Dict

import torch
import triton

from xformers.benchmarks.utils import TestCase, pretty_plot, pretty_print
from xformers.triton import outer_product_mean

SHAPES = [
    (1, 256, 256),
    (1, 512, 512),
    (1, 1024, 1024),
    (1, 2048, 2048),
    (1, 4096, 4096),
]


def to_gbs_fw(a, b, ms):
    # Rough bandwidth model: read the two input arrays, write the output once
    return (
        (a.numel() + b.numel() + a.shape[-1] * b.shape[-1]) * a.element_size() * 1e-9
    ) / (ms * 1e-3)


def bench_outer_product_mean(avg):
    device = torch.device("cuda")

    for dtype in [
        torch.float16,
        torch.float32,
    ]:
        results: Dict[str, Any] = {}

        for B, M, K in SHAPES:
            a = torch.rand(B, M, K, device=device, dtype=dtype, requires_grad=False)
            b = torch.rand(B, M, K, device=device, dtype=dtype, requires_grad=False)

            def torch_step(x, y):
                z = torch.einsum("...bac,...dae->...bdce", x, y)
                if avg:
                    return z / x.shape[-2]
                return z

            def triton_step(x, y):
                return outer_product_mean(x, y, average=avg)

            for testcase in [
                TestCase(
                    torch_step,
                    "pytorch - avg{}".format(avg),
                ),
                TestCase(
                    triton_step,
                    "triton - avg{}".format(avg),
                ),
            ]:
                time = triton.testing.do_bench(lambda: testcase.function(a, b))[0]
                key = f"B={B}, M={M}, K={K}"
                if key not in results:
                    results[key] = {}

                # Record the effective bandwidth
                bandwidth = to_gbs_fw(a, b, time)
                results[key][testcase.name] = f"{bandwidth:.1f}"

        pretty_print(results, title="\n --- Type: {} --- ".format(dtype), units="GB/s")
        pretty_plot(
            results,
            title="OuterProduct-AVG{}-{}".format(avg, dtype),
            units="GB/s",
            dash_key="pytorch",
        )


for avg in [False, True]:
    bench_outer_product_mean(avg)
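To sanity-check the bandwidth model in to_gbs_fw, a worked example (my numbers, not from the benchmark): at B=1, M=K=1024 in fp16, the model counts two 1024x1024 reads plus one 1024x1024 write, at 2 bytes per element:

# Worked example of the to_gbs_fw model: B=1, M=K=1024, fp16 (2 bytes/element)
elements = 1024 * 1024 * 2 + 1024 * 1024   # read a and b, write the (K, K) output
print(elements * 2 * 1e-9 / 1e-3)          # ~6.3 GB/s if one call takes 1 ms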
2 changes: 2 additions & 0 deletions xformers/triton/__init__.py
@@ -12,6 +12,7 @@
    from .dropout import FusedDropoutBias, dropout  # noqa
    from .fused_linear_layer import FusedLinear  # noqa
    from .layer_norm import FusedLayerNorm, layer_norm  # noqa
    from .outer_product_mean import outer_product_mean  # noqa
    from .softmax import log_softmax, softmax  # noqa

    __all__ = [
@@ -22,6 +23,7 @@
"FusedLinear",
"FusedLayerNorm",
"layer_norm",
"outer_product_mean",
]
except ImportError:
__all__ = []
88 changes: 88 additions & 0 deletions xformers/triton/k_outer_product_mean.py
@@ -0,0 +1,88 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import triton
import triton.language as tl


# fmt: off
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_I": 16, "BLOCK_J": 16}, num_stages=5, num_warps=1),
        triton.Config({"BLOCK_I": 32, "BLOCK_J": 32}, num_stages=5, num_warps=1),
        triton.Config({"BLOCK_I": 64, "BLOCK_J": 32}, num_stages=5, num_warps=2),
        triton.Config({"BLOCK_I": 32, "BLOCK_J": 64}, num_stages=5, num_warps=2),
        triton.Config({"BLOCK_I": 128, "BLOCK_J": 64}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_I": 64, "BLOCK_J": 128}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_I": 128, "BLOCK_J": 128}, num_stages=4, num_warps=4),
    ],
    key=["S", "I", "J"],
)
@triton.jit
def k_outer_product_mean(
    OUT,      # out ptr
    A, B,     # in ptrs
    S, I, J,  # dims  # noqa
    **META,   # optional meta-parameters for the kernel
):
"""
Implements Algorithm 10 in the supplementary data of
"Highly accurate protein structure prediction with AlphaFold",
Jumper et al. (https://doi.org/10.1038/s41586-021-03819-2)
The notations are preserved, in that we'll compute the outer product in between
A(i, s) and B(j, s), and then mean over s.
Note that s and (i, j) are flipped with respect to the paper, which
helps handling extra dimensions.
Args:
OUT (I, J)
A (I, S)
B (J, S)
"""
# fmt: on

    # Each kernel instance owns an (I, J) tile, to help with coefficient reuse,
    # and walks the shared S dimension in chunks
    BLOCK_I = META["BLOCK_I"]
    BLOCK_J = META["BLOCK_J"]
    GROUP_S = META["GROUP_S"]

    i_id = tl.program_id(axis=0) * BLOCK_I
    j_id = tl.program_id(axis=1) * BLOCK_J

    # Running (BLOCK_I, BLOCK_J) accumulator for this output tile
    running_mean = tl.zeros((BLOCK_I, BLOCK_J), dtype=tl.float32)

    # Index ranges for this tile
    rn_i = tl.arange(0, BLOCK_I) + i_id
    rn_j = tl.arange(0, BLOCK_J) + j_id
    rn_s = tl.arange(0, GROUP_S)
    scale = 1.0 / S

    i = 0
    for _ in range(S, 0, -GROUP_S):
        rs = rn_s + i * GROUP_S
        a_ptrs = A + rn_i[:, None] * S + rs[None, :]
        b_ptrs = B + rn_j[None, :] * S + rs[:, None]

        a = tl.load(a_ptrs, mask=((rn_i[:, None] < I) & (rs[None, :] < S)), other=0.0)
        b = tl.load(b_ptrs, mask=((rs[:, None] < S) & (rn_j[None, :] < J)), other=0.0)

        # tl.dot contracts this chunk of S directly
        outer_prod = tl.dot(a, b).to(tl.float32)

        # Accumulate, pre-scaled by 1/S when a mean is requested
        if META["AVERAGE"]:
            running_mean += outer_prod * scale
        else:
            running_mean += outer_prod

        i += 1

    # We're done for this tile, save the results
    out_ptr = OUT + rn_i[:, None] * J + rn_j[None, :]
    tl.store(out_ptr, running_mean, mask=(rn_i[:, None] < I) & (rn_j[None, :] < J))
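To see why the loop recovers the full contraction, here is a small PyTorch emulation of the chunked accumulation (sizes illustrative, not from the kernel):

import torch

# Chunked accumulation as in the kernel: each pass handles GROUP_S columns of s
I, J, S, GROUP_S = 64, 48, 128, 32
A = torch.rand(I, S)
B = torch.rand(J, S)

acc = torch.zeros(I, J)
for s0 in range(0, S, GROUP_S):
    a = A[:, s0:s0 + GROUP_S]      # (I, GROUP_S) tile
    b = B[:, s0:s0 + GROUP_S].T    # (GROUP_S, J) tile, as the kernel loads it
    acc += a @ b                   # tl.dot analogue: contracts the chunk of S

assert torch.allclose(acc, A @ B.T, atol=1e-5)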
61 changes: 61 additions & 0 deletions xformers/triton/outer_product_mean.py
@@ -0,0 +1,61 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import torch
import triton

from xformers.triton.k_outer_product_mean import k_outer_product_mean


def outer_product_mean(a, b, average: bool = True):
    """
    Implements Algorithm 10 in the supplementary data of
    "Highly accurate protein structure prediction with AlphaFold",
    Jumper et al. (https://doi.org/10.1038/s41586-021-03819-2)

    The notations are preserved, in that we compute the outer product between
    A(s, i) and B(s, j), and then the mean over s.
    """

    # Make sure that we're in the known [i, s] and [j, s] configuration
    assert a.shape[-1] == b.shape[-1]
    assert a.ndim == b.ndim

    if a.ndim > 2:
        a_ = a.reshape(-1, a.shape[-1])
    else:
        a_ = a

    if b.ndim > 2:
        b_ = b.reshape(-1, b.shape[-1])
    else:
        b_ = b

    if not a_.is_contiguous():
        a_ = a_.contiguous()

    if not b_.is_contiguous():
        b_ = b_.contiguous()

    I, S = a_.shape  # noqa: E741  # ambiguous name "I": keeping the paper notation
    J, _ = b_.shape

    outputs = torch.empty((I, J), device=a.device, dtype=a.dtype)

    def grid(META):
        return (
            triton.cdiv(I, META["BLOCK_I"]),
            triton.cdiv(J, META["BLOCK_J"]),
        )

    # fmt: off
    k_outer_product_mean[grid](
        outputs, a_, b_,
        S, I, J,
        GROUP_S=32,
        AVERAGE=average)
    # fmt: on

    return outputs.reshape(a.shape[0], a.shape[-2], b.shape[-2])
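A hypothetical usage sketch (assumes a CUDA device with Triton available; the shapes and the einsum reference are mine, not from the diff):

import torch
from xformers.triton import outer_product_mean

# a, b in [*, i|j, s] layout: s last, as the assertions above expect
a = torch.rand(1, 256, 128, device="cuda", dtype=torch.float16)
b = torch.rand(1, 256, 128, device="cuda", dtype=torch.float16)

out = outer_product_mean(a, b, average=True)
print(out.shape)  # torch.Size([1, 256, 256])

# Reference check against plain PyTorch
ref = torch.einsum("bis,bjs->bij", a, b) / a.shape[-1]
assert torch.allclose(out, ref, rtol=1e-2)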
