Source code for archai.discrete_search.search_spaces.nlp.tfpp.ops.causal_self_attn

from typing import Optional, Tuple, Union

import torch
from torch import nn
from transformers.models.codegen.modeling_codegen import (
    CodeGenConfig, fixed_pos_embedding, apply_rotary_pos_emb
)
from archai.discrete_search.search_spaces.config import ArchConfig


class CausalSelfAttention(nn.Module):
    def __init__(self, arch_config: ArchConfig, hf_config: CodeGenConfig,
                 hidden_size: int, total_heads: int, op_heads: int, **kwargs):
        assert hidden_size % total_heads == 0
        super().__init__()

        max_positions = hf_config.max_position_embeddings
        self.register_buffer(
            "causal_mask",
            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
                1, 1, max_positions, max_positions
            ),
        )

        self.hidden_size = hidden_size
        self.total_heads = total_heads
        # Number of heads computed by this op; `total_heads` is the head count of the
        # full layer, so this op may produce only a slice (`op_size`) of `hidden_size`.
        self.op_heads = op_heads

        self.head_size = hidden_size // total_heads
        self.op_size = (self.hidden_size // total_heads) * op_heads
        self.max_positions = max_positions
        self.scale_attn_weights = hf_config.scale_attn_weights

        self.attn_dropout = nn.Dropout(hf_config.attn_pdrop)
        self.scale_attn = torch.sqrt(torch.tensor(self.head_size, dtype=torch.float32)).to(torch.get_default_dtype())

        self.qkv_proj = nn.Linear(self.hidden_size, self.op_size * 3, bias=False)
        self.rotary_dim = getattr(hf_config, 'rotary_dim', None)

    def _split_heads(self, x, n_head, dim_head, mp_num):
        # Splits the last dimension into (n_head, dim_head), keeping the mp sharding layout
        reshaped = x.reshape(x.shape[:-1] + (n_head // mp_num, dim_head))
        reshaped = reshaped.reshape(x.shape[:-2] + (-1,) + reshaped.shape[-1:])
        return reshaped

    def _merge_heads(self, tensor, num_attention_heads, attn_head_size):
        """
        Merges attn_head_size dim and num_attn_heads dim into the hidden dim
        """
        if len(tensor.shape) == 5:
            tensor = tensor.permute(0, 1, 3, 2, 4).contiguous()
        elif len(tensor.shape) == 4:
            tensor = tensor.permute(0, 2, 1, 3).contiguous()
        else:
            raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}")
        new_shape = tensor.size()[:-2] + (num_attention_heads * attn_head_size,)
        return tensor.view(new_shape)

    def _attn(
        self,
        query,
        key,
        value,
        attention_mask=None,
        head_mask=None,
    ):
        # compute causal mask from causal mask buffer
        query_length, key_length = query.size(-2), key.size(-2)
        causal_mask = self.causal_mask[:, :, key_length - query_length : key_length, :key_length]

        # Keep the attention weights computation in fp32 to avoid overflow issues
        query = query.to(torch.float32)
        key = key.to(torch.float32)

        attn_weights = torch.matmul(query, key.transpose(-1, -2))
        attn_weights = attn_weights / self.scale_attn

        mask_value = torch.finfo(attn_weights.dtype).min
        # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
        # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
        mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
        attn_weights = torch.where(causal_mask, attn_weights, mask_value)

        if attention_mask is not None:
            # Apply the attention mask
            attn_weights = attn_weights + attention_mask

        attn_weights = nn.Softmax(dim=-1)(attn_weights)
        attn_weights = attn_weights.to(value.dtype)
        attn_weights = self.attn_dropout(attn_weights)

        # Mask heads if we want to
        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        attn_output = torch.matmul(attn_weights, value)

        return attn_output, attn_weights
    def forward(
        self,
        hidden_states: Optional[torch.FloatTensor],
        attention_mask: Optional[torch.FloatTensor] = None,
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        **kwargs
    ) -> Union[
        Tuple[torch.Tensor, Tuple[torch.Tensor]],
        Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]],
    ]:
        qkv = self.qkv_proj(hidden_states)

        # TODO(enijkamp): factor out number of logical TPU-v4 cores or make forward pass agnostic
        mp_num = 1
        qkv_split = qkv.reshape(qkv.shape[:-1] + (mp_num, -1))

        local_dim = self.head_size * self.op_heads // mp_num
        query, value, key = torch.split(qkv_split, local_dim, dim=-1)

        query = self._split_heads(query, self.op_heads, self.head_size, mp_num=mp_num)
        key = self._split_heads(key, self.op_heads, self.head_size, mp_num=mp_num)
        value = self._split_heads(value, self.op_heads, self.head_size, mp_num=mp_num)
        value = value.permute(0, 2, 1, 3)

        seq_len = key.shape[1]
        offset = 0

        if layer_past is not None:
            offset = layer_past[0].shape[-2]
            seq_len += offset

        if self.rotary_dim is not None:
            k_rot = key[:, :, :, : self.rotary_dim]
            k_pass = key[:, :, :, self.rotary_dim :]

            q_rot = query[:, :, :, : self.rotary_dim]
            q_pass = query[:, :, :, self.rotary_dim :]

            sincos = fixed_pos_embedding(k_rot, 1, seq_len=seq_len)
            k_rot = apply_rotary_pos_emb(k_rot, sincos, offset=offset)
            q_rot = apply_rotary_pos_emb(q_rot, sincos, offset=offset)

            key = torch.cat([k_rot, k_pass], dim=-1)
            query = torch.cat([q_rot, q_pass], dim=-1)

        key = key.permute(0, 2, 1, 3)
        query = query.permute(0, 2, 1, 3)

        if layer_past is not None:
            past_key = layer_past[0]
            past_value = layer_past[1]
            key = torch.cat((past_key, key), dim=-2)
            value = torch.cat((past_value, value), dim=-2)

        if use_cache is True:
            present = (key, value)
        else:
            present = None

        # compute self-attention: V x Softmax(QK^T)
        attn_output, _ = self._attn(query, key, value, attention_mask, head_mask)
        attn_output = self._merge_heads(attn_output, self.op_heads, self.head_size)

        return attn_output, present
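
The sketch below is a minimal way to exercise this op on its own; it is not part of the original module. It assumes a transformers version that still exposes fixed_pos_embedding and apply_rotary_pos_emb (the same assumption the imports above make), passes arch_config=None only because this op's __init__ never reads it, and uses illustrative config values and tensor shapes.

if __name__ == "__main__":
    # Minimal usage sketch (assumptions noted above; values are illustrative only).
    hf_config = CodeGenConfig(n_positions=1024, n_embd=256, n_head=8, rotary_dim=16)

    attn = CausalSelfAttention(
        arch_config=None,            # unused by this op's __init__ (see constructor above)
        hf_config=hf_config,
        hidden_size=hf_config.n_embd,
        total_heads=hf_config.n_head,
        op_heads=hf_config.n_head,   # let this op cover all heads in this sketch
    )

    hidden_states = torch.randn(2, 10, hf_config.n_embd)  # (batch, seq_len, hidden_size)
    attn_output, present = attn(hidden_states, use_cache=True)

    print(attn_output.shape)  # torch.Size([2, 10, 256]) since op_heads == total_heads
    print(present[0].shape)   # cached keys: (batch, op_heads, seq_len, head_size)

Feeding present back in as layer_past on the next call (with the new token's hidden states) continues generation with the cached keys and values, which is how the surrounding decoder is expected to use use_cache.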