Computer Vision#

PyTorch-Lightning#

Trainer#

class archai.trainers.cv.pl_trainer.PlTrainer(*, accelerator: str | Accelerator = 'auto', strategy: str | Strategy = 'auto', devices: List[int] | str | int = 'auto', num_nodes: int = 1, precision: Literal[64, 32, 16] | Literal['16-mixed', 'bf16-mixed', '32-true', '64-true'] | Literal['64', '32', '16', 'bf16'] = '32-true', logger: Logger | Iterable[Logger] | bool | None = None, callbacks: List[Callback] | Callback | None = None, fast_dev_run: int | bool = False, max_epochs: int | None = None, min_epochs: int | None = None, max_steps: int = -1, min_steps: int | None = None, max_time: str | timedelta | Dict[str, int] | None = None, limit_train_batches: int | float | None = None, limit_val_batches: int | float | None = None, limit_test_batches: int | float | None = None, limit_predict_batches: int | float | None = None, overfit_batches: int | float = 0.0, val_check_interval: int | float | None = None, check_val_every_n_epoch: int | None = 1, num_sanity_val_steps: int | None = None, log_every_n_steps: int | None = None, enable_checkpointing: bool | None = None, enable_progress_bar: bool | None = None, enable_model_summary: bool | None = None, accumulate_grad_batches: int = 1, gradient_clip_val: int | float | None = None, gradient_clip_algorithm: str | None = None, deterministic: bool | Literal['warn'] | None = None, benchmark: bool | None = None, inference_mode: bool = True, use_distributed_sampler: bool = True, profiler: Profiler | str | None = None, detect_anomaly: bool = False, barebones: bool = False, plugins: PrecisionPlugin | ClusterEnvironment | CheckpointIO | LayerSync | str | List[PrecisionPlugin | ClusterEnvironment | CheckpointIO | LayerSync | str] | None = None, sync_batchnorm: bool = False, reload_dataloaders_every_n_epochs: int = 0, default_root_dir: str | Path | None = None)#

PyTorch-Lightning trainer. Accepts the same constructor arguments as pytorch_lightning.Trainer and exposes train(), evaluate() and predict() entry points for the common training workflow.
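
As a sketch, the trainer can be constructed by overriding only the arguments of interest; the values below are illustrative, and anything omitted keeps the defaults listed above:

    from archai.trainers.cv.pl_trainer import PlTrainer

    # Illustrative configuration; omitted arguments keep their defaults.
    trainer = PlTrainer(
        accelerator="auto",   # let Lightning select GPU/CPU
        max_epochs=5,         # stop after 5 epochs
        precision="32-true",  # full-precision training
    )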

train(model: LightningModule, train_dataloaders: Any | LightningDataModule | None = None, val_dataloaders: Any | None = None, datamodule: LightningDataModule | None = None, ckpt_path: str | None = None) → None#

Train a model.

The signature mirrors pytorch_lightning.Trainer.fit(): pass the model, its dataloaders (or a datamodule), and an optional checkpoint path to resume from.
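
Below is a minimal, self-contained sketch of the expected call pattern; the TinyModule class and the random stand-in data are illustrative and not part of Archai:

    import torch
    import pytorch_lightning as pl
    from torch import nn
    from torch.utils.data import DataLoader, TensorDataset

    from archai.trainers.cv.pl_trainer import PlTrainer

    class TinyModule(pl.LightningModule):
        # Illustrative linear classifier over flattened 28x28 inputs.
        def __init__(self):
            super().__init__()
            self.layer = nn.Linear(28 * 28, 10)

        def forward(self, x):
            return self.layer(x.flatten(1))

        def training_step(self, batch, batch_idx):
            x, y = batch
            return nn.functional.cross_entropy(self(x), y)

        def test_step(self, batch, batch_idx):
            x, y = batch
            self.log("test_loss", nn.functional.cross_entropy(self(x), y))

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=1e-3)

    # Random stand-in data so the sketch is self-contained.
    dataset = TensorDataset(torch.randn(64, 1, 28, 28), torch.randint(0, 10, (64,)))
    loader = DataLoader(dataset, batch_size=16)

    model = TinyModule()
    trainer = PlTrainer(accelerator="auto", max_epochs=1)
    trainer.train(model, train_dataloaders=loader)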

evaluate(model: LightningModule | None = None, dataloaders: Any | LightningDataModule | None = None, ckpt_path: str | None = None, verbose: bool | None = True, datamodule: LightningDataModule | None = None) → List[Dict[str, float]]#

Evaluate a model.

The signature mirrors pytorch_lightning.Trainer.test(), returning one dictionary of metrics per evaluated dataloader.
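
Continuing the sketch under train(), where the illustrative TinyModule implements a test_step that logs a test_loss metric:

    # Reuses the illustrative `model`, `loader` and `trainer` from the
    # train() sketch; any dataloader or datamodule would work here.
    results = trainer.evaluate(model, dataloaders=loader)
    print(results)  # e.g. [{"test_loss": 0.42}], one dict per dataloader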

predict(model: LightningModule | None = None, dataloaders: Any | LightningDataModule | None = None, datamodule: LightningDataModule | None = None, return_predictions: bool | None = None, ckpt_path: str | None = None) → List[Any] | List[List[Any]] | None#

Run inference on your data. This calls the model's forward function to compute predictions. Useful for performing distributed and batched predictions. Logging is disabled in the predict hooks.

Parameters:
  • model – The model to predict with.

  • dataloaders – An iterable or collection of iterables specifying predict samples. Alternatively, a LightningDataModule that defines the predict_dataloader hook.

  • datamodule – A LightningDataModule that defines the predict_dataloader hook.

  • return_predictions – Whether to return predictions. True by default, except when an accelerator that spawns processes is used (returning predictions is not supported in that case).

  • ckpt_path – Either "best", "last", "hpc" or a path to the checkpoint you wish to predict with. If None and the model instance was passed, its current weights are used. Otherwise, the best model checkpoint from the previous trainer.fit call is loaded, provided a checkpoint callback is configured.

For more information about multiple dataloaders, see the PyTorch Lightning documentation.

Returns:

A list of predictions, one entry per provided dataloader, or None when return_predictions is False.

See the PyTorch Lightning documentation on inference for more.
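
As a usage sketch, reusing the illustrative model and trainer from the train() example (the unlabeled loader is a hypothetical stand-in):

    # A plain tensor acts as a map-style dataset, so each batch is a
    # (16, 1, 28, 28) tensor that Lightning's default predict_step
    # forwards straight through the model.
    predict_loader = DataLoader(torch.randn(32, 1, 28, 28), batch_size=16)
    preds = trainer.predict(model, dataloaders=predict_loader)
    # `preds` holds one output tensor per batch
    # (None if return_predictions=False).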