diff --git a/mmagic/models/editors/consistency_models/consistencymodel_modules.py b/mmagic/models/editors/consistency_models/consistencymodel_modules.py index 486f37299..70376cdd9 100644 --- a/mmagic/models/editors/consistency_models/consistencymodel_modules.py +++ b/mmagic/models/editors/consistency_models/consistencymodel_modules.py @@ -235,13 +235,13 @@ def __init__( self.convert_to_fp16() def convert_to_fp16(self): - """Convert the torso of the model to float16.""" + """Convert the tensor of the model to float16.""" self.input_blocks.apply(convert_module_to_f16) self.middle_block.apply(convert_module_to_f16) self.output_blocks.apply(convert_module_to_f16) def convert_to_fp32(self): - """Convert the torso of the model to float32.""" + """Convert the tensor of the model to float32.""" self.input_blocks.apply(convert_module_to_f32) self.middle_block.apply(convert_module_to_f32) self.output_blocks.apply(convert_module_to_f32) diff --git a/tests/test_models/test_editors/test_consistency_models/test_consistency_model_utils.py b/tests/test_models/test_editors/test_consistency_models/test_consistency_model_utils.py index 2c822f3e6..400ec5edc 100644 --- a/tests/test_models/test_editors/test_consistency_models/test_consistency_model_utils.py +++ b/tests/test_models/test_editors/test_consistency_models/test_consistency_model_utils.py @@ -69,6 +69,8 @@ def test_karras_sample(self): model_kwargs=model_kwargs) assert sample.shape == (batch_size, channel_num, image_size, image_size) + unet.convert_to_fp32() + unet.convert_to_fp16() def test_get_generator(self): self.assertIsInstance(get_generator('dummy'), DummyGenerator) @@ -122,6 +124,7 @@ def denoiser(x_t, sigma): generator_list = ['dummy', 'determ', 'determ-indiv'] for generator_str in generator_list: generator = get_generator(generator_str, 4, 0) + generator.randint(1, 2, (1, 2), dtype=torch.long, device='cpu') x_T = generator.randn(*shape, device=device) * sigma_max for sample in sample_list: if sample == 'progdist':