Skip to content

Commit

Permalink
fixing mypy errors
Browse files Browse the repository at this point in the history
  • Loading branch information
Chris Yuan committed Jul 25, 2022
1 parent 568c09a commit 0e3275d
Show file tree
Hide file tree
Showing 9 changed files with 53 additions and 39 deletions.
42 changes: 22 additions & 20 deletions tests/test_nvfuser.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,16 @@
)
from xformers.components.nvfuser.utils import build_nvfused

FUSED_PATTERNS = (
[
NVFusedBiasActivationDropout,
NVFusedBiasDropoutRes,
NVFusedBiasDropoutResLayerNorm,
]
if xformers._is_functorch_available
else []
)

# Testing odd (non-power-of-two for instance) shapes on purpose
SHAPES = [
(384, 512),
Expand All @@ -41,32 +51,24 @@
LATENT = 128
DEVICES = [torch.device("cuda")]

ACTIVATIONS = [
Activation.ReLU,
Activation.GeLU,
Activation.LeakyReLU,
Activation.SquaredReLU,
Activation.SmeLU,
]


@pytest.mark.skipif(not _gpu_available, reason="GPU is not available")
@pytest.mark.skipif(
not xformers._is_functorch_available, reason="Functorch is not available"
)
@pytest.mark.parametrize(
"fused_pattern",
[
NVFusedBiasActivationDropout,
NVFusedBiasDropoutRes,
NVFusedBiasDropoutResLayerNorm,
],
)
@pytest.mark.skipif(not _gpu_available, reason="GPU is not available")
@pytest.mark.parametrize("fused_pattern", FUSED_PATTERNS)
@pytest.mark.parametrize("shape", SHAPES)
@pytest.mark.parametrize("amp", [False, True])
@pytest.mark.parametrize("bias", [False, True])
@pytest.mark.parametrize(
"activation",
[
Activation.ReLU,
Activation.GeLU,
Activation.LeakyReLU,
Activation.SquaredReLU,
Activation.SmeLU,
],
)
@pytest.mark.parametrize("activation", ACTIVATIONS)
@pytest.mark.parametrize("p", [0, 0.1, 0.5])
@pytest.mark.parametrize("layer_norm_style", [LayerNormStyle.Pre, LayerNormStyle.Post])
def test_nvfused_pattern_parity(
Expand Down Expand Up @@ -118,7 +120,7 @@ def test_nvfused_pattern_parity(
@pytest.mark.skipif(
not xformers._is_functorch_available, reason="Functorch is not available"
)
@pytest.mark.parametrize("activation", [Activation.ReLU, Activation.GeLU])
@pytest.mark.parametrize("activation", ACTIVATIONS)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("p", [0, 0.1, 0.5])
def test_nvfused_mlp(activation: Activation, device: torch.device, p: float):
Expand Down
6 changes: 3 additions & 3 deletions xformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@
# Please update the doc version in docs/source/conf.py as well.
__version__ = "0.0.12.dev"

_is_sparse_available = True
_is_triton_available = torch.cuda.is_available()
_is_sparse_available: bool = True
_is_triton_available: bool = torch.cuda.is_available()

# Set to true to utilize functorch
_is_functorch_available = False
_is_functorch_available: bool = False


def _register_extensions():
Expand Down
19 changes: 11 additions & 8 deletions xformers/benchmarks/benchmark_nvfuser.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,13 @@ def bench_nvfused(
):
device = torch.device("cuda")

pattern_str = {
pattern_str = { # noqa
NVFusedBiasActivationDropout: "Bias_Act_Dropout",
NVFusedBiasDropoutRes: "Bias_Dropout_Res",
NVFusedBiasDropoutResLayerNorm: "Bias_Dropout_Res_LayerNorm",
}[fused_pattern]
}[
fused_pattern # noqa
]

for dtype in [
torch.float16,
Expand Down Expand Up @@ -181,13 +183,13 @@ def step(fn, residual, x):
for testcase in testcases:
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
torch.cuda.synchronize()
# torch.cuda.synchronize()

time = triton.testing.do_bench(
lambda: testcase.function(x=a), grad_to_none=[a, b]
)[0]

torch.cuda.synchronize()
# torch.cuda.synchronize()
max_memory = torch.cuda.max_memory_allocated() / 2**20

key = f"B={B}, M={M}, K={K}"
Expand Down Expand Up @@ -243,12 +245,13 @@ def step(fn, residual, x):
)


# for activation in [Activation.GeLU, None, Activation.SquaredReLU]:
for pattern in [
PATTERNS = [
NVFusedBiasActivationDropout,
NVFusedBiasDropoutRes,
NVFusedBiasDropoutResLayerNorm,
]:
]

for pattern in PATTERNS:
activations: List[Optional[Activation]] = (
[Activation.ReLU, Activation.GeLU, Activation.SquaredReLU]
if pattern == NVFusedBiasActivationDropout
Expand All @@ -263,4 +266,4 @@ def step(fn, residual, x):
else [None]
)
for style in styles:
bench_nvfused(pattern, bias, bw, activation, style)
bench_nvfused(pattern, bias, bw, activation, style) # noqa
1 change: 1 addition & 0 deletions xformers/components/attention/blocksparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from xformers.components.attention import Attention, AttentionConfig, register_attention

if _is_triton_available:

from triton.ops.blocksparse import matmul as blocksparse_matmul # type: ignore
from triton.ops.blocksparse import softmax as blocksparse_softmax # type: ignore

Expand Down
2 changes: 1 addition & 1 deletion xformers/components/nvfuser/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from xformers import _is_functorch_available

if _is_functorch_available:
if _is_functorch_available: # noqa
try:
from .bias_act_dropout import NVFusedBiasActivationDropout # noqa
from .bias_dropout_res import NVFusedBiasDropoutRes # noqa
Expand Down
5 changes: 4 additions & 1 deletion xformers/components/nvfuser/bias_act_dropout.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@


def _fn(
x: torch.Tensor, bias: Optional[torch.Tensor], activation: nn.Module, prob: float
x: torch.Tensor,
bias: Optional[torch.nn.parameter.Parameter],
activation: nn.Module,
prob: float,
) -> torch.Tensor:
if bias is not None:
x = torch.add(x, bias)
Expand Down
2 changes: 1 addition & 1 deletion xformers/components/nvfuser/bias_dropout_res.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

def _fn(
x: torch.Tensor,
bias: Optional[torch.Tensor],
bias: Optional[torch.nn.parameter.Parameter],
prob: float,
residual: torch.Tensor,
) -> torch.Tensor:
Expand Down
2 changes: 1 addition & 1 deletion xformers/components/nvfuser/bias_dropout_res_layernorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

def _fn(
x: torch.Tensor,
bias: Optional[torch.Tensor],
bias: Optional[torch.nn.parameter.Parameter],
prob: float,
layer_norm_style: Optional[LayerNormStyle],
norm: nn.Module,
Expand Down
13 changes: 9 additions & 4 deletions xformers/components/nvfuser/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,14 @@ def build_nvfused(
):
bias_shape = shape[-1] if bias else None
d_model = shape[-1]
init_args: Dict[nn.Module, List[Any]] = {
NVFusedBiasActivationDropout: [p, activation, bias_shape],
NVFusedBiasDropoutRes: [p, bias_shape],
NVFusedBiasDropoutResLayerNorm: [p, d_model, bias_shape, layer_norm_style],
init_args: Dict[nn.Module, List[Any]] = { # noqa
NVFusedBiasActivationDropout: [p, activation, bias_shape], # noqa
NVFusedBiasDropoutRes: [p, bias_shape], # noqa
NVFusedBiasDropoutResLayerNorm: [
p,
d_model,
bias_shape,
layer_norm_style,
], # noqa
}
return fused_pattern(*init_args[fused_pattern])

0 comments on commit 0e3275d

Please sign in to comment.