[feat] Dropout(Activation(x+bias)), now with partial BW fusion #144

Closed
wants to merge 7 commits into from
19 changes: 13 additions & 6 deletions BENCHMARKS.md
@@ -89,18 +89,25 @@ Note that in the Triton case the slowdowns at extreme sizes are because of regis

![Fused layer norm throughput in fp32 - training](docs/plots/layer_norm/LayerNorm_FW+BW_torch.float32.png))

### Fused dropout + bias
### Fused dropout + bias + activation

You can reproduce these numbers locally by running `python3 xformers/benchmarks/benchmark_triton_dropout.py`. The units are GB/s. These results are for a nVidia V100, Triton 1.1 and PyTorch 1.10.
You can reproduce these numbers locally by running `python3 xformers/benchmarks/benchmark_triton_dropout.py`. The units are GB/s. These results are for an nVidia 3080 laptop GPU, Triton 1.1 and PyTorch 1.10.

![Fused dropout+ bias throughput in fp16 - inference](docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float16.png)
![Fused dropout+ bias throughput in fp16 - inference](docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float16_Act:_gelu.png)

![Fused dropout+ bias throughput in fp16 - training](docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16.png))
![Fused dropout+ bias throughput in fp16 - training](docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16_Act:_gelu.png)

![Fused dropout+ bias throughput in fp32 - inference](docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float32.png))
![Fused dropout+ bias throughput in fp32 - inference](docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float32_Act:_gelu.png)

![Fused dropout+ bias throughput in fp32 - training](docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32.png))
![Fused dropout+ bias throughput in fp32 - training](docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32_Act:_gelu.png)

![Fused dropout+ bias throughput in fp16 - inference](docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float16_Act:_squared_relu.png)

![Fused dropout+ bias throughput in fp16 - training](docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16_Act:_squared_relu.png)

![Fused dropout+ bias throughput in fp32 - inference](docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float32_Act:_squared_relu.png)

![Fused dropout+ bias throughput in fp32 - training](docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32_Act:_squared_relu.png)

## LRA

1 change: 1 addition & 0 deletions README.md
@@ -173,6 +173,7 @@ Patrick et al., 2021](https://arxiv.org/abs/2106.05392)*
5. Hackable
1. Not using monolithic CUDA kernels, composable building blocks
2. Using [Triton](https://triton-lang.org/) for some optimized parts, explicit, pythonic and user-accessible
3. Native support for SquaredReLU (on top of ReLU, LeakyReLU, GeLU, ..), extensible activations

### FAQ ?

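For reference, the SquaredReLU activation advertised in the README bullet above (Primer, https://arxiv.org/abs/2109.08668) is simply ReLU followed by a square. A minimal PyTorch sketch, independent of the Triton kernels added in this PR:

```python
import torch


def squared_relu(x: torch.Tensor) -> torch.Tensor:
    # Primer's "squared ReLU": max(x, 0) ** 2
    r = torch.relu(x)
    return r * r


def squared_relu_grad(x: torch.Tensor) -> torch.Tensor:
    # d/dx [relu(x) ** 2] = 2 * relu(x)
    return 2.0 * torch.relu(x)
```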
(24 files in this PR could not be displayed; these are most likely the binary plot images under docs/plots/ referenced in BENCHMARKS.md.)
4 changes: 2 additions & 2 deletions examples/microGPT.py
@@ -311,9 +311,9 @@ def top_k_logits(logits, k):
gpus=1,
max_epochs=EPOCHS,
precision=16,
gradient_clip_val=1,
gradient_clip_val=1, # Use to catch divergent gradients, if experimenting
log_every_n_steps=1,
detect_anomaly=True,
# detect_anomaly=True, # Use to catch NaNs, if experimenting
accumulate_grad_batches=REF_BATCH // BATCH,
)

12 changes: 9 additions & 3 deletions tests/test_triton_dropout.py
@@ -53,7 +53,8 @@ def test_dropout_cpu():
@pytest.mark.parametrize("shape", SHAPES)
@pytest.mark.parametrize("amp", [False, True])
@pytest.mark.parametrize("bias", [False, True])
def test_dropout(shape, amp, bias):
@pytest.mark.parametrize("p", [0, 0.1, 0.5])
def test_dropout(shape, amp, bias, p):
"""
Check some basic dropout properties
"""
@@ -97,6 +98,11 @@ def test_dropout(shape, amp, bias):
== y.shape[1]
)

# Check that the drop probability is about right
y = triton_dropout(x, p=p)
drop_p = (y.numel() - y.count_nonzero()) / y.numel()
assert abs(drop_p - p) < 0.1


@pytest.mark.skipif(not _triton_available, reason="Triton is not available")
@pytest.mark.skipif(
@@ -107,7 +113,7 @@ def test_dropout(shape, amp, bias):
@pytest.mark.parametrize("amp", [False, True])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("activation", [a.value for a in Activation])
@pytest.mark.parametrize("p", [0, 0.001, 0.5])
@pytest.mark.parametrize("p", [0, 0.01, 0.5])
def test_dropout_parity(shape, amp, bias, activation, p):
"""
Check some basic dropout properties
@@ -158,4 +164,4 @@ def test_dropout_parity(shape, amp, bias, activation, p):
if bias:
assert torch.allclose(
torch.norm(b.grad), torch.norm(b_.grad), rtol=0.01
), f"{b.grad.norm()}\n{b_.grad.norm()}"
), f"{b.grad.norm()} - {b_.grad.norm()}"
4 changes: 3 additions & 1 deletion xformers/benchmarks/benchmark_triton_dropout.py
@@ -62,12 +62,14 @@ def torch_step(x):
y = torch_act(y)

if backward:
y.grad = None
torch.norm(y).backward()
return y

def triton_step(x):
y = triton_dropout(x)
if backward:
y.grad = None
torch.norm(y).backward()
return y

@@ -105,7 +107,7 @@ def triton_step(x):
)


for activation in [Activation.GeLU, None]:
for activation in [Activation.GeLU, None, Activation.SquaredReLU]:
for bw in [True, False]:
for bias in [True, False]:
bench_dropout(bias, bw, activation)
10 changes: 5 additions & 5 deletions xformers/benchmarks/utils.py
@@ -28,12 +28,12 @@
def pretty_print(results, title, units):
""" Printout the contents of a dict as a human-readable and Markdown compatible array"""
print(title)
header = " Units: {:<40}".format(units)
print("|" + header + "|" + "".join("{0:<20}|".format(k) for k in results.keys()))
header = " Units: {:<45}".format(units)
print("| " + header + "|" + "".join("{0:<20}|".format(k) for k in results.keys()))

offset = len(header)
print(
"|{}|".format("-" * offset)
"|-{}|".format("-" * offset)
+ "".join("{}|".format("-" * 20) for _ in results.keys())
)

@@ -44,7 +44,7 @@ def pretty_print(results, title, units):

for k, w in workloads.items():
print(
"|{0:<{offset}}|".format(k, offset=offset)
"| {0:<{offset}}|".format(k, offset=offset)
+ "".join("{:<20}|".format(v) for v in w)
)

@@ -85,7 +85,7 @@ def pretty_plot(results, title, units: str, filename=None, dash_key=""):
plt.xticks(rotation=45)

plt.savefig(filename, bbox_inches="tight")
plt.clf()
plt.close(f)


if _triton_is_available:
123 changes: 88 additions & 35 deletions xformers/triton/dropout.py
@@ -21,47 +21,57 @@
from xformers.triton.k_dropout import k_dropout_bw, k_dropout_fw
from xformers.triton.sum_strided import sum_2d_dim_0

GROUP_M = 16
BLOCK_M = GROUP_M // 4
BLOCK_N = 128


# Helper to handle the SPMD launch grid and error cases
class _dropout(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx, x, p, bias, activation, activation_grad):
def forward(ctx, x, p, bias, activation, activation_grad, trainable_bias):
Contributor Author: the trainable bias (or not) was not properly handled before; the bias was always assumed to be trainable, which is mostly true but not always.
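A minimal pure-PyTorch sketch of what this comment describes (illustrative class name and unfused dropout math, not the fused Triton path in this diff): the extra flag decides whether backward produces a bias gradient at all, so a frozen bias is left untouched by autograd.

```python
import torch


class _BiasDropoutSketch(torch.autograd.Function):
    # Hypothetical, simplified stand-in for _dropout: bias add + dropout only, p in (0, 1)
    @staticmethod
    def forward(ctx, x, bias, p, trainable_bias):
        y = x if bias is None else x + bias
        mask = (torch.rand_like(y) > p).to(y.dtype)
        ctx.save_for_backward(mask)
        ctx.p = p
        # Only emit a bias gradient when a bias exists *and* it is marked trainable
        ctx.trainable_bias = bias is not None and trainable_bias
        return y * mask / (1.0 - p)

    @staticmethod
    def backward(ctx, grad_out):
        (mask,) = ctx.saved_tensors
        grad_in = grad_out * mask / (1.0 - ctx.p)
        # Reduce over the leading dimensions for the bias; return None when frozen
        grad_bias = (
            grad_in.reshape(-1, grad_in.shape[-1]).sum(0) if ctx.trainable_bias else None
        )
        return grad_in, grad_bias, None, None
```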

# Soft-flatten an hypothetical 3rd dimension
x_ = x.reshape(-1, x.shape[-1]).contiguous()
y = torch.empty_like(x_)
_, N = x_.shape

assert bias is None or bias.dtype == x.dtype, bias
M, N = x_.shape

# Generate one seed per sample
# seed max is int32 max for positive numbers: 2**16
seeds = torch.randint(65536, (x_.shape[0],), device=x.device).to(torch.int32)
assert bias is None or (bias.dtype == x.dtype and bias.shape[0] == N)

# SPMD launch grid
def grid(meta):
return (
x_.shape[0],
triton.cdiv(x_.shape[1], meta["BLOCK_SIZE"]),
triton.cdiv(M, meta["BLOCK_M"] * 4),
triton.cdiv(N, meta["BLOCK_N"]),
)

N_BLOCK_N = triton.cdiv(N, BLOCK_N)

# Generate one seed per sample
# seed max is int32 max for positive numbers: 2**16
# FIXME: adjust the number of seeds needed
seeds = torch.randint(65536, (N_BLOCK_N,), device=x.device).to(torch.int32)

# fmt: off
k_dropout_fw[grid](
y, x_, bias if bias is not None else x_,
y, x_,
bias if bias is not None else x_,
seeds,
y.stride(0),
N,
M, N,
p,
USE_BIAS=bias is not None,
ACTIVATION=activation
ACTIVATION=activation,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N
)
# fmt: on

if activation is not None:
ctx.save_for_backward(seeds, bias, x)
else:
ctx.save_for_backward(seeds, bias, None)
ctx.trainable_bias = bias is not None

ctx.trainable_bias = bias is not None and trainable_bias
ctx.activation_grad = activation_grad
ctx.p = p

@@ -76,40 +76,68 @@ def backward(ctx, grad_out):
grad_out_ = grad_out.reshape(-1, grad_out.shape[-1]).contiguous()
grad_in = torch.empty_like(grad_out_)

_, N = grad_out_.shape
M, N = grad_out_.shape

# Optional inputs to compute the activation contribution to the gradient
assert inputs is not None or ctx.activation_grad is None

if inputs is None:
inputs = grad_out_
elif inputs.ndim > 2:
inputs = inputs.reshape(-1, grad_out.shape[-1])
inputs = inputs.reshape(-1, N)

# We split the problem in tiles:
# - over M there will be a follow-up reduction
# - over M, we go by 4 tiles at a time (a consequence of the random number generation)
# - over N we compromise between using as much memory parallelism as possible
#   (fill the warps; there are 32 threads per warp, and 4 warps by default) and not going
#   too big, because of register spilling
N_BLOCKS_M = triton.cdiv(M, GROUP_M)

if ctx.trainable_bias:
grad_bias = torch.empty(
(
N_BLOCKS_M,
N,
),
device=grad_in.device,
dtype=grad_in.dtype,
)

else:
grad_bias = grad_in # will not be used

# SPMD launch grid
def grid(meta):
return (
grad_out_.shape[0],
triton.cdiv(grad_out_.shape[1], meta["BLOCK_SIZE"]),
triton.cdiv(M, meta["BLOCK_M"] * 4),
triton.cdiv(N, meta["BLOCK_N"]),
)

# fmt: off
k_dropout_bw[grid](
grad_in, grad_out_, inputs, bias if bias is not None else inputs,
grad_in, grad_bias, grad_out_,
inputs, bias if bias is not None else inputs,
seeds,
grad_out_.stride(0), inputs.stride(0),
N,
M, N,
ctx.p,
USE_BIAS=bias is not None,
ACTIVATION_GRAD=ctx.activation_grad)
ACTIVATION_GRAD=ctx.activation_grad,
TRAINABLE_BIAS=ctx.trainable_bias,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
num_warps=8
)
# fmt: on

if ctx.trainable_bias:
grad_bias: Optional[torch.Tensor] = sum_2d_dim_0(grad_in)
else:
grad_bias = None

return grad_in.reshape_as(grad_out), None, grad_bias, None, None
return (
grad_in.reshape_as(grad_out),
None,
sum_2d_dim_0(grad_bias) if ctx.trainable_bias else None,
None,
None,
None,
)


def dropout(
@@ -129,7 +167,14 @@

act_kernel = get_triton_activation_kernel(activation)
act_grad_kernel = get_triton_activation_bwd_kernel(activation)
return _dropout.apply(x, p, bias, act_kernel, act_grad_kernel)
return _dropout.apply(
x,
p,
bias,
act_kernel,
act_grad_kernel,
bias is not None and bias.requires_grad,
)


class FusedDropoutBias(torch.nn.Module):
@@ -142,23 +187,31 @@ def __init__(
super().__init__()
self.p = p
self.activation_type = activation
self.register_buffer(
"bias", torch.zeros(bias_shape) if bias_shape is not None else None
self.bias = (
torch.zeros(bias_shape, requires_grad=True)
if bias_shape is not None
else None
)
self.activation = get_triton_activation_kernel(activation)
self.pytorch_activation = build_activation(self.activation_type)
self.activation_grad = get_triton_activation_bwd_kernel(activation)

def forward(self, x: torch.Tensor) -> torch.Tensor:
# Convenience, catch a possible type or device mismatch
if self.bias is not None: # type: ignore
self.bias = self.bias.to(dtype=x.dtype, device=x.device) # type: ignore

# This kernel is slower than pytorch for small buffers, bypassing it in that case
perf_check = x.shape[-1] > 512

# Catch a non-cuda setup, fallback to pytorch
if not x.is_cuda:
activation = build_activation(self.activation_type)
if not x.is_cuda or not perf_check:
x = x + self.bias if self.bias is not None else x
x = activation(x)
x = self.pytorch_activation(x)
return torch.nn.functional.dropout(x, self.p)

# The normal, Triton-backed path
p = self.p if self.training else 0.0
return _dropout.apply(x, p, self.bias, self.activation, self.activation_grad)
return _dropout.apply(
x, p, self.bias, self.activation, self.activation_grad, True
)
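The backward pass above fuses a partial bias-gradient reduction into the kernel: each group of GROUP_M rows writes one partial sum into a (N_BLOCKS_M, N) buffer, and sum_2d_dim_0 then finishes the column-wise reduction. A plain-PyTorch sketch of that two-stage reduction, with made-up sizes and without the dropout/activation terms:

```python
import torch

M, N, GROUP_M = 1024, 768, 16  # M chosen divisible by GROUP_M for simplicity
grad_in = torch.randn(M, N)

# Stage 1 (what each tile of the fused backward kernel does): partial sums over GROUP_M rows
n_row_blocks = M // GROUP_M
partial = grad_in.reshape(n_row_blocks, GROUP_M, N).sum(dim=1)  # shape (n_row_blocks, N)

# Stage 2 (sum_2d_dim_0 in the Triton path): collapse the partial sums into the bias gradient
grad_bias = partial.sum(dim=0)  # shape (N,)

assert torch.allclose(grad_bias, grad_in.sum(dim=0), atol=1e-3)
```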
9 changes: 3 additions & 6 deletions xformers/triton/k_activations.py
@@ -64,8 +64,7 @@ def relu(x):
.. _ReLU: https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html
"""
zero = 0.0
zero = zero.to(x.dtype)
return tl.where(x >= 0, x, zero)
return tl.where(x >= 0, x, zero.to(x.dtype))


@triton.jit
@@ -74,10 +73,8 @@ def relu_grad(x):
# in that it does not require the input to retrospectively compute its gradient
# here the input is the downstream gradient, and we return the upstream gradient directly
zero = 0.0
zero = zero.to(x.dtype)
one = 1.0
one = one.to(x.dtype)
return tl.where(x >= 0, one, zero)
return tl.where(x >= 0, one.to(x.dtype), zero.to(x.dtype))


@triton.jit
@@ -88,7 +85,7 @@ def squared_relu(x):
.. _Primer: https://arxiv.org/abs/2109.08668
"""
x_ = relu(x)
return x_ * x_
return (x_ * x_).to(x.dtype)


@triton.jit