from typing import Optional, Tuple, Union

import torch
from torch import nn
from transformers.models.codegen.modeling_codegen import (
    CodeGenConfig,
    apply_rotary_pos_emb,
    fixed_pos_embedding,
)

from archai.discrete_search.search_spaces.config import ArchConfig
class CausalSelfAttention(nn.Module):
    """Causal (autoregressive) multi-head self-attention in the CodeGen style.

    Only ``op_heads`` of the ``total_heads`` heads implied by ``hidden_size`` are
    computed here. Each head has size ``hidden_size // total_heads``, so the output
    feature size is ``op_heads * head_size`` (``self.op_size``), which may be smaller
    than ``hidden_size``. Q, K and V come from one fused, bias-free projection. When
    the Hugging Face config defines ``rotary_dim``, rotary position embeddings are
    applied to the first ``rotary_dim`` channels of every query/key head.
    """

    def __init__(self, arch_config: "ArchConfig", hf_config: "CodeGenConfig", hidden_size: int,
                 total_heads: int, op_heads: int, **kwargs):
        """Create the fused QKV projection, attention dropout and causal-mask buffer.

        Args:
            arch_config: Architecture config; not read by this module.
            hf_config: Hugging Face config. Supplies ``max_position_embeddings``,
                ``scale_attn_weights``, ``attn_pdrop`` and, optionally, ``rotary_dim``.
            hidden_size: Input feature size; must be divisible by ``total_heads``.
            total_heads: Head count used to derive the per-head size.
            op_heads: Number of heads this module actually computes.
            **kwargs: Accepted for interface compatibility and ignored.
        """
        # nn.Module.__init__ must run before any submodule or buffer is assigned.
        super().__init__()
        assert hidden_size % total_heads == 0, (
            f"hidden_size ({hidden_size}) must be divisible by total_heads ({total_heads})"
        )

        max_positions = hf_config.max_position_embeddings
        # Lower-triangular mask: query position i may attend to key positions <= i.
        # Stored as bool because torch.where expects a boolean condition
        # (uint8 conditions are deprecated/rejected depending on the PyTorch version).
        self.register_buffer(
            "causal_mask",
            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
                1, 1, max_positions, max_positions
            ),
        )

        self.hidden_size = hidden_size
        self.total_heads = total_heads
        self.op_heads = op_heads
        self.head_size = hidden_size // total_heads
        self.op_size = (self.hidden_size // total_heads) * op_heads
        self.max_positions = max_positions

        # Kept as a public attribute; not read anywhere in this class.
        self.scale_attn_weights = hf_config.scale_attn_weights
        self.attn_dropout = nn.Dropout(hf_config.attn_pdrop)
        # Scores are divided by sqrt(head_size).
        self.scale_attn = torch.sqrt(torch.tensor(self.head_size, dtype=torch.float32)).to(
            torch.get_default_dtype()
        )

        self.qkv_proj = nn.Linear(self.hidden_size, self.op_size * 3, bias=False)
        self.rotary_dim = getattr(hf_config, "rotary_dim", None)

    def _split_heads(self, x: torch.Tensor, n_head: int, dim_head: int, mp_num: int) -> torch.Tensor:
        """Split the last axis of ``x`` into heads of size ``dim_head``.

        ``x`` carries a model-parallel axis at position -2 (size ``mp_num``), which is
        folded into the head axis. The result has shape
        ``x.shape[:-2] + (n_head, dim_head)``.
        """
        reshaped = x.reshape(x.shape[:-1] + (n_head // mp_num, dim_head))
        reshaped = reshaped.reshape(x.shape[:-2] + (-1,) + reshaped.shape[-1:])
        return reshaped

    def _merge_heads(self, tensor: torch.Tensor, num_attention_heads: int, attn_head_size: int) -> torch.Tensor:
        """Merge the head axis and the per-head axis into one feature axis.

        A 4-D ``(b, heads, seq, head_size)`` tensor becomes ``(b, seq, heads * head_size)``.
        A 5-D tensor has axes 2 and 3 swapped before the last two axes are merged.

        Raises:
            ValueError: If ``tensor`` is neither rank 4 nor rank 5.
        """
        if len(tensor.shape) == 5:
            tensor = tensor.permute(0, 1, 3, 2, 4).contiguous()
        elif len(tensor.shape) == 4:
            tensor = tensor.permute(0, 2, 1, 3).contiguous()
        else:
            raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}")
        new_shape = tensor.size()[:-2] + (num_attention_heads * attn_head_size,)
        return tensor.view(new_shape)

    def _attn(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Scaled dot-product attention with a causal mask.

        ``query``/``key``/``value`` are indexed as ``(..., seq, head_size)``. ``key`` may
        be longer than ``query`` when cached positions were prepended; the queries are
        then aligned with the *last* ``query_length`` key positions.

        Returns:
            ``(attn_output, attn_weights)``. The weights are cast back to ``value.dtype``.
        """
        # Slice the causal mask so query rows line up with the trailing key positions.
        query_length, key_length = query.size(-2), key.size(-2)
        causal_mask = self.causal_mask[:, :, key_length - query_length : key_length, :key_length]

        # Keep the attention weights computation in fp32 to avoid overflow issues.
        query = query.to(torch.float32)
        key = key.to(torch.float32)

        attn_weights = torch.matmul(query, key.transpose(-1, -2))
        attn_weights = attn_weights / self.scale_attn

        # Masked positions get the most negative finite value so softmax drives them to ~0.
        # It must be a tensor on the same device/dtype as the weights for torch.where.
        mask_value = torch.tensor(
            torch.finfo(attn_weights.dtype).min, dtype=attn_weights.dtype, device=attn_weights.device
        )
        attn_weights = torch.where(causal_mask, attn_weights, mask_value)

        if attention_mask is not None:
            # Additive mask; presumably 0 for kept positions and large negative for masked ones.
            attn_weights = attn_weights + attention_mask

        attn_weights = torch.softmax(attn_weights, dim=-1)
        attn_weights = attn_weights.to(value.dtype)
        attn_weights = self.attn_dropout(attn_weights)

        # Multiplicative per-head mask, if provided.
        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        attn_output = torch.matmul(attn_weights, value)
        return attn_output, attn_weights

    def forward(
        self,
        hidden_states: Optional[torch.FloatTensor],
        attention_mask: Optional[torch.FloatTensor] = None,
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Union[
        Tuple[torch.Tensor, Tuple[torch.Tensor]],
        Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]],
    ]:
        """Run causal self-attention over ``hidden_states``.

        Args:
            hidden_states: 3-D input whose last axis is ``hidden_size``; axis 1 is
                treated as the sequence axis (assumed ``(batch, seq_len, hidden_size)``).
            attention_mask: Optional additive mask added to the attention scores.
            layer_past: Optional cached ``(past_key, past_value)`` pair. Each is indexed as
                ``(batch, op_heads, past_len, head_size)`` and is prepended along the
                sequence axis.
            head_mask: Optional multiplicative mask applied to the attention probabilities.
            use_cache: When exactly ``True``, return the concatenated ``(key, value)``.
            output_attentions: Accepted for API compatibility; currently ignored.

        Returns:
            ``(attn_output, present)``. ``attn_output`` has last axis ``op_size``.
            ``present`` is ``(key, value)`` including cached positions when caching,
            otherwise ``None``.
        """
        qkv = self.qkv_proj(hidden_states)
        # TODO(enijkamp): factor out number of logical TPU-v4 cores or make forward pass agnostic
        mp_num = 1
        qkv_split = qkv.reshape(qkv.shape[:-1] + (mp_num, -1))

        local_dim = self.head_size * self.op_heads // mp_num
        # NOTE: the fused projection is split as (query, value, key), not (query, key, value);
        # presumably this mirrors the CodeGen weight layout, so do not reorder.
        query, value, key = torch.split(qkv_split, local_dim, dim=-1)
        query = self._split_heads(query, self.op_heads, self.head_size, mp_num=mp_num)
        key = self._split_heads(key, self.op_heads, self.head_size, mp_num=mp_num)
        value = self._split_heads(value, self.op_heads, self.head_size, mp_num=mp_num)
        value = value.permute(0, 2, 1, 3)

        # Rotary positions for the new tokens start after any cached positions.
        seq_len = key.shape[1]
        offset = 0
        if layer_past is not None:
            offset = layer_past[0].shape[-2]
            seq_len += offset

        if self.rotary_dim is not None:
            # Rotate only the first rotary_dim channels of each head; pass the rest through.
            k_rot = key[:, :, :, : self.rotary_dim]
            k_pass = key[:, :, :, self.rotary_dim :]

            q_rot = query[:, :, :, : self.rotary_dim]
            q_pass = query[:, :, :, self.rotary_dim :]

            sincos = fixed_pos_embedding(k_rot, 1, seq_len=seq_len)
            k_rot = apply_rotary_pos_emb(k_rot, sincos, offset=offset)
            q_rot = apply_rotary_pos_emb(q_rot, sincos, offset=offset)

            key = torch.cat([k_rot, k_pass], dim=-1)
            query = torch.cat([q_rot, q_pass], dim=-1)

        key = key.permute(0, 2, 1, 3)
        query = query.permute(0, 2, 1, 3)

        if layer_past is not None:
            past_key = layer_past[0]
            past_value = layer_past[1]
            key = torch.cat((past_key, key), dim=-2)
            value = torch.cat((past_value, value), dim=-2)

        present = (key, value) if use_cache is True else None

        # compute self-attention: V x Softmax(QK^T)
        attn_output, _ = self._attn(query, key, value, attention_mask, head_mask)
        attn_output = self._merge_heads(attn_output, self.op_heads, self.head_size)

        return attn_output, present