PAG variant for HunyuanDiT, PAG refactor #8936
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Ahh cool! It looks like PAG improved a lot, no?! Does the Chinese prompt still work better than English for hunyuan-dit 1.2?
@@ -258,3 +260,86 @@ def pag_attn_processors(self):
            if proc.__class__ in (PAGCFGIdentitySelfAttnProcessor2_0, PAGIdentitySelfAttnProcessor2_0):
                processors[name] = proc
        return processors


class HunyuanDiTPAGMixin(PAGMixin):
I know it is pretty natural to inherit from PAGMixin (PAGMixin was not really designed to be extended to other models, I think). Since we want to keep the level of abstraction in diffusers to a minimum, let's try to come up with ideas so that HunyuanDiTPAGMixin does not inherit from PAGMixin.
(@sayakpaul is refactoring that in the PixArt PAG PR https://github.com/huggingface/diffusers/pull/8921/files#r1689002227; I think we only need one PAG mixin for all transformer models, so please share ideas there if you have any!)
@asomoza super cool analysis :)
@yiyixuxu WDYT about the implementation? I've tried to make it so that handling both UNet and transformer models is easy, without too many code branches to check.
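A hedged sketch of that idea (the helper name is illustrative, not the exact PR code): resolve whichever denoiser the pipeline carries once, so the PAG helpers don't need separate UNet/transformer branches.

def _get_denoiser(pipeline):
    # prefer the transformer (DiT-style pipelines), fall back to the UNet
    return getattr(pipeline, "transformer", None) or getattr(pipeline, "unet", None)

# the attention-processor helpers can then iterate a single module tree:
# for name, module in _get_denoiser(self).named_modules(): ...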
Once it's good to go, I can update all tests once the other open PAG variants are merged, or their respective authors can adopt the mixin implementation from here. It is a bit backward-breaking, but I think it's okay since we haven't released a Diffusers version with PAG yet, so some instability can be expected. Here's some code for a demo:

import torch
from diffusers import AutoPipelineForText2Image
pipe = AutoPipelineForText2Image.from_pretrained(
    "Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers",
    torch_dtype=torch.float16,
    enable_pag=True,
).to("cuda:0")
prompt = "A whimsical and creative image depicting a hybrid creature that is a mix of a waffle and a hippopotamus. This imaginative creature features the distinctive, bulky body of a hippo, but with a texture and appearance resembling a golden-brown, crispy waffle. The creature might have elements like waffle squares across its skin and a syrup-like sheen. It’s set in a surreal environment that playfully combines a natural water habitat of a hippo with elements of a breakfast table setting, possibly including oversized utensils or plates in the background. The image should evoke a sense of playful absurdity and culinary fantasy."
chinese_prompt = "一幅异想天开且富有创意的图像,描绘了华夫饼和河马的混合生物。这种富有想象力的生物具有河马独特而庞大的身体,但其质地和外观类似于金棕色的脆华夫饼。这种生物的皮肤上可能有华夫饼方块之类的元素,并且有糖浆般的光泽。它坐落在一个超现实的环境中,有趣地将河马的天然水生栖息地与早餐桌设置的元素结合在一起,背景可能包括超大的餐具或盘子。图像应该唤起一种俏皮的荒诞和烹饪幻想的感觉。"
test_layers = [
    [],
    "blocks.4",
    "blocks.10",
    "blocks.16",
    "blocks.20",
    "blocks\.22",
    r"blocks.24",
    "blocks.(16|17|18|19)",
    r"blocks\.(20|21|22|23)",
    ["blocks.16", "blocks.17", "blocks\.(20|23)", r"blocks.24"],
]
for index, layers in enumerate(test_layers):
    pipe.set_pag_applied_layers(layers)
    print(f"Processing {layers}")

    image = pipe(
        prompt=prompt,
        num_inference_steps=25,
        guidance_scale=4,
        pag_scale=2,
    ).images[0]
    image.save(f"hunyuandit_{index}_eng.png")

    image = pipe(
        prompt=chinese_prompt,
        num_inference_steps=25,
        guidance_scale=4,
        pag_scale=2,
    ).images[0]
    image.save(f"hunyuandit_{index}_chn.png")
Yes, seems like a limitation with Hunyuan. I never got it to work for really descriptive prompts with all expected objects/styles.
I haven't really dived too deep into this. I assumed it would be better for Chinese because of the model's origin.
        )

        self.set_pag_applied_layers(
            pag_applied_layers, pag_attn_processors=(PAGCFGHunyuanAttnProcessor2_0(), PAGHunyuanAttnProcessor2_0())
For PixArt/SD3, deriving from PAGMixin and changing this would be all the changes needed, I think. I have also enforced that pag_applied_layers is either a string or a list of strings, and disallowed passing integers for layers.
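A minimal sketch of what that string-only enforcement could look like (the helper name and error message are illustrative, not the exact diff):

from typing import List, Union

def _normalize_pag_applied_layers(pag_applied_layers: Union[str, List[str]]) -> List[str]:
    # accept a single string or a list of strings; reject integers and anything else
    if isinstance(pag_applied_layers, str):
        pag_applied_layers = [pag_applied_layers]
    if not isinstance(pag_applied_layers, list) or not all(isinstance(layer, str) for layer in pag_applied_layers):
        raise ValueError(f"`pag_applied_layers` must be a string or a list of strings, got {pag_applied_layers}.")
    return pag_applied_layers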
Thanks! Very nice!
I left some comments; I would also wait for this to merge and then refactor a little bit more: https://github.com/huggingface/diffusers/pull/8921/files#diff-f28d850df75fd10c4100a8de8d944f3bbab67e03fd9d3122fb7cce3572a5eebc
        r"""
        Check if the module is self-attention module based on its name.
        """
        return "attn1" in module_name and "to" not in name
        attn_id = getattr(self, "_self_attn_identifier", "attn1")
should we use this to identify self-attention?
        self.is_cross_attention = cross_attention_dim is not None
        pag_attn_processors = getattr(self, "_pag_attn_processors", None)
        if pag_attn_processors is None:
            # If this hasn't been set by the user, we default to the original PAG identity processors
            pag_attn_processors = (PAGCFGIdentitySelfAttnProcessor2_0(), PAGIdentitySelfAttnProcessor2_0())
I think it's better to raise an error here
    def set_pag_applied_layers(
        self,
        pag_applied_layers: Union[str, List[str]],
        pag_attn_processors: Optional[Tuple[AttentionProcessor, AttentionProcessor]] = None,
        pag_attn_processors: Optional[Tuple[AttentionProcessor, AttentionProcessor]] = None,
        pag_attn_processors: Optional[Tuple[AttentionProcessor, AttentionProcessor]] = (PAGCFGIdentitySelfAttnProcessor2_0(), PAGIdentitySelfAttnProcessor2_0()),
let's just
@@ -253,8 +214,13 @@ def pag_attn_processors(self):
            with the key as the name of the layer.
        """

        if not hasattr(self, "_pag_attn_processors") or self._pag_attn_processors is None:
let's just make sure it always has _pag_attn_processors set, to simplify things
    r"""Mixin class for PAG."""

    @staticmethod
    def _check_input_pag_applied_layer(layer):
where did this check go?
Why do we need this check anymore, btw? Since the matching is now regex-based, a user of the library could do a wide variety of things that would be nearly impossible to validate. How could we verify this?
IIUC, the regex matching itself should error out then, right, if there's no matching pattern? If we're not doing that already, could we add it?
We do raise an error if no matches are found via regex, so that should be covered
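For reference, a minimal sketch of how such a no-match check can be wired in (the helper name, the module filtering, and the error message are illustrative, not copied from the diff):

import re

def _set_pag_attn_processor_for_layer(denoiser, layer_id, pag_attn_processor):
    # collect every self-attention module whose name matches the applied-layer pattern
    matched = [
        module
        for name, module in denoiser.named_modules()
        if hasattr(module, "processor") and "attn1" in name and re.search(layer_id, name) is not None
    ]
    if not matched:
        # a pattern that matches nothing should fail loudly instead of silently doing nothing
        raise ValueError(f"Cannot find PAG layer to set attention processors for: {layer_id}")
    for module in matched:
        module.processor = pag_attn_processor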
@yiyixuxu Thanks for the review! I've addressed the comments for the new design; let me know if there's anything else that I missed. cc @sayakpaul for viz.

We won't be able to use integer layer indexes any more to apply PAG (say, for PixArt); I've made it so that we use strings throughout. With the new design, one can apply PAG in many ways (due to regex matching), and the intended way to use it is by specifying layer names following the blocks.{index} convention or a regex over them.

I will update the tests for everything once I get approval on the final design. Let's try to get this merged soon, since we can't introduce backward-breaking changes after the next release.
@a-r-r-o-w could we fix the PixArt Sigma tests as well? Can merge afterward.
LGTM, thank you. The refactor is beautiful. My comments are mostly minor. LMK in case something isn't clear enough.
Let's fix the PixArt tests and merge.
@@ -132,7 +132,7 @@ def retrieve_timesteps(
    return timesteps, num_inference_steps


-class PixArtSigmaPAGPipeline(DiffusionPipeline, PixArtPAGMixin):
+class PixArtSigmaPAGPipeline(DiffusionPipeline, PAGMixin):
We need to update the example doc too and the default value here:
... pag_applied_layers=[14],
pag_applied_layers: Union[str, List[str]] = "1",  # 1st transformer block
                # Identify the following simple cases:
                #   (1) Self Attention layer existing
                #   (2) Whether the module name matches pag layer id even partially
                #   (3) Make sure it's not a fake integral match if the layer_id ends with a number
                #       For example, blocks.1, blocks.10 should be differentiable if layer_id="blocks.1"
❤️
                    and re.search(layer_id, name) is not None
                    and not is_fake_integral_match(layer_id, name)
                ):
                    logger.debug(f"Apply PAG to layer: {name}")
❤️
        if self._pag_attn_processors is None:
            return {}

        valid_attn_processors = tuple(x.__class__ for x in self._pag_attn_processors)
Curious about the choice of tuple as a data structure here. Why not a set, in case performance is a concern?
Ah yes, apologies for the oversight. Should be set
        pag_attn_processors: Tuple[AttentionProcessor, AttentionProcessor] = (
            PAGCFGIdentitySelfAttnProcessor2_0(),
            PAGIdentitySelfAttnProcessor2_0(),
There should be Hunyuan processors as well, right?
Yes, correct. Will update
Actually, no. If this was a type hint, we would have had Hunyuan ones as well. I am instantiating with default pag attention processors here if the user does not provide any (which is how we do it in most pag pipelines). For Hunyuan, we explicitly pass the pag_attn_processors parameter.
@@ -127,7 +127,7 @@ def test_pag_disable_enable(self):
        out = pipe(**inputs).images[0, -3:, -3:, -1]

        # pag disabled with pag_scale=0.0
-        components["pag_applied_layers"] = [1]
+        components["pag_applied_layers"] = ["1"]
Shouldn't this be "block.1"?
Yes, my bad. Will update these. The intended usage is "blocks.{index}" or some regex combination of layers, but plain strings like this work too (in this case), since the PAG logic is able to fulfil the three conditions mentioned. It will give unexpected behaviour if you have layers named like:
something.blocks.1.another_something.1.attn1
something.blocks.42.another_something.1.attn1
In this case, the intent was probably to apply PAG only to something.blocks.1, but it will get applied to both. The general guide would be to follow the blocks.{index} convention, because we use it consistently in naming our blocks. With regex, there are many ways to shoot yourself in the foot, and we can't really account for all cases if users don't want to follow the general guide.
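A quick illustration of that pitfall (the module names below are made up, not taken from any real model):

import re

names = [
    "something.blocks.1.another_something.1.attn1",
    "something.blocks.42.another_something.1.attn1",
]
print([n for n in names if re.search("1", n)])         # matches both, via the trailing ".1.attn1"
print([n for n in names if re.search("blocks.1", n)])  # matches only the first
# "blocks.1" would still partially match a hypothetical "blocks.10" or "blocks.12";
# that is what the fake-integral-match guard shown earlier protects against.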
Yeah, fair enough. Maybe we can update the documentation to be explicit about that? PAG is a very widely used feature, so we need to be rigorous. WDYT?
@sayakpaul Why does it say that it can't find the
For the expected slices, I will fix them in a bit.
I would suggest taking a look more closely. If that does not cut it, I can take a closer look. Why do we have to update the slices? Usually, that means there might be some incorrect codepath being taken.
The reason the slices need to be updated in AnimateDiffPAG is that they were created for the original PAG implementation where, if you specified "mid", only attn1 was used.
Now that attn2 is also used by default (quality-wise it is similar), we'd need to update the slices. Edit: the easier fix would be to just specify attn1 layers to stay consistent with the test and example code here.
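For example, the restriction in the test could look something like the line below (the exact module path is an assumption about the motion UNet's naming, not taken from the PR):

# apply PAG only to self-attention (attn1) in the mid block, matching the
# behaviour of the original implementation; the module path is an assumption
components["pag_applied_layers"] = ["mid_block.*attn1"]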
@yiyixuxu @sayakpaul For some reason, MRO is fighting me here and I'm not completely sure why. Dump:

>>> from diffusers import HunyuanDiTPAGPipeline
>>> pipe = HunyuanDiTPAGPipeline(None, None, None, None, None)
You have disabled the safety checker for <class 'diffusers.pipelines.pag.pipeline_pag_hunyuandit.HunyuanDiTPAGPipeline'> by passing `safety_checker=None`. Ensure that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling it only for use-cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .
>>> dir(pipe)
['__call__', '__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattr__', '__getattribute__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', '_apply_perturbed_attention_guidance', '_callback_tensor_inputs', '_deprecated_kwargs', '_dict_from_json_file', '_exclude_from_cpu_offload', '_execution_device', '_get_init_keys', '_get_pag_scale', '_get_signature_keys', '_get_signature_types', '_internal_dict', '_is_onnx', '_load_connected_pipes', '_optional_components', '_pag_attn_processors', '_prepare_perturbed_attention_guidance', '_set_pag_attn_processor', '_upload_folder', 'check_inputs', 'components', 'config', 'config_name', 'default_sample_size', 'device', 'disable_attention_slicing', 'disable_xformers_memory_efficient_attention', 'do_classifier_free_guidance', 'do_pag_adaptive_scaling', 'do_perturbed_attention_guidance', 'download', 'dtype', 'enable_attention_slicing', 'enable_model_cpu_offload', 'enable_sequential_cpu_offload', 'enable_xformers_memory_efficient_attention', 'encode_prompt', 'extract_init_dict', 'feature_extractor', 'from_config', 'from_pipe', 'from_pretrained', 'get_config_dict', 'guidance_rescale', 'guidance_scale', 'has_compatibles', 'hf_device_map', 'ignore_for_config', 'image_processor', 'interrupt', 'load_config', 'maybe_free_model_hooks', 'model_cpu_offload_seq', 'name_or_path', 'num_timesteps', 'numpy_to_pil', 'pag_adaptive_scale', 'pag_applied_layers', 'pag_attn_processors', 'pag_scale', 'prepare_extra_step_kwargs', 'prepare_latents', 'progress_bar', 'push_to_hub', 'register_modules', 'register_to_config', 'remove_all_hooks', 'reset_device_map', 'run_safety_checker', 'safety_checker', 'save_config', 'save_pretrained', 'scheduler', 'set_attention_slice', 'set_pag_applied_layers', 'set_progress_bar_config', 'set_use_memory_efficient_attention_xformers', 'text_encoder', 'text_encoder_2', 'to', 'to_json_file', 'to_json_string', 'tokenizer', 'tokenizer_2', 'transformer', 'vae', 'vae_scale_factor']
>>> pipe.pag_attn_processors
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/aryan/work/diffusers/src/diffusers/configuration_utils.py", line 143, in __getattr__
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
AttributeError: 'HunyuanDiTPAGPipeline' object has no attribute 'pag_attn_processors'. Did you mean: '_pag_attn_processors'?
>>> pipe.pag_scale
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/aryan/work/diffusers/src/diffusers/configuration_utils.py", line 143, in __getattr__
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
AttributeError: 'HunyuanDiTPAGPipeline' object has no attribute 'pag_scale'

Also:
If you log the
That is indeed better. Regarding the MRO, I am taking a look. Will update here if I find anything.
@a-r-r-o-w I think I found the bug. Here's a patch:

Let me know if this resolves your issue.
Co-Authored-By: Sayak Paul <[email protected]>
Since we have Yiyi's approval, I think we can merge this.
Thanks @sayakpaul and @yiyixuxu for all the reviews and design iterations on this one, @asomoza for the really cool examples, and of course @sunovivid and team for the awesome technique! Will merge once CI is green.
* copy hunyuandit pipeline
* pag variant of hunyuan dit
* add tests
* update docs
* make style
* make fix-copies
* Update src/diffusers/pipelines/pag/pag_utils.py
* remove incorrect copied from
* remove pag hunyuan attn procs to resolve conflicts
* add pag attn procs again
* new implementation for pag_utils
* revert pag changes
* add pag refactor back; update pixart sigma
* update pixart pag tests
* apply suggestions from review
  Co-Authored-By: [email protected]
* make style
* update docs, fix tests
* fix tests
* fix test_components_function since list not accepted as valid __init__ param
* apply patch to fix broken tests
  Co-Authored-By: Sayak Paul <[email protected]>
* make style
* fix hunyuan tests

---------

Co-authored-by: Sayak Paul <[email protected]>
What does this PR do?
Part of #8785.
Code
TLDR; Layers 16-28 seem to be the sweet spot for applying PAG in Hunyuan. High CFG causes the final image to have many artifacts, so values between 2-5 seem to work best. A low PAG scale between 1-3 also yields better results. Higher PAG and CFG cause the hilt of the sword to become more curved, introduce artifacts near the hand, nose, and whiskers, and make the image look more blown out and high-contrast.
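Based on those observations, a hedged usage sketch (the layer regex is just one way to express blocks 16-28 and assumes HunyuanDiT's blocks.{index} naming; the scales are the ranges reported above, and the prompt is only a placeholder):

import torch
from diffusers import AutoPipelineForText2Image

pipe = AutoPipelineForText2Image.from_pretrained(
    "Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers",
    torch_dtype=torch.float16,
    enable_pag=True,
).to("cuda")
pipe.set_pag_applied_layers([r"blocks\.(1[6-9]|2[0-8])\."])  # blocks 16-28

image = pipe(
    prompt="a photo of an astronaut riding a horse on the moon",
    num_inference_steps=25,
    guidance_scale=4.0,  # CFG 2-5 reported to work best
    pag_scale=2.0,       # PAG scale 1-3 reported to work best
).images[0]
image.save("hunyuandit_pag.png")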
Ablations
[0, 0.5, 2.5, 4]
[1.0, 2.5, 5.0, 7.0]
TODO:
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@yiyixuxu @asomoza @sunovivid