Skip to content

Commit

Permalink
Fixing the randomness problem, updating the graphs
Browse files Browse the repository at this point in the history
  • Loading branch information
blefaudeux committed Dec 23, 2021
1 parent b0d5f91 commit 13baac5
Show file tree
Hide file tree
Showing 12 changed files with 19 additions and 5 deletions.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
9 changes: 9 additions & 0 deletions tests/test_triton_dropout.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,15 @@ def test_dropout_cpu():
x = torch.normal(0, 1, size=(16, 16), device="cpu")
_ = triton_dropout(x)

# Check eval means no dropout
triton_dropout.eval()
y = triton_dropout(x)
assert y.count_nonzero() == y.numel()

triton_dropout.train()
y = triton_dropout(x)
assert y.count_nonzero() != y.numel()


@pytest.mark.skipif(not _triton_available, reason="Triton is not available")
@pytest.mark.skipif(
Expand Down
15 changes: 10 additions & 5 deletions xformers/triton/dropout.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@
from xformers.triton.k_dropout import k_dropout_bw, k_dropout_fw
from xformers.triton.sum_strided import sum_2d_dim_0

# NOTE: GROUP_M and BLOCK_N need to be kept low (<16x64)
# for the random numbers to be good enough
GROUP_M = 16
BLOCK_M = GROUP_M // 4
BLOCK_N = 128
BLOCK_N = 64


# Helper to handle the SPMD launch grid and error cases
Expand Down Expand Up @@ -61,7 +63,8 @@ def grid(meta):
USE_BIAS=bias is not None,
ACTIVATION=activation,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N
BLOCK_N=BLOCK_N,
num_warps=2
)
# fmt: on

Expand Down Expand Up @@ -135,7 +138,7 @@ def grid(meta):
TRAINABLE_BIAS=ctx.trainable_bias,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
num_warps=8
num_warps=2
)
# fmt: on

Expand Down Expand Up @@ -200,17 +203,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.bias is not None: # type: ignore
self.bias = self.bias.to(dtype=x.dtype, device=x.device) # type: ignore

# Train/inference
p = self.p if self.training else 0.0

# This kernel is slower than pytorch for small buffers, bypassing it in that case
perf_check = x.shape[-1] > 512

# Catch a non-cuda setup, fallback to pytorch
if not x.is_cuda or not perf_check:
x = x + self.bias if self.bias is not None else x
x = self.pytorch_activation(x)
return torch.nn.functional.dropout(x, self.p)
return torch.nn.functional.dropout(x, p)

# The normal, Triton-backed path
p = self.p if self.training else 0.0
return _dropout.apply(
x, p, self.bias, self.activation, self.activation_grad, True
)

0 comments on commit 13baac5

Please sign in to comment.