
Add a unit test to check memory use, needs improvements.
blefaudeux committed Dec 22, 2021
1 parent 0329d4f commit a47ef5b
Showing 1 changed file: tests/test_memory_efficient_attention.py (40 additions, 5 deletions)
@@ -21,11 +21,11 @@

# Testing odd shapes on purpose
SHAPES = [
-    (384, 256),
-    (1, 384, 128),
-    (8, 384, 128),
-    (8, 784, 512),
-    (2, 2048, 384),
+    # (384, 256),
+    # (1, 384, 128),
+    # (8, 384, 128),
+    # (8, 784, 512),
+    # (2, 2048, 384),
    (4, 3136, 1024),
    (2, 1024, 1024),
]
@@ -63,3 +63,38 @@ def test_mem_efficient_attention_parity(shape, dtype):
# assert torch.allclose(res_pytorch, res_me, rtol=1e-1) FIXME
# TODO: test different sequence lengths for q and k
# TODO: check parity with normal attention


@pytest.mark.skipif(not _triton_available, reason="Triton is not available")
@pytest.mark.skipif(
    not _triton_available or gpu_capabilities_older_than_70(),
    reason="Triton requires a SM70+ GPU",
)
@pytest.mark.parametrize("shape", SHAPES)
@pytest.mark.parametrize("dtype", [torch.float32])
def test_mem_efficient_attention_memory_use(shape, dtype):
    # Forward a random batch of data
    q = torch.rand(shape, dtype=dtype, device=torch.device("cuda"))
    k = torch.rand(shape, dtype=dtype, device=torch.device("cuda"))
    v = torch.rand(shape, dtype=dtype, device=torch.device("cuda"))

    # Vanilla attention
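    # empty_cache() + reset_peak_memory_stats() ensure that max_memory_allocated()
    # below only reflects allocations made from this point on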
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    _ = attention_pytorch(q, k, v)
    torch.cuda.synchronize()
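    # max_memory_allocated() returns bytes; integer-dividing by 2 ** 20 yields MiB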
    max_memory_torch = torch.cuda.max_memory_allocated() // 2 ** 20
    print(f"Dense - Peak memory use: {max_memory_torch} MiB")

    # Mem efficient attention
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

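    # NOTE: the trailing None is presumably the optional attention mask / bias input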
    _ = mem_efficient_attention.apply(q, k, v, None)
    torch.cuda.synchronize()

    max_memory_me = torch.cuda.max_memory_allocated() // 2 ** 20
    print(f"Memory efficient - Peak memory use: {max_memory_me} MiB")

    assert max_memory_me <= max_memory_torch
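
As a rough sanity check on the numbers this test prints (an editorial sketch, not part of the commit): vanilla attention materializes the full (B, S, S) logits matrix, so for the (4, 3136, 1024) shape in fp32 that buffer alone should dominate the dense peak.

# Back-of-the-envelope footprint of the dense attention matrix, assuming
# the vanilla path materializes the full (B, S, S) logits in fp32.
B, S = 4, 3136  # batch and sequence length of the (4, 3136, 1024) test shape
fp32_bytes = 4
attn_matrix_mib = B * S * S * fp32_bytes / 2 ** 20
print(f"Expected dense logits buffer: ~{attn_matrix_mib:.0f} MiB")  # ~150 MiB

To see the printed peaks, something like pytest -s -k memory_use tests/test_memory_efficient_attention.py should work: -s keeps print output visible and -k filters on the test name.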
