paraphernalia.torch.lightning module

Tools for working with PyTorch Lightning.

class ImageCheckpoint(path_template, video_path=None, interval=50, preview=True)

A PyTorch Lightning callback for saving and previewing images.

Image batches of shape (b, c, h, w) are expected to be produced by module.forward().

path_template can draw on the following variables:

  • index: the index of the image in the provided batch

  • model: the Lightning model

  • trainer: the Lightning trainer
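A minimal sketch of how such a template might be expanded into one path per image in a batch. It assumes path_template is an ordinary Python str.format string; the source does not say which templating mechanism is used. The SimpleNamespace objects are hypothetical stand-ins for the model and trainer (model.name is invented for illustration; current_epoch mirrors a real Trainer attribute), and expand_paths is a hypothetical helper, not part of this module.

```python
from types import SimpleNamespace


def expand_paths(path_template, batch_size, model, trainer):
    """Return one output path per image in a batch of size batch_size.

    Hypothetical helper: assumes str.format-style templates that may use
    the index, model and trainer variables listed above.
    """
    return [
        path_template.format(index=i, model=model, trainer=trainer)
        for i in range(batch_size)
    ]


# Hypothetical stand-ins for a LightningModule and a Trainer.
model = SimpleNamespace(name="generator")
trainer = SimpleNamespace(current_epoch=3)

paths = expand_paths(
    "out/{model.name}/epoch{trainer.current_epoch:03d}_{index}.png",
    batch_size=2,
    model=model,
    trainer=trainer,
)
print(paths)
# ['out/generator/epoch003_0.png', 'out/generator/epoch003_1.png']
```

str.format allows attribute lookups such as {trainer.current_epoch} and format specs such as :03d, so a single template string can encode the epoch and per-image index.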

Parameters
  • path_template (str) – a path template as described above

  • interval (int, optional) – the checkpoint interval, in batches. Defaults to 50.

  • preview (bool, optional) – if True, display an ipywidget preview panel. Defaults to True.

  • video_path (str, optional) – a video output path. Defaults to None.

save(batch, trainer, module)

Save the image batch.

Parameters
  • batch (torch.Tensor) – the image batch, of shape (b, c, h, w)

  • trainer (pytorch_lightning.trainer.trainer.Trainer) – the Lightning trainer

  • module (pytorch_lightning.core.lightning.LightningModule) – the Lightning module

save_frame(batch)

Parameters
  • batch (torch.Tensor) – the image batch, of shape (b, c, h, w)

preview(batch)

Preview the image batch if configured, otherwise do nothing.

Parameters
  • batch (torch.Tensor) – the image batch, of shape (b, c, h, w)

checkpoint(trainer, module)

Main checkpoint function, called at the start of each epoch, at the end of training, and every self.interval batches.

Parameters
  • trainer (pytorch_lightning.trainer.trainer.Trainer) – the Lightning trainer

  • module (pytorch_lightning.core.lightning.LightningModule) – the Lightning module

on_batch_end(trainer, module)

Called at the end of each batch.

Checkpoints every self.interval batches.

Parameters
  • trainer (pytorch_lightning.trainer.trainer.Trainer) – the Lightning trainer

  • module (pytorch_lightning.core.lightning.LightningModule) – the Lightning module

Return type

None
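A sketch of the interval rule, assuming a running batch count that starts at 1 and is not reset between epochs. The source does not specify which counter is compared against self.interval (a real implementation might use trainer.global_step or a per-epoch batch index instead), and should_checkpoint is a hypothetical helper.

```python
def should_checkpoint(batch_count, interval=50):
    """Return True when batch_count is a positive multiple of interval.

    Hypothetical helper. Assumes a 1-based counter; with a 0-based counter
    the very first batch would also qualify, since 0 % interval == 0.
    """
    return batch_count > 0 and batch_count % interval == 0


# With the default interval of 50, which of the first 200 batches checkpoint?
hits = [n for n in range(1, 201) if should_checkpoint(n)]
print(hits)  # [50, 100, 150, 200]
```

The batch_count > 0 guard is one way to avoid an unintended checkpoint on a zero-based first batch; whether the actual callback does this is not stated in the source.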

on_epoch_start(trainer, module)

Called at the start of an epoch.

Always checkpoints.

Parameters
  • trainer (pytorch_lightning.trainer.trainer.Trainer) – the Lightning trainer

  • module (pytorch_lightning.core.lightning.LightningModule) – the Lightning module

on_train_end(trainer, module)

Called when training ends.

Always checkpoints.

Parameters
  • trainer (pytorch_lightning.trainer.trainer.Trainer) – the Lightning trainer

  • module (pytorch_lightning.core.lightning.LightningModule) – the Lightning module
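The hooks described in this section all route into checkpoint(). The following pure-Python stand-in, which does not import PyTorch Lightning, mimics that wiring with a toy loop so the resulting schedule can be inspected. ToyImageCheckpoint, its batch counter, its event log, and run_toy_loop are all hypothetical; unlike the documented checkpoint(trainer, module), the stand-in takes an extra reason argument purely for logging. In real use, Lightning itself calls the hooks.

```python
from types import SimpleNamespace


class ToyImageCheckpoint:
    """Hypothetical stand-in mirroring the hook wiring described above."""

    def __init__(self, interval=50):
        self.interval = interval
        self.batch_count = 0  # assumed running counter across epochs
        self.events = []      # records (reason, epoch, batch_count)

    def checkpoint(self, trainer, module, reason):
        # The documented method saves and previews images; here we only log.
        self.events.append((reason, trainer.current_epoch, self.batch_count))

    def on_epoch_start(self, trainer, module):
        self.checkpoint(trainer, module, "epoch_start")  # always checkpoints

    def on_batch_end(self, trainer, module):
        self.batch_count += 1
        if self.batch_count % self.interval == 0:  # every `interval` batches
            self.checkpoint(trainer, module, "batch_end")

    def on_train_end(self, trainer, module):
        self.checkpoint(trainer, module, "train_end")  # always checkpoints


def run_toy_loop(callback, epochs, batches_per_epoch):
    """Drive the hooks in the order a training loop would (simplified)."""
    trainer = SimpleNamespace(current_epoch=0)
    module = None  # stand-in for a LightningModule
    for epoch in range(epochs):
        trainer.current_epoch = epoch
        callback.on_epoch_start(trainer, module)
        for _ in range(batches_per_epoch):
            callback.on_batch_end(trainer, module)
    callback.on_train_end(trainer, module)
    return callback.events


events = run_toy_loop(ToyImageCheckpoint(interval=3), epochs=2, batches_per_epoch=4)
for event in events:
    print(event)
```

With interval=3 and 4 batches per epoch, this prints an epoch-start entry for each epoch, batch-end entries after batches 3 and 6, and a final train-end entry.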