diff --git a/docs/source/en/api/loaders/single_file.md b/docs/source/en/api/loaders/single_file.md
index 80b494ceb2e7..0af0ce6488d4 100644
--- a/docs/source/en/api/loaders/single_file.md
+++ b/docs/source/en/api/loaders/single_file.md
@@ -35,6 +35,7 @@ The [`~loaders.FromSingleFileMixin.from_single_file`] method allows you to load:
 - [`StableDiffusionXLInstructPix2PixPipeline`]
 - [`StableDiffusionXLControlNetPipeline`]
 - [`StableDiffusionXLKDiffusionPipeline`]
+- [`StableDiffusion3Pipeline`]
 - [`LatentConsistencyModelPipeline`]
 - [`LatentConsistencyModelImg2ImgPipeline`]
 - [`StableDiffusionControlNetXSPipeline`]
@@ -49,6 +50,7 @@ The [`~loaders.FromSingleFileMixin.from_single_file`] method allows you to load:
 - [`StableCascadeUNet`]
 - [`AutoencoderKL`]
 - [`ControlNetModel`]
+- [`SD3Transformer2DModel`]
 
 ## FromSingleFileMixin
 
diff --git a/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md b/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md
index a7b0103bd2e0..cc605f4a94bb 100644
--- a/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md
+++ b/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md
@@ -21,9 +21,9 @@ The abstract from the paper is:
 
 ## Usage Example
 
-_As the model is gated, before using it with diffusers you first need to go to the [Stable Diffusion 3 Medium Hugging Face page](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers), fill in the form and accept the gate. Once you are in, you need to login so that your system knows you’ve accepted the gate._ 
+_As the model is gated, before using it with diffusers you first need to go to the [Stable Diffusion 3 Medium Hugging Face page](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate._
 
-Use the command below to log in: 
+Use the command below to log in:
 
 ```bash
 huggingface-cli login
@@ -211,17 +211,38 @@ model = SD3Transformer2DModel.from_single_file("https://huggingface.co/stability
 
 ## Loading the single checkpoint for the `StableDiffusion3Pipeline`
 
+### Loading the single file checkpoint without T5
+
 ```python
+import torch
 from diffusers import StableDiffusion3Pipeline
-from transformers import T5EncoderModel
 
-text_encoder_3 = T5EncoderModel.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", subfolder="text_encoder_3", torch_dtype=torch.float16)
-pipe = StableDiffusion3Pipeline.from_single_file("https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/sd3_medium_incl_clips.safetensors", torch_dtype=torch.float16, text_encoder_3=text_encoder_3)
+pipe = StableDiffusion3Pipeline.from_single_file(
+    "https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/sd3_medium_incl_clips.safetensors",
+    torch_dtype=torch.float16,
+    text_encoder_3=None
+)
+pipe.enable_model_cpu_offload()
+
+image = pipe("a picture of a cat holding a sign that says hello world").images[0]
+image.save('sd3-single-file.png')
 ```
-
-`from_single_file` support for the `fp8` version of the checkpoints is coming soon. Watch this space.
- 
+### Loading the single file checkpoint with T5
+
+```python
+import torch
+from diffusers import StableDiffusion3Pipeline
+
+pipe = StableDiffusion3Pipeline.from_single_file(
+    "https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/sd3_medium_incl_clips_t5xxlfp8.safetensors",
+    torch_dtype=torch.float16,
+)
+pipe.enable_model_cpu_offload()
+
+image = pipe("a picture of a cat holding a sign that says hello world").images[0]
+image.save('sd3-single-file-t5-fp8.png')
+```
 
 
 ## StableDiffusion3Pipeline
diff --git a/src/diffusers/loaders/single_file.py b/src/diffusers/loaders/single_file.py
index d2b69b234ab0..d7bf67288c0a 100644
--- a/src/diffusers/loaders/single_file.py
+++ b/src/diffusers/loaders/single_file.py
@@ -28,9 +28,11 @@
     _legacy_load_safety_checker,
     _legacy_load_scheduler,
     create_diffusers_clip_model_from_ldm,
+    create_diffusers_t5_model_from_checkpoint,
     fetch_diffusers_config,
     fetch_original_config,
     is_clip_model_in_single_file,
+    is_t5_in_single_file,
     load_single_file_checkpoint,
 )
 
@@ -118,6 +120,16 @@ def load_single_file_sub_model(
             is_legacy_loading=is_legacy_loading,
         )
 
+    elif is_transformers_model and is_t5_in_single_file(checkpoint):
+        loaded_sub_model = create_diffusers_t5_model_from_checkpoint(
+            class_obj,
+            checkpoint=checkpoint,
+            config=cached_model_config_path,
+            subfolder=name,
+            torch_dtype=torch_dtype,
+            local_files_only=local_files_only,
+        )
+
     elif is_tokenizer and is_legacy_loading:
         loaded_sub_model = _legacy_load_clip_tokenizer(
             class_obj, checkpoint=checkpoint, config=cached_model_config_path, local_files_only=local_files_only
diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py
index d788aa11d37d..98fef894ee2f 100644
--- a/src/diffusers/loaders/single_file_utils.py
+++ b/src/diffusers/loaders/single_file_utils.py
@@ -252,7 +252,6 @@
 LDM_CLIP_PREFIX_TO_REMOVE = [
     "cond_stage_model.transformer.",
     "conditioner.embedders.0.transformer.",
-    "text_encoders.clip_l.transformer.",
 ]
 OPEN_CLIP_PREFIX = "conditioner.embedders.0.model."
 LDM_OPEN_CLIP_TEXT_PROJECTION_DIM = 1024
@@ -399,11 +398,14 @@ def is_open_clip_sdxl_model(checkpoint):
 
 
 def is_open_clip_sd3_model(checkpoint):
-    is_open_clip_sdxl_refiner_model(checkpoint)
+    if CHECKPOINT_KEY_NAMES["open_clip_sd3"] in checkpoint:
+        return True
+
+    return False
 
 
 def is_open_clip_sdxl_refiner_model(checkpoint):
-    if CHECKPOINT_KEY_NAMES["open_clip_sd3"] in checkpoint:
+    if CHECKPOINT_KEY_NAMES["open_clip_sdxl_refiner"] in checkpoint:
         return True
 
     return False
@@ -1233,11 +1235,14 @@
     return new_checkpoint
 
 
-def convert_ldm_clip_checkpoint(checkpoint):
+def convert_ldm_clip_checkpoint(checkpoint, remove_prefix=None):
     keys = list(checkpoint.keys())
     text_model_dict = {}
 
-    remove_prefixes = LDM_CLIP_PREFIX_TO_REMOVE
+    remove_prefixes = []
+    remove_prefixes.extend(LDM_CLIP_PREFIX_TO_REMOVE)
+    if remove_prefix:
+        remove_prefixes.append(remove_prefix)
 
     for key in keys:
         for prefix in remove_prefixes:
@@ -1376,6 +1381,13 @@ def create_diffusers_clip_model_from_ldm(
     ):
         diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint)
 
+    elif (
+        is_clip_sd3_model(checkpoint)
+        and checkpoint[CHECKPOINT_KEY_NAMES["clip_sd3"]].shape[-1] == position_embedding_dim
+    ):
+        diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint, "text_encoders.clip_l.transformer.")
+        diffusers_format_checkpoint["text_projection.weight"] = torch.eye(position_embedding_dim)
+
     elif is_open_clip_model(checkpoint):
         prefix = "cond_stage_model.model."
         diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)
@@ -1391,9 +1403,11 @@ def create_diffusers_clip_model_from_ldm(
         prefix = "conditioner.embedders.0.model."
         diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)
 
-    elif is_open_clip_sd3_model(checkpoint):
-        prefix = "text_encoders.clip_g.transformer."
-        diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)
+    elif (
+        is_open_clip_sd3_model(checkpoint)
+        and checkpoint[CHECKPOINT_KEY_NAMES["open_clip_sd3"]].shape[-1] == position_embedding_dim
+    ):
+        diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint, "text_encoders.clip_g.transformer.")
 
     else:
         raise ValueError("The provided checkpoint does not seem to contain a valid CLIP model.")
@@ -1755,7 +1769,7 @@ def convert_sd3_t5_checkpoint_to_diffusers(checkpoint):
     keys = list(checkpoint.keys())
     text_model_dict = {}
 
-    remove_prefixes = ["text_encoders.t5xxl.transformer.encoder."]
+    remove_prefixes = ["text_encoders.t5xxl.transformer."]
 
     for key in keys:
         for prefix in remove_prefixes:
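
For reference, a minimal sketch (not part of the diff above) of how the newly supported single-file `SD3Transformer2DModel` load can be combined with the existing diffusers-format repository. It assumes the gate on `stabilityai/stable-diffusion-3-medium` and `stabilityai/stable-diffusion-3-medium-diffusers` has been accepted and that you are logged in; the output filename is illustrative.

```python
import torch
from diffusers import SD3Transformer2DModel, StableDiffusion3Pipeline

# Load only the transformer from the single-file checkpoint
# (same checkpoint URL as used in the docs changes above).
transformer = SD3Transformer2DModel.from_single_file(
    "https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/sd3_medium_incl_clips.safetensors",
    torch_dtype=torch.float16,
)

# Reuse the diffusers-format repo for the remaining components
# (text encoders, tokenizers, VAE, scheduler).
pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    transformer=transformer,
    torch_dtype=torch.float16,
)
pipe.enable_model_cpu_offload()

image = pipe("a picture of a cat holding a sign that says hello world").images[0]
image.save("sd3-single-file-transformer.png")
```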