Tencent Hunyuan Team: add HunyuanDiT related updates #8240
Conversation
Thanks for working on this.
I left a number of design related comments. I think we need to get them sorted first and then work on other nits. But let's also wait for @yiyixuxu to comment further before making any changes :)
Thanks for the PR!
I think the first few main tasks are:
- rewrite attentions using diffusers (I will help with this)
- remove the `timm` dependency (I think you only use it for a simple function, it should be easy to remove; see the sketch after this list)
- refactor the pipelines (I agree with @sayakpaul's comments there, we can start with the things he has pointed out)
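As a rough illustration of the `timm` removal: if the only usage is a small helper (which timm function is actually used here is an assumption, not confirmed from the diff; `to_2tuple` is a common one in ViT-style code), a local replacement could look like this sketch:

```python
# Hypothetical replacement for a small timm utility such as `to_2tuple`;
# which timm function HunyuanDiT actually relies on is an assumption.
from typing import Tuple, Union


def to_2tuple(value: Union[int, Tuple[int, int]]) -> Tuple[int, int]:
    """Return `value` as an (h, w) pair, duplicating a single int."""
    if isinstance(value, (tuple, list)):
        return tuple(value)
    return (value, value)
```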
```python
            embedder_t5=embedder_t5,
        )

        self.text_encoder.pooler.to_empty(device='cpu')  ### workaround for the meta device in pooler...
```
@sayakpaul can you look into how we can make it work with `from_pretrained()`?
Do we need to make a wrapper of the BERT model inside diffusers?
Hi: I removed it; feel free to test out the PR branch and cherry-pick this commit 3f85b1d if the results look ok to you. I included a testing script here: #8265 (comment)
Hi, I did the following things:
For now, I will not change the remote state_dict. Please review and comment, thanks!
Thanks for the changes and for incorporating our feedback so nicely so far! I have left a couple of comments.
Broadly, I think, for the next phase of the integration, we could focus on:
- Solving the `BertModel` pooler problem as mentioned in the comments.
- Removing all the unnecessary blocks related to attention, etc.
- Moving the embedding-related functionalities to `embeddings.py`.
- Trying to have Rotary Embeddings as a class (a rough sketch follows this comment).
Would also like to cross-check with @yiyixuxu here at this point if these points make sense.
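For the rotary-embedding point, a minimal sketch of what a standalone class could look like (illustrative only, not the HunyuanDiT implementation; the names and the rotate-half convention are assumptions):

```python
import torch
import torch.nn as nn


class RotaryEmbedding(nn.Module):
    """Minimal rotary position embedding sketch (not the final diffusers class)."""

    def __init__(self, dim: int, base: float = 10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, positions: torch.Tensor):
        # positions: (seq_len,) -> cos/sin tables of shape (seq_len, dim)
        freqs = torch.outer(positions.float(), self.inv_freq)
        emb = torch.cat([freqs, freqs], dim=-1)
        return emb.cos(), emb.sin()


def apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: (..., seq_len, dim); rotates channel pairs by a position-dependent angle
    x1, x2 = x.chunk(2, dim=-1)
    rotated = torch.cat([-x2, x1], dim=-1)
    return x * cos + rotated * sin
```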
thanks!
Looks much better now! I left a few more comments.
Remaining to-dos:
- can you address the comments on the pipeline?
- I will help refactor the models
- working on tests (cc @sayakpaul here, can you give more guidance?)
- make the `BertModel` work with `DiffusionPipeline.from_pretrained()`
cc @sayakpaul and @DN6 here, see https://github.com/huggingface/diffusers/pull/8240/files#r1611508812
I made some improvements according to Sayak and Yiyi's suggestions. Several additional problems:
Why wouldn't it work, though? Could you provide more details here? In order for the pipeline to operate in
`diffusers/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py`, line 1264 in b3d10d6
Additionally, could you see if this worked so that we can avoid
I refactored the model here: #8310. I did the following things:
I changed arg names to be more aligned with our transformer models and blocks, and also removed a bunch of functionalities that are not used in this implementation. Let me know if I did anything wrong or if any of the changes do not make sense! Feel free to just pick the commit and make any modifications on your PR.
Hi Sayak,
I guess the reason is that I'm using the latest
Okay great. It seems like the FP16 problem and also the
I pushed a new version. In this version:
The new test file has been updated. Thank you for the help! Please review the new version @yiyixuxu @sayakpaul
Note:

```python
from transformers import BertModel

bert_model = BertModel.from_pretrained("XCLiu/HunyuanDiT-0523", add_pooling_layer=True, subfolder="text_encoder")
pipe = HunyuanDiTPipeline.from_pretrained("XCLiu/HunyuanDiT-0523", text_encoder=bert_model, transformer=model, torch_dtype=torch.float32)
pipe.to("cuda")
pipe.save_pretrained("HunyuanDiT-ckpt")
del pipe

pipe = HunyuanDiTPipeline.from_pretrained("HunyuanDiT-ckpt")
pipe.to("cuda")  # was pipe.cuda(); .to("cuda") is the standard way to move a pipeline
```
thanks!!
I think we can merge this soon
We have tests and docs left. Docs can be added in a separate PR if you need more time, but let's quickly add a test; you can use the pixart_sigma tests as a reference: https://github.com/huggingface/diffusers/tree/main/tests/pipelines/pixart_sigma
`src/diffusers/models/embeddings.py` (Outdated), `@@ -806,6 +920,27 @@ def forward(self, caption):`:

```python
    return hidden_states


# YiYi notes: combine PixArtAlphaTextProjection and HunYuanTextProjection?
```
cc @DN6 here: should we use `PixArtAlphaTextProjection` instead? It seems like this projection layer is going to be used a lot.
Yeah I think we can make it a generic TextProjection class and reuse it.
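A generic version along those lines might look like the following sketch (the class name and default activation are assumptions, not the final diffusers class):

```python
import torch.nn as nn


class TextProjection(nn.Module):
    """Illustrative generic text projection: linear -> activation -> linear."""

    def __init__(self, in_features: int, hidden_size: int, out_features: int = None, act_fn: str = "gelu_tanh"):
        super().__init__()
        out_features = out_features or hidden_size
        self.linear_1 = nn.Linear(in_features, hidden_size)
        self.act_1 = nn.GELU(approximate="tanh") if act_fn == "gelu_tanh" else nn.SiLU()
        self.linear_2 = nn.Linear(hidden_size, out_features)

    def forward(self, caption):
        return self.linear_2(self.act_1(self.linear_1(caption)))
```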
```python
        ).to(origin_dtype)


class AdaLayerNormShift(nn.Module):
```
@gnobitab, was this a method your team came up with? Is there any reason you chose this over the regular Ada layer norm with both scale and shift?
I want to understand this because we will need to decide whether to keep this module here or put it in `embeddings.py` so it is easier for other researchers to use.
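For reference, a minimal sketch of the difference between a shift-only adaptive layer norm and the usual scale-and-shift variant (module names and shapes are illustrative, not copied from the PR):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class AdaLayerNormShiftSketch(nn.Module):
    """Shift-only: the conditioning embedding predicts one offset per channel."""

    def __init__(self, dim: int):
        super().__init__()
        self.linear = nn.Linear(dim, dim)
        self.norm = nn.LayerNorm(dim, elementwise_affine=True, eps=1e-6)

    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, dim), emb: (batch, dim)
        shift = self.linear(F.silu(emb))
        return self.norm(x) + shift.unsqueeze(1)


class AdaLayerNormScaleShiftSketch(nn.Module):
    """Scale-and-shift: the embedding predicts both a gain and an offset."""

    def __init__(self, dim: int):
        super().__init__()
        self.linear = nn.Linear(dim, 2 * dim)
        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)

    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
        scale, shift = self.linear(F.silu(emb)).chunk(2, dim=-1)
        return self.norm(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
```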
Yeah, I think `AdaLayerNorm` can be reused here as well:

```python
class AdaLayerNorm(nn.Module):
```
I'll go double check with the team and get back to you later
to-do [2]: refactor?
`@@ -507,6 +626,88 @@ def forward(self, timestep, class_labels, hidden_dtype=None):`:

```python
        return conditioning


class HunyuanDiTAttentionPool(nn.Module):
```
to-do [1] (low priority): see if we can consolidate this class with `AttentionPooling` in `diffusers/src/diffusers/models/embeddings.py` (line 588 in bc108e1) and `Kandinsky3AttentionPooling`. We can ask the community for help.
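As an illustration of what a consolidated pooling module could look like (a sketch only; the real classes differ in details such as positional embeddings and output projections):

```python
import torch
import torch.nn as nn


class GenericAttentionPool(nn.Module):
    """Pool a token sequence into one vector with a learned query (illustrative)."""

    def __init__(self, dim: int, num_heads: int = 8):
        super().__init__()
        self.query = nn.Parameter(torch.randn(1, 1, dim) / dim**0.5)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, dim) -> pooled: (batch, dim)
        q = self.query.expand(x.shape[0], -1, -1)
        pooled, _ = self.attn(q, x, x)
        return pooled.squeeze(1)
```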
```python
        return hidden_states


class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
```
to-do [3]: optimization! cc @sayakpaul
```python
        text_encoder: BertModel,
        tokenizer: BertTokenizer,
        transformer: HunyuanDiT2DModel,
        scheduler: DDPMScheduler,
```
[to-do 3] Does Hunyuan DiT work with other schedulers? @gnobitab
If so, we need to change the type here, and let's also update the doc example with a more efficient scheduler.

Suggested change:

```diff
-        scheduler: DDPMScheduler,
+        scheduler: KarrasDiffusionSchedulers,
```
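For context, swapping in a faster scheduler after loading is the usual diffusers pattern; a sketch, assuming the `XCLiu/HunyuanDiT-0523` checkpoint used elsewhere in this thread and `DPMSolverMultistepScheduler` as the replacement:

```python
import torch
from diffusers import DPMSolverMultistepScheduler, HunyuanDiTPipeline

# Load the pipeline, then rebuild the scheduler from the existing config.
pipe = HunyuanDiTPipeline.from_pretrained("XCLiu/HunyuanDiT-0523", torch_dtype=torch.float16)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe.enable_model_cpu_offload()

image = pipe(prompt="An astronaut riding a horse", num_inference_steps=25).images[0]
```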
```python
enable_full_determinism()


class HunyuanDiTPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
```
to-do [4]
I added a simple fast test and a slow test. @gnobitab can we add more tests? See examples: https://github.com/huggingface/diffusers/blob/bc108e15333cb0e8a092647320cbb4d70d6d0f03/tests/pipelines/pixart_sigma/test_pixart.py
We can help with this too, but we prefer authors to add tests and make sure the current implementation is correct!
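One candidate for an extra fast test, following the PixArt-Sigma pattern referenced above (a sketch; it assumes the test class inherits `PipelineTesterMixin`, whose batching helper is reused here):

```python
# Sketch of an additional fast test; assumes the surrounding
# HunyuanDiTPipelineFastTests class inherits PipelineTesterMixin.
def test_inference_batch_single_identical(self):
    # A batched prompt should match the single-prompt output within tolerance.
    self._test_inference_batch_single_identical(expected_max_diff=1e-3)
```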
```python
        max_diff = np.abs(image_slice.flatten() - expected_slice).max()
        self.assertLessEqual(max_diff, 1e-3)

    def test_sequential_cpu_offload_forward_pass(self):
```
to-do [5] (low priority): the sequential CPU offload does not work; we can look into this, but it is not a priority since the other offloading method works and is way more popular. cc @sayakpaul
Merging now! I left a bunch of to-dos in the comments too, for @sayakpaul and @gnobitab.
Here is the script that works now for the checkpoint in my PR to your repo: https://huggingface.co/XCLiu/HunyuanDiT-0523/discussions/2

```python
# integration test (hunyuan dit)
import torch
from diffusers import HunyuanDiTPipeline

device = "cuda"
dtype = torch.float16
repo = "XCLiu/HunyuanDiT-0523"

pipe = HunyuanDiTPipeline.from_pretrained(repo, revision="refs/pr/2", torch_dtype=dtype)
pipe.enable_model_cpu_offload()

### NOTE: HunyuanDiT supports both Chinese and English inputs
prompt = "一个宇航员在骑马"
# prompt = "An astronaut riding a horse"
generator = torch.Generator(device="cuda").manual_seed(0)
image = pipe(height=1024, width=1024, prompt=prompt, generator=generator).images[0]
image.save("yiyi_test_out.png")
```

And this is the script I used to convert the current checkpoint:
```python
import torch
from huggingface_hub import hf_hub_download
from diffusers import HunyuanDiTPipeline, HunyuanDiT2DModel
from transformers import T5EncoderModel, T5Tokenizer
import safetensors.torch

device = "cuda"
dtype = torch.float32
repo = "XCLiu/HunyuanDiT-0523"

tokenizer_2 = T5Tokenizer.from_pretrained(repo, subfolder="tokenizer_t5")
text_encoder_2 = T5EncoderModel.from_pretrained(repo, subfolder="embedder_t5", torch_dtype=dtype)

model_config = HunyuanDiT2DModel.load_config("XCLiu/HunyuanDiT-0523", subfolder="transformer")
model = HunyuanDiT2DModel.from_config(model_config).to(device)

ckpt_path = hf_hub_download(
    "XCLiu/HunyuanDiT-0523",
    filename="diffusion_pytorch_model.safetensors",
    subfolder="transformer",
)
state_dict = safetensors.torch.load_file(ckpt_path)

prefix = "time_extra_emb."

# time_embedding.linear_1 -> timestep_embedder.linear_1
state_dict[f"{prefix}timestep_embedder.linear_1.weight"] = state_dict["time_embedding.linear_1.weight"]
state_dict[f"{prefix}timestep_embedder.linear_1.bias"] = state_dict["time_embedding.linear_1.bias"]
state_dict.pop("time_embedding.linear_1.weight")
state_dict.pop("time_embedding.linear_1.bias")

# time_embedding.linear_2 -> timestep_embedder.linear_2
state_dict[f"{prefix}timestep_embedder.linear_2.weight"] = state_dict["time_embedding.linear_2.weight"]
state_dict[f"{prefix}timestep_embedder.linear_2.bias"] = state_dict["time_embedding.linear_2.bias"]
state_dict.pop("time_embedding.linear_2.weight")
state_dict.pop("time_embedding.linear_2.bias")

# pooler.positional_embedding
state_dict[f"{prefix}pooler.positional_embedding"] = state_dict["pooler.positional_embedding"]
state_dict.pop("pooler.positional_embedding")

# pooler.k_proj
state_dict[f"{prefix}pooler.k_proj.weight"] = state_dict["pooler.k_proj.weight"]
state_dict[f"{prefix}pooler.k_proj.bias"] = state_dict["pooler.k_proj.bias"]
state_dict.pop("pooler.k_proj.weight")
state_dict.pop("pooler.k_proj.bias")

# pooler.q_proj
state_dict[f"{prefix}pooler.q_proj.weight"] = state_dict["pooler.q_proj.weight"]
state_dict[f"{prefix}pooler.q_proj.bias"] = state_dict["pooler.q_proj.bias"]
state_dict.pop("pooler.q_proj.weight")
state_dict.pop("pooler.q_proj.bias")

# pooler.v_proj
state_dict[f"{prefix}pooler.v_proj.weight"] = state_dict["pooler.v_proj.weight"]
state_dict[f"{prefix}pooler.v_proj.bias"] = state_dict["pooler.v_proj.bias"]
state_dict.pop("pooler.v_proj.weight")
state_dict.pop("pooler.v_proj.bias")

# pooler.c_proj
state_dict[f"{prefix}pooler.c_proj.weight"] = state_dict["pooler.c_proj.weight"]
state_dict[f"{prefix}pooler.c_proj.bias"] = state_dict["pooler.c_proj.bias"]
state_dict.pop("pooler.c_proj.weight")
state_dict.pop("pooler.c_proj.bias")

# style_embedder.weight
state_dict[f"{prefix}style_embedder.weight"] = state_dict["style_embedder.weight"]
state_dict.pop("style_embedder.weight")

# extra_embedder.linear_1
state_dict[f"{prefix}extra_embedder.linear_1.weight"] = state_dict["extra_embedder.linear_1.weight"]
state_dict[f"{prefix}extra_embedder.linear_1.bias"] = state_dict["extra_embedder.linear_1.bias"]
state_dict.pop("extra_embedder.linear_1.weight")
state_dict.pop("extra_embedder.linear_1.bias")

# extra_embedder.linear_2
state_dict[f"{prefix}extra_embedder.linear_2.weight"] = state_dict["extra_embedder.linear_2.weight"]
state_dict[f"{prefix}extra_embedder.linear_2.bias"] = state_dict["extra_embedder.linear_2.bias"]
state_dict.pop("extra_embedder.linear_2.weight")
state_dict.pop("extra_embedder.linear_2.bias")

model.load_state_dict(state_dict)
model.to(dtype)

pipe = HunyuanDiTPipeline.from_pretrained(
    repo,
    tokenizer_2=tokenizer_2,
    text_encoder_2=text_encoder_2,
    transformer=model,
    torch_dtype=dtype,
)
```
Thanks for merging! Reply to your TODOs:
(2) Scheduler: I tested several fast samplers. From my test, I think it is safe to switch from `DDPMScheduler`.
(4) and (5): I merged your PR.
This PR did the following things:

- Added `HunyuanDiTPipeline` in `src/diffusers/pipelines/hunyuandit/` and `HunyuanDiT2DModel` in `./src/diffusers/models/transformers/`.
- For `HunyuanDiT2DModel`, added `HunyuanDiTBlock` and helper functions in `src/diffusers/models/attention.py`.
- Checkpoint: `XCLiu/HunyuanDiT-0523`.
- In this branch, you can run HunyuanDiT in FP32 (see the sketch after this description for an illustrative snippet).
- Dependency: maybe the `timm` package.

TODO lists:

- `use_fp16` in `HunyuanDiTPipeline.__call__()`. The reason is `BertModel` does not support FP16 quantization. In our repo we only quantize the diffusion transformer to FP16. I guess there must be some smart way to support FP16.
- `HunyuanDiTBlock` related codes in `src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py`.

Thank you so much! I'll be there and help with everything.

cc: @sayakpaul @yiyixuxu
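As a rough illustration of the FP32 usage mentioned above (a minimal sketch based on the scripts earlier in this thread; the `XCLiu/HunyuanDiT-0523` checkpoint repo is assumed, and this is not the exact snippet from the PR description):

```python
# Hedged FP32 usage sketch; the checkpoint repo is carried over from earlier
# in this thread as an assumption.
import torch
from diffusers import HunyuanDiTPipeline

pipe = HunyuanDiTPipeline.from_pretrained("XCLiu/HunyuanDiT-0523", torch_dtype=torch.float32)
pipe.to("cuda")

image = pipe(prompt="An astronaut riding a horse", height=1024, width=1024).images[0]
image.save("hunyuan_fp32.png")
```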