Source code for archai.discrete_search.evaluators.pt_profiler_utils.pt_profiler_hooks

# Copyright (c) DeepSpeed Team - Microsoft Corporation.
# Licensed under the MIT License.
# https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/profiling/flops_profiler/profiler.py

from collections import OrderedDict
from typing import Callable, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F

FLOPS = []  # Stack of per-scope lists of (function_name, flops) records
MACS = []  # Stack of per-scope lists of (function_name, macs) records
TORCH_FUNCTIONS = {}  # Original callables, keyed by name, used to restore the hooks


def __shape_inner_product(dims: Tuple[int, ...]) -> int:
    p = 1
    for v in dims:
        p *= v

    return p


def _linear_hook(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> Tuple[int, int]:
    out_features = weight.shape[0]
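    # Each output feature requires `in_features` multiply-accumulates per input row,
    # so, for example (shapes assumed for illustration), an input of shape (8, 128)
    # with a weight of shape (64, 128) gives macs = 8 * 128 * 64 = 65_536 and
    # flops = 2 * macs = 131_072.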
    macs = torch.numel(input) * out_features

    return 2 * macs, macs


def _relu_hook(input: torch.Tensor, inplace: Optional[bool] = False) -> Tuple[int, int]:
    return torch.numel(input), 0


def _prelu_hook(input: torch.Tensor, weight: torch.Tensor) -> Tuple[int, int]:
    return torch.numel(input), 0


def _elu_hook(input: torch.Tensor, alpha: Optional[float] = 1.0, inplace: Optional[bool] = False) -> Tuple[int, int]:
    return torch.numel(input), 0


def _leakyrelu_hook(
    input: torch.Tensor, negative_slope: Optional[float] = 0.01, inplace: Optional[bool] = False
) -> Tuple[int, int]:
    return torch.numel(input), 0


def _relu6_hook(input: torch.Tensor, inplace: Optional[bool] = False) -> Tuple[int, int]:
    return torch.numel(input), 0


def _silu_hook(input: torch.Tensor, inplace: Optional[bool] = False) -> Tuple[int, int]:
    return torch.numel(input), 0


def _gelu_hook(input: torch.Tensor, approximate: str = "none") -> Tuple[int, int]:
    return torch.numel(input), 0


def _pool_hook(
    input: torch.Tensor,
    kernel_size: Union[int, Tuple[int, int]],
    stride: Optional[Union[int, Tuple[int, int]]] = None,
    padding: Optional[int] = 0,
    dilation: Optional[int] = None,
    ceil_mode: Optional[bool] = False,
    count_include_pad: Optional[bool] = True,
    divisor_override: Optional[int] = None,
    return_indices: Optional[bool] = None,
) -> Tuple[int, int]:
    return torch.numel(input), 0


def _conv_hook(
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    stride: Optional[Union[int, Tuple[int, ...]]] = 1,
    padding: Optional[Union[int, str]] = 0,
    dilation: Optional[Union[int, Tuple[int, ...]]] = 1,
    groups: Optional[int] = 1,
) -> Tuple[int, int]:
    assert weight.shape[1] * groups == input.shape[1]

    batch_size = input.shape[0]

    in_channels = input.shape[1]
    out_channels = weight.shape[0]

    kernel_dims = list(weight.shape[2:])
    input_dims = list(input.shape[2:])

    length = len(input_dims)

    paddings = padding if type(padding) is tuple else (padding,) * length
    strides = stride if type(stride) is tuple else (stride,) * length
    dilations = dilation if type(dilation) is tuple else (dilation,) * length

    output_dims = []
    for idx, input_dim in enumerate(input_dims):
        output_dim = (input_dim + 2 * paddings[idx] - (dilations[idx] * (kernel_dims[idx] - 1) + 1)) // strides[idx] + 1
        output_dims.append(output_dim)

    filters_per_channel = out_channels // groups
    conv_per_position_macs = int(__shape_inner_product(kernel_dims)) * in_channels * filters_per_channel
    active_elements_count = batch_size * int(__shape_inner_product(output_dims))
    overall_conv_macs = conv_per_position_macs * active_elements_count
    overall_conv_flops = 2 * overall_conv_macs
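    # Example (shapes assumed for illustration): a 3x3 convolution over an input of
    # shape (1, 16, 32, 32) with weight (32, 16, 3, 3), stride 1, padding 1 and
    # groups 1 gives output_dims = [32, 32], conv_per_position_macs = 9 * 16 * 32
    # = 4_608, active_elements_count = 1_024, overall_conv_macs = 4_718_592 and
    # overall_conv_flops = 9_437_184.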

    bias_flops = 0
    if bias is not None:
        bias_flops = out_channels * active_elements_count

    return int(overall_conv_flops + bias_flops), int(overall_conv_macs)


def _conv_transpose_hook(
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    stride: Optional[Union[int, Tuple[int, ...]]] = 1,
    padding: Optional[Union[int, str]] = 0,
    output_padding: Optional[int] = 0,
    dilation: Optional[Union[int, Tuple[int, ...]]] = 1,
    groups: Optional[int] = 1,
) -> Tuple[int, int]:
    batch_size = input.shape[0]

    in_channels = input.shape[1]
    out_channels = weight.shape[0]

    kernel_dims = list(weight.shape[2:])
    input_dims = list(input.shape[2:])

    length = len(input_dims)

    paddings = padding if type(padding) is tuple else (padding,) * length
    strides = stride if type(stride) is tuple else (stride,) * length
    dilations = dilation if type(dilation) is tuple else (dilation,) * length

    output_dims = []
    for idx, input_dim in enumerate(input_dims):
        output_dim = (input_dim + 2 * paddings[idx] - (dilations[idx] * (kernel_dims[idx] - 1) + 1)) // strides[idx] + 1
        output_dims.append(output_dim)

    filters_per_channel = out_channels // groups
    conv_per_position_macs = int(__shape_inner_product(kernel_dims)) * in_channels * filters_per_channel
    active_elements_count = batch_size * int(__shape_inner_product(input_dims))
    overall_conv_macs = conv_per_position_macs * active_elements_count
    overall_conv_flops = 2 * overall_conv_macs

    bias_flops = 0
    if bias is not None:
        bias_flops = out_channels * batch_size * int(__shape_inner_product(output_dims))

    return int(overall_conv_flops + bias_flops), int(overall_conv_macs)


def _batch_norm_hook(
    input: torch.Tensor,
    running_mean: Optional[torch.Tensor] = None,
    running_var: Optional[torch.Tensor] = None,
    weight: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
    training: Optional[bool] = False,
    momentum: Optional[float] = 0.1,
    eps: Optional[float] = 1e-05,
) -> Tuple[int, int]:
    has_affine = weight is not None

    if training:
        return torch.numel(input) * (5 if has_affine else 4), 0

    flops = torch.numel(input) * (2 if has_affine else 1)

    return flops, 0


def _layer_norm_hook(
    input: torch.Tensor,
    normalized_shape: List[int],
    weight: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
    eps: Optional[float] = 1e-5,
) -> Tuple[int, int]:
    has_affine = weight is not None

    return torch.numel(input) * (5 if has_affine else 4), 0


def _instance_norm_hook(
    input: torch.Tensor,
    running_mean: Optional[torch.Tensor] = None,
    running_var: Optional[torch.Tensor] = None,
    weight: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
    use_input_stats: Optional[bool] = True,
    momentum: Optional[float] = 0.1,
    eps: Optional[float] = 1e-5,
) -> Tuple[int, int]:
    has_affine = weight is not None

    return torch.numel(input) * (5 if has_affine else 4), 0


def _group_norm_hook(
    input: torch.Tensor,
    num_groups: int,
    weight: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
    eps: Optional[float] = 1e-5,
) -> Tuple[int, int]:
    has_affine = weight is not None

    return torch.numel(input) * (5 if has_affine else 4), 0


def _upsample_hook(
    input: torch.Tensor,
    size: Optional[Union[int, Tuple[int, ...]]] = None,
    scale_factor: Optional[Union[float, Tuple[float]]] = None,
    mode: Optional[str] = "nearest",
    align_corners: Optional[bool] = None,
    recompute_scale_factor: Optional[bool] = None,
) -> Tuple[int, int]:
    if size is not None:
        if isinstance(size, tuple):
            return int(__shape_inner_product(size)), 0
        else:
            return int(size), 0

    assert scale_factor is not None, "Either `size` or `scale_factor` should be defined."

    flops = torch.numel(input)
    if isinstance(scale_factor, tuple) and len(scale_factor) == len(input):
        flops *= int(__shape_inner_product(scale_factor))
    else:
        flops = int(flops * scale_factor ** len(input))

    return flops, 0


def _softmax_hook(
    input: torch.Tensor, dim: Optional[int] = None, _stacklevel: Optional[int] = 3, dtype: Optional[torch.dtype] = None
) -> Tuple[int, int]:
    return torch.numel(input), 0


def _embedding_hook(
    input: torch.Tensor,
    weight: torch.Tensor,
    padding_idx: Optional[int] = None,
    max_norm: Optional[float] = None,
    norm_type: Optional[float] = 2.0,
    scale_grad_by_freq: Optional[bool] = False,
    sparse: Optional[bool] = False,
) -> Tuple[int, int]:
    return 0, 0


def _matmul_hook(input: torch.Tensor, other: torch.Tensor, *, out: Optional[Tuple[int, ...]] = None) -> Tuple[int, int]:
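    # MACs are estimated as numel(input) * other.shape[-1]. Example (shapes assumed
    # for illustration): input (4, 32, 64) @ other (64, 16) gives
    # macs = 4 * 32 * 64 * 16 = 131_072 and flops = 2 * macs.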
    macs = __shape_inner_product(input.shape) * other.shape[-1]

    return 2 * macs, macs


def _addmm_hook(
    input: torch.Tensor,
    mat1: torch.Tensor,
    mat2: torch.Tensor,
    *,
    beta: Optional[int] = 1,
    alpha: Optional[int] = 1,
    out: Optional[Tuple[int, ...]] = None
) -> Tuple[int, int]:
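    # addmm computes beta * input + alpha * (mat1 @ mat2), so the matmul MACs are
    # counted plus one FLOP per element of `input` for the addition. Example (shapes
    # assumed for illustration): input (8, 64), mat1 (8, 32) and mat2 (32, 64) give
    # macs = 8 * 32 * 64 = 16_384 and flops = 2 * macs + 8 * 64 = 33_280.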
    macs = __shape_inner_product(mat1.shape) * mat2.shape[-1]

    return 2 * macs + __shape_inner_product(input.shape), macs


def _einsum_hook(equation: str, *operands) -> Tuple[int, int]:
    equation = equation.replace(" ", "")

    # Fix for `opt_einsum.contract`
    if len(operands) == 1 and isinstance(operands[0], tuple):
        operands = operands[0]

    input_shapes = [o.shape for o in operands]

    letter_order = OrderedDict((k, 0) for k in equation if k.isalpha()).keys()
    mapping = {ord(x): 97 + i for i, x in enumerate(letter_order)}
    equation = equation.translate(mapping)

    np_arrs = [np.zeros(s) for s in input_shapes]
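    # Let NumPy's einsum path optimizer estimate the contraction cost on zero-filled
    # arrays of the same shapes, then parse the "optimized flop" line of its report.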
    optim = np.einsum_path(equation, *np_arrs, optimize="optimal")[1]
    for line in optim.split("\n"):
        if "optimized flop" in line.lower():
            flop = int(float(line.split(":")[-1]))

            return flop, 0

    raise NotImplementedError("Unsupported einsum operation.")


def __elementwise_hook(input: torch.Tensor, other: torch.Tensor) -> Tuple[int, int]:
    if not torch.is_tensor(input):
        if torch.is_tensor(other):
            return __shape_inner_product(other.shape), 0
        else:
            return 1, 0

    elif not torch.is_tensor(other):
        return __shape_inner_product(input.shape), 0

    else:
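        # Count one FLOP per element of the broadcast result shape. Example (shapes
        # assumed for illustration): input (8, 1, 16) with other (8, 4, 1) gives
        # final_shape = [8, 4, 16] and flops = 512.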
        dim_input = len(input.shape)
        dim_other = len(other.shape)
        max_dim = max(dim_input, dim_other)

        final_shape = []
        for i in range(max_dim):
            in_i = input.shape[i] if i < dim_input else 1
            ot_i = other.shape[i] if i < dim_other else 1
            if in_i > ot_i:
                final_shape.append(in_i)
            else:
                final_shape.append(ot_i)

        flops = __shape_inner_product(final_shape)

        return flops, 0


def _mul_hook(input: torch.Tensor, other: torch.Tensor, *, out: Optional[Tuple[int, ...]] = None) -> Tuple[int, int]:
    return __elementwise_hook(input, other)


def _add_hook(
    input: torch.Tensor, other: torch.Tensor, *, alpha: Optional[int] = 1, out: Optional[Tuple[int, ...]] = None
) -> Tuple[int, int]:
    return __elementwise_hook(input, other)


def _wrap_fn(fn: Callable, new_fn: Callable) -> Callable:
    """Wraps a function with another function.

    Args:
        fn: Current function.
        new_fn: New function.

    Returns:
        (Callable): Wrapped function.

    """

    old_fn = fn
    name = fn.__name__
    TORCH_FUNCTIONS[name] = old_fn

    def __wrap_fn(*args, **kwargs):
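        # Record the estimated cost on the innermost open profiling scope (the last
        # list in FLOPS/MACS), then dispatch to the original function.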
        flops, macs = new_fn(*args, **kwargs)

        if FLOPS:
            FLOPS[-1].append((name, flops))
        if MACS and macs:
            MACS[-1].append((name, macs))

        return old_fn(*args, **kwargs)

    __wrap_fn.__name__ = fn.__name__

    return __wrap_fn


def enable_functional_hooks() -> None:
    """Enables functional API profiler hooks."""

    F.linear = _wrap_fn(F.linear, _linear_hook)

    F.conv1d = _wrap_fn(F.conv1d, _conv_hook)
    F.conv2d = _wrap_fn(F.conv2d, _conv_hook)
    F.conv3d = _wrap_fn(F.conv3d, _conv_hook)

    F.conv_transpose1d = _wrap_fn(F.conv_transpose1d, _conv_transpose_hook)
    F.conv_transpose2d = _wrap_fn(F.conv_transpose2d, _conv_transpose_hook)
    F.conv_transpose3d = _wrap_fn(F.conv_transpose3d, _conv_transpose_hook)

    F.relu = _wrap_fn(F.relu, _relu_hook)
    F.prelu = _wrap_fn(F.prelu, _prelu_hook)
    F.elu = _wrap_fn(F.elu, _elu_hook)
    F.leaky_relu = _wrap_fn(F.leaky_relu, _leakyrelu_hook)
    F.relu6 = _wrap_fn(F.relu6, _relu6_hook)
    if hasattr(F, "silu"):
        F.silu = _wrap_fn(F.silu, _silu_hook)
    F.gelu = _wrap_fn(F.gelu, _gelu_hook)

    F.batch_norm = _wrap_fn(F.batch_norm, _batch_norm_hook)
    F.layer_norm = _wrap_fn(F.layer_norm, _layer_norm_hook)
    F.instance_norm = _wrap_fn(F.instance_norm, _instance_norm_hook)
    F.group_norm = _wrap_fn(F.group_norm, _group_norm_hook)

    F.avg_pool1d = _wrap_fn(F.avg_pool1d, _pool_hook)
    F.avg_pool2d = _wrap_fn(F.avg_pool2d, _pool_hook)
    F.avg_pool3d = _wrap_fn(F.avg_pool3d, _pool_hook)
    F.max_pool1d = _wrap_fn(F.max_pool1d, _pool_hook)
    F.max_pool2d = _wrap_fn(F.max_pool2d, _pool_hook)
    F.max_pool3d = _wrap_fn(F.max_pool3d, _pool_hook)
    F.adaptive_avg_pool1d = _wrap_fn(F.adaptive_avg_pool1d, _pool_hook)
    F.adaptive_avg_pool2d = _wrap_fn(F.adaptive_avg_pool2d, _pool_hook)
    F.adaptive_avg_pool3d = _wrap_fn(F.adaptive_avg_pool3d, _pool_hook)
    F.adaptive_max_pool1d = _wrap_fn(F.adaptive_max_pool1d, _pool_hook)
    F.adaptive_max_pool2d = _wrap_fn(F.adaptive_max_pool2d, _pool_hook)
    F.adaptive_max_pool3d = _wrap_fn(F.adaptive_max_pool3d, _pool_hook)

    F.upsample = _wrap_fn(F.upsample, _upsample_hook)
    F.interpolate = _wrap_fn(F.interpolate, _upsample_hook)

    F.softmax = _wrap_fn(F.softmax, _softmax_hook)

    F.embedding = _wrap_fn(F.embedding, _embedding_hook)


def disable_functional_hooks() -> None:
    """Disables functional API profiler hooks."""

    F.linear = TORCH_FUNCTIONS[F.linear.__name__]

    F.conv1d = TORCH_FUNCTIONS[F.conv1d.__name__]
    F.conv2d = TORCH_FUNCTIONS[F.conv2d.__name__]
    F.conv3d = TORCH_FUNCTIONS[F.conv3d.__name__]

    F.conv_transpose1d = TORCH_FUNCTIONS[F.conv_transpose1d.__name__]
    F.conv_transpose2d = TORCH_FUNCTIONS[F.conv_transpose2d.__name__]
    F.conv_transpose3d = TORCH_FUNCTIONS[F.conv_transpose3d.__name__]

    F.relu = TORCH_FUNCTIONS[F.relu.__name__]
    F.prelu = TORCH_FUNCTIONS[F.prelu.__name__]
    F.elu = TORCH_FUNCTIONS[F.elu.__name__]
    F.leaky_relu = TORCH_FUNCTIONS[F.leaky_relu.__name__]
    F.relu6 = TORCH_FUNCTIONS[F.relu6.__name__]
    if hasattr(F, "silu"):
        F.silu = TORCH_FUNCTIONS[F.silu.__name__]
    F.gelu = TORCH_FUNCTIONS[F.gelu.__name__]

    F.batch_norm = TORCH_FUNCTIONS[F.batch_norm.__name__]
    F.layer_norm = TORCH_FUNCTIONS[F.layer_norm.__name__]
    F.instance_norm = TORCH_FUNCTIONS[F.instance_norm.__name__]
    F.group_norm = TORCH_FUNCTIONS[F.group_norm.__name__]

    F.avg_pool1d = TORCH_FUNCTIONS[F.avg_pool1d.__name__]
    F.avg_pool2d = TORCH_FUNCTIONS[F.avg_pool2d.__name__]
    F.avg_pool3d = TORCH_FUNCTIONS[F.avg_pool3d.__name__]
    F.max_pool1d = TORCH_FUNCTIONS[F.max_pool1d.__name__]
    F.max_pool2d = TORCH_FUNCTIONS[F.max_pool2d.__name__]
    F.max_pool3d = TORCH_FUNCTIONS[F.max_pool3d.__name__]
    F.adaptive_avg_pool1d = TORCH_FUNCTIONS[F.adaptive_avg_pool1d.__name__]
    F.adaptive_avg_pool2d = TORCH_FUNCTIONS[F.adaptive_avg_pool2d.__name__]
    F.adaptive_avg_pool3d = TORCH_FUNCTIONS[F.adaptive_avg_pool3d.__name__]
    F.adaptive_max_pool1d = TORCH_FUNCTIONS[F.adaptive_max_pool1d.__name__]
    F.adaptive_max_pool2d = TORCH_FUNCTIONS[F.adaptive_max_pool2d.__name__]
    F.adaptive_max_pool3d = TORCH_FUNCTIONS[F.adaptive_max_pool3d.__name__]

    F.upsample = TORCH_FUNCTIONS[F.upsample.__name__]
    F.interpolate = TORCH_FUNCTIONS[F.interpolate.__name__]

    F.softmax = TORCH_FUNCTIONS[F.softmax.__name__]

    F.embedding = TORCH_FUNCTIONS[F.embedding.__name__]


def enable_tensor_hooks() -> None:
    """Enables tensor-based operations profiler hooks."""

    torch.matmul = _wrap_fn(torch.matmul, _matmul_hook)
    torch.mm = _wrap_fn(torch.mm, _matmul_hook)
    torch.bmm = _wrap_fn(torch.bmm, _matmul_hook)
    torch.addmm = _wrap_fn(torch.addmm, _addmm_hook)
    torch.mul = _wrap_fn(torch.mul, _mul_hook)
    torch.add = _wrap_fn(torch.add, _add_hook)
    torch.einsum = _wrap_fn(torch.einsum, _einsum_hook)


def disable_tensor_hooks() -> None:
    """Disables tensor-based operations profiler hooks."""

    torch.matmul = TORCH_FUNCTIONS[torch.matmul.__name__]
    torch.mm = TORCH_FUNCTIONS[torch.mm.__name__]
    torch.bmm = TORCH_FUNCTIONS[torch.bmm.__name__]
    torch.addmm = TORCH_FUNCTIONS[torch.addmm.__name__]
    torch.mul = TORCH_FUNCTIONS[torch.mul.__name__]
    torch.add = TORCH_FUNCTIONS[torch.add.__name__]
    torch.einsum = TORCH_FUNCTIONS[torch.einsum.__name__]
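

# Minimal usage sketch (illustrative only): the toy model, input and manual
# FLOPS/MACS scope handling below are stand-ins for what the surrounding profiler
# utilities normally do, and are not part of this module's API.
if __name__ == "__main__":
    model = torch.nn.Sequential(torch.nn.Linear(128, 64), torch.nn.ReLU())
    sample = torch.randn(8, 128)

    # Open a recording scope, install the hooks and run a forward pass.
    FLOPS.append([])
    MACS.append([])
    enable_functional_hooks()
    enable_tensor_hooks()

    with torch.no_grad():
        model(sample)

    disable_tensor_hooks()
    disable_functional_hooks()

    # Aggregate the (name, count) records left by the wrapped functions.
    total_flops = sum(f for _, f in FLOPS.pop())
    total_macs = sum(m for _, m in MACS.pop())
    print(f"FLOPs: {total_flops} | MACs: {total_macs}")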