
[Chore] allow auraflow latest to be torch compile compatible. #8859

Merged
merged 3 commits into from
Jul 17, 2024

Conversation

sayakpaul
Copy link
Member

@sayakpaul sayakpaul commented Jul 13, 2024

What does this PR do?

Otherwise, we run into a shape broadcasting problem during torch.compile(). Note that this doesn't happen when using the internal AuraDiffusion/auradiffusion-v0.1a0 variant.

Benchmarking script

import torch 

torch.set_float32_matmul_precision("high")

from diffusers import AuraFlowPipeline
import time

id = "fal/AuraFlow" # "AuraDiffusion/auradiffusion-v0.1a0"
pipeline = AuraFlowPipeline.from_pretrained(
    id, 
    torch_dtype=torch.float16
).to("cuda")
pipeline.set_progress_bar_config(disable=True)

torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_check_all_directions = True

pipeline.transformer.to(memory_format=torch.channels_last)
pipeline.vae.to(memory_format=torch.channels_last)
pipeline.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True)
pipeline.vae.decode = torch.compile(pipeline.vae.decode, mode="max-autotune", fullgraph=True)

prompt = "a photo of a cat"
for _ in range(3):
    _ = pipeline(
        prompt=prompt,
        num_inference_steps=50, 
        guidance_scale=5.0,
        generator=torch.manual_seed(1),
    )

start = time.time()
for _ in range(10):
    _ = pipeline(
        prompt=prompt,
        num_inference_steps=50, 
        guidance_scale=5.0,
        generator=torch.manual_seed(1),
    )
end = time.time()
avg_inference_time = (end - start) / 10
print(f"Average inference time: {avg_inference_time:.3f} seconds.")
Results

| With `torch.compile()` | Without `torch.compile()` |
|------------------------|---------------------------|
| 7.293 seconds          | 8.801 seconds             |
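As a quick sanity check on the numbers in the table, the relative speedup from compilation works out to roughly 1.2x:

```python
# Timings taken from the benchmark table above (average of 10 runs each).
with_compile = 7.293     # seconds, torch.compile enabled
without_compile = 8.801  # seconds, eager mode

speedup = without_compile / with_compile
reduction = (without_compile - with_compile) / without_compile

print(f"Speedup: {speedup:.2f}x")             # 1.21x
print(f"Latency reduction: {reduction:.1%}")  # 17.1%
```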

@apolinario FYI.

@sayakpaul sayakpaul requested review from DN6 and yiyixuxu July 13, 2024 13:28
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@yiyixuxu
Copy link
Collaborator

I don't think there is anything wrong with torch.compile.
Maybe you should update the default height and width for the aura pipelines, now that the checkpoint only supports 1024: https://github.com/huggingface/diffusers/blob/bbd2f9d4e9ae70b04fedf65903fd1fb035437db4/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py#L394C9-L394C15

@bghira
Copy link
Contributor

bghira commented Jul 14, 2024

it's a shame that the height and width aren't just None, defaulting to the sample size in the model config. but i guess it's actually hard to calculate that concretely; we can hardcode it based on the max_pos_embed value. i'm not sure if we're going to make any 512px models available. they are very fast and nice to work with.

@sayakpaul
Copy link
Member Author

@yiyixuxu done.

@bghira

height = height or self.transformer.config.sample_size * self.vae_scale_factor
width = width or self.transformer.config.sample_size * self.vae_scale_factor

is our automatic resolution inference.
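A minimal sketch of how that fallback behaves, using assumed illustrative values (a sample_size of 128 and a vae_scale_factor of 8 would yield the 1024px default this PR settles on; the actual config values may differ):

```python
# Hypothetical config values for illustration only.
sample_size = 128     # stand-in for self.transformer.config.sample_size
vae_scale_factor = 8  # stand-in for self.vae_scale_factor

def resolve_resolution(height=None, width=None):
    # Mirrors the `height = height or ...` pattern quoted above:
    # a falsy (None) argument falls back to the config-derived default.
    height = height or sample_size * vae_scale_factor
    width = width or sample_size * vae_scale_factor
    return height, width

print(resolve_resolution())          # (1024, 1024) under these assumptions
print(resolve_resolution(512, 512))  # explicit values pass through unchanged
```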

@apolinario
Copy link
Collaborator

@sayakpaul, by changing the default, you undid the initial change. Is that intended? f6c975a

@sayakpaul
Copy link
Member Author

Yeah, that is right.

@sayakpaul sayakpaul merged commit c2fbf8d into main Jul 17, 2024
16 of 18 checks passed
@sayakpaul sayakpaul deleted the aura-flow-compile branch July 17, 2024 02:56
Disty0 pushed a commit to Disty0/diffusers that referenced this pull request Jul 18, 2024
…gface#8859)

* allow auraflow latest to be torch compile compatible.

* default to 1024 1024.
sayakpaul added a commit that referenced this pull request Dec 23, 2024
* allow auraflow latest to be torch compile compatible.

* default to 1024 1024.