Source code for block_zoo.transformer.MultiHeadAttention

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

import torch
import torch.nn as nn
import math

from block_zoo.BaseLayer import BaseLayer, BaseConf
from utils.DocInherit import DocInherit
import copy

class MultiHeadAttentionConf(BaseConf):
    """ Configuration of MultiHeadAttention layer

    Args:
        n_head (int): number of attention heads
        scale (bool): whether to scale the attention scores by sqrt(head_dim)
        attn_dropout (float): dropout rate applied to the attention weights
        resid_dropout (float): dropout rate applied after the output Linear projection
    """
    def __init__(self, **kwargs):
        super(MultiHeadAttentionConf, self).__init__(**kwargs)

    @DocInherit
    def default(self):
        self.n_head = 12
        self.scale = True
        self.attn_dropout = 0.1
        self.resid_dropout = 0.1

    @DocInherit
    def declare(self):
        self.num_of_inputs = 1
        self.input_ranks = [3]

    @DocInherit
    def inference(self):
        # the output keeps the same shape as the input
        self.output_dim = copy.deepcopy(self.input_dims[0])
        super(MultiHeadAttentionConf, self).inference()

    @DocInherit
    def verify(self):
        super(MultiHeadAttentionConf, self).verify()
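
# Illustrative sketch (added; not part of the original module): the layer splits the
# last input dimension evenly across n_head heads, so with the default n_head = 12 and
# an assumed hidden size of 768:
#
#     hidden_dim = 768                      # example value for input_dims[0][-1]
#     n_head = 12                           # default from MultiHeadAttentionConf
#     assert hidden_dim % n_head == 0       # enforced in MultiHeadAttention.__init__
#     head_dim = hidden_dim // n_head       # 64 dimensions attended per head
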
class MultiHeadAttention(nn.Module):
    """ MultiHeadAttention layer

    Args:
        layer_conf (MultiHeadAttentionConf): configuration of a layer
    """
    def __init__(self, layer_conf):
        super(MultiHeadAttention, self).__init__()
        self.layer_conf = layer_conf
        self.split_size = self.layer_conf.input_dims[0][-1]
        self.n_state = self.layer_conf.input_dims[0][-1]
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # the model dimension must be evenly divisible across the attention heads
        assert self.n_state % self.layer_conf.n_head == 0
        # single projection that produces query, key and value in one pass
        self.c_attn = nn.Linear(self.layer_conf.input_dims[0][-1], self.n_state * 3)
        # output projection applied after the heads are merged back together
        self.c_proj = nn.Linear(self.layer_conf.input_dims[0][-1], self.n_state)

    def _attn(self, q, k, v):
        # scaled dot-product attention with a causal (lower-triangular) mask
        w = torch.matmul(q, k).to(self.device)
        if self.layer_conf.scale:
            w = w / math.sqrt(v.size(-1))
        # self.b is the causal mask registered in forward(); masked positions get -1e9
        w = w * self.b + -1e9 * (1 - self.b)
        w = nn.Softmax(dim=-1)(w)
        w = nn.Dropout(self.layer_conf.attn_dropout)(w)
        return torch.matmul(w, v)
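
    # Shape sketch of _attn (added for illustration; the sizes are assumed examples):
    #
    #     q, v : [batch, n_head, seq_len, head_dim]
    #     k    : [batch, n_head, head_dim, seq_len]   (keys arrive pre-transposed, see split_heads)
    #     w = q @ k                    -> [batch, n_head, seq_len, seq_len]
    #     w = w / sqrt(head_dim)          (only when layer_conf.scale is True)
    #     w = w * b + -1e9 * (1 - b)      b: lower-triangular mask, so token i attends to j <= i
    #     w = softmax(w, dim=-1), then dropout
    #     out = w @ v                  -> [batch, n_head, seq_len, head_dim]
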
    def merge_heads(self, x):
        # [batch, n_head, seq_len, head_dim] -> [batch, seq_len, n_head * head_dim]
        x = x.permute(0, 2, 1, 3).contiguous()
        new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
        return x.view(*new_x_shape)

    def split_heads(self, x, k=False):
        # [batch, seq_len, dim] -> [batch, n_head, seq_len, head_dim]
        # (keys are additionally transposed to [batch, n_head, head_dim, seq_len])
        new_x_shape = x.size()[:-1] + (self.layer_conf.n_head, x.size(-1) // self.layer_conf.n_head)
        x = x.view(*new_x_shape)
        if k:
            return x.permute(0, 2, 3, 1)
        else:
            return x.permute(0, 2, 1, 3)
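
    # Round-trip sketch (added for illustration, with assumed sizes): for an input of
    # shape [2, 10, 48] and n_head = 12, split_heads produces [2, 12, 10, 4]
    # ([2, 12, 4, 10] for keys), and merge_heads maps [2, 12, 10, 4] back to [2, 10, 48].
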
    def forward(self, string, string_len):
        """ process inputs

        Args:
            string (Tensor): [batch_size, seq_len, dim]
            string_len (Tensor): [batch_size]

        Returns:
            Tensor: [batch_size, seq_len, output_dim]
            Tensor: [batch_size]
        """
        # causal mask: a lower-triangular matrix of shape [1, 1, seq_len, seq_len]
        self.register_buffer('b', torch.tril(torch.ones(string.shape[1], string.shape[1]).to(self.device)).view(1, 1, string.shape[1], string.shape[1]))
        # project to query/key/value in one pass, then split along the last dimension
        x = self.c_attn(string)
        query, key, value = x.split(self.split_size, dim=2)
        query = self.split_heads(query)
        key = self.split_heads(key, k=True)
        value = self.split_heads(value)
        a = self._attn(query, key, value)
        a = self.merge_heads(a)
        a = self.c_proj(a)
        a = nn.Dropout(self.layer_conf.resid_dropout)(a)
        return a, string_len
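
if __name__ == "__main__":
    # Minimal usage sketch (added for illustration; not part of the original module).
    # A SimpleNamespace stands in for MultiHeadAttentionConf here, since the layer only
    # reads these attributes from its configuration; the sizes are assumed example values.
    from types import SimpleNamespace

    conf = SimpleNamespace(input_dims=[[2, 10, 48]], n_head=12, scale=True,
                           attn_dropout=0.1, resid_dropout=0.1)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    layer = MultiHeadAttention(conf).to(device)

    string = torch.rand(2, 10, 48, device=device)                        # [batch_size, seq_len, dim]
    string_len = torch.full((2,), 10, dtype=torch.long, device=device)   # [batch_size]

    output, output_len = layer(string, string_len)
    print(output.shape)   # torch.Size([2, 10, 48]), same shape as the input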