Skip to content

Commit

Permalink
add more test
Browse files Browse the repository at this point in the history
  • Loading branch information
xiaomile committed Dec 13, 2023
1 parent 05fe595 commit f68dc99
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -235,13 +235,13 @@ def __init__(
self.convert_to_fp16()

def convert_to_fp16(self):
"""Convert the torso of the model to float16."""
"""Convert the tensor of the model to float16."""
self.input_blocks.apply(convert_module_to_f16)
self.middle_block.apply(convert_module_to_f16)
self.output_blocks.apply(convert_module_to_f16)

def convert_to_fp32(self):
"""Convert the torso of the model to float32."""
"""Convert the tensor of the model to float32."""
self.input_blocks.apply(convert_module_to_f32)
self.middle_block.apply(convert_module_to_f32)
self.output_blocks.apply(convert_module_to_f32)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ def test_karras_sample(self):
model_kwargs=model_kwargs)
assert sample.shape == (batch_size, channel_num, image_size,
image_size)
unet.convert_to_fp32()
unet.convert_to_fp16()

def test_get_generator(self):
self.assertIsInstance(get_generator('dummy'), DummyGenerator)
Expand Down Expand Up @@ -122,6 +124,7 @@ def denoiser(x_t, sigma):
generator_list = ['dummy', 'determ', 'determ-indiv']
for generator_str in generator_list:
generator = get_generator(generator_str, 4, 0)
generator.randint(1, 2, (1, 2), dtype=torch.long, device='cpu')
x_T = generator.randn(*shape, device=device) * sigma_max
for sample in sample_list:
if sample == 'progdist':
Expand Down

0 comments on commit f68dc99

Please sign in to comment.