
Add support for lumina2 #10642

Merged: 28 commits merged into huggingface:main on Feb 11, 2025

Conversation

zhuole1025 (Contributor)

What does this PR do?

This PR adds the official Lumina-Image 2.0 to diffusers. Lumina-Image 2.0 is the latest model in the Lumina family and will be released very soon (https://huggingface.co/Alpha-VLLM/Lumina-Image-2.0). It is a 2B-parameter Diffusion Transformer that significantly improves instruction following and generates higher-quality, more diverse images. Our paper will be released soon, and we have finished the diffusers pipeline for Lumina-Image 2.0.
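For reference, a minimal usage sketch of the pipeline added in this PR (assuming the checkpoint is published at Alpha-VLLM/Lumina-Image-2.0; the prompt and settings are illustrative only):

import torch
from diffusers import Lumina2Text2ImgPipeline

# Minimal sketch: load the pipeline in bf16 and generate one image.
pipe = Lumina2Text2ImgPipeline.from_pretrained(
    "Alpha-VLLM/Lumina-Image-2.0", torch_dtype=torch.bfloat16
).to("cuda")

image = pipe(
    prompt="A capybara holding a sign that reads Hello World",
    num_inference_steps=25,
    generator=torch.Generator(device="cuda").manual_seed(0),
).images[0]
image.save("lumina2.png")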

Core library:

cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len]
img_freqs_cis[i, :img_len] = freqs_cis[i, cap_len:cap_len+img_len]

for layer in self.context_refiner:
@yiyixuxu (Collaborator) commented Jan 29, 2025

I think context_refiner and noise_refiner should be moved inside forward for better readability.
Does it make sense to wrap the code that creates freqs_cis inside a class like this?

class LTXVideoRotaryPosEmbed(nn.Module):
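(For illustration, a rough, hypothetical sketch of what such a wrapper could look like; the class name and frequency computation are illustrative, not the actual diffusers implementation:)

import torch
import torch.nn as nn

class RotaryPosEmbedSketch(nn.Module):
    # Hypothetical module that builds per-axis complex rotary frequencies from integer position ids.
    def __init__(self, axes_dim, theta=10000.0):
        super().__init__()
        self.axes_dim = axes_dim
        self.theta = theta

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        # ids: (batch, seq_len, num_axes) integer positions for each axis
        components = []
        for i, dim in enumerate(self.axes_dim):
            pos = ids[..., i].float()  # (batch, seq_len)
            freqs = 1.0 / (self.theta ** (torch.arange(0, dim, 2, device=ids.device).float() / dim))
            angles = pos[..., None] * freqs  # (batch, seq_len, dim // 2)
            components.append(torch.polar(torch.ones_like(angles), angles))  # complex64
        return torch.cat(components, dim=-1)  # (batch, seq_len, sum(axes_dim) // 2)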

@zhuole1025 (Contributor, Author)

Thanks for the suggestion! I have moved the refiner into forward() for better readability. As for the freqs_cis, we already have a class Lumina2PosEmbed.

@rodjjo commented Feb 3, 2025

I know we have support for pipe.enable_model_cpu_offload.
Do you intend to add pipe.enable_sequential_cpu_offload?

I tried it out; it is not currently supported.
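(For context, a minimal sketch of how offloading is enabled on the pipeline; the sequential variant is the analogous call once it is supported here:)

import torch
from diffusers import Lumina2Text2ImgPipeline

pipe = Lumina2Text2ImgPipeline.from_pretrained(
    "Alpha-VLLM/Lumina-Image-2.0", torch_dtype=torch.bfloat16
)
# Keeps whole sub-models on CPU and moves each to GPU only while it runs.
pipe.enable_model_cpu_offload()
# pipe.enable_sequential_cpu_offload()  # not yet working for this pipeline at the time of this comment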

@a-r-r-o-w (Member) left a comment

Thanks for the PR @zhuole1025! It looks great and almost ready to merge

We need to add tests and docs apart from addressing the comments/questions here. For both, this PR would serve as a good example of which files need to be updated. In tests, a modeling test (test_models_transformer_lumina2.py) and a pipeline test (test_lumina2.py) will be needed. More than happy to help make any of the required changes to move the PR to completion 🤗

Review threads (resolved):
src/diffusers/models/attention.py
src/diffusers/models/attention_processor.py (3 threads, outdated)
src/diffusers/models/embeddings.py (outdated)
src/diffusers/models/transformers/transformer_lumina2.py (2 threads, outdated)
src/diffusers/pipelines/lumina2/pipeline_lumina2.py (2 threads, outdated)
@a-r-r-o-w added the close-to-merge and roadmap (Add to current release roadmap) labels on Feb 4, 2025
@zhuole1025 (Contributor, Author) commented Feb 5, 2025

Thanks for all the suggestions! I have fixed all of them and updated a new version for review~

@a-r-r-o-w mentioned this pull request on Feb 6, 2025
@a-r-r-o-w (Member)

@zhuole1025 @csuhan Let me know your thoughts about the changes in #10732. I've verified on my end that the inference works as expected, so we should be able to merge the PR soon after these changes are incorporated

@Column01 commented Feb 9, 2025

@nitinmukesh's sample is functionally equivalent to the code I used to produce that error.

# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.bool().view(batch_size, 1, 1, -1)
attention_mask = attention_mask.expand(-1, attn.heads, sequence_length, -1)
A Collaborator suggested a change (removing the expand line):

attention_mask = attention_mask.expand(-1, attn.heads, sequence_length, -1)

🚀

@zhuole1025 (Contributor, Author)

Since we are using GQA in lumina2, is it okay to remove this line?
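(For illustration, torch SDPA broadcasts a (batch, 1, 1, kv_len) boolean mask across heads and query positions, so the explicit expand is not strictly required; toy shapes below, not the Lumina2 code:)

import torch
import torch.nn.functional as F

batch, heads, q_len, kv_len, head_dim = 2, 8, 16, 16, 64
q = torch.randn(batch, heads, q_len, head_dim)
k = torch.randn(batch, heads, kv_len, head_dim)
v = torch.randn(batch, heads, kv_len, head_dim)

# True = attend; the (batch, 1, 1, kv_len) mask broadcasts over heads and query positions.
mask = torch.ones(batch, 1, 1, kv_len, dtype=torch.bool)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
print(out.shape)  # torch.Size([2, 8, 16, 64])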

@Column01 commented Feb 9, 2025

I can confirm that after pulling the latest changes I was unable to get the program to crash with VAE slicing and tiling enabled. However, further testing I did BEFORE updating revealed that enabling them wasn't the problem: on some runs it would crash with just the parameters below included and no VAE options at all. After updating, it appears to no longer be an issue.

cfg_trunc_ratio=0.25,
cfg_normalization=True,
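(For reference, a minimal sketch of the kind of call involved, assuming the pipeline object from this PR and otherwise default arguments; the prompt is illustrative:)

import torch

generator = torch.Generator(device="cuda").manual_seed(0)
images = pipe(
    prompt="A capybara holding a sign that reads Hello World",
    num_inference_steps=25,
    cfg_trunc_ratio=0.25,
    cfg_normalization=True,
    generator=generator,
).images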

@Vargol commented Feb 11, 2025

There are a few MPS incompatibilities currently: a couple of things Torch on MPS doesn't support and doesn't fall back to CPU for, and a float64 reference.

 in _get_freqs_cis
    result.append(torch.gather(freqs.unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index))
                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: repeat(): Not supported for complex yet!
 in _get_freqs_cis
    result.append(torch.gather(xx.repeat(index.shape[0], 1, 1), dim=1, index=index))
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: gather(): Yet not supported for complex
  File "/Volumes/SSD2TB/AI/Lumina/lib/python3.11/site-packages/diffusers/models/transformers/transformer_lumina2.py", line 298, in forward
    cap_freqs_cis = torch.zeros(*cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: Cannot convert a float64 Tensor to MPS as the MPS framework doesn't support float64. Please use float32 instead.

The float64 reference is actually in _precompute_freqs_cis

emb = get_1d_rotary_pos_embed(d, e, theta=self.theta, freqs_dtype=torch.float64)
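(A hedged sketch of a device-dependent dtype guard that avoids this; illustrative only, the actual fix is whatever landed in the linked commit:)

import torch

def pick_freqs_dtype(device: torch.device) -> torch.dtype:
    # MPS does not support float64, so fall back to float32 there.
    return torch.float32 if device.type == "mps" else torch.float64

# Hypothetical use inside _precompute_freqs_cis:
# emb = get_1d_rotary_pos_embed(d, e, theta=self.theta, freqs_dtype=pick_freqs_dtype(device))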

@yiyixuxu (Collaborator)

@Vargol thanks for reporting, can you test to see if it runs on MPS now with the latest commit 9ea9bbe?

@yiyixuxu merged commit 81440fd into huggingface:main on Feb 11, 2025 (11 of 12 checks passed)
@Vargol commented Feb 11, 2025

It fixes the float64 issue, but the unsupported torch ops are still there; not sure if you expected that.
I still had to hack _get_freqs_cis to:

def _get_freqs_cis(self, ids: torch.Tensor) -> torch.Tensor:
    result = []
    for i in range(len(self.axes_dim)):
        freqs = self.freqs_cis[i].to(ids.device)
        index = ids[:, :, i : i + 1].repeat(1, 1, freqs.shape[-1]).to("cpu", torch.int64)
        # MPS doesn't support repeat()/gather() on complex tensors yet, so do the
        # gather on CPU and move the result back to MPS afterwards.
        freq_unsqueezed = freqs.unsqueeze(0).to("cpu")
        result.append(torch.gather(freq_unsqueezed.repeat(index.shape[0], 1, 1), dim=1, index=index))
    return torch.cat(result, dim=-1).to("mps")

To get it working on MPS.

@yiyixuxu (Collaborator)

@Vargol can you test on this PR? #10776

@Vargol commented Feb 12, 2025

The new PR backs out the float64 fix from the last PR; if you add that back, then it works.

@vladmandic (Contributor)

quick q - if I look at the reference implementation, it relies on flash_attn, but the diffusers implementation does not?

@a-r-r-o-w (Member)

@vladmandic Diffusers only supports directly calling torch SDPA. Custom attention processors can be used for flash-attn or other backends. To enable flash attention with PyTorch, you can set the torch backend to SDPBackend.FLASH_ATTENTION.

We briefly discussed a dispatching mechanism for different libraries/backends but haven't made decisions on how to move forward yet. Maybe something to consider for the release after the current one.
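(A minimal sketch of forcing the flash-attention SDPA backend on recent PyTorch; `pipe` is assumed to be the already-loaded pipeline:)

from torch.nn.attention import SDPBackend, sdpa_kernel

# Restrict SDPA to the flash-attention backend for the duration of the call.
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    images = pipe(prompt="A capybara holding a sign that reads Hello World", num_inference_steps=25).images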

@yiyixuxu (Collaborator)

@asomoza
I think not using the mask affects the quality of the shorter prompt quite a bit when using multiple prompts, can you verify?

# test lumina2
import torch
from diffusers import Lumina2Text2ImgPipeline
import itertools
from pathlib import Path
import shutil

# branch = "refactor_lumina2"
branch = "main"
params = {
    'use_mask_in_transformer': [True, False],  
}

# Generate all combinations
param_combinations = list(itertools.product(*params.values()))

# Create output directory (remove if exists)
output_dir = Path(f"yiyi_test_6_outputs_{branch}")
if output_dir.exists():
    shutil.rmtree(output_dir)
output_dir.mkdir(exist_ok=True)

pipe = Lumina2Text2ImgPipeline.from_pretrained("Alpha-VLLM/Lumina-Image-2.0", torch_dtype=torch.bfloat16).to("cuda")

prompt = [
    "focused exterior view on a living room of a limestone rock regular blocks shaped villa with sliding windows and timber screens in Provence along the cliff facing the sea, with waterfalls from the roof to the pool, designed by Zaha Hadid, with rocky textures and form, made of regular giant rock blocks stacked each other with infinity edge pool in front of it, blends in with the surrounding nature. Regular rock blocks. Giant rock blocks shaping the space. The image to capture the infinity edge profile of the pool and the flow of water going down creating a waterfall effect. Adriatic Sea. The design is sustainable and semi prefab. The photo is shot on a canon 5D mark 4",
    "A capybara holding a sign that reads Hello World"
]

# Run test for each combination
for (mask,) in param_combinations:
    print(f"\nTesting combination:")
    print(f"  use_mask_in_transformer: {mask}")
    
    # Generate image
    generator = torch.Generator(device="cuda").manual_seed(0)
    images = pipe(
        prompt=prompt,
        num_inference_steps=25,
        use_mask_in_transformer=mask,
        generator=generator,
    ).images
    
    # Save images
    for i, image in enumerate(images):
        output_path = output_dir / f"output_mask{int(mask)}_prompt{i}.png"
        image.save(output_path)
        print(f"Saved to: {output_path}")

@asomoza (Member) commented Feb 12, 2025

There's something different in your version or env; I don't get that bad a result. Also, there are a couple of things to note:

  • You're not using a system prompt, and I think it was trained to use one all the time.
  • Using a really short prompt is not a realistic use case.
  • Text generation is really finicky everywhere. Even though it's a cool research feature to play with, I don't use it to benchmark because I can literally do a better job in seconds with any image software, and it's not really usable in production. Even with Flux I sometimes get bad results or low-quality text.
  • There are some quirks of the model that I haven't figured out how to use yet; for example, I can't get it to produce pose estimation images (with mask or no mask).

But to compare with your results, these are my images without a system prompt:

[comparison images: capybara and cat, mask vs. no mask]

with a system prompt and the capybara I get a more similar result to yours but without the bad text:

[comparison images: capybara with system prompt, mask vs. no mask]

There's a difference, but I need a lot more tests to see if there's a loss in quality. Sometimes I think the masked ones are better, but sometimes not; also, most people won't notice the difference unless you show them side by side.

I do test the models with text and with all the stuff I see in the paper or technical report, and for lumina2 it's really hard to tell yet if there's a loss in quality for me; maybe we will see a clearer difference with controlnets and loras.

IMO we should still make the default True just in case, even if it adds a few seconds to the generation time.

@yiyixuxu (Collaborator)

@asomoza
it comes with a default system prompt, you don't have to explicitly provide one:

self.system_prompt = "You are an assistant designed to generate superior images with the superior degree of image-text alignment based on textual prompts or user prompts."

my results were generated using a list of prompts at the same time; I linked a test script in my previous comment here: #10642 (comment)
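(For reference, a minimal sketch of overriding the default, assuming the pipeline exposes a system_prompt argument as the attribute above suggests; the strings are illustrative:)

image = pipe(
    prompt="A capybara holding a sign that reads Hello World",
    system_prompt="You are an assistant designed to generate superior images with the superior degree of image-text alignment based on textual prompts or user prompts.",
    num_inference_steps=25,
).images[0]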

@asomoza (Member) commented Feb 12, 2025

oh I totally missed that, so I was using a double system prompt. But still, I don't get that bad a result, and without the system prompt it generates cartoon images for me. I'll take another look then.

@asomoza (Member) commented Feb 12, 2025

with your code I can see the difference, though not as much as yours because both are bad for me:

[comparison images: prompt 1, mask vs. no mask]

@yiyixuxu since this is happening with a list of prompts, isn't the issue with that rather than mask vs. no mask? I don't use that option because it's not really usable for generating good images, so I'll add it to my tests from now on, but the difference for the capybara is really noticeable versus the single-prompt ones.

@yiyixuxu (Collaborator)

@asomoza
I don't think you are using a double one; it's just that if you don't provide one, it has a default, otherwise it will use yours.

oh I totally missed that, so I was using a double system prompt.

for the outputs, are you getting different results with the same testing script? is this pipeline not deterministic?

@asomoza (Member) commented Feb 12, 2025

some more tests:

List of prompts with:

prompt = [
    "A cat wearing sunglasses, in the back there's a sign that says: Beware of the cat!",
    "A capybara holding a sign that reads: I'm cute or what?",
]
[comparison images: cat and capybara prompts, mask vs. no mask]

Single prompts:

prompt = "A cat wearing sunglasses, in the back there's a sign that says: Beware of the cat!"
[comparison images: cat, mask vs. no mask]
prompt = "A capybara holding a sign that reads: I'm cute or what?"
[comparison images: capybara, mask vs. no mask]

I think this is just random; we're getting some bad results and some good results depending on the seed and probably the prompt.

for the outputs, are you getting different results with the same testing script? is this pipeline not deterministic?

no, the pipeline is deterministic. I got different results when I switched to your script with a list of prompts; that's expected, right? since we're using the same generator.

With a single prompt I can reproduce the exact same image all the time. BTW, I'm always using a seed of 0 for these tests.

@asomoza (Member) commented Feb 12, 2025

I don't think you are using a double one; it's just that if you don't provide one, it has a default, otherwise it will use yours.

I tested this and you're correct: I get the exact same image with and without the system prompt.

@yiyixuxu (Collaborator) commented Feb 12, 2025

@asomoza
thanks!
can you test only the list-of-prompts use case to see if we always get better results with the mask, or whether it is sort of random? I'm particularly interested in the shorter prompt in the list (because its hidden_states is padded when passed to the transformer block)

in this one example with the list of prompts, both seem to have better generation with the mask.

@asomoza (Member) commented Feb 12, 2025

so using a list, and switching the order:

long prompt first:

[comparison images: long prompt first, mask vs. no mask]

short prompt first:

[comparison images: short prompt first, mask vs. no mask]

Doing some more tests, the order and the longer prompt don't matter; the only one that changes a lot is the shorter prompt, so I'll post those only:

shorter prompt without text: "Photo of a capybara wearing sunglasses and a jacket"

[comparison images: mask vs. no mask]

even shorter prompt without text: "a capybara"

[comparison images: mask vs. no mask]

medium prompt with text: "photo of a cat wearing stylish black sunglasses. The cat has a light brown and white fur pattern with distinct stripes. Behind the cat, there is a wooden sign with the text: Beware of the cat!!! written in a playful, handwritten style."

[comparison images: mask vs. no mask]

medium prompt without text: "photo of a cat wearing stylish black sunglasses and a leather jacket. The cat has a light brown and white fur pattern with distinct stripes. The overall scene has a humorous and whimsical tone, combining the cat's cool demeanor with human clothes and eyewear."

[comparison images: mask vs. no mask]

Finished with the testing. This is a nice find @yiyixuxu: we can infer that without a mask, when using a list of prompts where one is shorter, text generation takes a quality hit, so this is important for web services or APIs that use this model that way.

Overall, I can't really see any other loss in quality, except that a shorter prompt without details will look worse without a mask, because the model does only the very basics with the prompt. For example, I noticed that without a mask, if you ask for a subject only, the model will most of the time generate the subject without a background, which IMO is really nice, but for normal people or use cases this won't look good because of the lack of detail and lighting.

@asomoza (Member) commented Feb 12, 2025

also an example with just text:

prompt = [
    "A sign that says: Lumina2 and diffusers are really good!!!",
    "A sign that says: We should make the capybara the official diffusers mascot!",
]
[comparison images: both sign prompts, mask vs. no mask]
