diff --git a/launch.py b/launch.py
index 5c8d305e..b5a77fb9 100644
--- a/launch.py
+++ b/launch.py
@@ -1,3 +1,4 @@
+import sys
 import argparse
 import os
 import time
@@ -87,13 +88,20 @@ def main():
         TensorBoardLogger(args.runs_dir, name=config.name, version=config.trial_name),
         CSVLogger(config.exp_dir, name=config.trial_name, version='csv_logs')
     ]
-    
+
+    if sys.platform == 'win32':
+        # multi-GPU (DDP) training is not supported on Windows, so fall back to single-process DP
+        strategy = 'dp'
+        assert n_gpus == 1, 'multi-GPU training is not supported on Windows'
+    else:
+        strategy = 'ddp_find_unused_parameters_false'
+
     trainer = Trainer(
         devices=n_gpus,
         accelerator='gpu',
         callbacks=callbacks,
         logger=loggers,
-        strategy='ddp_find_unused_parameters_false',
+        strategy=strategy,
         **config.trainer
     )
 
diff --git a/models/network_utils.py b/models/network_utils.py
index 38b70453..f5bafbbf 100644
--- a/models/network_utils.py
+++ b/models/network_utils.py
@@ -8,6 +8,7 @@
 
 from pytorch_lightning.utilities.rank_zero import rank_zero_debug, _get_rank
 
 from utils.misc import config_to_primitive
+from models.utils import get_activation
 
 
@@ -69,9 +70,12 @@ def __init__(self, dim_in, dim_out, config):
             self.layers += [self.make_linear(self.n_neurons, self.n_neurons, is_first=False, is_last=False), self.make_activation()]
         self.layers += [self.make_linear(self.n_neurons, dim_out, is_first=False, is_last=True)]
         self.layers = nn.Sequential(*self.layers)
+        self.output_activation = get_activation(config['output_activation'])
 
     def forward(self, x):
-        return self.layers(x.float())
+        x = self.layers(x.float())
+        x = self.output_activation(x)
+        return x
 
     def make_linear(self, dim_in, dim_out, is_first, is_last):
         layer = nn.Linear(dim_in, dim_out, bias=False)
diff --git a/models/utils.py b/models/utils.py
index 3918bd60..5dacbf65 100644
--- a/models/utils.py
+++ b/models/utils.py
@@ -67,7 +67,8 @@ def backward(ctx, g): # pylint: disable=arguments-differ
 
 
 def get_activation(name):
-    if name is None or name in ['none', 'None']:
+    name = name.lower() if name is not None else 'none'
+    if name == 'none':
         return nn.Identity()
     elif name.startswith('scale'):
         scale_factor = float(name[5:])
diff --git a/systems/nerf.py b/systems/nerf.py
index 61a10a06..44a81274 100644
--- a/systems/nerf.py
+++ b/systems/nerf.py
@@ -152,8 +152,13 @@ def validation_epoch_end(self, out):
         if self.trainer.is_global_zero:
             out_set = {}
             for step_out in out:
-                for oi, index in enumerate(step_out['index']):
-                    out_set[index[0].item()] = {'psnr': step_out['psnr'][oi]}
+                # DP: 'index' is a 1-D tensor holding a single image index
+                if step_out['index'].ndim == 1:
+                    out_set[step_out['index'].item()] = {'psnr': step_out['psnr']}
+                # DDP: gathered outputs carry an extra leading (per-process) dimension
+                else:
+                    for oi, index in enumerate(step_out['index']):
+                        out_set[index[0].item()] = {'psnr': step_out['psnr'][oi]}
             psnr = torch.mean(torch.stack([o['psnr'] for o in out_set.values()]))
             self.log('val/psnr', psnr, prog_bar=True, rank_zero_only=True)
 
@@ -180,8 +185,13 @@ def test_epoch_end(self, out):
         if self.trainer.is_global_zero:
             out_set = {}
             for step_out in out:
-                for oi, index in enumerate(step_out['index']):
-                    out_set[index[0].item()] = {'psnr': step_out['psnr'][oi]}
+                # DP: 'index' is a 1-D tensor holding a single image index
+                if step_out['index'].ndim == 1:
+                    out_set[step_out['index'].item()] = {'psnr': step_out['psnr']}
+                # DDP: gathered outputs carry an extra leading (per-process) dimension
+                else:
+                    for oi, index in enumerate(step_out['index']):
+                        out_set[index[0].item()] = {'psnr': step_out['psnr'][oi]}
             psnr = torch.mean(torch.stack([o['psnr'] for o in out_set.values()]))
             self.log('test/psnr', psnr, prog_bar=True, rank_zero_only=True)
 
diff --git a/systems/neus.py b/systems/neus.py
index 7bb2c268..a609c9ad 100644
--- a/systems/neus.py
+++ b/systems/neus.py
@@ -163,8 +163,13 @@ def validation_epoch_end(self, out):
         if self.trainer.is_global_zero:
             out_set = {}
             for step_out in out:
-                for oi, index in enumerate(step_out['index']):
-                    out_set[index[0].item()] = {'psnr': step_out['psnr'][oi]}
+                # DP: 'index' is a 1-D tensor holding a single image index
+                if step_out['index'].ndim == 1:
+                    out_set[step_out['index'].item()] = {'psnr': step_out['psnr']}
+                # DDP: gathered outputs carry an extra leading (per-process) dimension
+                else:
+                    for oi, index in enumerate(step_out['index']):
+                        out_set[index[0].item()] = {'psnr': step_out['psnr'][oi]}
             psnr = torch.mean(torch.stack([o['psnr'] for o in out_set.values()]))
             self.log('val/psnr', psnr, prog_bar=True, rank_zero_only=True)
 
@@ -195,8 +200,13 @@ def test_epoch_end(self, out):
         if self.trainer.is_global_zero:
             out_set = {}
             for step_out in out:
-                for oi, index in enumerate(step_out['index']):
-                    out_set[index[0].item()] = {'psnr': step_out['psnr'][oi]}
+                # DP: 'index' is a 1-D tensor holding a single image index
+                if step_out['index'].ndim == 1:
+                    out_set[step_out['index'].item()] = {'psnr': step_out['psnr']}
+                # DDP: gathered outputs carry an extra leading (per-process) dimension
+                else:
+                    for oi, index in enumerate(step_out['index']):
+                        out_set[index[0].item()] = {'psnr': step_out['psnr'][oi]}
             psnr = torch.mean(torch.stack([o['psnr'] for o in out_set.values()]))
             self.log('test/psnr', psnr, prog_bar=True, rank_zero_only=True)
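
Note (illustration only, not part of the patch): a minimal, self-contained sketch of the DP/DDP branching added above. The helper name collect_psnr and the sample tensor shapes are assumptions for this sketch: under DP, 'index' is taken to be a 1-D single-element tensor; under DDP, a 2-D tensor with a leading per-process dimension, matching how the two branches index the data.

import torch

def collect_psnr(out):
    # Mirrors the branching added to validation_epoch_end/test_epoch_end above.
    out_set = {}
    for step_out in out:
        if step_out['index'].ndim == 1:
            # DP-style output: a single 1-D index tensor per step.
            out_set[step_out['index'].item()] = {'psnr': step_out['psnr']}
        else:
            # DDP-style output: iterate over the leading (per-process) dimension.
            for oi, index in enumerate(step_out['index']):
                out_set[index[0].item()] = {'psnr': step_out['psnr'][oi]}
    return torch.mean(torch.stack([o['psnr'] for o in out_set.values()]))

dp_out = [{'index': torch.tensor([0]), 'psnr': torch.tensor(30.0)}]
ddp_out = [{'index': torch.tensor([[0], [1]]), 'psnr': torch.tensor([30.0, 31.0])}]
print(collect_psnr(dp_out))   # tensor(30.)
print(collect_psnr(ddp_out))  # tensor(30.5000)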