Source code for archai.supergraph.utils.tester

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import Optional

import torch
from overrides import EnforceOverrides
from torch import Tensor, nn
from torch.utils.data import DataLoader

from archai.common import ml_utils
from archai.common.apex_utils import ApexUtils
from archai.common.config import Config
from archai.common.ordered_dict_logger import get_global_logger
from archai.supergraph.utils.metrics import Metrics

logger = get_global_logger()


class Tester(EnforceOverrides):
    def __init__(self, conf_val: Config, model: nn.Module, apex: ApexUtils) -> None:
        self._title = conf_val['title']
        self._logger_freq = conf_val['logger_freq']
        conf_lossfn = conf_val['lossfn']
        self.batch_chunks = conf_val['batch_chunks']

        self._apex = apex
        self.model = model
        self._lossfn = ml_utils.get_lossfn(conf_lossfn).to(apex.device)
        self._metrics = None
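
For reference, `__init__` indexes exactly four keys of `conf_val`. A minimal sketch of an equivalent configuration follows, with a plain dict standing in for the real `Config` object; the values and the exact loss-function schema accepted by `ml_utils.get_lossfn` are assumptions, not documented by this module:

# Illustrative only: key names match the lookups in __init__ above;
# the values shown are hypothetical.
conf_val = {
    'title': 'eval_test',    # logging section name, also passed to Metrics
    'logger_freq': 10,       # how often Metrics emits log entries
    'lossfn': {'type': 'CrossEntropyLoss'},  # consumed by ml_utils.get_lossfn
    'batch_chunks': 1,       # >1 splits each test batch to fit in GPU RAM
}
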
    def test(self, test_dl: DataLoader) -> Metrics:
        logger.pushd(self._title)

        self._metrics = self._create_metrics()  # recreate metrics for this run
        self._pre_test()
        self._test_epoch(test_dl)
        self._post_test()

        logger.popd()
        return self.get_metrics()  # type: ignore
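
A hypothetical end-to-end call, assuming `model`, `apex`, and `conf_val` are constructed as above and `test_dataset` is an existing test set:

# Hypothetical usage: run one evaluation pass and collect its Metrics.
# test_dataset, model, apex and conf_val are assumed to already exist.
from torch.utils.data import DataLoader

test_dl = DataLoader(test_dataset, batch_size=256, shuffle=False)
tester = Tester(conf_val, model, apex)
metrics = tester.test(test_dl)  # fresh Metrics are created for each call
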
    def _test_epoch(self, test_dl: DataLoader) -> None:
        self._metrics.pre_epoch()
        self.model.eval()
        steps = len(test_dl)

        with torch.no_grad(), logger.pushd('steps'):
            for step, (x, y) in enumerate(test_dl):
                # derived class might alter the mode through pre/post hooks
                assert not self.model.training
                logger.pushd(step)

                self._pre_step(x, y, self._metrics)  # pyright: ignore[reportGeneralTypeIssues]

                # divide batch into chunks if needed so it fits in GPU RAM
                if self.batch_chunks > 1:
                    x_chunks, y_chunks = torch.chunk(x, self.batch_chunks), torch.chunk(y, self.batch_chunks)
                else:
                    x_chunks, y_chunks = (x,), (y,)

                logits_chunks = []
                loss_sum, loss_count = 0.0, 0
                for xc, yc in zip(x_chunks, y_chunks):
                    xc, yc = xc.to(self.get_device(), non_blocking=True), yc.to(self.get_device(), non_blocking=True)

                    logits_c = self.model(xc)
                    tupled_out = isinstance(logits_c, tuple) and len(logits_c) >= 2
                    if tupled_out:
                        logits_c = logits_c[0]

                    loss_c = self._lossfn(logits_c, yc)

                    loss_sum += loss_c.item() * len(logits_c)
                    loss_count += len(logits_c)
                    logits_chunks.append(logits_c.detach().cpu())

                self._post_step(x, y,
                                ml_utils.join_chunks(logits_chunks),
                                torch.tensor(loss_sum / loss_count),
                                steps, self._metrics)  # pyright: ignore[reportGeneralTypeIssues]

                # TODO: we possibly need to sync so all replicas are up to date
                self._apex.sync_devices()

                logger.popd()

        self._metrics.post_epoch()  # no "val" dataset for the test phase
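
The chunk loop weights each chunk's mean loss by its chunk size and divides by the total count, which recovers the full-batch mean whenever the loss function reduces by mean. A standalone sketch of that arithmetic (names here are illustrative, not part of this module):

import torch
import torch.nn.functional as F

x = torch.randn(10, 4)            # batch of 10 examples, 4 classes
y = torch.randint(0, 4, (10,))

full_loss = F.cross_entropy(x, y).item()   # mean over all 10 examples

loss_sum, loss_count = 0.0, 0
for xc, yc in zip(torch.chunk(x, 3), torch.chunk(y, 3)):  # sizes 4, 4, 2
    loss_sum += F.cross_entropy(xc, yc).item() * len(xc)
    loss_count += len(xc)

assert abs(loss_sum / loss_count - full_loss) < 1e-4
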
    def get_metrics(self) -> Optional[Metrics]:
        return self._metrics

    def state_dict(self) -> dict:
        return {
            'metrics': self._metrics.state_dict()
        }

    def get_device(self):
        return self._apex.device

    def load_state_dict(self, state_dict: dict) -> None:
        self._metrics.load_state_dict(state_dict['metrics'])
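
Note that `state_dict` and `load_state_dict` both delegate to `self._metrics`, so a checkpoint round-trip is only meaningful once `test()` has created the metrics. A hypothetical sketch:

# Hypothetical checkpointing sketch: only valid after test() has populated
# self._metrics on the tester being restored into.
ckpt = tester.state_dict()
torch.save(ckpt, 'tester_state.pt')
tester.load_state_dict(torch.load('tester_state.pt'))
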
    def _pre_test(self) -> None:
        self._metrics.pre_run()

    def _post_test(self) -> None:
        self._metrics.post_run()

    def _pre_step(self, x: Tensor, y: Tensor, metrics: Metrics) -> None:
        metrics.pre_step(x, y)

    def _post_step(self, x: Tensor, y: Tensor, logits: Tensor, loss: Tensor, steps: int, metrics: Metrics) -> None:
        metrics.post_step(x, y, logits, loss, steps)

    def _create_metrics(self) -> Metrics:
        return Metrics(self._title, self._apex, logger_freq=self._logger_freq)
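
These `_pre_*`/`_post_*` hooks are the intended extension points: the default implementations just delegate to `Metrics`, and the step loop only asserts that a derived class has not left the model in training mode. A hypothetical subclass that additionally records each step's chunk-weighted average loss (class name and attribute are invented for illustration):

from overrides import overrides

class LossRecordingTester(Tester):
    """Hypothetical subclass: keeps each step's average loss for inspection."""

    def __init__(self, conf_val: Config, model: nn.Module, apex: ApexUtils) -> None:
        super().__init__(conf_val, model, apex)
        self.step_losses = []

    @overrides
    def _post_step(self, x: Tensor, y: Tensor, logits: Tensor, loss: Tensor,
                   steps: int, metrics: Metrics) -> None:
        super()._post_step(x, y, logits, loss, steps, metrics)
        self.step_losses.append(loss.item())  # loss is the chunk-weighted mean

Since `Tester` derives from `EnforceOverrides`, the `@overrides` decorator is required on the overriding method.
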