Source code for block_zoo.attentions.FullAttention

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
import copy
import numpy as np
from utils.DocInherit import DocInherit

from block_zoo.BaseLayer import BaseLayer, BaseConf
from utils.exceptions import ConfigurationError

class FullAttentionConf(BaseConf):
    def __init__(self, **kwargs):
        super(FullAttentionConf, self).__init__(**kwargs)
    @DocInherit
    def default(self):
        self.hidden_dim = 128
        self.activation = 'ReLU'
    @DocInherit
    def declare(self):
        self.num_of_inputs = 4
        self.input_ranks = [3]
    @DocInherit
    def inference(self):
        # e.g. use the query to represent the passage, therefore the output dim depends on the query's dim
        self.output_dim = copy.deepcopy(self.input_dims[1])
        super(FullAttentionConf, self).inference()  # PUT THIS LINE AT THE END OF inference()
    @DocInherit
    def verify_before_inference(self):
        super(FullAttentionConf, self).verify_before_inference()
        necessary_attrs_for_user = ['hidden_dim']
        for attr in necessary_attrs_for_user:
            self.add_attr_exist_assertion_for_user(attr)
    def verify(self):
        super(FullAttentionConf, self).verify()

        supported_activation_pytorch = [None, 'Sigmoid', 'Tanh', 'ReLU', 'PReLU', 'ReLU6', 'LeakyReLU',
            'LogSigmoid', 'ELU', 'SELU', 'Threshold', 'Hardtanh', 'Softplus', 'Softshrink', 'Softsign',
            'Tanhshrink', 'Softmin', 'Softmax', 'Softmax2d', 'LogSoftmax']

        value_checks = [('activation', supported_activation_pytorch)]
        for attr, legal_values in value_checks:
            self.add_attr_value_assertion(attr, legal_values)
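# --- Illustrative sketch, not part of the repository source ---
# The activation name validated above is later resolved into a torch.nn module by
# FullAttention.__init__ via eval("nn." + layer_conf.activation)(). The hypothetical
# helper below shows an equivalent lookup with getattr, restricted to torch.nn
# attributes; it reuses the `import torch.nn as nn` at the top of this module.
def _resolve_activation_sketch(name):
    """Map a configured activation name (e.g. 'ReLU') to an nn module instance, or None."""
    if name is None:
        return None
    return getattr(nn, name)()   # e.g. _resolve_activation_sketch('ReLU') -> nn.ReLU()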
class FullAttention(BaseLayer):
    """ Fully-aware fusion of:
        Huang, H.-Y., Zhu, C., Shen, Y., & Chen, W. (2018). FusionNet: Fusing via Fully-Aware Attention with Application to Machine Comprehension.
    """
    def __init__(self, layer_conf):
        super(FullAttention, self).__init__(layer_conf)
        self.layer_conf.hidden_dim = layer_conf.hidden_dim

        self.linear = nn.Linear(layer_conf.input_dims[2][-1], layer_conf.hidden_dim, bias=False)
        # share the projection when string1_HoW and string2_HoW have the same last dimension,
        # i.e. input_dims[2][-1] == input_dims[3][-1]
        if layer_conf.input_dims[2][-1] == layer_conf.input_dims[3][-1]:
            self.linear2 = self.linear
        else:
            self.linear2 = nn.Linear(layer_conf.input_dims[3][-1], layer_conf.hidden_dim, bias=False)
        self.linear_final = Parameter(torch.ones(1, layer_conf.hidden_dim), requires_grad=True)
        self.activation = eval("nn." + layer_conf.activation)()
    def forward(self, string1, string1_len, string2, string2_len, string1_HoW, string1_HoW_len, string2_HoW, string2_HoW_len):
        """ To get the representation of string1, use string1 and string2 to obtain attention weights, and use string2 to represent string1.

        Note: the semantic information of string1 itself is not used here; only string1's seq_len is needed.

        Args:
            string1: [batch_size, seq_len1, input_dim1]
            string1_len: [batch_size]
            string2: [batch_size, seq_len2, input_dim2]
            string2_len: [batch_size]
            string1_HoW: [batch_size, seq_len1, att_dim1]
            string1_HoW_len: [batch_size]
            string2_HoW: [batch_size, seq_len2, att_dim2]
            string2_HoW_len: [batch_size]

        Returns:
            string1's representation: [batch_size, seq_len1, input_dim2]
            string1_len: [batch_size]
        """
        string1_key = self.activation(self.linear(string1_HoW.contiguous().view(-1, string1_HoW.size()[2])))     # [bs * seq_len1, att_dim1] -> [bs * seq_len1, hidden_dim]
        string2_key = self.activation(self.linear2(string2_HoW.contiguous().view(-1, string2_HoW.size()[2])))    # [bs * seq_len2, att_dim2] -> [bs * seq_len2, hidden_dim]

        final_v = self.linear_final.expand_as(string2_key)
        string2_key = final_v * string2_key

        string1_rep = string1_key.view(-1, string1.size(1), 1, self.layer_conf.hidden_dim).transpose(1, 2).contiguous().view(-1, string1.size(1), self.layer_conf.hidden_dim)    # [bs, seq_len1, hidden_dim]
        string2_rep = string2_key.view(-1, string2.size(1), 1, self.layer_conf.hidden_dim).transpose(1, 2).contiguous().view(-1, string2.size(1), self.layer_conf.hidden_dim)    # [bs, seq_len2, hidden_dim]

        scores = string1_rep.bmm(string2_rep.transpose(1, 2)).view(-1, 1, string1.size(1), string2.size(1))    # [bs, 1, seq_len1, seq_len2]

        string2_len_np = string2_len.cpu().numpy()
        if torch.cuda.device_count() > 1:
            # under multi-GPU data parallelism, use the padded length; otherwise the mask lengths would be inconsistent across replicas and raise an exception
            string2_max_len = string2.shape[1]
        else:
            string2_max_len = string2_len_np.max()
        string2_mask = np.array([[0] * num + [1] * (string2_max_len - num) for num in string2_len_np])
        string2_mask = torch.from_numpy(string2_mask).unsqueeze(1).unsqueeze(2).expand_as(scores)
        if self.is_cuda():
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            string2_mask = string2_mask.to(device)
        scores.data.masked_fill_(string2_mask.data.byte(), -float('inf'))    # mask out string2's padded positions

        alpha_flat = F.softmax(scores.view(-1, string2.size(1)), dim=1)    # [bs * seq_len1, seq_len2]
        alpha = alpha_flat.view(-1, string1.size(1), string2.size(1))      # [bs, seq_len1, seq_len2]

        string1_atten_seq = alpha.bmm(string2)    # weighted sum of string2: [bs, seq_len1, input_dim2]

        return string1_atten_seq, string1_len
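# --- Illustrative sketch, not part of the repository source ---
# A self-contained rendering of the attention computed in FullAttention.forward():
# project the two "history of word" inputs, score them against each other, mask
# string2's padded positions, softmax over string2, and take the weighted sum of
# string2 as string1's representation. All shapes and dummy tensors below are made
# up for illustration; the projection is a plain nn.Linear as in __init__.
if __name__ == '__main__':
    torch.manual_seed(0)
    bs, seq_len1, seq_len2, att_dim, input_dim2, hidden_dim = 2, 5, 7, 16, 12, 8

    string1_HoW = torch.randn(bs, seq_len1, att_dim)
    string2_HoW = torch.randn(bs, seq_len2, att_dim)
    string2 = torch.randn(bs, seq_len2, input_dim2)
    string2_len = torch.tensor([7, 4])                       # true lengths of string2 per example

    linear = nn.Linear(att_dim, hidden_dim, bias=False)      # shared projection (both HoW dims match here)
    final_v = torch.ones(1, hidden_dim)                      # plays the role of linear_final

    key1 = F.relu(linear(string1_HoW))                       # [bs, seq_len1, hidden_dim]
    key2 = F.relu(linear(string2_HoW)) * final_v             # [bs, seq_len2, hidden_dim]

    scores = key1.bmm(key2.transpose(1, 2))                  # [bs, seq_len1, seq_len2]

    # mask string2's padded positions before the softmax
    pad_mask = torch.arange(seq_len2).unsqueeze(0) >= string2_len.unsqueeze(1)   # [bs, seq_len2]
    scores = scores.masked_fill(pad_mask.unsqueeze(1), float('-inf'))

    alpha = F.softmax(scores, dim=-1)                        # attention over string2 for each string1 position
    string1_rep = alpha.bmm(string2)                         # [bs, seq_len1, input_dim2]
    print(string1_rep.shape)                                 # torch.Size([2, 5, 12])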