Computer Vision
PyTorch-Lightning
Trainer
- class archai.trainers.cv.pl_trainer.PlTrainer(*, accelerator: str | Accelerator = 'auto', strategy: str | Strategy = 'auto', devices: List[int] | str | int = 'auto', num_nodes: int = 1, precision: Literal[64, 32, 16] | Literal['16-mixed', 'bf16-mixed', '32-true', '64-true'] | Literal['64', '32', '16', 'bf16'] = '32-true', logger: Logger | Iterable[Logger] | bool | None = None, callbacks: List[Callback] | Callback | None = None, fast_dev_run: int | bool = False, max_epochs: int | None = None, min_epochs: int | None = None, max_steps: int = -1, min_steps: int | None = None, max_time: str | timedelta | Dict[str, int] | None = None, limit_train_batches: int | float | None = None, limit_val_batches: int | float | None = None, limit_test_batches: int | float | None = None, limit_predict_batches: int | float | None = None, overfit_batches: int | float = 0.0, val_check_interval: int | float | None = None, check_val_every_n_epoch: int | None = 1, num_sanity_val_steps: int | None = None, log_every_n_steps: int | None = None, enable_checkpointing: bool | None = None, enable_progress_bar: bool | None = None, enable_model_summary: bool | None = None, accumulate_grad_batches: int = 1, gradient_clip_val: int | float | None = None, gradient_clip_algorithm: str | None = None, deterministic: bool | Literal['warn'] | None = None, benchmark: bool | None = None, inference_mode: bool = True, use_distributed_sampler: bool = True, profiler: Profiler | str | None = None, detect_anomaly: bool = False, barebones: bool = False, plugins: PrecisionPlugin | ClusterEnvironment | CheckpointIO | LayerSync | str | List[PrecisionPlugin | ClusterEnvironment | CheckpointIO | LayerSync | str] | None = None, sync_batchnorm: bool = False, reload_dataloaders_every_n_epochs: int = 0, default_root_dir: str | Path | None = None)
PyTorch-Lightning trainer.
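Several constructor arguments (devices, num_nodes, accumulate_grad_batches) jointly determine how many samples contribute to each optimizer step. The helper below is a hypothetical back-of-the-envelope calculation, not part of archai or Lightning. It assumes a DDP-style strategy in which every process (one per device, on every node) loads its own batch, and it takes devices as an integer count rather than the list/str/"auto" forms the constructor also accepts.

```python
def effective_batch_size(batch_size: int, devices: int = 1, num_nodes: int = 1,
                         accumulate_grad_batches: int = 1) -> int:
    """Samples contributing to one optimizer step (hypothetical helper).

    Assumes a DDP-style strategy: each of the devices * num_nodes processes
    loads its own batch of batch_size samples, and gradients are accumulated
    over accumulate_grad_batches batches before each optimizer step.
    """
    for name, value in (("batch_size", batch_size), ("devices", devices),
                        ("num_nodes", num_nodes),
                        ("accumulate_grad_batches", accumulate_grad_batches)):
        if value < 1:
            raise ValueError(f"{name} must be >= 1, got {value}")
    world_size = devices * num_nodes  # total number of training processes
    return batch_size * world_size * accumulate_grad_batches


# e.g. PlTrainer(devices=4, num_nodes=2, accumulate_grad_batches=8)
# with a per-device batch of 32 samples:
print(effective_batch_size(32, devices=4, num_nodes=2, accumulate_grad_batches=8))  # 2048
```

Under other strategies (for example, ones that split a single batch across devices) this arithmetic does not apply.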
- train(model: LightningModule, train_dataloaders: Any | LightningDataModule | None = None, val_dataloaders: Any | None = None, datamodule: LightningDataModule | None = None, ckpt_path: str | None = None) → None
Train a model.
This method should contain the logic for training the model.
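train accepts data either as dataloaders or through a LightningDataModule. Assuming it forwards its arguments to Lightning's Trainer.fit, a datamodule passed as train_dataloaders is treated as the datamodule, and combining explicit dataloaders with datamodule= is rejected. The stdlib-only sketch below mimics that routing; FakeDataModule and resolve_fit_inputs are hypothetical stand-ins, not archai or Lightning names.

```python
from typing import Any, Optional, Tuple


class FakeDataModule:
    """Hypothetical stand-in for a LightningDataModule."""


def resolve_fit_inputs(
    train_dataloaders: Any = None,
    val_dataloaders: Any = None,
    datamodule: Optional[FakeDataModule] = None,
) -> Tuple[Any, Any, Optional[FakeDataModule]]:
    """Mimic how a fit-style entry point routes its data arguments (sketch)."""
    # A datamodule passed in the train_dataloaders slot is promoted to datamodule.
    if isinstance(train_dataloaders, FakeDataModule):
        datamodule, train_dataloaders = train_dataloaders, None
    # Explicit dataloaders and a datamodule are mutually exclusive.
    if (train_dataloaders is not None or val_dataloaders is not None) and datamodule is not None:
        raise ValueError("pass either dataloaders or a datamodule, not both")
    return train_dataloaders, val_dataloaders, datamodule


try:
    resolve_fit_inputs(train_dataloaders=[[1, 2]], datamodule=FakeDataModule())
except ValueError as err:
    print(err)  # pass either dataloaders or a datamodule, not both
```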
- evaluate(model: LightningModule | None = None, dataloaders: Any | LightningDataModule | None = None, ckpt_path: str | None = None, verbose: bool | None = True, datamodule: LightningDataModule | None = None) → List[Dict[str, float]]
Evaluate a model.
This method should contain the logic for evaluating the model.
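evaluate is annotated to return List[Dict[str, float]], read here as one dictionary of metrics per evaluation dataloader. Assuming that shape, the hypothetical helper below averages one metric across dataloaders. It also accepts keys carrying a per-dataloader suffix (such as "/dataloader_idx_0"), in case the logger appends one. The metric name test_acc and its values are made up purely for illustration.

```python
from statistics import fmean
from typing import Dict, List


def mean_metric(results: List[Dict[str, float]], name: str) -> float:
    """Average one metric over the per-dataloader dicts returned by evaluate (sketch).

    A key matches if it equals name or starts with name + "/", which tolerates
    a per-dataloader suffix such as "/dataloader_idx_0".
    """
    values = [
        value
        for metrics in results
        for key, value in metrics.items()
        if key == name or key.startswith(name + "/")
    ]
    if not values:
        raise KeyError(f"metric {name!r} not found in any result dict")
    return fmean(values)


# Illustrative values only; real keys depend on what the model logs.
results = [{"test_acc/dataloader_idx_0": 0.75}, {"test_acc/dataloader_idx_1": 0.25}]
print(mean_metric(results, "test_acc"))  # 0.5
```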
- predict(model: LightningModule | None = None, dataloaders: Any | LightningDataModule | None = None, datamodule: LightningDataModule | None = None, return_predictions: bool | None = None, ckpt_path: str | None = None) → List[Any] | List[List[Any]] | None
Run inference on your data. This calls the model's forward function to compute predictions and is useful for distributed and batched prediction. Logging is disabled in the predict hooks.
- Parameters:
model – The model to predict with.
dataloaders – An iterable or collection of iterables specifying predict samples. Alternatively, a LightningDataModule that defines the predict_dataloader hook (pytorch_lightning.core.hooks.DataHooks.predict_dataloader).
datamodule – A LightningDataModule that defines the predict_dataloader hook (pytorch_lightning.core.hooks.DataHooks.predict_dataloader).
return_predictions – Whether to return predictions. True by default, except when an accelerator that spawns processes is used (not supported).
ckpt_path – Either "best", "last", "hpc", or the path to the checkpoint you wish to predict with. If None and the model instance was passed, the current weights are used. Otherwise, the best model checkpoint from the previous trainer.fit call is loaded if a checkpoint callback is configured.
For more information about multiple dataloaders, see the corresponding section of the Lightning documentation.
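The ckpt_path rules above can be read as a small decision table. The function below is a hypothetical, stdlib-only restatement of that documented behavior; it loads nothing and is not part of archai or Lightning. The docstring does not say what happens when no model is passed and no checkpoint callback is configured, so this sketch simply raises in that case.

```python
from typing import Optional

_ALIASES = {"best", "last", "hpc"}  # string aliases accepted by ckpt_path


def describe_ckpt_choice(ckpt_path: Optional[str], model_passed: bool,
                         has_checkpoint_callback: bool) -> str:
    """Describe which weights predict would use, per the documented rules (sketch)."""
    if ckpt_path in _ALIASES:
        return f"alias checkpoint: {ckpt_path}"
    if ckpt_path is not None:
        return f"checkpoint file: {ckpt_path}"
    if model_passed:
        return "current weights of the passed model"
    if has_checkpoint_callback:
        return "best checkpoint from the previous trainer.fit call"
    # Not covered by the documentation; this sketch chooses to refuse.
    raise ValueError("no model passed and no checkpoint callback configured")


print(describe_ckpt_choice(None, model_passed=True, has_checkpoint_callback=True))
# current weights of the passed model
```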
- Returns:
The predictions: a list with one entry per batch when a single dataloader is used, or a list of such lists (one per dataloader) when several are provided; None if predictions are not returned.
See the Lightning inference section for more.
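predict is annotated to return List[Any] | List[List[Any]] | None. Assuming the usual Lightning convention (one entry per batch for a single dataloader, a list of such lists for several dataloaders, and None when predictions are not returned), the hypothetical helper below normalizes the result into one list of batch outputs per dataloader. The caller passes the number of dataloaders explicitly, because a single-dataloader result whose batch outputs are themselves lists cannot be told apart from the nested form by inspection.

```python
from typing import Any, List, Optional, Union

Predictions = Optional[Union[List[Any], List[List[Any]]]]


def per_dataloader(predictions: Predictions, num_dataloaders: int) -> List[List[Any]]:
    """Normalize a predict() result to one list of batch outputs per dataloader (sketch)."""
    if predictions is None:
        # e.g. return_predictions=False: nothing was collected.
        return [[] for _ in range(num_dataloaders)]
    if num_dataloaders == 1:
        # Single dataloader: the result is already a flat list of batch outputs.
        return [list(predictions)]
    if len(predictions) != num_dataloaders:
        raise ValueError(
            f"expected {num_dataloaders} per-dataloader lists, got {len(predictions)}"
        )
    return [list(batches) for batches in predictions]


print(per_dataloader(["b0", "b1"], num_dataloaders=1))            # [['b0', 'b1']]
print(per_dataloader([["a0"], ["b0", "b1"]], num_dataloaders=2))  # [['a0'], ['b0', 'b1']]
```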