Source code for block_zoo.attentions.BiAttFlow

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

import torch
import torch.nn as nn
import torch.nn.functional as F
from block_zoo.BaseLayer import BaseConf, BaseLayer
from utils.DocInherit import DocInherit
import copy


class BiAttFlowConf(BaseConf):
    """ Configuration of the BiAttFlow (attention flow) layer

    Args:
        attention_dropout (float): dropout rate applied to the attention matrix
    """
    def __init__(self, **kwargs):
        super(BiAttFlowConf, self).__init__(**kwargs)

    @DocInherit
    def default(self):
        self.attention_dropout = 0

    @DocInherit
    def declare(self):
        self.num_of_inputs = 2
        self.input_ranks = [3]

    @DocInherit
    def inference(self):
        self.output_dim = copy.deepcopy(self.input_dims[0])
        # forward() concatenates four tensors of size feature_dim, hence 4 * the input feature dim
        self.output_dim[-1] = 4 * self.input_dims[0][-1]
        super(BiAttFlowConf, self).inference()

    @DocInherit
    def verify(self):
        super(BiAttFlowConf, self).verify()

        necessary_attrs_for_user = ['attention_dropout']
        for attr in necessary_attrs_for_user:
            self.add_attr_exist_assertion_for_user(attr)
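The `inference` step above only does shape bookkeeping: the last output dimension becomes four times the input feature dimension because `forward` returns the concatenation of the content with three attention-weighted tensors of the same size. A minimal sketch of that computation, using hypothetical shape values chosen purely for illustration:

# Illustrative only: how output_dim follows from input_dims in BiAttFlowConf.inference.
# The concrete shape [32, 100, 128] is a made-up example, not a real configuration.
import copy

input_dims = [[32, 100, 128]]           # [batch_size, content_seq_len, feature_dim]
output_dim = copy.deepcopy(input_dims[0])
output_dim[-1] = 4 * input_dims[0][-1]  # four feature_dim-sized tensors are concatenated
print(output_dim)                       # [32, 100, 512]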
class BiAttFlow(BaseLayer):
    """ Attention flow layer for BiDAF

    [paper]: https://arxiv.org/pdf/1611.01603.pdf

    Args:
        layer_conf (BiAttFlowConf): configuration of the BiAttFlow layer
    """
    def __init__(self, layer_conf):
        super(BiAttFlow, self).__init__(layer_conf)
        self.layer_conf = layer_conf

        # trilinear similarity: maps [query; content; content * query] to a scalar score
        self.W = nn.Linear(layer_conf.input_dims[0][-1] * 3, 1)
        # separate per-term weights (currently unused in forward)
        self.attention_weight_content = nn.Linear(layer_conf.input_dims[0][-1], 1)
        self.attention_weight_query = nn.Linear(layer_conf.input_dims[0][-1], 1)
        self.attention_weight_cq = nn.Linear(layer_conf.input_dims[0][-1], 1)
        self.attention_dropout = nn.Dropout(layer_conf.attention_dropout)

    def forward(self, content, content_len, query, query_len=None):
        """ Implement the attention flow layer of the BiDAF model

        :param content (Tensor): [batch_size, content_seq_len, dim]
        :param content_len: [batch_size]
        :param query (Tensor): [batch_size, query_seq_len, dim]
        :param query_len: [batch_size]
        :return: a tensor of shape [batch_size, content_seq_len, 4 * dim], together with the unchanged content_len
        """
        assert content.size()[2] == query.size()[2], \
            'The dimension of axis 2 of content and query must be consistent! But now, content.size() is %s and query.size() is %s' % (content.size(), query.size())

        batch_size = content.size()[0]
        content_seq_len = content.size()[1]
        query_seq_len = query.size()[1]
        feature_dim = content.size()[2]

        # [batch_size, content_seq_len, query_seq_len, feature_dim]
        content_aug = content.unsqueeze(2).expand(batch_size, content_seq_len, query_seq_len, feature_dim)
        # [batch_size, content_seq_len, query_seq_len, feature_dim]
        query_aug = query.unsqueeze(1).expand(batch_size, content_seq_len, query_seq_len, feature_dim)
        content_aug = content_aug.contiguous().view(batch_size * content_seq_len * query_seq_len, feature_dim)
        query_aug = query_aug.contiguous().view(batch_size * content_seq_len * query_seq_len, feature_dim)

        content_query_comb = torch.cat((query_aug, content_aug, content_aug * query_aug), 1)

        attention = self.W(content_query_comb)
        attention = self.attention_dropout(attention)
        # [batch_size, content_seq_len, query_seq_len]
        attention = attention.view(batch_size, content_seq_len, query_seq_len)
        attention_logits = F.softmax(attention, dim=2)

        # [batch_size, content_seq_len, query_seq_len] x [batch_size, query_seq_len, feature_dim]
        #     --> [batch_size, content_seq_len, feature_dim]
        content2query_att = torch.bmm(attention_logits, query)

        # [batch_size, 1, content_seq_len]
        b = F.softmax(torch.max(attention, dim=2)[0], dim=1).unsqueeze(1)
        # [batch_size, 1, content_seq_len] x [batch_size, content_seq_len, feature_dim]
        #     --> [batch_size, feature_dim] after squeeze
        query2content_att = torch.bmm(b, content).squeeze(1)
        # [batch_size, content_seq_len, feature_dim]
        query2content_att = query2content_att.unsqueeze(1).expand(-1, content_seq_len, -1)

        result = torch.cat([content, content2query_att, content * content2query_att, content * query2content_att], dim=-1)
        return result, content_len
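For reference, the bidirectional attention flow performed in `forward` can be exercised standalone with plain PyTorch tensors. The sketch below mirrors the steps above outside the class; the tensor shapes and the freshly initialized similarity weight `W` are illustrative stand-ins, not the layer's configured or trained parameters.

# A minimal, self-contained sketch of the attention flow computed in forward().
# Shapes and the linear layer W are hypothetical; the real layer takes them from its conf.
import torch
import torch.nn as nn
import torch.nn.functional as F

batch_size, content_seq_len, query_seq_len, feature_dim = 2, 7, 5, 16
content = torch.randn(batch_size, content_seq_len, feature_dim)
query = torch.randn(batch_size, query_seq_len, feature_dim)
W = nn.Linear(feature_dim * 3, 1)   # trilinear similarity, as in __init__ above

# pairwise similarity matrix S: [batch_size, content_seq_len, query_seq_len]
c_aug = content.unsqueeze(2).expand(-1, -1, query_seq_len, -1)
q_aug = query.unsqueeze(1).expand(-1, content_seq_len, -1, -1)
S = W(torch.cat((q_aug, c_aug, c_aug * q_aug), dim=-1)).squeeze(-1)

# content-to-query attention: every content position attends over the query positions
c2q = torch.bmm(F.softmax(S, dim=2), query)                     # [B, C, F]

# query-to-content attention: a single distribution over content positions
b = F.softmax(S.max(dim=2)[0], dim=1).unsqueeze(1)              # [B, 1, C]
q2c = torch.bmm(b, content).expand(-1, content_seq_len, -1)     # [B, C, F]

result = torch.cat([content, c2q, content * c2q, content * q2c], dim=-1)
print(result.shape)   # torch.Size([2, 7, 64]), i.e. feature_dim * 4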