Skip to content

Commit

Permalink
add latest freeinit test from #8969
Browse files Browse the repository at this point in the history
  • Loading branch information
a-r-r-o-w committed Jul 26, 2024
1 parent 3cfd4fd commit 56f1cc4
Showing 1 changed file with 48 additions and 0 deletions.
48 changes: 48 additions & 0 deletions tests/pipelines/animatediff/test_animatediff_sparsectrl.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
AnimateDiffSparseControlNetPipeline,
AutoencoderKL,
DDIMScheduler,
DPMSolverMultistepScheduler,
LCMScheduler,
MotionAdapter,
SparseControlNetModel,
StableDiffusionPipeline,
Expand Down Expand Up @@ -426,5 +428,51 @@ def test_free_init(self):
"Disabling of FreeInit should lead to results similar to the default pipeline results",
)

def test_free_init_with_schedulers(self):
components = self.get_dummy_components()
pipe: AnimateDiffSparseControlNetPipeline = self.pipeline_class(**components)
pipe.set_progress_bar_config(disable=None)
pipe.to(torch_device)

inputs_normal = self.get_dummy_inputs(torch_device)
frames_normal = pipe(**inputs_normal).frames[0]

schedulers_to_test = [
DPMSolverMultistepScheduler.from_config(
components["scheduler"].config,
timestep_spacing="linspace",
beta_schedule="linear",
algorithm_type="dpmsolver++",
steps_offset=1,
clip_sample=False,
),
LCMScheduler.from_config(
components["scheduler"].config,
timestep_spacing="linspace",
beta_schedule="linear",
steps_offset=1,
clip_sample=False,
),
]
components.pop("scheduler")

for scheduler in schedulers_to_test:
components["scheduler"] = scheduler
pipe: AnimateDiffSparseControlNetPipeline = self.pipeline_class(**components)
pipe.set_progress_bar_config(disable=None)
pipe.to(torch_device)

pipe.enable_free_init(num_iters=2, use_fast_sampling=False)

inputs = self.get_dummy_inputs(torch_device)
frames_enable_free_init = pipe(**inputs).frames[0]
sum_enabled = np.abs(to_np(frames_normal) - to_np(frames_enable_free_init)).sum()

self.assertGreater(
sum_enabled,
1e1,
"Enabling of FreeInit should lead to results different from the default pipeline results",
)

def test_vae_slicing(self):
return super().test_vae_slicing(image_count=2)

0 comments on commit 56f1cc4

Please sign in to comment.