Expand Single File support in SD3 Pipeline (#8517)
* update

* update
DN6 authored and yiyixuxu committed Jun 20, 2024
1 parent 46418bd commit 7fada49
Showing 4 changed files with 66 additions and 17 deletions.
2 changes: 2 additions & 0 deletions docs/source/en/api/loaders/single_file.md
@@ -35,6 +35,7 @@ The [`~loaders.FromSingleFileMixin.from_single_file`] method allows you to load:
- [`StableDiffusionXLInstructPix2PixPipeline`]
- [`StableDiffusionXLControlNetPipeline`]
- [`StableDiffusionXLKDiffusionPipeline`]
- [`StableDiffusion3Pipeline`]
- [`LatentConsistencyModelPipeline`]
- [`LatentConsistencyModelImg2ImgPipeline`]
- [`StableDiffusionControlNetXSPipeline`]
@@ -49,6 +50,7 @@ The [`~loaders.FromSingleFileMixin.from_single_file`] method allows you to load:
- [`StableCascadeUNet`]
- [`AutoencoderKL`]
- [`ControlNetModel`]
- [`SD3Transformer2DModel`]
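
As a quick illustration, loading the newly supported `SD3Transformer2DModel` from a single checkpoint file follows the same pattern as the other classes listed above (a minimal sketch; the checkpoint URL points at the public SD3 Medium repository and is shown for illustration):

```python
import torch
from diffusers import SD3Transformer2DModel

# Load only the diffusion transformer from a combined single-file checkpoint
model = SD3Transformer2DModel.from_single_file(
    "https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/sd3_medium.safetensors",
    torch_dtype=torch.float16,
)
```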

## FromSingleFileMixin

37 changes: 29 additions & 8 deletions docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md
@@ -21,9 +21,9 @@ The abstract from the paper is:

## Usage Example

_As the model is gated, before using it with diffusers you first need to go to the [Stable Diffusion 3 Medium Hugging Face page](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate._

Use the command below to log in:

```bash
huggingface-cli login
```
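
Alternatively, you can log in programmatically with `huggingface_hub` (a minimal sketch, assuming you have a User Access Token with read access to gated repositories):

```python
from huggingface_hub import login

# Prompts for a User Access Token if one is not passed explicitly
login()
```
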
@@ -211,17 +211,38 @@ model = SD3Transformer2DModel.from_single_file("https://huggingface.co/stability

## Loading the single checkpoint for the `StableDiffusion3Pipeline`

### Loading the single file checkpoint without T5

```python
import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_single_file(
"https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/sd3_medium_incl_clips.safetensors",
torch_dtype=torch.float16,
text_encoder_3=None
)
pipe.enable_model_cpu_offload()

image = pipe("a picture of a cat holding a sign that says hello world").images[0]
image.save('sd3-single-file.png')
```
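
Passing `text_encoder_3=None` skips loading the T5-XXL encoder entirely, so prompts are encoded with the two CLIP text encoders alone. This substantially lowers memory use, at some cost in adherence to long or complex prompts.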

### Loading the single file checkpoint with T5

```python
import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_single_file(
"https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/sd3_medium_incl_clips_t5xxlfp8.safetensors",
torch_dtype=torch.float16,
)
pipe.enable_model_cpu_offload()

image = pipe("a picture of a cat holding a sign that says hello world").images[0]
image.save('sd3-single-file-t5-fp8.png')
```
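
The `sd3_medium_incl_clips_t5xxlfp8.safetensors` checkpoint bundles the CLIP encoders together with an fp8 copy of T5-XXL, so no separate `text_encoder_3` needs to be supplied here.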

## StableDiffusion3Pipeline

12 changes: 12 additions & 0 deletions src/diffusers/loaders/single_file.py
@@ -28,9 +28,11 @@
_legacy_load_safety_checker,
_legacy_load_scheduler,
create_diffusers_clip_model_from_ldm,
create_diffusers_t5_model_from_checkpoint,
fetch_diffusers_config,
fetch_original_config,
is_clip_model_in_single_file,
is_t5_in_single_file,
load_single_file_checkpoint,
)

@@ -118,6 +120,16 @@ def load_single_file_sub_model(
is_legacy_loading=is_legacy_loading,
)

elif is_transformers_model and is_t5_in_single_file(checkpoint):
loaded_sub_model = create_diffusers_t5_model_from_checkpoint(
class_obj,
checkpoint=checkpoint,
config=cached_model_config_path,
subfolder=name,
torch_dtype=torch_dtype,
local_files_only=local_files_only,
)

elif is_tokenizer and is_legacy_loading:
loaded_sub_model = _legacy_load_clip_tokenizer(
class_obj, checkpoint=checkpoint, config=cached_model_config_path, local_files_only=local_files_only
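
For context, the new `is_t5_in_single_file` helper gates this branch by probing the loaded state dict for a T5-specific tensor. A minimal sketch of the idea (the exact marker key is an assumption, based on the `text_encoders.t5xxl.transformer.` prefix used in the conversion code later in this diff):

```python
def is_t5_in_single_file(checkpoint):
    # Assumed marker key: the shared token-embedding table of the bundled T5-XXL encoder
    if "text_encoders.t5xxl.transformer.shared.weight" in checkpoint:
        return True

    return False
```
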
32 changes: 23 additions & 9 deletions src/diffusers/loaders/single_file_utils.py
@@ -252,7 +252,6 @@
LDM_CLIP_PREFIX_TO_REMOVE = [
    "cond_stage_model.transformer.",
    "conditioner.embedders.0.transformer.",
]
OPEN_CLIP_PREFIX = "conditioner.embedders.0.model."
LDM_OPEN_CLIP_TEXT_PROJECTION_DIM = 1024
@@ -399,11 +398,14 @@ def is_open_clip_sdxl_model(checkpoint):


def is_open_clip_sd3_model(checkpoint):
if CHECKPOINT_KEY_NAMES["open_clip_sd3"] in checkpoint:
return True

return False


def is_open_clip_sdxl_refiner_model(checkpoint):
    if CHECKPOINT_KEY_NAMES["open_clip_sdxl_refiner"] in checkpoint:
return True

return False
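
Both detectors key off marker entries in the module-level `CHECKPOINT_KEY_NAMES` mapping. The relevant entries look roughly like this (the exact key strings are assumptions, inferred from the `text_encoders.clip_l` / `text_encoders.clip_g` prefixes used elsewhere in this diff):

```python
CHECKPOINT_KEY_NAMES = {
    # SD3 combined checkpoints: position embeddings of the bundled CLIP encoders (assumed names)
    "clip_sd3": "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight",
    "open_clip_sd3": "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight",
    # SDXL refiner: OpenCLIP encoder under the SGM "conditioner" namespace (assumed name)
    "open_clip_sdxl_refiner": "conditioner.embedders.0.model.text_projection",
}
```
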
@@ -1233,11 +1235,14 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
return new_checkpoint


def convert_ldm_clip_checkpoint(checkpoint, remove_prefix=None):
keys = list(checkpoint.keys())
text_model_dict = {}

remove_prefixes = []
remove_prefixes.extend(LDM_CLIP_PREFIX_TO_REMOVE)
if remove_prefix:
remove_prefixes.append(remove_prefix)

for key in keys:
for prefix in remove_prefixes:
@@ -1376,6 +1381,13 @@ def create_diffusers_clip_model_from_ldm(
):
diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint)

elif (
is_clip_sd3_model(checkpoint)
and checkpoint[CHECKPOINT_KEY_NAMES["clip_sd3"]].shape[-1] == position_embedding_dim
):
diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint, "text_encoders.clip_l.transformer.")
diffusers_format_checkpoint["text_projection.weight"] = torch.eye(position_embedding_dim)

elif is_open_clip_model(checkpoint):
prefix = "cond_stage_model.model."
diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)
@@ -1391,9 +1403,11 @@ def create_diffusers_clip_model_from_ldm(
prefix = "conditioner.embedders.0.model."
diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)

elif (
is_open_clip_sd3_model(checkpoint)
and checkpoint[CHECKPOINT_KEY_NAMES["open_clip_sd3"]].shape[-1] == position_embedding_dim
):
diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint, "text_encoders.clip_g.transformer.")

else:
raise ValueError("The provided checkpoint does not seem to contain a valid CLIP model.")
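
The `shape[-1] == position_embedding_dim` guards matter because an SD3 combined checkpoint contains both CLIP encoders: comparing the stored position-embedding width with the width expected by the model being instantiated (768 for CLIP-L, 1280 for CLIP-G) selects the matching branch. For CLIP-L, the checkpoint ships no `text_projection.weight`, so it is synthesized as an identity matrix with `torch.eye`.
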
@@ -1755,7 +1769,7 @@ def convert_sd3_t5_checkpoint_to_diffusers(checkpoint):
keys = list(checkpoint.keys())
text_model_dict = {}

remove_prefixes = ["text_encoders.t5xxl.transformer."]

for key in keys:
for prefix in remove_prefixes:
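
The prefix fix in this hunk is significant: with the old prefix `text_encoders.t5xxl.transformer.encoder.`, keys stored directly under `transformer.` (notably the shared token-embedding table, `shared.weight`) never matched, and the stripped encoder keys lost the `encoder.` segment that `T5EncoderModel` expects in its state dict. Stripping only `text_encoders.t5xxl.transformer.` yields correctly named keys such as `shared.weight` and `encoder.block.0.layer.0.SelfAttention.q.weight`.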
