Source code for archai.discrete_search.search_spaces.nlp.tfpp.ops.mha

''' Modified from https://github.com/HazyResearch/flash-attention/ '''

import math
from warnings import warn
from typing import Tuple, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig
from einops import rearrange

try:
    from flash_attn.ops.fused_dense import FusedDense
    from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding
except ImportError:
    FusedDense = None
    FlashRotaryEmbedding = None

try:
    import ft_attention
except ImportError:
    ft_attention = None

try:
    from flash_attn.modules.mha import FlashSelfAttention, _update_kv_cache
except ImportError:
    FlashSelfAttention, _update_kv_cache = None, None

from ..utils import get_optim_flag

class BaseRotaryEmbedding(nn.Module):
    def __init__(self, dim: int, base=10000, scale_base=0, device=None):
        super().__init__()

        if scale_base > 0:
            raise NotImplementedError

        # Generate and save the inverse frequency buffer (non trainable)
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
        self.register_buffer("inv_freq", inv_freq)

        self.scale_base = scale_base
        scale = ((torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
                 if scale_base > 0 else None)
        self.register_buffer("scale", scale)

        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None
        self._cos_k_cached = None
        self._sin_k_cached = None

    def _update_cos_sin_cache(self, x, seqlen_offset=0):
        """x: (batch, seqlen, nheads, headdim) or (batch, seqlen, 3, nheads, headdim)"""
        seqlen = x.shape[1] + seqlen_offset

        # Reset the tables if the sequence length has changed,
        # or if we're on a new device (possibly due to tracing for instance)
        if (seqlen > self._seq_len_cached or self._cos_cached.device != x.device
                or self._cos_cached.dtype != x.dtype):
            self._seq_len_cached = seqlen
            t = torch.arange(seqlen, device=x.device, dtype=self.inv_freq.dtype)

            # Don't do einsum, it converts fp32 to fp16
            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            freqs = torch.outer(t, self.inv_freq.to(device=t.device))

            if self.scale is None:
                self._cos_cached = torch.cos(freqs).to(x.dtype)
                self._sin_cached = torch.sin(freqs).to(x.dtype)
            else:
                power = ((torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
                          - seqlen // 2) / self.scale_base)
                scale = self.scale.to(device=power.device) ** rearrange(power, 's -> s 1')

                # We want the multiplication by scale to happen in fp32
                self._cos_cached = (torch.cos(freqs) * scale).to(x.dtype)
                self._sin_cached = (torch.sin(freqs) * scale).to(x.dtype)
                self._cos_k_cached = (torch.cos(freqs) / scale).to(x.dtype)
                self._sin_k_cached = (torch.sin(freqs) / scale).to(x.dtype)

    def apply_rotary_emb_qkv(self,
                             qkv: torch.FloatTensor,
                             sin: torch.FloatTensor,
                             cos: torch.FloatTensor,
                             sin_k: Optional[torch.FloatTensor] = None,
                             cos_k: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
        _, seqlen, three, _, headdim = qkv.shape
        assert three == 3

        rotary_seqlen, rotary_dim = cos.shape
        rotary_dim *= 2
        assert rotary_dim <= headdim
        assert seqlen <= rotary_seqlen

        cos_k = cos if cos_k is None else cos_k
        sin_k = sin if sin_k is None else sin_k
        assert sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2)

        q_rot = qkv[:, :, 0, :, :rotary_dim]
        q_pass = qkv[:, :, 0, :, rotary_dim:]

        k_rot = qkv[:, :, 1, :, :rotary_dim]
        k_pass = qkv[:, :, 1, :, rotary_dim:]

        # Queries
        q1, q2 = q_rot.chunk(2, dim=-1)
        c, s = rearrange(cos[:seqlen], 's d -> s 1 d'), rearrange(sin[:seqlen], 's d -> s 1 d')
        q_rot = torch.cat([
            q1 * c - q2 * s,
            q1 * s + q2 * c
        ], axis=-1)

        # Keys
        k1, k2 = k_rot.chunk(2, dim=-1)
        c, s = rearrange(cos_k[:seqlen], 's d -> s 1 d'), rearrange(sin_k[:seqlen], 's d -> s 1 d')
        k_rot = torch.cat([
            k1 * c - k2 * s,
            k1 * s + k2 * c
        ], axis=-1)

        q = torch.cat([
            q_rot,
            q_pass
        ], axis=-1)

        k = torch.cat([
            k_rot,
            k_pass
        ], axis=-1)

        qkv = torch.cat([
            q.unsqueeze(2),
            k.unsqueeze(2),
            qkv[:, :, 2:3, :, :]
        ], axis=2)

        # Unlike the original in-place flash-attention kernel, this builds and
        # returns a new qkv tensor.
        return qkv

    def forward(self, qkv: torch.Tensor, seqlen_offset: int = 0) -> torch.Tensor:
        """seqlen_offset: can be used in generation where the qkv being passed in
        is only the last token in the batch.
        """
        self._update_cos_sin_cache(qkv, seqlen_offset)

        return self.apply_rotary_emb_qkv(
            qkv, self._sin_cached[seqlen_offset:], self._cos_cached[seqlen_offset:]
        )
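
# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module): a minimal,
# hypothetical example of applying BaseRotaryEmbedding to a packed qkv tensor
# of shape (batch, seqlen, 3, nheads, headdim). All sizes below are
# assumptions chosen for illustration only.
def _example_rotary_usage() -> torch.Tensor:
    rotary = BaseRotaryEmbedding(dim=32)    # rotate the first 32 dims of each head
    qkv = torch.randn(2, 128, 3, 12, 64)    # (batch, seqlen, 3, nheads, headdim)
    # The cos/sin caches are built lazily on the first call; only q and k
    # (the first two slots along dim=2) are rotated, v is left untouched.
    return rotary(qkv, seqlen_offset=0)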

class SelfAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.

    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
            (default: 1/sqrt(d_keys) where d_keys is computed at runtime)
        attention_dropout: The dropout rate to apply to the attention.
            (default: 0.0)
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()

        self.causal = causal
        self.softmax_scale = softmax_scale
        self.dropout_p = attention_dropout

    def forward(self, qkv, causal=None, key_padding_mask=None):
        """Implements the multihead softmax attention.

        Arguments
        ---------
            qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
            causal: if passed, will override self.causal
            key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
                False means to mask out. (B, S)
        """
        batch_size, seqlen = qkv.shape[0], qkv.shape[1]
        causal = self.causal if causal is None else causal
        q, k, v = qkv.unbind(dim=2)

        softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
        scores = torch.einsum('bthd,bshd->bhts', q, k * softmax_scale)

        if key_padding_mask is not None:
            padding_mask = torch.full((batch_size, seqlen), -10000.0, dtype=scores.dtype,
                                      device=scores.device)
            padding_mask.masked_fill_(key_padding_mask, 0.0)
            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
            scores = scores + rearrange(padding_mask, 'b s -> b 1 1 s')

        if causal:
            # "triu_tril_cuda_template" not implemented for 'BFloat16'
            # So we have to construct the mask in float
            causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
            scores = scores + causal_mask.to(dtype=scores.dtype)

        attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
        attention_drop = F.dropout(attention, self.dropout_p if self.training else 0.0)
        output = torch.einsum('bhts,bshd->bthd', attention_drop, v)

        return output
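
# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module): a minimal,
# hypothetical example of running the pure-PyTorch SelfAttention fallback on a
# packed qkv tensor. The shapes and the padding mask are assumptions chosen
# for illustration only.
def _example_self_attention_usage() -> torch.Tensor:
    attn = SelfAttention(causal=True, attention_dropout=0.0)
    qkv = torch.randn(2, 128, 3, 12, 64)                     # (B, S, 3, H, D)
    # Keep every position except the last 10 tokens of the second sequence.
    key_padding_mask = torch.ones(2, 128, dtype=torch.bool)
    key_padding_mask[1, -10:] = False
    # Output has shape (B, S, H, D); heads are merged by the caller (see MHA below).
    return attn(qkv, key_padding_mask=key_padding_mask)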

class MHA(nn.Module):
    def __init__(self, hf_config: PretrainedConfig, hidden_size: int, total_heads: int, op_heads: int,
                 bias=True, dropout=0.0, softmax_scale=None, causal=True, layer_idx=None,
                 rotary_emb_scale_base=0, return_residual=False, checkpointing=False,
                 device=None, dtype=None, **kwargs) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()

        self.hidden_size = hidden_size
        self.total_heads = total_heads
        self.op_heads = op_heads

        assert self.hidden_size % op_heads == 0, "hidden_size must be divisible by op_heads"

        self.head_dim = self.hidden_size // total_heads
        self.op_size = op_heads * self.head_dim

        self.causal = causal
        self.layer_idx = layer_idx
        self.rotary_emb_dim = getattr(hf_config, 'rotary_dim', 0)
        self.fused_dense = get_optim_flag(hf_config, 'fused_dense')
        self.flash_attn = get_optim_flag(hf_config, 'flash_attn')
        self.return_residual = return_residual
        self.checkpointing = checkpointing

        if self.rotary_emb_dim > 0:
            if get_optim_flag(hf_config, 'flash_rotary_emb'):
                assert FlashRotaryEmbedding is not None, 'rotary_emb is not installed'
                self.rotary_emb = FlashRotaryEmbedding(self.rotary_emb_dim,
                                                       scale_base=rotary_emb_scale_base,
                                                       device=device)
            else:
                self.rotary_emb = BaseRotaryEmbedding(self.rotary_emb_dim,
                                                      scale_base=rotary_emb_scale_base,
                                                      device=device)
        else:
            warn('MHA: rotary_emb_dim is 0, no rotary embedding will be used. Performance may degrade.')

        linear_cls = nn.Linear
        if self.fused_dense:
            assert FusedDense is not None, 'Need to install fused_dense'
            linear_cls = FusedDense

        self.Wqkv = linear_cls(hidden_size, 3 * self.op_size, bias=bias, **factory_kwargs)

        if self.flash_attn:
            assert FlashSelfAttention is not None, 'flash_attn is not installed'
            self.inner_attn = FlashSelfAttention(causal=causal, softmax_scale=softmax_scale,
                                                 attention_dropout=dropout)
        else:
            self.inner_attn = SelfAttention(causal=causal, softmax_scale=softmax_scale,
                                            attention_dropout=dropout)

    def _update_kv_cache(self, kv, inference_params):
        """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
        assert not self.dwconv, 'Generation does not support dwconv yet'
        assert self.layer_idx is not None, 'Generation requires layer_idx in the constructor'

        return _update_kv_cache(kv, inference_params, self.layer_idx)

    def forward(self, x, x_kv=None, key_padding_mask=None, cu_seqlens=None, max_seqlen=None,
                mixer_subset=None, inference_params=None, **kwargs):
        """
        Arguments:
            x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
                cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total
                is the sum of the sequence lengths in the batch.
            x_kv: (batch, seqlen, hidden_dim), only applicable for cross-attention. If None, use x.
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into x. Only applicable when using
                FlashAttention.
            max_seqlen: int. Maximum sequence length in the batch.
            key_padding_mask: boolean mask, True means to keep, False means to mask out.
                (batch, seqlen). Only applicable when not using FlashAttention.
            mixer_subset: for cross-attention only. If not None, will take a subset of x
                before applying the query projection. Useful for e.g., ViT where we only care
                about the CLS token in the last layer.
            inference_params: for generation. Adapted from Megatron-LM (and Apex)
                https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
        """
        if cu_seqlens is not None:
            assert max_seqlen is not None
            assert key_padding_mask is None
            assert self.flash_attn
            assert not self.dwconv
            assert self.rotary_emb_dim == 0

        if key_padding_mask is not None:
            assert cu_seqlens is None
            assert max_seqlen is None
            assert not self.flash_attn

        if inference_params is not None:
            assert key_padding_mask is None
            assert cu_seqlens is None and max_seqlen is None
            assert not self.dwconv

        attn_kwargs = ({'cu_seqlens': cu_seqlens, 'max_seqlen': max_seqlen}
                       if self.flash_attn
                       else {'key_padding_mask': key_padding_mask})

        assert x_kv is None and mixer_subset is None

        qkv = self.Wqkv(x)
        qkv = rearrange(qkv, '... (three h d) -> ... three h d', three=3, d=self.head_dim)

        if inference_params is None:
            if self.rotary_emb_dim > 0:
                qkv = self.rotary_emb(qkv)

            if not self.checkpointing:
                context = self.inner_attn(qkv, **attn_kwargs)
            else:
                context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **attn_kwargs)
        else:
            if (not inference_params.fused_ft_kernel) or inference_params.sequence_len_offset == 0:
                if self.rotary_emb_dim > 0:
                    qkv = self.rotary_emb(qkv, seqlen_offset=inference_params.sequence_len_offset)

                q = qkv[:, :, 0]
                kv = self._update_kv_cache(qkv[:, :, 1:], inference_params)

                # If we're processing the prompt, causal=None (use self.causal).
                # If we're decoding, then causal=False.
                causal = None if inference_params.sequence_len_offset == 0 else False
                context = self.inner_cross_attn(q, kv, causal=causal)
            else:
                assert inference_params.fused_ft_kernel
                assert ft_attention is not None

                context = ft_attention.single_query_attention(
                    *rearrange(qkv, 'b 1 three h d -> b three h d').unbind(dim=1),
                    *inference_params.key_value_memory_dict[self.layer_idx],
                    inference_params.lengths_per_sample, inference_params.sequence_len_offset,
                    self.rotary_emb_dim
                )
                context = rearrange(context, 'b h d -> b 1 h d')

        out = rearrange(context, '... h d -> ... (h d)')

        return (out, None) if not self.return_residual else ((out, x), None)
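
# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module): a minimal,
# hypothetical way to instantiate MHA without the optional fused kernels. It
# assumes that `get_optim_flag` reads the boolean optimization flags
# ('fused_dense', 'flash_attn', 'flash_rotary_emb') straight off the
# HuggingFace config; all names and sizes below are assumptions chosen for
# illustration only.
def _example_mha_usage() -> torch.Tensor:
    hf_config = PretrainedConfig(rotary_dim=32, fused_dense=False,
                                 flash_attn=False, flash_rotary_emb=False)
    mha = MHA(hf_config, hidden_size=768, total_heads=12, op_heads=12,
              dropout=0.0, causal=True, layer_idx=0)
    x = torch.randn(2, 128, 768)    # (batch, seqlen, hidden_dim)
    out, _ = mha(x)                 # out: (batch, seqlen, op_heads * head_dim)
    return out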