Utilities
Augmented Searcher
Augmented Trainer
- archai.supergraph.utils.augmented_trainer.run_epoch(conf, logger, model: Module, loader, loss_fn, optimizer, split_type: str, epoch=0, verbose=1, scheduler=None)
Runs one epoch over the given data loader and model. If an optimizer is supplied, backpropagation and a model update are performed as well, so the same function serves both test and train modes.
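run_epoch() switches between training and evaluation depending on whether an optimizer is passed. The framework-free sketch below illustrates that pattern with a one-weight linear model and squared-error loss. ToySGD and run_epoch_sketch are hypothetical names, not part of archai; the real function works on a torch Module, a data loader, and a loss function instead.

```python
from typing import Iterable, List, Optional, Tuple


class ToySGD:
    """Hypothetical stand-in for an optimizer: plain SGD on a single weight."""

    def __init__(self, lr: float) -> None:
        self.lr = lr

    def step(self, weight: float, grad: float) -> float:
        return weight - self.lr * grad


def run_epoch_sketch(
    weight: float,
    loader: Iterable[Tuple[float, float]],
    optimizer: Optional[ToySGD] = None,
) -> Tuple[float, float]:
    """Run one epoch of y_hat = weight * x with squared-error loss.

    With an optimizer, the weight is updated after every sample (train mode);
    without one, the weight is left unchanged (test mode).
    Returns the (possibly updated) weight and the mean loss over the epoch.
    """
    losses: List[float] = []
    for x, y in loader:
        y_hat = weight * x
        loss = (y_hat - y) ** 2
        losses.append(loss)
        if optimizer is not None:
            grad = 2 * (y_hat - y) * x  # d(loss)/d(weight)
            weight = optimizer.step(weight, grad)
    return weight, sum(losses) / len(losses)
```

Calling it without an optimizer only measures the loss; passing ToySGD(lr=0.1) turns the same loop into a training epoch.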
Checkpoint
- class archai.supergraph.utils.checkpoint.CheckPoint(conf_checkpoint: Config, load_existing: bool)
Callback-based checkpoint model.
Start a new checkpoint by calling new() and save it by calling commit(). This class is also a dictionary: items that need to be saved are stored by setting key-value pairs after new(). As soon as any dictionary key is set, the checkpoint becomes dirty. On commit(), the dictionary is saved and emptied. Invariant: the checkpoint remains dirty until commit() is called.
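The new()/commit() protocol and its dirty-flag invariant can be illustrated with a small dictionary subclass. This is a minimal sketch, not archai's implementation: CheckpointSketch and its saved list are hypothetical, and the real class presumably persists to storage configured through conf_checkpoint rather than to an in-memory list.

```python
from typing import Any, Dict, List


class CheckpointSketch(dict):
    """Hypothetical dict-based checkpoint illustrating the new()/commit() protocol."""

    def __init__(self) -> None:
        super().__init__()
        self.is_dirty = False
        self.saved: List[Dict[str, Any]] = []  # stand-in for files written to disk

    def new(self) -> None:
        # Start a fresh checkpoint: nothing set yet, so it is not dirty.
        self.clear()
        self.is_dirty = False

    def __setitem__(self, key: str, value: Any) -> None:
        super().__setitem__(key, value)
        self.is_dirty = True  # any key assignment marks the checkpoint dirty

    def commit(self) -> None:
        # Persist a copy of the current contents, then empty the dictionary.
        self.saved.append(dict(self))
        self.clear()
        self.is_dirty = False
```

In this sketch only item assignment (ckpt["key"] = value) sets the flag; dict.update() bypasses the overridden __setitem__ and would not.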
Heatmap
- archai.supergraph.utils.heatmap.heatmap(data: array, ax: Axes | None = None, xtick_labels: List[str] | None = None, ytick_labels: List[str] | None = None, cbar_kwargs: Dict[str, Any] | None = None, cbar_label: str | None = None, fmt: str | None = '{x:.2f}', **kwargs) → None
Plot a heatmap.
- Parameters:
data – Data to plot.
ax – Axis to plot on.
xtick_labels – Labels for the x-axis.
ytick_labels – Labels for the y-axis.
cbar_kwargs – Keyword arguments to pass to the color bar.
cbar_label – Label for the color bar.
fmt – Format of the annotations.
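The default fmt value '{x:.2f}' reads like a str.format template with a single field named x, the same convention matplotlib's StrMethodFormatter uses. Assuming heatmap() renders each cell annotation that way (this page does not confirm it), the hypothetical helper below shows the text each cell would receive.

```python
from typing import List, Sequence


def format_annotations(
    data: Sequence[Sequence[float]], fmt: str = "{x:.2f}"
) -> List[List[str]]:
    """Render each cell of a 2-D grid with a str.format template whose field is named x."""
    return [[fmt.format(x=value) for value in row] for row in data]
```

Any format specification works in the template, for example "{x:.0%}" to show fractions as whole percentages.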
Metrics
- class archai.supergraph.utils.metrics.Metrics(title: str, apex: ApexUtils | None, logger_freq: int = 50)
Records top-1, top-5, and loss metrics and tracks the best values seen so far.
There are three levels of metrics:
1. Run level: metrics for one call of fit(), for example the best top-1.
2. Epoch level: running averages of top-1, top-5, and loss maintained over the epoch.
3. Step level: metrics for every step in the epoch.
pre_run() must be called before fit(); it resets all metrics. Similarly, pre_epoch() resets the running averages and pre_step() resets step-level metrics such as average step time.
post_step() updates the running averages, while post_epoch() updates the best values seen so far at the end of each epoch.
- cur_epoch() → EpochMetrics
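The call order of the run, epoch, and step hooks can be sketched as follows, tracking only top-1 accuracy. MetricsSketch and its method arguments are hypothetical; archai's Metrics also tracks top-5, loss, and step time, and its method signatures may differ.

```python
from typing import List, Optional


class MetricsSketch:
    """Hypothetical illustration of the run/epoch/step metric levels (top-1 only)."""

    def pre_run(self) -> None:
        # Run level: reset everything, including the best value seen so far.
        self.best_top1: Optional[float] = None
        self.epoch_top1_avgs: List[float] = []

    def pre_epoch(self) -> None:
        # Epoch level: reset the running average.
        self._top1_sum = 0.0
        self._steps = 0

    def pre_step(self) -> None:
        # Step level: per-step values such as step time would be reset here.
        pass

    def post_step(self, top1: float) -> None:
        # Fold this step's value into the running average.
        self._top1_sum += top1
        self._steps += 1

    def post_epoch(self) -> None:
        # Close the epoch and update the best value seen so far.
        avg = self._top1_sum / self._steps
        self.epoch_top1_avgs.append(avg)
        if self.best_top1 is None or avg > self.best_top1:
            self.best_top1 = avg
```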
- class archai.supergraph.utils.metrics.EpochMetrics(index: int)
Stores the metrics for a single epoch. Training metrics are stored in top1, top5, etc., while validation metrics are stored in val_metrics.
- class archai.supergraph.utils.metrics.RunMetrics
Metrics for the entire run. It mainly consists of the metrics for each epoch.
- add_epoch() → EpochMetrics
- cur_epoch() → EpochMetrics
- best_epoch() → Tuple[EpochMetrics, EpochMetrics | None, EpochMetrics | None]
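The relationship between the run-level and epoch-level containers can be sketched with dataclasses. All names below are hypothetical. In particular, this sketch's best_epoch() returns a single record chosen by training top-1, whereas archai's version returns a three-element tuple whose last two entries may be None.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class EpochMetricsSketch:
    """Hypothetical per-epoch record holding only training top-1."""
    index: int
    top1: float = 0.0


@dataclass
class RunMetricsSketch:
    """Hypothetical run-level container holding one record per epoch."""
    epochs: List[EpochMetricsSketch] = field(default_factory=list)

    def add_epoch(self) -> EpochMetricsSketch:
        # The new epoch's index is its position in the run.
        epoch = EpochMetricsSketch(index=len(self.epochs))
        self.epochs.append(epoch)
        return epoch

    def cur_epoch(self) -> EpochMetricsSketch:
        return self.epochs[-1]

    def best_epoch(self) -> EpochMetricsSketch:
        # Assumption: "best" means highest training top-1.
        return max(self.epochs, key=lambda e: e.top1)
```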
Multi-Optimizer
- class archai.supergraph.utils.multi_optim.OptimSched(optim: Optimizer, sched: _LRScheduler | None, sched_on_epoch: bool | None)
Holds an optimizer and its optional learning-rate scheduler.
- class archai.supergraph.utils.multi_optim.MultiOptim
- append(optim_sched: OptimSched) → None
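The sched_on_epoch flag suggests that each scheduler is stepped either once per epoch or once per batch. Assuming that reading, the sketch below shows how a MultiOptim-style container could dispatch step calls. The step() and epoch() methods and the StepCounter stand-in are hypothetical, since this page only documents append().

```python
from dataclasses import dataclass
from typing import List, Optional


class StepCounter:
    """Hypothetical stand-in for an optimizer or LR scheduler; counts step() calls."""

    def __init__(self) -> None:
        self.steps = 0

    def step(self) -> None:
        self.steps += 1


@dataclass
class OptimSchedSketch:
    optim: StepCounter
    sched: Optional[StepCounter]
    sched_on_epoch: Optional[bool]  # assumption: None or False means per-batch


class MultiOptimSketch:
    def __init__(self) -> None:
        self._items: List[OptimSchedSketch] = []

    def append(self, optim_sched: OptimSchedSketch) -> None:
        self._items.append(optim_sched)

    def step(self) -> None:
        # Called once per batch: step every optimizer, plus per-batch schedulers.
        for item in self._items:
            item.optim.step()
            if item.sched is not None and not item.sched_on_epoch:
                item.sched.step()

    def epoch(self) -> None:
        # Called once per epoch: step only the per-epoch schedulers.
        for item in self._items:
            if item.sched is not None and item.sched_on_epoch:
                item.sched.step()
```

Keeping the flag next to each scheduler lets optimizers with different schedule granularities share one training loop.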
Tester
Trainer
- class archai.supergraph.utils.trainer.Trainer(conf_train: Config, model: Module, checkpoint: CheckPoint | None = None)
- fit(data_loaders: DataLoaders) → Metrics
- create_multi_optim(train_len: int) → MultiOptim
- create_scheduler(conf_sched: Config, optim: Optimizer, steps_per_epoch: int) → Tuple[_LRScheduler | None, bool]
- pre_fit(data_loaders: DataLoaders) → None
- post_fit(data_loaders: DataLoaders) → None
- pre_epoch(data_loaders: DataLoaders) → None
- post_epoch(data_loaders: DataLoaders) → None
- update_checkpoint(checkpoint: CheckPoint) → None
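The pre_/post_ methods follow a template-method pattern: fit() drives the loop and calls the hooks at fixed points. The sketch below assumes the natural nesting, with fit-level hooks around per-epoch hooks; TrainerSketch is hypothetical, and the exact call order inside archai's fit(), including where update_checkpoint() is invoked, is not documented on this page.

```python
from typing import List


class TrainerSketch:
    """Hypothetical template-method trainer that records the order of hook calls."""

    def __init__(self, epochs: int) -> None:
        self.epochs = epochs
        self.calls: List[str] = []

    def fit(self) -> List[str]:
        self.pre_fit()
        for _ in range(self.epochs):
            self.pre_epoch()
            self.calls.append("train_epoch")  # batches would be processed here
            self.post_epoch()
        self.post_fit()
        return self.calls

    def pre_fit(self) -> None:
        self.calls.append("pre_fit")

    def post_fit(self) -> None:
        self.calls.append("post_fit")

    def pre_epoch(self) -> None:
        self.calls.append("pre_epoch")

    def post_epoch(self) -> None:
        self.calls.append("post_epoch")
```

A subclass can override a hook (calling super() first) to add behavior such as extra logging without rewriting fit().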