paraphernalia.torch.lightning module
Tools for working with PyTorch Lightning.
- class ImageCheckpoint(path_template, video_path=None, interval=50, preview=True)
A PyTorch Lightning callback for saving and previewing images.
module.forward() should return image batches of shape (b, c, h, w).
path_template can draw on the following variables:
index: the index of the image in the provided batch
model: the Lightning model
trainer: the Lightning trainer
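A minimal sketch of how such a template might expand, assuming path_template is an ordinary Python str.format string (this page does not state the template syntax). trainer.current_epoch is a real Lightning Trainer attribute; the stand-in objects and model.name are hypothetical and exist only for illustration.

```python
from types import SimpleNamespace

# Hypothetical stand-ins for the Lightning trainer and model;
# only the attributes referenced by the template are modeled.
trainer = SimpleNamespace(current_epoch=5)
model = SimpleNamespace(name="vae")  # hypothetical attribute

# Assumed str.format syntax: attribute access and format specs both work.
path_template = "out/{model.name}/epoch{trainer.current_epoch:03d}_{index}.png"

# One path per image index in a batch of two.
paths = [path_template.format(index=i, model=model, trainer=trainer) for i in range(2)]
print(paths)  # ['out/vae/epoch005_0.png', 'out/vae/epoch005_1.png']
```

Including {index} in the template keeps images from the same batch from overwriting one another.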
- Parameters
path_template (str) – a path template as described above
video_path (str, optional) –
interval (int, optional) – the checkpoint interval, in batches. Defaults to 50.
preview (bool, optional) – if True, display an ipywidget preview panel. Defaults to True.
- save(batch, trainer, module)
Save the image batch.
- Parameters
batch (torch.Tensor) –
trainer (pytorch_lightning.trainer.trainer.Trainer) –
module (pytorch_lightning.core.lightning.LightningModule) –
- preview(batch)
Preview the image batch if configured, otherwise do nothing.
- Parameters
batch (torch.Tensor) –
- checkpoint(trainer, module)
Main checkpoint function, called on epoch start and training end.
- Parameters
trainer (pytorch_lightning.trainer.trainer.Trainer) –
module (pytorch_lightning.core.lightning.LightningModule) –
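The control flow described by save(), preview() and checkpoint() can be sketched without Lightning. This is a hypothetical, framework-free stand-in (SketchCheckpoint is not part of paraphernalia); it assumes checkpoint() calls module.forward() and then save() and preview(), and that paths use str.format. Instead of writing or displaying images it records what it would have done, and it does not model details such as eval mode or gradient handling.

```python
from types import SimpleNamespace


class SketchCheckpoint:
    """Hypothetical stand-in for ImageCheckpoint's save/preview/checkpoint flow."""

    def __init__(self, path_template, preview=True):
        self.path_template = path_template
        self.preview_enabled = preview  # separate name so it doesn't shadow preview()
        self.saved = []      # paths a real callback would write images to
        self.previewed = []  # batches a real callback would display

    def save(self, batch, trainer, module):
        # One output path per image, filled in from the template variables.
        for index, _image in enumerate(batch):
            path = self.path_template.format(index=index, model=module, trainer=trainer)
            self.saved.append(path)

    def preview(self, batch):
        # Do nothing unless previewing was requested.
        if self.preview_enabled:
            self.previewed.append(batch)

    def checkpoint(self, trainer, module):
        # Generate a batch with forward(), then save and (optionally) preview it.
        batch = module.forward()
        self.save(batch, trainer, module)
        self.preview(batch)


# Stub module whose forward() returns a two-image "batch".
module = SimpleNamespace(forward=lambda: ["img0", "img1"])
trainer = SimpleNamespace(current_epoch=3)
cb = SketchCheckpoint("epoch{trainer.current_epoch}_{index}.png")
cb.checkpoint(trainer, module)
print(cb.saved)  # ['epoch3_0.png', 'epoch3_1.png']
```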
- on_batch_end(trainer, module)
Called at the end of each batch.
Calls checkpoint() when the current batch count is a multiple of self.interval.
- Parameters
trainer (pytorch_lightning.trainer.trainer.Trainer) –
module (pytorch_lightning.core.lightning.LightningModule) –
- Return type
None
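The interval gating amounts to a modulo test. This is a minimal sketch using a hypothetical helper, should_checkpoint(); which counter the real callback reads, and whether it starts at 0 or 1, is not documented here. If it starts at 0, the very first batch also triggers a checkpoint.

```python
def should_checkpoint(step: int, interval: int = 50) -> bool:
    """Return True when `step` is a multiple of `interval` (hypothetical helper)."""
    if interval <= 0:
        raise ValueError("interval must be a positive integer")
    return step % interval == 0


# With the default interval of 50, batches 0, 50, 100 and 150 trigger a checkpoint.
steps = [s for s in range(0, 160) if should_checkpoint(s)]
print(steps)  # [0, 50, 100, 150]
```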