# Source code for archai.supergraph.nas.nas_utils
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Optional
import tensorwatch as tw
from archai.common.config import Config
from archai.common.ordered_dict_logger import get_global_logger
from archai.supergraph.nas.model import Model
from archai.supergraph.utils.checkpoint import CheckPoint
logger = get_global_logger()
def checkpoint_empty(checkpoint:Optional[CheckPoint])->bool:
    """Tell whether there is no usable checkpoint.

    Args:
        checkpoint: The checkpoint to inspect, or ``None``.

    Returns:
        ``True`` if ``checkpoint`` is ``None``; otherwise whatever
        ``checkpoint.is_empty()`` returns.
    """
    # `or` short-circuits, so is_empty() is never called on None
    return checkpoint is None or checkpoint.is_empty()
def create_checkpoint(conf_checkpoint:Config, resume:bool)->Optional[CheckPoint]:
    """Creates checkpoint given its config. If resume is True then attempt is
    made to load existing checkpoint otherwise an empty checkpoint is created.

    Args:
        conf_checkpoint: Checkpoint configuration handed to ``CheckPoint``.
            When it is ``None`` no checkpoint object is created.
        resume: Forwarded to the ``CheckPoint`` constructor.

    Returns:
        The constructed ``CheckPoint``, or ``None`` when ``conf_checkpoint``
        is ``None``.
    """
    # No config means checkpointing is not set up: carry None through instead
    # of constructing a CheckPoint.
    checkpoint = (CheckPoint(conf_checkpoint, resume)
                  if conf_checkpoint is not None else None)

    # Log the outcome; the path is only read when a checkpoint object exists.
    logger.info({'checkpoint_empty': checkpoint_empty(checkpoint),
                 'conf_checkpoint_none': conf_checkpoint is None, 'resume': resume,
                 'checkpoint_path': None if checkpoint is None else checkpoint.filepath})
    return checkpoint
def get_model_stats(model:Model,
                    input_tensor_shape=None, clone_model=True)->tw.ModelStats:
    """Build a tensorwatch ``ModelStats`` object for ``model``.

    Args:
        model: The model to collect statistics for.
        input_tensor_shape: Shape passed to ``tw.ModelStats`` as the input
            tensor shape. Defaults to ``[1, 3, 32, 32]`` when omitted or
            ``None``.
        clone_model: Forwarded to ``tw.ModelStats`` as ``clone_model``.

    Returns:
        The ``tw.ModelStats`` instance created for ``model``.
    """
    if input_tensor_shape is None:
        # Build the default per call: a list literal in the signature would be
        # one shared object across all calls and is handed to third-party code.
        input_tensor_shape = [1, 3, 32, 32]

    # model stats is doing some hooks so do it last
    model_stats = tw.ModelStats(model, input_tensor_shape,
                                clone_model=clone_model)
    return model_stats