diff --git a/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float16_Act:_gelu.png b/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float16_Act:_gelu.png index 088ed16884..c4ebfb507c 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float16_Act:_gelu.png and b/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float16_Act:_gelu.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float32_Act:_gelu.png b/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float32_Act:_gelu.png index f818ff4552..fb483ce0ff 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float32_Act:_gelu.png and b/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float32_Act:_gelu.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float16_Act:_gelu.png b/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float16_Act:_gelu.png index 82f6c035c6..2b46f2b0a9 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float16_Act:_gelu.png and b/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float16_Act:_gelu.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float32_Act:_gelu.png b/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float32_Act:_gelu.png index 022e73330d..4995a7c0d2 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float32_Act:_gelu.png and b/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float32_Act:_gelu.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16_Act:_None.png b/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16_Act:_None.png index a23fe0baf7..044564e4d2 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16_Act:_None.png and b/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16_Act:_None.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16_Act:_gelu.png b/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16_Act:_gelu.png index 8f466eb95f..7a8bcab0c0 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16_Act:_gelu.png and b/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16_Act:_gelu.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32_Act:_None.png b/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32_Act:_None.png index 16f9e90975..a27b8809d7 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32_Act:_None.png and b/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32_Act:_None.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32_Act:_gelu.png b/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32_Act:_gelu.png index dd90d79c87..05f51ac367 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32_Act:_gelu.png and b/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32_Act:_gelu.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float16_Act:_gelu.png b/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float16_Act:_gelu.png index ae7c4d533a..b319f08279 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float16_Act:_gelu.png and b/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float16_Act:_gelu.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float32_Act:_gelu.png b/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float32_Act:_gelu.png index b84c860197..d18ec43b40 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float32_Act:_gelu.png and b/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float32_Act:_gelu.png differ diff --git a/tests/test_triton_dropout.py b/tests/test_triton_dropout.py index 108145e572..e2e5e05040 100644 --- a/tests/test_triton_dropout.py +++ b/tests/test_triton_dropout.py @@ -44,6 +44,15 @@ def test_dropout_cpu(): x = torch.normal(0, 1, size=(16, 16), device="cpu") _ = triton_dropout(x) + # Check eval means no dropout + triton_dropout.eval() + y = triton_dropout(x) + assert y.count_nonzero() == y.numel() + + triton_dropout.train() + y = triton_dropout(x) + assert y.count_nonzero() != y.numel() + @pytest.mark.skipif(not _triton_available, reason="Triton is not available") @pytest.mark.skipif( diff --git a/xformers/triton/dropout.py b/xformers/triton/dropout.py index 4c8571de99..64c12f6b84 100644 --- a/xformers/triton/dropout.py +++ b/xformers/triton/dropout.py @@ -21,9 +21,11 @@ from xformers.triton.k_dropout import k_dropout_bw, k_dropout_fw from xformers.triton.sum_strided import sum_2d_dim_0 +# NOTE: GROUP_M and BLOCK_N need to be kept low (<16x64) +# for the random numbers to be good enough GROUP_M = 16 BLOCK_M = GROUP_M // 4 -BLOCK_N = 128 +BLOCK_N = 64 # Helper to handle the SPMD launch grid and error cases @@ -61,7 +63,8 @@ def grid(meta): USE_BIAS=bias is not None, ACTIVATION=activation, BLOCK_M=BLOCK_M, - BLOCK_N=BLOCK_N + BLOCK_N=BLOCK_N, + num_warps=2 ) # fmt: on @@ -135,7 +138,7 @@ def grid(meta): TRAINABLE_BIAS=ctx.trainable_bias, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, - num_warps=8 + num_warps=2 ) # fmt: on @@ -200,6 +203,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: if self.bias is not None: # type: ignore self.bias = self.bias.to(dtype=x.dtype, device=x.device) # type: ignore + # Train/inference + p = self.p if self.training else 0.0 + # This kernel is slower than pytorch for small buffers, bypassing it in that case perf_check = x.shape[-1] > 512 @@ -207,10 +213,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: if not x.is_cuda or not perf_check: x = x + self.bias if self.bias is not None else x x = self.pytorch_activation(x) - return torch.nn.functional.dropout(x, self.p) + return torch.nn.functional.dropout(x, p) # The normal, Triton-backed path - p = self.p if self.training else 0.0 return _dropout.apply( x, p, self.bias, self.activation, self.activation_grad, True )