Source code for block_zoo.BiLSTMAtt

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
import numpy as np

from block_zoo.BaseLayer import BaseLayer, BaseConf
from utils.DocInherit import DocInherit


class BiLSTMAttConf(BaseConf):
    """ Configuration of BiLSTMAtt layer

    Args:
        hidden_dim (int): dimension of hidden state
        dropout (float): dropout rate
        num_layers (int): number of BiLSTM layers

    ``output_dim`` and ``attention_dim`` are not user supplied; both are
    derived from ``hidden_dim`` in :meth:`inference`.
    """

    def __init__(self, **kwargs):
        super(BiLSTMAttConf, self).__init__(**kwargs)

    @DocInherit
    def default(self):
        # Values used when the user configuration does not override them.
        self.hidden_dim = 128
        self.dropout = 0.0
        self.num_layers = 1

    @DocInherit
    def declare(self):
        # The layer consumes exactly one input, and that input must be rank 3.
        self.num_of_inputs = 1
        self.input_ranks = [3]

    @DocInherit
    def inference(self):
        # Both LSTM directions contribute hidden_dim features each.
        bidirectional_dim = 2 * self.hidden_dim

        # Same shape as the (only) input, except for the feature axis.
        self.output_dim = copy.deepcopy(self.input_dims[0])
        self.output_dim[-1] = bidirectional_dim

        # The attention matrix is square over the BiLSTM feature size.
        self.attention_dim = bidirectional_dim

        # The parent call has to stay the last statement of inference().
        super(BiLSTMAttConf, self).inference()

    @DocInherit
    def verify_before_inference(self):
        super(BiLSTMAttConf, self).verify_before_inference()
        # hidden_dim is read by inference(), so it must exist beforehand.
        self.add_attr_exist_assertion_for_user('hidden_dim')

    @DocInherit
    def verify(self):
        super(BiLSTMAttConf, self).verify()
        # Attributes that must be present once inference() has run.
        for required_attr in ('dropout', 'attention_dim'):
            self.add_attr_exist_assertion_for_user(required_attr)
class BiLSTMAtt(BaseLayer):
    """ BiLSTM with self attention

    A bidirectional LSTM is run over the length-aware input sequence; its
    outputs are then re-weighted by a bilinear self-attention
    ``softmax(H * att * H^T) * H`` where ``att`` is a learnable square matrix
    of size ``layer_conf.attention_dim``.

    Args:
        layer_conf (BiLSTMAttConf): configuration of a layer
    """

    def __init__(self, layer_conf):
        super(BiLSTMAtt, self).__init__(layer_conf)
        self.lstm = nn.LSTM(layer_conf.input_dims[0][-1], layer_conf.hidden_dim, layer_conf.num_layers,
                            bidirectional=True, dropout=layer_conf.dropout, batch_first=True)
        # randn() followed by uniform_() is kept as-is so that the amount of RNG
        # state consumed (and therefore seeded reproducibility) does not change.
        self.att = nn.Parameter(torch.randn(layer_conf.attention_dim, layer_conf.attention_dim),
                                requires_grad=True)
        nn.init.uniform_(self.att, a=0, b=1)
        # Explicit axis: normalise over the key positions (last axis). The
        # implicit-dim form is deprecated and picks the axis from the input rank.
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, string, string_len):
        """ process inputs

        Args:
            string (Tensor): [batch_size, seq_len, dim]
            string_len (Tensor): [batch_size]

        Returns:
            tuple:
                Tensor: attended BiLSTM output, [batch_size, seq_len, 2 * hidden_dim]

                Tensor: ``string_len``, returned unchanged, [batch_size]
        """
        padded_seq_len = string.shape[1]

        # pack_padded_sequence expects the batch ordered by decreasing length;
        # idx_unsort is the inverse permutation used to restore the input order.
        string_len_sorted, idx_sort = string_len.sort(descending=True)
        idx_unsort = idx_sort.sort()[1]
        string_sorted = string.index_select(0, idx_sort)

        # Handling padding in Recurrent Networks.
        # The lengths tensor must live on the CPU for pack_padded_sequence,
        # even when the data itself is on an accelerator.
        string_packed = nn.utils.rnn.pack_padded_sequence(string_sorted, string_len_sorted.cpu(),
                                                          batch_first=True)
        self.lstm.flatten_parameters()
        string_output = self.lstm(string_packed)[0]     # PackedSequence
        # total_length keeps the time axis at the original padded length.
        string_output = nn.utils.rnn.pad_packed_sequence(string_output, batch_first=True,
                                                         total_length=padded_seq_len)[0]

        # Un-sort by length
        string_output = string_output.index_select(0, idx_unsort).contiguous()     # [batch, seqlen, 2*nhid]

        # Self Attention: bilinear score between every pair of time steps
        alphas = string_output.matmul(self.att).bmm(string_output.transpose(1, 2).contiguous())  # [batch, seqlen, seqlen]

        # Set probas of padding to (almost) zero in softmax. Padded time steps
        # are filled with zeros by pad_packed_sequence (its default
        # padding_value), so every score involving one is exactly 0 and that is
        # what is used to detect them. masked_fill keeps the dtype of alphas.
        # NOTE(review): rows that belong to padded query positions are entirely
        # masked, hence receive a uniform distribution; callers are presumably
        # expected to ignore outputs beyond string_len - confirm.
        alphas = alphas.masked_fill(alphas == 0, -10000.0)

        # softmax over the key positions
        alphas = self.softmax(alphas)     # [batch, seqlen, seqlen]

        string_output = alphas.bmm(string_output)    # [batch, seqlen, 2*nhid]

        return string_output, string_len