Commit

github-actions[bot] authored Jan 27, 2021
2 parents 4aa67de + 3da28fd commit 6820889
Showing 28 changed files with 354 additions and 59 deletions.
8 changes: 6 additions & 2 deletions CHANGELOG.md
@@ -11,6 +11,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Add support for summarized model total params size in megabytes ([#5590](https://github.com/PyTorchLightning/pytorch-lightning/pull/5590))


- Add Support for multiple train loaders ([#1959](https://github.com/PyTorchLightning/pytorch-lightning/pull/1959))


@@ -68,6 +69,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added compositional metrics ([#5464](https://github.com/PyTorchLightning/pytorch-lightning/pull/5464))


- Added Trainer method `predict(...)` for high performance predictions ([#5579](https://github.com/PyTorchLightning/pytorch-lightning/pull/5579))


- Added AUC/AUROC class interface ([#5479](https://github.com/PyTorchLightning/pytorch-lightning/pull/5479))


@@ -120,7 +124,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).


- Removed deprecated `TrainResult` ([#5323](https://github.com/PyTorchLightning/pytorch-lightning/pull/5323))


- Removed deprecated `EvalResult` ([#5633](https://github.com/PyTorchLightning/pytorch-lightning/pull/5633))

@@ -155,7 +159,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed `transfer_batch_to_device` for DDP with `len(devices_ids) == 1` ([#5195](https://github.com/PyTorchLightning/pytorch-lightning/pull/5195))
- Logging only on `not should_accumulate()` during training ([#5417](https://github.com/PyTorchLightning/pytorch-lightning/pull/5417))
- Resolve interpolation bug with Hydra ([#5406](https://github.com/PyTorchLightning/pytorch-lightning/pull/5406))
- Resolve interpolation bug with Hydra ([#5406](https://github.com/PyTorchLightning/pytorch-lightning/pull/5406))
- Check environ before selecting a seed to prevent warning message ([#4743](https://github.com/PyTorchLightning/pytorch-lightning/pull/4743))


24 changes: 23 additions & 1 deletion docs/source/starter/introduction_guide.rst
@@ -881,8 +881,30 @@ Or maybe we have a model that we use to do generation
z = sample_noise()
generated_imgs = model(z)
How you split up what goes in ``forward`` vs ``training_step`` depends on how you want to use this model for
To perform inference at scale, it is possible to use ``trainer.predict`` with the LightningModule ``predict`` function.
By default, the LightningModule ``predict`` method calls ``forward``, but it can be overridden to add any processing logic.

.. code-block:: python

    class LitMNISTDreamer(LightningModule):

        def forward(self, z):
            imgs = self.decoder(z)
            return imgs

        def predict(self, batch, batch_idx: int, dataloader_idx: int = None):
            return self(batch)


    model = LitMNISTDreamer()
    trainer.predict(model, datamodule)
How you split up what goes in ``forward`` vs ``training_step`` vs ``predict`` depends on how you want to use this model for
prediction.
However, we recommend that ``forward`` contain only the tensor operations of your model, that ``training_step`` encapsulate the ``forward`` logic together with logging,
metrics and loss computation, and that ``predict`` wrap ``forward`` with preprocessing and postprocessing, as sketched below.

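As an illustration only (``LitClassifier``, its ``backbone``, and the normalization step below are hypothetical, not part of this commit), a ``predict`` override that wraps ``forward`` with preprocessing and postprocessing might look like this:

.. code-block:: python

    class LitClassifier(LightningModule):

        def forward(self, x):
            return self.backbone(x)

        def predict(self, batch, batch_idx: int, dataloader_idx: int = None):
            x = (batch - batch.mean()) / (batch.std() + 1e-6)  # preprocessing: normalize the batch
            logits = self(x)                                    # reuse ``forward``
            return logits.argmax(dim=-1)                        # postprocessing: class indices
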
----------------

3 changes: 3 additions & 0 deletions pytorch_lightning/accelerators/legacy/cpu_accelerator.py
@@ -77,6 +77,9 @@ def validation_step(self, args):
def test_step(self, args):
return self._step(self.trainer.model.test_step, args)

def predict(self, args):
return self._step(self.trainer.model.predict, args)

def sync_tensor(self,
tensor: Union[torch.Tensor],
group: Optional[Any] = None,
3 changes: 3 additions & 0 deletions pytorch_lightning/accelerators/legacy/ddp2_accelerator.py
@@ -66,6 +66,9 @@ def validation_step(self, args):
def test_step(self, args):
return self._step(args)

def predict(self, args):
return self._step(args)

def _step(self, args):
args = self.ddp_plugin.on_before_forward(self.trainer.get_model(), *args)
if self.trainer.amp_backend == AMPType.NATIVE:
3 changes: 3 additions & 0 deletions pytorch_lightning/accelerators/legacy/ddp_accelerator.py
@@ -164,6 +164,9 @@ def validation_step(self, args):
def test_step(self, args):
return self._step(args)

def predict(self, args):
return self._step(args)

def _step(self, args):
args = self.ddp_plugin.on_before_forward(self.trainer.get_model(), *args)
if self.trainer.amp_backend == AMPType.NATIVE:
@@ -178,6 +178,9 @@ def validation_step(self, args):
def test_step(self, args):
return self._step(args)

def predict(self, args):
return self._step(args)

def _step(self, args):
args = self.ddp_plugin.on_before_forward(self.trainer.get_model(), *args)
if self.trainer.amp_backend == AMPType.NATIVE:
3 changes: 3 additions & 0 deletions pytorch_lightning/accelerators/legacy/ddp_hpc_accelerator.py
@@ -83,6 +83,9 @@ def validation_step(self, args):
def test_step(self, args):
return self._step(args)

def predict(self, args):
return self._step(args)

def _step(self, args):
args = self.ddp_plugin.on_before_forward(self.trainer.get_model(), *args)
if self.trainer.amp_backend == AMPType.NATIVE:
@@ -212,6 +212,9 @@ def validation_step(self, args):
def test_step(self, args):
return self._step(args)

def predict(self, args):
return self._step(args)

def _step(self, args):
args = self.ddp_plugin.on_before_forward(self.trainer.get_model(), *args)
if self.trainer.amp_backend == AMPType.NATIVE:
3 changes: 3 additions & 0 deletions pytorch_lightning/accelerators/legacy/dp_accelerator.py
@@ -132,6 +132,9 @@ def validation_step(self, args):
def test_step(self, args):
return self._step(args)

def predict(self, args):
return self._step(args)

def training_step_end(self, output):
if isinstance(output, Result):
output.dp_reduce()
3 changes: 3 additions & 0 deletions pytorch_lightning/accelerators/legacy/gpu_accelerator.py
@@ -85,6 +85,9 @@ def validation_step(self, args):
def test_step(self, args):
return self._step(self.trainer.model.test_step, args)

def predict(self, args):
return self._step(self.trainer.model.predict, args)

def to_device(self, batch):
gpu_id = 0
if isinstance(self.trainer.data_parallel_device_ids, list):
3 changes: 3 additions & 0 deletions pytorch_lightning/accelerators/legacy/horovod_accelerator.py
@@ -134,6 +134,9 @@ def validation_step(self, args):
def test_step(self, args):
return self._step(self.trainer.model.test_step, args)

def predict(self, args):
return self._step(self.trainer.model.predict, args)

def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):
super().backward(closure_loss, optimizer, opt_idx, *args, **kwargs)
optimizer.synchronize()
3 changes: 3 additions & 0 deletions pytorch_lightning/accelerators/legacy/tpu_accelerator.py
@@ -159,6 +159,9 @@ def validation_step(self, args):
def test_step(self, args):
return self._step(self.trainer.model.test_step, args)

def predict(self, args):
return self._step(self.trainer.model.predict, args)

def process_dataloader(self, dataloader):
device = xm.xla_device(self.trainer.tpu_id)
dataloader = xla_pl.ParallelLoader(dataloader, [device])
8 changes: 5 additions & 3 deletions pytorch_lightning/callbacks/progress.py
@@ -291,10 +291,12 @@ def init_validation_tqdm(self) -> tqdm:
)
return bar

def init_test_tqdm(self) -> tqdm:
def init_test_tqdm(self, trainer=None) -> tqdm:
""" Override this to customize the tqdm bar for testing. """
desc = "Testing"
desc = "Predicting" if trainer is not None and getattr(trainer, "is_predicting", False) else "Testing"
bar = tqdm(
desc='Testing',
desc=desc,
position=(2 * self.process_position),
disable=self.is_disabled,
leave=True,
@@ -361,7 +363,7 @@ def on_train_end(self, trainer, pl_module):

def on_test_start(self, trainer, pl_module):
super().on_test_start(trainer, pl_module)
self.test_progress_bar = self.init_test_tqdm()
self.test_progress_bar = self.init_test_tqdm(trainer=trainer)
self.test_progress_bar.total = convert_inf(self.total_test_batches)

def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
8 changes: 8 additions & 0 deletions pytorch_lightning/core/lightning.py
@@ -66,6 +66,7 @@ class LightningModule(
"on_gpu",
"current_epoch",
"global_step",
"running_stage",
] + DeviceDtypeModuleMixin.__jit_unused_properties__

def __init__(self, *args, **kwargs):
@@ -102,6 +103,7 @@ def __init__(self, *args, **kwargs):
self._running_manual_backward = False
self._current_hook_fx_name = None
self._current_dataloader_idx = None
self.running_stage = None
self._automatic_optimization: bool = True

def optimizers(self, use_pl_optimizer: bool = True) -> Union[Optimizer, List[Optimizer], List[LightningOptimizer]]:
@@ -982,6 +984,12 @@ def test_epoch_end(self, outputs):
self.log('final_metric', final_value)
"""

def predict(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int] = None):
"""
Use this function with trainer.predict(...). Override if you need to add any processing logic.
"""
return self(batch)

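For illustration, a minimal self-contained sketch of driving this hook from ``trainer.predict(...)``; the ``RandomDataset`` and ``LitRegressor`` below are made up for the example and are not part of this commit. The default ``predict`` shown above simply forwards each batch, so no override is needed here.

    import torch
    from torch.utils.data import DataLoader, Dataset
    from pytorch_lightning import LightningModule, Trainer

    class RandomDataset(Dataset):
        def __len__(self):
            return 64

        def __getitem__(self, idx):
            return torch.randn(8)

    class LitRegressor(LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(8, 1)

        def forward(self, x):
            return self.layer(x)

    # the default predict() calls self(batch) for every prediction batch
    trainer = Trainer()
    trainer.predict(LitRegressor(), DataLoader(RandomDataset(), batch_size=16))
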
def configure_optimizers(
self,
):
44 changes: 34 additions & 10 deletions pytorch_lightning/overrides/data_parallel.py
@@ -28,6 +28,7 @@

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.warnings import WarningCache


@@ -78,14 +79,22 @@ def forward(self, *inputs, **kwargs):
"them on device: {}".format(self.src_device_obj, t.device))

inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)

if len(self.device_ids) == 1:
# lightning
if self.module.training:

running_stage = self.module.running_stage

if running_stage == RunningStage.TRAINING:
return self.module.training_step(*inputs[0], **kwargs[0])
if self.module.testing:

elif running_stage == RunningStage.TESTING:
return self.module.test_step(*inputs[0], **kwargs[0])

return self.module.validation_step(*inputs[0], **kwargs[0])
elif running_stage == RunningStage.EVALUATING:
return self.module.validation_step(*inputs[0], **kwargs[0])

else:
return self.module.predict(*inputs[0], **kwargs[0])

replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
outputs = self.parallel_apply(replicas, inputs, kwargs)
@@ -187,15 +196,24 @@ def __init__(self, pl_module: LightningModule):
self.module = pl_module

def forward(self, *inputs, **kwargs):
if self.module.training:

running_stage = self.module.running_stage

if running_stage == RunningStage.TRAINING:
output = self.module.training_step(*inputs, **kwargs)
warn_if_output_is_none(output, "training_step")
elif self.module.testing:

elif running_stage == RunningStage.TESTING:
output = self.module.test_step(*inputs, **kwargs)
warn_if_output_is_none(output, "test_step")
else:

elif running_stage == RunningStage.EVALUATING:
output = self.module.validation_step(*inputs, **kwargs)
warn_if_output_is_none(output, "validation_step")

else:
output = self.module.predict(*inputs, **kwargs)

return output


@@ -276,16 +294,22 @@ def _worker(i, module, input, kwargs, device=None):

# ---------------
# CHANGE
if module.training:
if module.running_stage == RunningStage.TRAINING:
output = module.training_step(*input, **kwargs)
fx_called = 'training_step'
elif module.testing:

elif module.running_stage == RunningStage.TESTING:
output = module.test_step(*input, **kwargs)
fx_called = 'test_step'
else:

elif module.running_stage == RunningStage.EVALUATING:
output = module.validation_step(*input, **kwargs)
fx_called = 'validation_step'

else:
output = module.predict(*input, **kwargs)
fx_called = 'predict'

if output is None:
warn_missing_output(fx_called)

15 changes: 12 additions & 3 deletions pytorch_lightning/overrides/fairscale.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE

LightningShardedDataParallel = None
@@ -23,10 +24,18 @@ def forward(self, *inputs, **kwargs):
if self.enable_broadcast_buffers:
self.sync_buffers()

if self.module.training:
running_stage = self.module.running_stage

if running_stage == RunningStage.TRAINING:
outputs = self.module.training_step(*inputs, **kwargs)
elif self.module.testing:

elif running_stage == RunningStage.TESTING:
outputs = self.module.test_step(*inputs, **kwargs)
else:

elif running_stage == RunningStage.EVALUATING:
outputs = self.module.validation_step(*inputs, **kwargs)

else:
outputs = self.module.predict(*inputs, **kwargs)

return outputs
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/configuration_validator.py
@@ -52,7 +52,7 @@ def __verify_train_loop_configuration(self, model):
# verify model has a train dataloader
# -----------------------------------
has_train_dataloader = is_overridden('train_dataloader', model)
if not has_train_dataloader:
if not has_train_dataloader and not self.trainer._predicting:
raise MisconfigurationException(
'No `train_dataloader()` method defined. Lightning `Trainer` expects as minimum a'
' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.'
@@ -62,7 +62,7 @@
# verify model has optimizer
# -----------------------------------
has_optimizers = is_overridden('configure_optimizers', model)
if not has_optimizers:
if not has_optimizers and not self.trainer._predicting:
raise MisconfigurationException(
'No `configure_optimizers()` method defined. Lightning `Trainer` expects as minimum a'
' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.'
@@ -17,12 +17,12 @@
import torch

from pytorch_lightning.core.step_result import Result
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities import DistributedType, LightningEnum


class LoggerStages(LightningEnum):
""" Train/validation/test phase in each training step.
>>> # you can match the type with string
>>> LoggerStages.TRAIN == 'train'
True
@@ -371,7 +371,7 @@ def update_logger_connector(self) -> None:
callback_metrics = {}
batch_pbar_metrics = {}
batch_log_metrics = {}
is_train = self._stage in LoggerStages.TRAIN.value
is_train = self._stage in RunningStage.TRAINING

if not self._has_batch_loop_finished:
# get pbar
0 comments on commit 6820889