Source code for block_zoo.attentions.Attention

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

import torch
import numpy as np
import torch.nn as nn
import copy

from block_zoo.BaseLayer import BaseLayer, BaseConf
from utils.DocInherit import DocInherit


class AttentionConf(BaseConf):
    """Configuration for Attention layer """
    def __init__(self, **kwargs):
        super(AttentionConf, self).__init__(**kwargs)

    @DocInherit
    def default(self):
        pass

    @DocInherit
    def declare(self):
        self.num_of_inputs = 2
        self.input_ranks = [3, 3]

    @DocInherit
    def inference(self):
        self.output_dim = copy.deepcopy(self.input_dims[0])

        super(AttentionConf, self).inference()  # PUT THIS LINE AT THE END OF inference()

    @DocInherit
    def verify(self):
        super(AttentionConf, self).verify()
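
# --- Added sketch, not part of the original NeuronBlocks source ----------
# The conf above pins down the layer's contract: two rank-3 inputs whose
# last (feature) dimensions must match, and an output shaped like the first
# input. The helper below is a hypothetical, minimal illustration; it assumes
# AttentionConf() needs no extra kwargs and that BaseConf.inference() only
# derives bookkeeping (e.g. the output rank) from output_dim. In the real
# framework, input_dims is populated by the model graph, not by hand.
def _attention_conf_shape_demo():
    conf = AttentionConf()
    conf.input_dims = [[-1, 20, 128], [-1, 30, 128]]  # [batch, seq_len, dim]; shapes assumed
    conf.inference()
    return conf.output_dim  # [-1, 20, 128]: same shape as the first input
# --------------------------------------------------------------------------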
class Attention(BaseLayer):
    """ Attention layer

    Given sequences X and Y, match sequence Y to each element x_i in X:

    * alpha_{ij} = softmax_j(x_i * y_j)
    * o_i = sum_j(alpha_{ij} * y_j)

    Args:
        layer_conf (AttentionConf): configuration of a layer
    """
    def __init__(self, layer_conf):
        super(Attention, self).__init__(layer_conf)
        # the dot-product scores require both inputs to share the same feature dimension
        assert layer_conf.input_dims[0][-1] == layer_conf.input_dims[1][-1]
        self.softmax = nn.Softmax(dim=-1)
    def forward(self, x, x_len, y, y_len):
        """
        Args:
            x (Tensor): [batch_size, x_max_len, dim].
            x_len (Tensor): [batch_size], actual lengths of the sequences in x.
            y (Tensor): [batch_size, y_max_len, dim].
            y_len (Tensor): [batch_size], actual lengths of the sequences in y.

        Returns:
            output: [batch_size, x_max_len, dim], the same shape as x.
        """
        scores = x.bmm(y.transpose(2, 1))   # [batch_size, x_max_len, y_max_len]

        # build a mask that is 1 at padded positions of y and 0 at valid ones
        batch_size, y_max_len, _ = y.size()
        y_length = y_len.cpu().numpy()
        y_mask = np.ones((batch_size, y_max_len))
        for i, single_len in enumerate(y_length):
            y_mask[i][:single_len] = 0
        y_mask = torch.from_numpy(y_mask).bool().to(scores.device)
        y_mask = y_mask.unsqueeze(1).expand(scores.size())

        # padded positions get -inf scores, hence zero weight after the softmax
        scores.masked_fill_(y_mask, float('-inf'))
        alpha = self.softmax(scores)        # [batch_size, x_max_len, y_max_len]
        output = alpha.bmm(y)               # [batch_size, x_max_len, dim]
        return output, x_len
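
# --- Added smoke test, not part of the original NeuronBlocks source ------
# A minimal end-to-end check of the forward pass above. It assumes that
# BaseLayer.__init__ simply stores layer_conf (as in NeuronBlocks'
# BaseLayer); all shapes and lengths below are made up for illustration.
# Padded positions of y (beyond y_len) receive -inf scores, hence zero
# attention weight after the softmax.
if __name__ == '__main__':
    batch_size, x_max_len, y_max_len, dim = 2, 4, 5, 8

    conf = AttentionConf()
    conf.input_dims = [[batch_size, x_max_len, dim], [batch_size, y_max_len, dim]]
    conf.inference()

    attn = Attention(conf)
    x = torch.randn(batch_size, x_max_len, dim)
    y = torch.randn(batch_size, y_max_len, dim)
    x_len = torch.tensor([4, 3])
    y_len = torch.tensor([5, 2])  # the second sample has 3 padded positions in y

    output, out_len = attn(x, x_len, y, y_len)
    print(output.shape)  # torch.Size([2, 4, 8]) -- same shape as x
# --------------------------------------------------------------------------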