# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import math
from collections import OrderedDict
from numbers import Number
from typing import Iterable, Mapping, Sequence

import torch
import torch.nn as nn

def summary(model, input_size):
    """Print a layer-by-layer summary of ``model`` and return its parameter counts.

    Args:
        model: Module forwarded unchanged to ``summary_string``.
        input_size: Input shape forwarded unchanged to ``summary_string``.

    Returns:
        The second element returned by ``summary_string``, i.e. the
        ``(total_params, trainable_params)`` pair.
    """
    report, param_counts = summary_string(model, input_size)
    # The textual report only goes to stdout; callers get the counts back.
    print(report)
    return param_counts
def is_scaler(o):
    """Return ``True`` when ``o`` is treated as a scalar leaf value.

    A value counts as scalar if it is ``None``, a string, or an instance of
    ``numbers.Number`` (which includes ``bool``, ``int``, ``float`` and
    ``complex``).
    """
    if o is None:
        return True
    return isinstance(o, (Number, str))
def get_tensor_stat(tensor):
    """Return the shape, element count and estimated allocation size of a tensor.

    Args:
        tensor: A ``torch.Tensor``.

    Returns:
        A 3-tuple ``([shape], numel, memory_size)`` where ``shape`` is the
        tensor size as a tuple (``(1,)`` for a 0-dim tensor), ``numel`` is the
        number of elements in the tensor, and ``memory_size`` is the byte size
        of the tensor's backing storage rounded up to a multiple of 512.

    Raises:
        AssertionError: If ``tensor`` is not a ``torch.Tensor``.
    """
    assert isinstance(tensor, torch.Tensor), "get_tensor_stat expects a torch.Tensor"

    # Allocation granularity (bytes) used for the round-up below; the size of
    # any tensor is rounded up to a multiple of this value.
    PYTORCH_MIN_ALLOCATE = 2**9

    numel = tensor.numel()

    # Measure the backing storage, not numel * element_size: the storage is
    # the object tied to the actual allocation, and a view can be much smaller
    # than the storage it keeps alive.
    # NOTE(review): the storage-size expression was missing from the reviewed
    # source; this reconstruction should be confirmed against the upstream file.
    try:
        storage_bytes = tensor.untyped_storage().nbytes()
    except AttributeError:
        # Older PyTorch releases only expose the typed storage API.
        storage_bytes = tensor.storage().size() * tensor.element_size()

    memory_size = math.ceil(storage_bytes / PYTORCH_MIN_ALLOCATE) * PYTORCH_MIN_ALLOCATE

    # A 0-dim (scalar) tensor has an empty size; report it as one element.
    size = tuple(tensor.size()) or (1,)

    return ([size], numel, memory_size)
def get_all_tensor_stats(o):
    """Recursively total the tensor statistics found inside ``o``.

    Args:
        o: A scalar, a tensor, a mapping, any other iterable, or an arbitrary
            object. Mappings are searched through their values, other
            iterables through their items, and objects through ``__dict__``.

    Returns:
        A 3-tuple ``(shapes, numel, memory_size)`` in the same layout as
        ``get_tensor_stat``. Scalars and unrecognised objects contribute
        ``([[]], 0, 0)``; every iterable level also starts from that value,
        so ``shapes`` may contain empty-list placeholders.
    """
    if is_scaler(o):
        return ([[]], 0, 0)

    if isinstance(o, torch.Tensor):
        return get_tensor_stat(o)

    # Mappings must be tested before the generic Iterable case, otherwise
    # only their keys would be visited.
    if isinstance(o, Mapping):
        return get_all_tensor_stats(o.values())

    if isinstance(o, Iterable):
        shapes, numel, memory_size = [[]], 0, 0
        for item in o:
            item_shapes, item_numel, item_memory = get_all_tensor_stats(item)
            shapes = shapes + item_shapes
            numel = numel + item_numel
            memory_size = memory_size + item_memory
        return (shapes, numel, memory_size)

    if hasattr(o, "__dict__"):
        return get_all_tensor_stats(o.__dict__)

    return ([[]], 0, 0)
