Add support for lumina2 #10642
Conversation
cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len]
img_freqs_cis[i, :img_len] = freqs_cis[i, cap_len:cap_len+img_len]

for layer in self.context_refiner:
I think context_refiner and noise_refiner should be moved inside forward for better readability.
Does it make sense to wrap the code that creates freqs_cis inside a class like this?
class LTXVideoRotaryPosEmbed(nn.Module):
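For illustration, a minimal sketch of what such a wrapper module could look like; the axis dims, theta, and rope layout here are assumptions for the example, not the actual Lumina2 or LTXVideo code:

```python
import torch
import torch.nn as nn


class RotaryPosEmbedSketch(nn.Module):
    """Hypothetical wrapper that builds complex freqs_cis from per-axis position ids."""

    def __init__(self, axes_dim=(32, 32, 32), theta=10000.0):
        super().__init__()
        self.axes_dim = axes_dim
        self.theta = theta

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        # ids: (batch, seq_len, num_axes) integer positions, one column per axis
        components = []
        for i, dim in enumerate(self.axes_dim):
            pos = ids[..., i].float()  # (batch, seq_len)
            freqs = 1.0 / (self.theta ** (torch.arange(0, dim, 2, device=ids.device).float() / dim))
            angles = pos.unsqueeze(-1) * freqs  # (batch, seq_len, dim // 2)
            components.append(torch.polar(torch.ones_like(angles), angles))
        return torch.cat(components, dim=-1)  # complex-valued freqs_cis


# usage sketch
rope = RotaryPosEmbedSketch()
ids = torch.zeros(2, 16, 3, dtype=torch.long)
freqs_cis = rope(ids)  # (2, 16, 48), complex64
```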
Thanks for the suggestion! I have moved the refiners into forward() for better readability. As for freqs_cis, we already have a class Lumina2PosEmbed.
I know we have support for pipe.enable_model_cpu_offload. I tried it out; it's not currently supported.
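(For context, this is roughly the call that was tried; a minimal sketch using the checkpoint from this PR's description:)

```python
import torch
from diffusers import Lumina2Text2ImgPipeline

pipe = Lumina2Text2ImgPipeline.from_pretrained(
    "Alpha-VLLM/Lumina-Image-2.0", torch_dtype=torch.bfloat16
)
# offloads each sub-model to CPU and moves it to the GPU only when it is needed
pipe.enable_model_cpu_offload()

image = pipe("A capybara holding a sign that reads Hello World", num_inference_steps=25).images[0]
```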
Thanks for the PR @zhuole1025! It looks great and is almost ready to merge.
We need to add tests and docs apart from addressing the comments/questions here. For both, this PR would serve as a good example of which files need to be updated. In tests, a modeling test (test_models_transformer_lumina2.py) and a pipeline test (test_lumina2.py) will be needed. More than happy to help make any of the required changes to move the PR to completion 🤗
Thanks for all the suggestions! I have fixed all of them and updated a new version for review~
@zhuole1025 @csuhan Let me know your thoughts about the changes in #10732. I've verified on my end that the inference works as expected, so we should be able to merge the PR soon after these changes are incorporated.
@nitinmukesh's sample is functionally equivalent to what I had to produce that error.
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.bool().view(batch_size, 1, 1, -1)
attention_mask = attention_mask.expand(-1, attn.heads, sequence_length, -1)
attention_mask = attention_mask.expand(-1, attn.heads, sequence_length, -1)
🚀
Since we are using GQA in lumina2, is it okay to remove this line?
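For reference, torch's scaled_dot_product_attention accepts any mask that is broadcastable to (batch, heads, q_len, kv_len), so a (batch, 1, 1, kv_len) boolean mask should work without the explicit expand. A small standalone check (the shapes here are illustrative, not Lumina2's):

```python
import torch
import torch.nn.functional as F

batch, heads, q_len, kv_len, head_dim = 2, 8, 16, 16, 64
q = torch.randn(batch, heads, q_len, head_dim)
k = torch.randn(batch, heads, kv_len, head_dim)
v = torch.randn(batch, heads, kv_len, head_dim)

mask = torch.ones(batch, kv_len, dtype=torch.bool)
mask[:, kv_len // 2:] = False  # pretend the second half of the sequence is padding

# broadcastable mask vs. fully expanded mask give the same result
out_broadcast = F.scaled_dot_product_attention(q, k, v, attn_mask=mask.view(batch, 1, 1, kv_len))
out_expanded = F.scaled_dot_product_attention(
    q, k, v, attn_mask=mask.view(batch, 1, 1, kv_len).expand(-1, heads, q_len, -1)
)
print(torch.allclose(out_broadcast, out_expanded))  # True
```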
Can confirm that after pulling the latest changes I was unable to get the program to crash with VAE slicing and tiling enabled. However, further testing that I did before updating revealed that enabling them wasn't the problem: on some runs it would crash with just the parameters below and no VAE settings at all. After updating, it appears to no longer be an issue.
There are a few MPS incompatibilities currently: a couple of things Torch on MPS doesn't support and doesn't fall back to CPU on, and a float64 reference.

In _get_freqs_cis:

result.append(torch.gather(freqs.unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index))
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: repeat(): Not supported for complex yet!

Also in _get_freqs_cis:

result.append(torch.gather(xx.repeat(index.shape[0], 1, 1), dim=1, index=index))
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: gather(): Yet not supported for complex

File "/Volumes/SSD2TB/AI/Lumina/lib/python3.11/site-packages/diffusers/models/transformers/transformer_lumina2.py", line 298, in forward
    cap_freqs_cis = torch.zeros(*cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
TypeError: Cannot convert a float64 Tensor to MPS as the MPS framework doesn't support float64. Please use float32 instead.

The float64 reference is actually in emb = get_1d_rotary_pos_embed(d, e, theta=self.theta, freqs_dtype=torch.float64)
It fixes the float64, but the torch unsupported bits are still there, not sure if you expected that. To get it working on MPS:

def _get_freqs_cis(self, ids: torch.Tensor) -> torch.Tensor:
    result = []
    for i in range(len(self.axes_dim)):
        freqs = self.freqs_cis[i].to(ids.device)
        index = ids[:, :, i : i + 1].repeat(1, 1, freqs.shape[-1]).to('cpu', torch.int64)
        freq_unsqueezed = freqs.unsqueeze(0).to("cpu")
        result.append(torch.gather(freq_unsqueezed.repeat(index.shape[0], 1, 1), dim=1, index=index))
    return torch.cat(result, dim=-1).to('mps')
The new PR backs out the float64 fix from the last PR; if you add that back then it works.
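(For anyone hitting this on MPS: a minimal sketch of one way to avoid float64 rope frequencies there, assuming the freqs_dtype argument quoted above can be threaded through. This is an illustration, not necessarily the actual fix in the PR:)

```python
import torch

# MPS has no float64, so fall back to float32 there and keep float64 elsewhere
use_mps = torch.backends.mps.is_available()
freqs_dtype = torch.float32 if use_mps else torch.float64

# hypothetical call site, mirroring the line quoted earlier:
# emb = get_1d_rotary_pos_embed(d, e, theta=self.theta, freqs_dtype=freqs_dtype)
```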
quick q - if I look at the reference implementation, it relies on
@vladmandic Diffusers only supports directly calling torch SDPA. Custom attention processors can be used for flash-attn or other backends. To enable flash attention with pytorch, you can set the torch SDPA backend accordingly.

We briefly discussed a dispatching mechanism for different libraries/backends, but there are no decisions on how to move forward yet. Maybe something to consider for the release after the current schedule.
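For reference, a minimal sketch of restricting torch SDPA to the flash-attention kernel in recent PyTorch (2.3+); whether flash attention is actually used still depends on the GPU, dtype, and shapes:

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.bfloat16)
v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.bfloat16)

# only the flash-attention backend is allowed inside this context
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v)
```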
@asomoza

# test lumina2
import torch
from diffusers import Lumina2Text2ImgPipeline
import itertools
from pathlib import Path
import shutil

# branch = "refactor_lumina2"
branch = "main"

params = {
    'use_mask_in_transformer': [True, False],
}

# Generate all combinations
param_combinations = list(itertools.product(*params.values()))

# Create output directory (remove if exists)
output_dir = Path(f"yiyi_test_6_outputs_{branch}")
if output_dir.exists():
    shutil.rmtree(output_dir)
output_dir.mkdir(exist_ok=True)

pipe = Lumina2Text2ImgPipeline.from_pretrained("Alpha-VLLM/Lumina-Image-2.0", torch_dtype=torch.bfloat16).to("cuda")

prompt = [
    "focused exterior view on a living room of a limestone rock regular blocks shaped villa with sliding windows and timber screens in Provence along the cliff facing the sea, with waterfalls from the roof to the pool, designed by Zaha Hadid, with rocky textures and form, made of regular giant rock blocks stacked each other with infinity edge pool in front of it, blends in with the surrounding nature. Regular rock blocks. Giant rock blocks shaping the space. The image to capture the infinity edge profile of the pool and the flow of water going down creating a waterfall effect. Adriatic Sea. The design is sustainable and semi prefab. The photo is shot on a canon 5D mark 4",
    "A capybara holding a sign that reads Hello World"
]

# Run test for each combination
for (mask,) in param_combinations:
    print(f"\nTesting combination:")
    print(f"  use_mask_in_transformer: {mask}")

    # Generate image
    generator = torch.Generator(device="cuda").manual_seed(0)
    images = pipe(
        prompt=prompt,
        num_inference_steps=25,
        use_mask_in_transformer=mask,
        generator=generator,
    ).images

    # Save images
    for i, image in enumerate(images):
        output_path = output_dir / f"output_mask{int(mask)}_prompt{i}.png"
        image.save(output_path)
        print(f"Saved to: {output_path}")
@asomoza
My results are generated when using a list of prompts at the same time; I linked a test script in my previous comment here: #10642 (comment)
Oh, I totally missed that, so I was using a double system prompt. But still, I don't get results that bad, and without the system prompt it generates cartoon images for me. I'll take another look then.
With your code, I can see the difference, though not as much as yours because both are bad for me:
@yiyixuxu since this is happening with a list of prompts, isn't the issue with that instead of the mask or no mask? I don't use that option because it's not really usable for generating good images, so I'll add it to my tests from now on, but the difference for the capybara is really noticeable versus the single-prompt ones.
@asomoza
For the outputs, are you getting different results with the same testing script? Is this pipeline not deterministic?
Some more tests. List of prompts with:

prompt = [
    "A cat wearing sunglasses, in the back there's a sign that says: Beware of the cat!",
    "A capybara holding a sign that reads: I'm cute or what?",
]
Single prompts:

prompt = "A cat wearing sunglasses, in the back there's a sign that says: Beware of the cat!"
prompt = "A capybara holding a sign that reads: I'm cute or what?"
I think this is just random; we're getting some bad results and good results depending on the seed and probably the prompt.
No, the pipeline is deterministic. I got different results when I switched to using your script with a list of prompts; that's expected, right? Since we're using the same generator. With a single prompt I can reproduce the exact same image every time. BTW, I'm always using a seed of
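(Side note, and an assumption on my part: diffusers pipelines generally accept one generator per prompt, which seeds each image in the batch independently. A sketch with the same checkpoint as the test script above:)

```python
import torch
from diffusers import Lumina2Text2ImgPipeline

pipe = Lumina2Text2ImgPipeline.from_pretrained(
    "Alpha-VLLM/Lumina-Image-2.0", torch_dtype=torch.bfloat16
).to("cuda")

prompts = [
    "A cat wearing sunglasses, in the back there's a sign that says: Beware of the cat!",
    "A capybara holding a sign that reads: I'm cute or what?",
]
# one generator per prompt, so each image is seeded independently of batch composition
generators = [torch.Generator(device="cuda").manual_seed(i) for i in range(len(prompts))]
images = pipe(prompt=prompts, generator=generators, num_inference_steps=25).images
```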
I tested this and you're correct, I get the same exact image with and without the system prompt.
@asomoza in this one example for the list of prompts, both seem to have better generation with the mask;
so using a list, and switching the order: long prompt first:
short prompt first:
Doing some more tests, the order and the longer prompt don't matter; the only one that changes a lot is the shorter prompt, so I'll post those only. Shorter prompt without text: "Photo of a capybara wearing sunglasses and a jacket"
even shorter prompt without text: "a capybara"
medium prompt with text: "photo of a cat wearing stylish black sunglasses. The cat has a light brown and white fur pattern with distinct stripes. Behind the cat, there is a wooden sign with the text: Beware of the cat!!! written in a playful, handwritten style."
medium prompt without text: "photo of a cat wearing stylish black sunglasses and a leather jacket. The cat has a light brown and white fur pattern with distinct stripes. The overall scene has a humorous and whimsical tone, combining the cat's cool demeanor with human clothes and eyewear."
Finished with the testing. This is a nice find @yiyixuxu; we can infer that without a mask and using a list of prompts where one is shorter, the text generation will take a quality hit, so this is important for web services or APIs that use this model that way. Overall, I can't really see any other loss in quality, with the exception that a shorter prompt, without details, will look worse without a mask because the model does only the bare minimum with the prompt. For example, I noticed that without a mask, if you ask for a subject only, the model will most of the time generate the subject without a background, which IMO is really nice, but for normal people or use cases this won't look good because of the lack of detail and lighting.
What does this PR do?
This PR adds the official Lumina-Image 2.0 to diffusers. Lumina-Image 2.0 is the latest model in the Lumina family and will be released very soon (https://huggingface.co/Alpha-VLLM/Lumina-Image-2.0). It is a 2B-parameter Diffusion Transformer that significantly improves instruction following and generates higher-quality, more diverse images. Our paper will be released soon, and we have finished the diffusers pipeline for Lumina-Image 2.0.
Core library: