diff --git a/hubconf.py b/hubconf.py index 1031f6b..cd42923 100644 --- a/hubconf.py +++ b/hubconf.py @@ -10,13 +10,14 @@ } -def melgan(model_name='nvidia_tacotron2_LJ11_epoch6400', pretrained=True, progress=True): +def melgan(model_name='nvidia_tacotron2_LJ11_epoch6400', pretrained=True, progress=True, device='cuda'): params = model_params[model_name] model = Generator(params['mel_channel']) if pretrained: state_dict = torch.hub.load_state_dict_from_url(params['model_url'], - progress=progress) + progress=progress, + map_location=torch.device(device)) model.load_state_dict(state_dict['model_g']) model.eval(inference=True)