Source code for archai.common.apex_utils

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import contextlib
import os

import psutil
import ray
import torch
import torch.distributed as dist
from torch import Tensor, nn
from torch.backends import cudnn
from torch.cuda.amp import GradScaler
from torch.nn import SyncBatchNorm
from torch.nn.parallel import DistributedDataParallel
from torch.optim.optimizer import Optimizer

from archai.common import ml_utils, utils
from archai.common.config import Config
from archai.common.ordered_dict_logger import get_global_logger
from archai.supergraph.utils.multi_optim import MultiOptim

logger = get_global_logger()


class ApexUtils:
    def __init__(self, apex_config:Config)->None:
        # region conf vars
        self._enabled = apex_config['enabled'] # global switch to disable anything apex
        self._distributed_enabled = apex_config['distributed_enabled'] # enable/disable distributed mode
        self._mixed_prec_enabled = apex_config['mixed_prec_enabled'] # enable/disable mixed precision
        # torch.amp has default 'O1' optimization level and cannot be configured further
        # torch.amp keeps BN in fp32
        # There is no loss_scale option in torch.amp
        self._sync_bn = apex_config['sync_bn'] # whether to replace BNs with sync BNs for distributed model
        self._scale_lr = apex_config['scale_lr'] # enable/disable scaling of LR by world size in distributed mode
        self._min_world_size = apex_config['min_world_size'] # allows confirming that we are indeed in a distributed setting
        seed = apex_config['seed']
        detect_anomaly = apex_config['detect_anomaly']
        conf_gpu_ids = apex_config['gpus']

        conf_ray = apex_config['ray']
        self.ray_enabled = conf_ray['enabled']
        self.ray_local_mode = conf_ray['local_mode']
        # endregion

        self._scaler = None

        self._set_ranks(conf_gpu_ids)

        #_log_info({'apex_config': apex_config.to_dict()})
        self._log_info({'ray.enabled': self.is_ray(), 'apex.enabled': self._enabled})
        self._log_info({'torch.distributed.is_available': dist.is_available(),
                        'apex.distributed_enabled': self._distributed_enabled,
                        'apex.mixed_prec_enabled': self._mixed_prec_enabled})

        if dist.is_available():
            # dist.* properties are otherwise not accessible
            self._op_map = {'mean': dist.ReduceOp.SUM, 'sum': dist.ReduceOp.SUM,
                            'min': dist.ReduceOp.MIN, 'max': dist.ReduceOp.MAX}
            self._log_info({'gloo_available': dist.is_gloo_available(),
                            'mpi_available': dist.is_mpi_available(),
                            'nccl_available': dist.is_nccl_available()})

        # init/enable mixed precision
        if self.is_mixed():
            assert cudnn.enabled, "Amp requires cudnn backend to be enabled."
            self._scaler = GradScaler()

        # enable distributed processing
        if self.is_dist():
            assert not self.is_ray(), "Ray is not yet enabled for Apex distributed mode"

            assert dist.is_available() # distributed module is available
            assert dist.is_nccl_available()
            if not dist.is_initialized():
                dist.init_process_group(backend='nccl', init_method='env://')
            assert dist.is_initialized()
            assert dist.get_world_size() == self.world_size
            assert dist.get_rank() == self.global_rank

        if self.is_ray():
            assert not self.is_dist(), "Ray is not yet enabled for Apex distributed mode"

            if not ray.is_initialized():
                ray.init(local_mode=self.ray_local_mode, include_dashboard=False,
                         # for some reason Ray is detecting wrong number of GPUs
                         num_gpus=torch.cuda.device_count())
            ray_cpus = ray.nodes()[0]['Resources']['CPU']
            ray_gpus = ray.nodes()[0]['Resources']['GPU']
            self._log_info({'ray_cpus': ray_cpus, 'ray_gpus': ray_gpus})

        assert self.world_size >= 1
        assert not self._min_world_size or self.world_size >= self._min_world_size
        assert self.local_rank >= 0 and self.local_rank < self.world_size
        assert self.global_rank >= 0 and self.global_rank < self.world_size
        assert self._gpu < torch.cuda.device_count()

        torch.cuda.set_device(self._gpu)
        self.device = torch.device('cuda', self._gpu)
        self._setup_gpus(seed, detect_anomaly)

        self._log_info({'dist_initialized': dist.is_initialized() if dist.is_available() else False,
                        'world_size': self.world_size,
                        'gpu': self._gpu, 'gpu_ids': self.gpu_ids,
                        'local_rank': self.local_rank,
                        'global_rank': self.global_rank})

    def _setup_gpus(self, seed:float, detect_anomaly:bool):
        utils.setup_cuda(seed, local_rank=self.local_rank)

        torch.autograd.set_detect_anomaly(detect_anomaly)
        self._log_info({'set_detect_anomaly': detect_anomaly,
                        'is_anomaly_enabled': torch.is_anomaly_enabled()})

        self._log_info({'gpu_names': utils.cuda_device_names(),
                        'gpu_count': torch.cuda.device_count(),
                        'CUDA_VISIBLE_DEVICES': os.environ['CUDA_VISIBLE_DEVICES']
                            if 'CUDA_VISIBLE_DEVICES' in os.environ else 'NotSet',
                        'cudnn.enabled': cudnn.enabled,
                        'cudnn.benchmark': cudnn.benchmark,
                        'cudnn.deterministic': cudnn.deterministic,
                        'cudnn.version': cudnn.version()
                        })
        self._log_info({'memory': str(psutil.virtual_memory())})
        self._log_info({'CPUs': str(psutil.cpu_count())})

        # gpu_usage = os.popen(
        #     'nvidia-smi --query-gpu=memory.total,memory.used --format=csv,nounits,noheader'
        #         ).read().split('\n')
        # for i, line in enumerate(gpu_usage):
        #     vals = line.split(',')
        #     if len(vals) == 2:
        #         _log_info('GPU {} mem: {}, used: {}'.format(i, vals[0], vals[1]))

    def _set_ranks(self, conf_gpu_ids:str)->None:
        # this function needs to work even when torch.distributed is not available

        if 'WORLD_SIZE' in os.environ:
            self.world_size = int(os.environ['WORLD_SIZE'])
        else:
            self.world_size = 1

        if 'LOCAL_RANK' in os.environ:
            self.local_rank = int(os.environ['LOCAL_RANK'])
        else:
            self.local_rank = 0

        if 'RANK' in os.environ:
            self.global_rank = int(os.environ['RANK'])
        else:
            self.global_rank = 0

        assert self.local_rank < torch.cuda.device_count(), \
            f'local_rank={self.local_rank} but device_count={torch.cuda.device_count()}' \
            ' Possible cause may be Pytorch is not GPU enabled or you have too few GPUs'

        self.gpu_ids = [int(i) for i in conf_gpu_ids.split(',') if i]

        # which GPU to use; we will use only 1 GPU per process to avoid complications with apex
        # remap if GPU IDs are specified
        if len(self.gpu_ids):
            assert len(self.gpu_ids) > self.local_rank
            self._gpu = self.gpu_ids[self.local_rank]
        else:
            self._gpu = self.local_rank % torch.cuda.device_count()
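The constructor reads a fixed set of keys from its configuration section. A minimal sketch of that section, written here as a plain Python dict purely for illustration (the real object is an archai `Config`, and the values shown are placeholders rather than the library's defaults):

apex_conf = {
    'enabled': True,               # master switch for everything in this class
    'distributed_enabled': True,   # use torch.distributed (NCCL) when launched with WORLD_SIZE > 1
    'mixed_prec_enabled': True,    # wrap forward passes in torch.cuda.amp.autocast / GradScaler
    'sync_bn': False,              # convert BatchNorm layers to SyncBatchNorm in to_amp()
    'scale_lr': False,             # rescale LR by world_size / batch_size in to_amp()
    'min_world_size': 0,           # assert that at least this many processes are running
    'seed': 42.0,
    'detect_anomaly': False,       # torch.autograd.set_detect_anomaly
    'gpus': '',                    # e.g. '0,2' to map local ranks onto specific GPU ids
    'ray': {'enabled': False, 'local_mode': False},
}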
    def is_mixed(self)->bool:
        return self._enabled and self._mixed_prec_enabled

    def is_dist(self)->bool:
        return self._enabled and self._distributed_enabled and self.world_size > 1

    def is_master(self)->bool:
        return self.global_rank == 0

    def is_ray(self)->bool:
        return self.ray_enabled

    def _log_info(self, d:dict)->None:
        if logger is not None:
            logger.info(d, override_key=True)

    def sync_devices(self)->None:
        if self.is_dist():
            torch.cuda.synchronize(self.device)

    def barrier(self)->None:
        if self.is_dist():
            dist.barrier() # wait for all processes to come to this point
    def reduce(self, val, op='mean'):
        if self.is_dist():
            if not isinstance(val, Tensor):
                rt = torch.tensor(val).to(self.device)
                converted = True
            else:
                rt = val.clone().to(self.device)
                converted = False
            r_op = self._op_map[op]
            dist.all_reduce(rt, op=r_op)
            if op=='mean':
                rt /= self.world_size
            if converted and len(rt.shape)==0:
                return rt.item()
            return rt
        else:
            return val
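A hedged usage sketch for reduce: averaging a per-rank scalar metric across all workers. The `apex` instance is assumed to have been constructed already; in a single-process run reduce simply returns the value unchanged.

local_acc = 0.87                                # placeholder: this rank's own metric
global_acc = apex.reduce(local_acc, op='mean')  # all_reduce with SUM, then divide by world_size
if apex.is_master():
    print(f'validation accuracy: {global_acc:.4f}')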
    def _get_one_optim(self, multi_optim:MultiOptim)->Optimizer:
        assert len(multi_optim)==1, \
            'Mixed precision is only supported for one optimizer' \
            f' but {len(multi_optim)} optimizers were supplied'
        return multi_optim[0].optim

    def backward(self, loss:torch.Tensor)->None:
        if self.is_mixed():
            self._scaler.scale(loss).backward() # pyright: ignore[reportGeneralTypeIssues, reportOptionalMemberAccess]
        else:
            loss.backward()

    def autocast(self):
        if self.is_mixed():
            return torch.cuda.amp.autocast()
        else:
            return contextlib.nullcontext()

    def step(self, multi_optim:MultiOptim)->None:
        if self.is_mixed():
            # self._scaler.unscale_ will be called automatically if it isn't called yet from grad clipping
            # https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler.step
            for optim_shed in multi_optim:
                self._scaler.step(optim_shed.optim) # pyright: ignore[reportOptionalMemberAccess]
            self._scaler.update() # pyright: ignore[reportOptionalMemberAccess]
        else:
            multi_optim.step()
    def to_amp(self, model:nn.Module, multi_optim:MultiOptim, batch_size:int)\
            ->nn.Module:
        # convert BNs to sync BNs in distributed mode
        if self.is_dist() and self._sync_bn:
            model = SyncBatchNorm.convert_sync_batchnorm(model)
            self._log_info({'BNs_converted': True})

        model = model.to(self.device)

        # scale LR
        if self.is_dist() and self._scale_lr:
            for optim_shed in multi_optim:
                optim = optim_shed.optim
                lr = ml_utils.get_optim_lr(optim)
                scaled_lr = lr * self.world_size / float(batch_size)
                ml_utils.set_optim_lr(optim, scaled_lr)
                self._log_info({'lr_scaled': True, 'old_lr': lr, 'new_lr': scaled_lr})

        if self.is_dist():
            model = DistributedDataParallel(model, device_ids=[self._gpu], output_device=self._gpu)

        return model
    def clip_grad(self, clip:float, model:nn.Module, multi_optim:MultiOptim)->None:
        if clip > 0.0:
            if self.is_mixed():
                # https://pytorch.org/docs/stable/notes/amp_examples.html#working-with-multiple-models-losses-and-optimizers
                self._scaler.unscale_(multi_optim[0].optim) # pyright: ignore[reportOptionalMemberAccess]
                nn.utils.clip_grad_norm_(model.parameters(), clip)
            else:
                nn.utils.clip_grad_norm_(model.parameters(), clip)

    def state_dict(self):
        if self.is_mixed():
            return self._scaler.state_dict() # pyright: ignore[reportOptionalMemberAccess]
        else:
            return None

    def load_state_dict(self, state_dict):
        if self.is_mixed():
            self._scaler.load_state_dict(state_dict) # pyright: ignore[reportOptionalMemberAccess]
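Putting the pieces together, a hedged sketch of a single training step built from the methods above. The names `apex`, `model`, `multi_optim`, `loader`, and `loss_fn` are assumptions standing in for objects the caller has already created (with archai's Config and MultiOptim); only the ApexUtils calls themselves come from this module.

# one-time setup: SyncBN conversion, device placement, LR scaling and DDP wrapping as configured
model = apex.to_amp(model, multi_optim, batch_size=128)

for x, y in loader:
    x, y = x.to(apex.device), y.to(apex.device)
    for optim_shed in multi_optim:              # same iteration pattern ApexUtils.step uses
        optim_shed.optim.zero_grad()
    with apex.autocast():                       # fp16 autocast only when mixed precision is enabled
        loss = loss_fn(model(x), y)
    apex.backward(loss)                         # scales the loss when a GradScaler is active
    apex.clip_grad(5.0, model, multi_optim)     # unscales gradients first under mixed precision
    apex.step(multi_optim)                      # scaler.step + scaler.update, or a plain step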