Skip to content

Commit

Permalink
Modify FlowMatch Scale Noise (#8678)
Browse files Browse the repository at this point in the history
* initial fix

* apply suggestion

* delete step_index line
  • Loading branch information
asomoza authored and sayakpaul committed Dec 23, 2024
1 parent ac2d23d commit 26d6d99
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -852,7 +852,7 @@ def __call__(
# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
latent_timestep = timesteps[:1].repeat(batch_size * num_inference_steps)
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

# 5. Prepare latent variables
if latents is None:
Expand Down
27 changes: 24 additions & 3 deletions src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,31 @@ def scale_noise(
`torch.FloatTensor`:
A scaled input sample.
"""
if self.step_index is None:
self._init_step_index(timestep)
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=sample.device, dtype=sample.dtype)

if sample.device.type == "mps" and torch.is_floating_point(timestep):
# mps does not support float64
schedule_timesteps = self.timesteps.to(sample.device, dtype=torch.float32)
timestep = timestep.to(sample.device, dtype=torch.float32)
else:
schedule_timesteps = self.timesteps.to(sample.device)
timestep = timestep.to(sample.device)

# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
if self.begin_index is None:
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timestep]
elif self.step_index is not None:
# add_noise is called after first denoising step (for inpainting)
step_indices = [self.step_index] * timestep.shape[0]
else:
# add noise is called before first denoising step to create initial latent(img2img)
step_indices = [self.begin_index] * timestep.shape[0]

sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < len(sample.shape):
sigma = sigma.unsqueeze(-1)

sigma = self.sigmas[self.step_index]
sample = sigma * noise + (1.0 - sigma) * sample

return sample
Expand Down

0 comments on commit 26d6d99

Please sign in to comment.