def get_shape(o):
    """Return a short human-readable description of the shape of ``o``.

    Args:
        o: Any object.

    Returns:
        A string chosen by the first rule that matches:

        * scalar (see ``is_scaler``): ``str(o)``
        * has a ``shape`` attribute: ``"shape<o.shape>"``
        * has a ``size`` attribute: ``"size<o.size()>"`` (or ``o.size`` itself
          when it is not callable)
        * sequence: ``"seq[]"`` if empty, ``"seq[<len>]"`` if its first item
          is a scalar, otherwise ``"seq"`` followed by the list of item shapes
        * mapping: ``"map[]"`` if empty, ``"map[<len>]"`` if its first value
          is a scalar, otherwise ``"map"`` followed by the list of
          ``(key shape, value shape)`` pairs
        * anything else: ``"N/A"``
    """
    if is_scaler(o):
        return str(o)

    if hasattr(o, "shape"):
        return f"shape{o.shape}"

    if hasattr(o, "size"):
        # ``size`` is a method on tensors but a plain attribute on some other
        # types; calling a non-callable attribute would raise TypeError.
        size = o.size
        return f"size{size() if callable(size) else size}"

    if isinstance(o, Sequence):
        if len(o) == 0:
            return "seq[]"
        if is_scaler(o[0]):
            return f"seq[{len(o)}]"
        return f"seq{[get_shape(oi) for oi in o]}"

    if isinstance(o, Mapping):
        if len(o) == 0:
            return "map[]"
        # A mapping is not an iterator, so ``next(o)`` raises TypeError. Peek
        # at the first value instead, mirroring the ``o[0]`` check used for
        # sequences.
        if is_scaler(next(iter(o.values()))):
            return f"map[{len(o)}]"
        # Iterating a mapping yields keys only; ``items()`` gives the pairs.
        arr = [(get_shape(ki), get_shape(vi)) for ki, vi in o.items()]
        return f"map{arr}"

    return "N/A"
def summary_string(model, input_size, dtype=torch.float32):
    """Build a textual per-layer summary of ``model`` from one forward pass.

    A random tensor of shape ``input_size`` is created on the device of the
    model's first parameter (CPU when the model has no parameters) and passed
    through ``model``. Forward hooks on every sub-module other than
    ``nn.Sequential`` and ``nn.ModuleList`` containers record input/output
    tensor statistics and parameter counts.

    Args:
        model: The ``nn.Module`` to inspect.
        input_size: Shape passed directly to ``torch.rand``; it is used as
            given, so it must include any batch dimension the model expects.
        dtype: Data type of the random input tensor.

    Returns:
        A pair ``(summary_str, (total_params, trainable_params))``.
        ``summary_str`` is the formatted report. A layer's parameters (weight
        plus bias) count as trainable only when its ``weight.requires_grad``
        is ``True``.
    """
    layer_stats = OrderedDict()
    hooks = []

    def hook(module, inputs, output):
        class_name = str(module.__class__).split(".")[-1].split("'")[0]
        m_key = "%s-%i" % (class_name, len(layer_stats) + 1)
        entry = layer_stats[m_key] = OrderedDict()
        entry["input"] = get_all_tensor_stats(inputs)
        entry["output"] = get_all_tensor_stats(output)

        # NOTE(review): the element-count expressions were missing from the
        # reviewed source; ``numel()`` is used as the reconstruction and
        # yields plain Python ints.
        params = 0
        if hasattr(module, "weight") and hasattr(module.weight, "size"):
            params += module.weight.numel()
            entry["trainable"] = module.weight.requires_grad
        if hasattr(module, "bias") and hasattr(module.bias, "size"):
            params += module.bias.numel()
        entry["nb_params"] = params

    def register_hook(module):
        # Pure containers only delegate to their children, which get their
        # own hooks, so they are skipped.
        if not isinstance(module, (nn.Sequential, nn.ModuleList)):
            hooks.append(module.register_forward_hook(hook))

    # Place the probe input on the model's device; a parameter-free model
    # would make ``next(model.parameters())`` raise StopIteration.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    x = torch.rand(input_size, dtype=dtype, device=device)

    try:
        model.apply(register_hook)
        model(x)
    finally:
        # Always detach the hooks, even if registration or the forward pass
        # fails, so the model is not left instrumented.
        for h in hooks:
            h.remove()

    row_format = "{:>20} {:>25} {:>15}"
    lines = [
        "----------------------------------------------------------------",
        row_format.format("Layer (type)", "Output (elments, mem)", "Param #"),
        "================================================================",
    ]

    total_params = 0
    trainable_params = 0
    total_input = get_tensor_stat(x)
    total_output = ([], 0, 0)
    for layer, entry in layer_stats.items():
        nb_params = entry["nb_params"]
        lines.append(
            row_format.format(
                layer,
                str(entry["output"][1:]),
                "{0:,}".format(nb_params),
            )
        )
        total_params += nb_params
        total_output = tuple(a + b for a, b in zip(total_output, entry["output"]))
        if entry.get("trainable") is True:
            trainable_params += nb_params

    total_numel = total_params + total_output[1] + total_input[1]
    lines += [
        "================================================================",
        "Total params: {0:,}".format(total_params),
        "Trainable params: {0:,}".format(trainable_params),
        "Non-trainable params: {0:,}".format(total_params - trainable_params),
        "----------------------------------------------------------------",
        f"Input Elments: {total_input[1]:.4e}",
        f"Input Mem: {total_input[2]:.4e}",
        f"Layer Output Elements: {total_output[1]:.4e}",
        f"Layer Output Mem: {total_output[2]:.4e}",
        f"Params {total_params:.4e}",
        f"Total Elements {total_numel:.4e}",
        "----------------------------------------------------------------",
    ]
    summary_str = "\n".join(lines) + "\n"

    return summary_str, (total_params, trainable_params)