Enabling gradient checkpointing in eval() mode #9878

Merged · 7 commits · Nov 8, 2024
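Every change in this PR is the same one-line swap: the guard on gradient checkpointing goes from `self.training and self.gradient_checkpointing` to `torch.is_grad_enabled() and self.gradient_checkpointing` (a few blocks previously checked only `self.gradient_checkpointing`). Checkpointing therefore follows the autograd state rather than the train/eval flag, so it also activates when a module sits in `eval()` mode but gradients are enabled, for example a frozen component that gradients still need to flow through. A minimal sketch of the pattern, using a hypothetical `TinyBlockStack` module rather than any diffusers class:

```python
import torch
import torch.nn as nn
import torch.utils.checkpoint


class TinyBlockStack(nn.Module):
    """Hypothetical stand-in for a block-based diffusers module (not part of the library)."""

    def __init__(self, dim: int = 32, depth: int = 3):
        super().__init__()
        self.blocks = nn.ModuleList([nn.Linear(dim, dim) for _ in range(depth)])
        self.gradient_checkpointing = True

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            # The condition this PR switches to: follow the autograd state,
            # not the train/eval flag.
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                # Recompute this block's activations during backward instead of storing them.
                hidden_states = torch.utils.checkpoint.checkpoint(
                    block, hidden_states, use_reentrant=False
                )
            else:
                hidden_states = block(hidden_states)
        return hidden_states


model = TinyBlockStack().eval()  # eval() alone no longer disables checkpointing
x = torch.randn(4, 32, requires_grad=True)

out = model(x)        # grad is enabled by default, so the forward pass is checkpointed
out.sum().backward()

with torch.no_grad():
    _ = model(x)      # pure inference path, no checkpointing overhead
```

Under `torch.no_grad()` the condition is false, so inference pays no recomputation cost; with gradients enabled, each block's activations are recomputed during the backward pass instead of being kept in memory.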
examples/community/matryoshka.py (8 changes: 4 additions & 4 deletions)

```diff
@@ -868,7 +868,7 @@ def forward(
         blocks = list(zip(self.resnets, self.attentions))

         for i, (resnet, attn) in enumerate(blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):

@@ -1029,7 +1029,7 @@ def forward(

         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):

@@ -1191,7 +1191,7 @@ def forward(

             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):

@@ -1364,7 +1364,7 @@ def forward(

         # Blocks
         for block in self.transformer_blocks:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
```
[file path not shown]

```diff
@@ -215,7 +215,7 @@ def forward(

         # 2. Blocks
         for block_index, block in enumerate(self.transformer.transformer_blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 # rc todo: for training and gradient checkpointing
                 print("Gradient checkpointing is not supported for the controlnet transformer model, yet.")
                 exit(1)
```
src/diffusers/models/autoencoders/autoencoder_kl_allegro.py (4 changes: 2 additions & 2 deletions)

```diff
@@ -506,7 +506,7 @@ def forward(self, sample: torch.Tensor) -> torch.Tensor:
         sample = self.temp_conv_in(sample)
         sample = sample + residual

-        if self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:

             def create_custom_forward(module):
                 def custom_forward(*inputs):

@@ -646,7 +646,7 @@ def forward(self, sample: torch.Tensor) -> torch.Tensor:

         upscale_dtype = next(iter(self.up_blocks.parameters())).dtype

-        if self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:

             def create_custom_forward(module):
                 def custom_forward(*inputs):
```
src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py (10 changes: 5 additions & 5 deletions)

```diff
@@ -420,7 +420,7 @@ def forward(
         for i, resnet in enumerate(self.resnets):
             conv_cache_key = f"resnet_{i}"

-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module):
                     def create_forward(*inputs):

@@ -522,7 +522,7 @@ def forward(
         for i, resnet in enumerate(self.resnets):
             conv_cache_key = f"resnet_{i}"

-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module):
                     def create_forward(*inputs):

@@ -636,7 +636,7 @@ def forward(
         for i, resnet in enumerate(self.resnets):
             conv_cache_key = f"resnet_{i}"

-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module):
                     def create_forward(*inputs):

@@ -773,7 +773,7 @@ def forward(

         hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))

-        if self.training and self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:

             def create_custom_forward(module):
                 def custom_forward(*inputs):

@@ -939,7 +939,7 @@ def forward(

         hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))

-        if self.training and self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:

             def create_custom_forward(module):
                 def custom_forward(*inputs):
```
src/diffusers/models/autoencoders/autoencoder_kl_mochi.py (10 changes: 5 additions & 5 deletions)

```diff
@@ -206,7 +206,7 @@ def forward(
         for i, (resnet, norm, attn) in enumerate(zip(self.resnets, self.norms, self.attentions)):
             conv_cache_key = f"resnet_{i}"

-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module):
                     def create_forward(*inputs):

@@ -311,7 +311,7 @@ def forward(
         for i, (resnet, norm, attn) in enumerate(zip(self.resnets, self.norms, self.attentions)):
             conv_cache_key = f"resnet_{i}"

-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module):
                     def create_forward(*inputs):

@@ -392,7 +392,7 @@ def forward(
         for i, resnet in enumerate(self.resnets):
             conv_cache_key = f"resnet_{i}"

-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module):
                     def create_forward(*inputs):

@@ -529,7 +529,7 @@ def forward(
         hidden_states = self.proj_in(hidden_states)
         hidden_states = hidden_states.permute(0, 4, 1, 2, 3)

-        if self.training and self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:

             def create_custom_forward(module):
                 def create_forward(*inputs):

@@ -646,7 +646,7 @@ def forward(
         hidden_states = self.conv_in(hidden_states)

         # 1. Mid
-        if self.training and self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:

             def create_custom_forward(module):
                 def create_forward(*inputs):
```
[file path not shown]

```diff
@@ -95,7 +95,7 @@ def forward(
         sample = self.conv_in(sample)

         upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
-        if self.training and self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:

             def create_custom_forward(module):
                 def custom_forward(*inputs):
```
src/diffusers/models/autoencoders/vae.py (10 changes: 5 additions & 5 deletions)

```diff
@@ -142,7 +142,7 @@ def forward(self, sample: torch.Tensor) -> torch.Tensor:

         sample = self.conv_in(sample)

-        if self.training and self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:

             def create_custom_forward(module):
                 def custom_forward(*inputs):

@@ -291,7 +291,7 @@ def forward(
         sample = self.conv_in(sample)

         upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
-        if self.training and self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:

             def create_custom_forward(module):
                 def custom_forward(*inputs):

@@ -544,7 +544,7 @@ def forward(
         sample = self.conv_in(sample)

         upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
-        if self.training and self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:

             def create_custom_forward(module):
                 def custom_forward(*inputs):

@@ -876,7 +876,7 @@ def __init__(

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         r"""The forward method of the `EncoderTiny` class."""
-        if self.training and self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:

             def create_custom_forward(module):
                 def custom_forward(*inputs):

@@ -962,7 +962,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         # Clamp.
         x = torch.tanh(x / 3) * 3

-        if self.training and self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:

             def create_custom_forward(module):
                 def custom_forward(*inputs):
```
src/diffusers/models/controlnets/controlnet_flux.py (4 changes: 2 additions & 2 deletions)

```diff
@@ -329,7 +329,7 @@ def forward(

         block_samples = ()
         for index_block, block in enumerate(self.transformer_blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):

@@ -363,7 +363,7 @@ def custom_forward(*inputs):

         single_block_samples = ()
         for index_block, block in enumerate(self.single_transformer_blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
```
src/diffusers/models/controlnets/controlnet_sd3.py (2 changes: 1 addition & 1 deletion)

```diff
@@ -324,7 +324,7 @@ def forward(
         block_res_samples = ()

         for block in self.transformer_blocks:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
```
src/diffusers/models/controlnets/controlnet_xs.py (6 changes: 3 additions & 3 deletions)

```diff
@@ -1466,7 +1466,7 @@ def custom_forward(*inputs):
             h_ctrl = torch.cat([h_ctrl, b2c(h_base)], dim=1)

             # apply base subblock
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                 h_base = torch.utils.checkpoint.checkpoint(
                     create_custom_forward(b_res),

@@ -1489,7 +1489,7 @@ def custom_forward(*inputs):

             # apply ctrl subblock
             if apply_control:
-                if self.training and self.gradient_checkpointing:
+                if torch.is_grad_enabled() and self.gradient_checkpointing:
                     ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                     h_ctrl = torch.utils.checkpoint.checkpoint(
                         create_custom_forward(c_res),

@@ -1898,7 +1898,7 @@ def maybe_apply_freeu_to_subblock(hidden_states, res_h_base):
             hidden_states, res_h_base = maybe_apply_freeu_to_subblock(hidden_states, res_h_base)
             hidden_states = torch.cat([hidden_states, res_h_base], dim=1)

-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                 hidden_states = torch.utils.checkpoint.checkpoint(
                     create_custom_forward(resnet),
```
src/diffusers/models/transformers/auraflow_transformer_2d.py (4 changes: 2 additions & 2 deletions)

```diff
@@ -466,7 +466,7 @@ def forward(

         # MMDiT blocks.
         for index_block, block in enumerate(self.joint_transformer_blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):

@@ -497,7 +497,7 @@ def custom_forward(*inputs):
         combined_hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

         for index_block, block in enumerate(self.single_transformer_blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
```
[file path not shown]

```diff
@@ -452,7 +452,7 @@ def forward(

         # 3. Transformer blocks
         for i, block in enumerate(self.transformer_blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
```
src/diffusers/models/transformers/dit_transformer_2d.py (2 changes: 1 addition & 1 deletion)

```diff
@@ -184,7 +184,7 @@ def forward(

         # 2. Blocks
         for block in self.transformer_blocks:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
```
src/diffusers/models/transformers/latte_transformer_3d.py (4 changes: 2 additions & 2 deletions)

```diff
@@ -238,7 +238,7 @@ def forward(
         for i, (spatial_block, temp_block) in enumerate(
             zip(self.transformer_blocks, self.temporal_transformer_blocks)
         ):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 hidden_states = torch.utils.checkpoint.checkpoint(
                     spatial_block,
                     hidden_states,

@@ -271,7 +271,7 @@ def forward(
             if i == 0 and num_frame > 1:
                 hidden_states = hidden_states + self.temp_pos_embed

-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 hidden_states = torch.utils.checkpoint.checkpoint(
                     temp_block,
                     hidden_states,
```
src/diffusers/models/transformers/pixart_transformer_2d.py (2 changes: 1 addition & 1 deletion)

```diff
@@ -386,7 +386,7 @@ def forward(

         # 2. Blocks
         for block in self.transformer_blocks:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
```
[file path not shown]

```diff
@@ -414,7 +414,7 @@ def forward(
             attention_mask = torch.cat([prepend_mask, attention_mask], dim=-1)

         for block in self.transformer_blocks:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
```
src/diffusers/models/transformers/transformer_2d.py (2 changes: 1 addition & 1 deletion)

```diff
@@ -415,7 +415,7 @@ def forward(

         # 2. Blocks
         for block in self.transformer_blocks:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
```
src/diffusers/models/transformers/transformer_allegro.py (2 changes: 1 addition & 1 deletion)

```diff
@@ -371,7 +371,7 @@ def forward(
         # 3. Transformer blocks
         for i, block in enumerate(self.transformer_blocks):
             # TODO(aryan): Implement gradient checkpointing
-            if self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
```
[file path not shown]

```diff
@@ -341,7 +341,7 @@ def forward(
         hidden_states = hidden_states[:, text_seq_length:]

         for index_block, block in enumerate(self.transformer_blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
```
src/diffusers/models/transformers/transformer_flux.py (4 changes: 2 additions & 2 deletions)

```diff
@@ -480,7 +480,7 @@ def forward(
         image_rotary_emb = self.pos_embed(ids)

         for index_block, block in enumerate(self.transformer_blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):

@@ -525,7 +525,7 @@ def custom_forward(*inputs):
         hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

         for index_block, block in enumerate(self.single_transformer_blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
```
src/diffusers/models/transformers/transformer_mochi.py (2 changes: 1 addition & 1 deletion)

```diff
@@ -350,7 +350,7 @@ def forward(
         )

         for i, block in enumerate(self.transformer_blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
```
src/diffusers/models/transformers/transformer_sd3.py (2 changes: 1 addition & 1 deletion)

```diff
@@ -317,7 +317,7 @@ def forward(
         encoder_hidden_states = self.context_embedder(encoder_hidden_states)

         for index_block, block in enumerate(self.transformer_blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
```
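For an actual diffusers model, the behavior this PR enables would look roughly like the sketch below. `enable_gradient_checkpointing()` and the UNet call signature are existing diffusers APIs; the model id, tensor shapes, and the frozen-UNet scenario are illustrative assumptions.

```python
import torch
from diffusers import UNet2DConditionModel

# Illustrative model id; any SD 1.x-style UNet checkpoint would do.
unet = UNet2DConditionModel.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet"
)
unet.enable_gradient_checkpointing()
unet.eval()  # e.g. a frozen UNet used inside someone else's training or guidance loop

sample = torch.randn(1, 4, 64, 64, requires_grad=True)
timestep = torch.tensor([10])
encoder_hidden_states = torch.randn(1, 77, 768)

# With this PR, the forward pass below is checkpointed even though the model
# is in eval() mode, because gradients are enabled.
noise_pred = unet(sample, timestep, encoder_hidden_states).sample
noise_pred.mean().backward()  # gradients reach `sample` at lower peak memory
```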