Add offloading tests and fix obscure edge case (#1860)
janeyx99 authored Oct 18, 2024
1 parent f8073ed commit 3ca0d30
Showing 2 changed files with 153 additions and 4 deletions.
131 changes: 131 additions & 0 deletions tests/torchtune/training/test_activation_offloading.py
@@ -0,0 +1,131 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch
from tests.test_utils import gpu_test
from torch import nn
from torchtune.training import OffloadActivations


@gpu_test(gpu_count=1)
@pytest.mark.parametrize("use_streams", [True, False])
def test_offloading_is_same_as_without(use_streams) -> None:
    with torch.device("cuda"):
        torch.manual_seed(2024)
        model = nn.Sequential(
            nn.Linear(10, 10),
            nn.Linear(10, 10),
            nn.Linear(10, 10),
            nn.ReLU(),
        )
        torch.manual_seed(2024)
        model_c = nn.Sequential(
            nn.Linear(10, 10),
            nn.Linear(10, 10),
            nn.Linear(10, 10),
            nn.ReLU(),
        )

    inp = torch.randn((2, 10), device="cuda")
    loss = model(inp).sum()
    loss.backward()

    with OffloadActivations(use_streams=use_streams):
        loss_c = model_c(inp).sum()
        loss_c.backward()

    for param, param_c in zip(model.parameters(), model_c.parameters()):
        assert torch.equal(param.grad, param_c.grad)


@gpu_test(gpu_count=1)
def test_offloading_works_with_view_outputs() -> None:
"""
This test is quite contrived but tests against a very obscure situation where
any of the outputs of a backward node are a view of the unpacked tensor.
We want to ensure that if an unpacked tensor may be used later that we do not
free it too early.
How did we contrive this test? We need the backward to execute as so:
1. We first need a node that unpacks a tensor and returns a view of the tensor
2. The next node just needs to pass that view along--this NoOp node is needed
to bypass our heuristic where we delete the _previous_ node's stash after
executing the current node.
3. We need to allow the tensor to die to be contaminated with new info, and
we need a way to look into the contents of the contaminated tensor. We
separate these into two nodes (because having them in the same node does
not properly let the tensor reference die as it is within scope.) The
"Compute" Node queues up ~1 second of work on CUDA followed by a kernel
evaluating whether dX is full of 1s. The next Node then inspects the
earlier activation and asserts the result of dX == 1, which is a sync!
Note that for the backward to execute in the above order, the fwd was made
to execute in reverse order.
"""

    class BwdReturnsViewOfActivation(torch.autograd.Function):
        @staticmethod
        def forward(ctx, cloned_activation):
            cloned_activation = cloned_activation.t()
            ctx.save_for_backward(cloned_activation)
            return torch.rand(2, 4, device="cuda")

        @staticmethod
        def backward(ctx, dy):
            unpacked_activation = ctx.saved_tensors[0]
            return unpacked_activation.t()

    class NoOp(torch.autograd.Function):
        @staticmethod
        def forward(ctx, cloned_activation):
            ctx.save_for_backward(cloned_activation)
            return cloned_activation.clone()

        @staticmethod
        def backward(ctx, viewed_activation):
            # unused value; unpacking it is what exercises the stash logic
            rando_activation = ctx.saved_tensors[0]
            return viewed_activation

    class ComputeNode(torch.autograd.Function):
        @staticmethod
        def forward(ctx, activation):
            return activation.clone()

        @staticmethod
        def backward(ctx, viewed_activation):
            torch.cuda._sleep(2000000000)  # 2e9 is ~1s worth of GPU cycles
            return viewed_activation == 1

    class InspectEarlierActivation(torch.autograd.Function):
        @staticmethod
        def forward(ctx, activation):
            ctx.save_for_backward(torch.ones_like(activation) * 5)
            return activation

        @staticmethod
        def backward(ctx, viewed_activation_all_1):
            corrupter = ctx.saved_tensors[0]
            assert torch.all(
                viewed_activation_all_1
            )  # is the same as before (1s) and NOT the corrupter (5s)!!
            return corrupter

    def fwd(t):
        a = InspectEarlierActivation.apply(t)
        b = ComputeNode.apply(a)
        c = NoOp.apply(b)
        d = BwdReturnsViewOfActivation.apply(c)
        return d.sum()

    tensor_c = torch.ones(256, 1024, device="cuda", requires_grad=True)
    ctx = OffloadActivations(use_streams=True)
    with ctx:
        loss_c = fwd(tensor_c)
    # delete the fwd stash to avoid our peek-in-fwd-stash heuristic in the bwd
    ctx.fwd_stash = {}
    loss_c.backward()
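
Note: the tests above drive torchtune's OffloadActivations context manager, which is built on PyTorch's saved-tensor hooks (pack/unpack callbacks invoked whenever autograd saves a tensor for backward). A minimal sketch of the underlying idea -- naive synchronous CPU offloading, assuming a CUDA device; the helper names are illustrative, and this is not torchtune's actual implementation:

import torch
from torch import nn

# Naive sketch: move every saved activation to CPU on pack, copy it back on
# unpack. OffloadActivations layers size filtering, pinned memory, and
# optional stream overlap on top of this idea.
def pack_to_cpu(tensor: torch.Tensor) -> torch.Tensor:
    return tensor.to("cpu")

def unpack_to_gpu(tensor: torch.Tensor) -> torch.Tensor:
    return tensor.to("cuda")

model = nn.Linear(10, 10, device="cuda")
inp = torch.randn(2, 10, device="cuda")
with torch.autograd.graph.saved_tensors_hooks(pack_to_cpu, unpack_to_gpu):
    loss = model(inp).sum()
loss.backward()  # saved activations are brought back from CPU here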
26 changes: 22 additions & 4 deletions torchtune/training/_activation_offloading.py
@@ -146,8 +146,12 @@ def pack_tensor(activation: torch.Tensor) -> int:
             num_bytes = get_num_bytes_tensor(activation)
             tensor_id = get_tensor_id()
 
-            # only offload hefty bois
-            if num_bytes >= self.min_tensor_size_bytes:
+            # only offload hefty bois if they're activations (our heuristic for that is to
+            # check if they're not params or buffers)!
+            if num_bytes >= self.min_tensor_size_bytes and (
+                not isinstance(activation, torch.nn.Parameter)
+                and not isinstance(activation, torch.nn.Buffer)
+            ):
                 if self.use_streams:
                     # First, sync back and dereference previously offloaded tensors
                     # as the offloading should be done sufficiently long ago.
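
Why the new isinstance check matters: the pack hook sees every tensor autograd saves for backward, which can include module weights, not only activations. A small illustration of that behavior (assumes a CUDA device; pack, unpack, and seen_types are illustrative names, not torchtune API):

import torch
from torch import nn

seen_types = []

def pack(t: torch.Tensor) -> torch.Tensor:
    seen_types.append(type(t).__name__)  # record what autograd saves
    return t

def unpack(t: torch.Tensor) -> torch.Tensor:
    return t

lin = nn.Linear(8, 8, device="cuda")
inp = torch.randn(2, 8, device="cuda", requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    out = lin(inp)
out.sum().backward()
# seen_types typically includes 'Parameter' for the saved weight -- exactly
# the kind of tensor the isinstance check above excludes from offloading.
print(seen_types)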
@@ -281,8 +285,22 @@ def wait_and_del_remaining_references() -> None:
         def hook(outputs, inputs):
             # create events for the current node inputs/outputs if they were streamed in
             if brought_back_from_cpu:
-                event = self.s0.record_event()
-                self.bwd_ev_stash[unpack_tensor_id] = event
+                # if any of the outputs is a view of the tensor, meaning the tensor might be used later,
+                # we cannot presume to delete it after only the current node is done! So we use our frenemy,
+                # record_stream, to ensure the Tensor stays unmessed with until it's done getting used
+                # in the compute stream (s0 here). Note that the con here is we introduce non-deterministic
+                # memory usage, but this case should not happen often.
+                unpacked_tensor = self.bwd_tensor_stash[unpack_tensor_id]
+                if any(
+                    o.untyped_storage() is unpacked_tensor.untyped_storage()
+                    for o in outputs
+                    if o is not None
+                ):
+                    unpacked_tensor.record_stream(self.s0)
+                    del self.bwd_tensor_stash[unpack_tensor_id]
+                else:
+                    event = self.s0.record_event()
+                    self.bwd_ev_stash[unpack_tensor_id] = event
 
             # if there are still things in the fwd_stash, get rid of them as we're in bwd now
             for id in [k for k in self.fwd_stash.keys()]:
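For context on the record_stream pattern in the fix above: CUDA's caching allocator considers a tensor's memory free as soon as its last reference dies on the home stream, so memory still being read on another stream can be recycled too early. Tensor.record_stream marks the tensor as in use on that stream, so the allocator holds the memory until the stream's pending work finishes. A minimal sketch of the pattern (illustrative, not torchtune code):

import torch

side_stream = torch.cuda.Stream()
x = torch.randn(1024, device="cuda")  # allocated on the default stream

side_stream.wait_stream(torch.cuda.current_stream())  # order x's init before use
with torch.cuda.stream(side_stream):
    y = x * 2  # x is consumed on side_stream

# Without this call, deleting x could let the allocator hand x's memory to a
# new allocation on the default stream before side_stream finishes reading it.
x.record_stream(side_stream)
del x

torch.cuda.current_stream().wait_stream(side_stream)  # sync before using y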
