Add offloading tests and fix obscure edge case #1860
Merged
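For context, `OffloadActivations` is the context manager exercised by the new tests: activations saved for backward are offloaded to CPU during the forward pass and brought back when the backward pass needs them, optionally using CUDA streams to overlap the transfers. A minimal sketch of the usage pattern the tests cover (the toy model and shapes here are illustrative, not taken from the PR):

```python
import torch
from torch import nn
from torchtune.training import OffloadActivations

# Illustrative model; any nn.Module works the same way.
model = nn.Sequential(nn.Linear(10, 10), nn.ReLU()).cuda()
inp = torch.randn(2, 10, device="cuda")

# Inside this context, activations saved for backward are offloaded to CPU
# and restored on demand when .backward() runs.
with OffloadActivations(use_streams=True):
    loss = model(inp).sum()
    loss.backward()
```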
@@ -0,0 +1,131 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch
from tests.test_utils import gpu_test
from torch import nn
from torchtune.training import OffloadActivations


@gpu_test(gpu_count=1)
@pytest.mark.parametrize("use_streams", [True, False])
def test_offloading_is_same_as_without(use_streams) -> None:
    with torch.device("cuda"):
        torch.manual_seed(2024)
        model = nn.Sequential(
            nn.Linear(10, 10),
            nn.Linear(10, 10),
            nn.Linear(10, 10),
            nn.ReLU(),
        )
        torch.manual_seed(2024)
        model_c = nn.Sequential(
            nn.Linear(10, 10),
            nn.Linear(10, 10),
            nn.Linear(10, 10),
            nn.ReLU(),
        )

    inp = torch.randn((2, 10), device="cuda")
    loss = model(inp).sum()
    loss.backward()

    with OffloadActivations(use_streams=use_streams):
        loss_c = model_c(inp).sum()
        loss_c.backward()

    for param, param_c in zip(model.parameters(), model_c.parameters()):
        assert torch.equal(param.grad, param_c.grad)


@gpu_test(gpu_count=1)
def test_offloading_works_with_view_outputs() -> None:
    """
    This test is quite contrived but tests against a very obscure situation where
    any of the outputs of a backward node are a view of the unpacked tensor.

    We want to ensure that if an unpacked tensor may be used later that we do not
    free it too early.

    How did we contrive this test? We need the backward to execute as so:
    1. We first need a node that unpacks a tensor and returns a view of the tensor
    2. The next node just needs to pass that view along--this NoOp node is needed
       to bypass our heuristic where we delete the _previous_ node's stash after
       executing the current node.
    3. We need to allow the tensor to die to be contaminated with new info, and
       we need a way to look into the contents of the contaminated tensor. We
       separate these into two nodes (because having them in the same node does
       not properly let the tensor reference die as it is within scope.) The
       "Compute" Node queues up ~1 second of work on CUDA followed by a kernel
       evaluating whether dX is full of 1s. The next Node then inspects the
       earlier activation and asserts the result of dX == 1, which is a sync!

    Note that for the backward to execute in the above order, the fwd was made
    to execute in reverse order.
    """

    class BwdReturnsViewOfActivation(torch.autograd.Function):
        @staticmethod
        def forward(ctx, cloned_activation):
            cloned_activation = cloned_activation.t()
            ctx.save_for_backward(cloned_activation)
            return torch.rand(2, 4, device="cuda")

        @staticmethod
        def backward(ctx, dy):
            unpacked_activation = ctx.saved_tensors[0]
            return unpacked_activation.t()

    class NoOp(torch.autograd.Function):
        @staticmethod
        def forward(ctx, cloned_activation):
            ctx.save_for_backward(cloned_activation)
            return cloned_activation.clone()

        @staticmethod
        def backward(ctx, viewed_activation):
            rando_activation = ctx.saved_tensors[0]
            return viewed_activation

    class ComputeNode(torch.autograd.Function):
        @staticmethod
        def forward(ctx, activation):
            return activation.clone()

        @staticmethod
        def backward(ctx, viewed_activation):
            torch.cuda._sleep(2000000000)  # 2e9 is ~1s worth of GPU cycles
            return viewed_activation == 1

    class InspectEarlierActivation(torch.autograd.Function):
        @staticmethod
        def forward(ctx, activation):
            ctx.save_for_backward(torch.ones_like(activation) * 5)
            return activation

        @staticmethod
        def backward(ctx, viewed_activation_all_1):
            corrupter = ctx.saved_tensors[0]
            assert torch.all(
                viewed_activation_all_1
            )  # is the same as before (1s) and NOT W (5s)!!
            return corrupter

    def fwd(t):
        a = InspectEarlierActivation.apply(t)
        b = ComputeNode.apply(a)
        c = NoOp.apply(b)
        d = BwdReturnsViewOfActivation.apply(c)
        return d.sum()

    tensor_c = torch.ones(256, 1024, device="cuda", requires_grad=True)
    ctx = OffloadActivations(use_streams=True)
    with ctx:
        loss_c = fwd(tensor_c)
        # delete the fwd stash to avoid our peek-in-fwd-stash heuristic in the bwd
        ctx.fwd_stash = {}
        loss_c.backward()
When I run with this, I get:

I think `o` can be None.
Good catch, thanks! Should be updated in the latest commit.
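For readers without the full diff: the issue flagged above is that one of the values handled during unpacking (called `o` in the review) can be `None` and must be skipped before use. A purely hypothetical sketch of that kind of guard, with every name except `o` made up for illustration (the actual fix is in the PR's latest commit):

```python
from typing import Iterable, Optional

import torch


def consume_outputs(outputs: Iterable[Optional[torch.Tensor]]) -> int:
    # Hypothetical helper, not torchtune code: some entries may be None,
    # so guard each output `o` before touching it (the reviewer's point).
    handled = 0
    for o in outputs:
        if o is None:
            continue
        handled += o.numel()
    return handled


# A None entry is skipped instead of raising an AttributeError.
print(consume_outputs([torch.ones(2, 2), None, torch.zeros(3)]))  # 7
```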