Source code for block_zoo.attentions.LinearAttention

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
import numpy as np

from block_zoo.BaseLayer import BaseLayer, BaseConf
from utils.DocInherit import DocInherit

class LinearAttentionConf(BaseConf):
    """Configuration of the Linear attention layer

    Args:
        keep_dim (bool): whether to keep the sequence_length dimension.
                If False, the representations are summed along the sequence axis and the
                layer returns (batch_size, dim); if True, the layer keeps the same shape
                as the input and returns (batch_size, sequence_length, dim).
    """
    def __init__(self, **kwargs):
        super(LinearAttentionConf, self).__init__(**kwargs)

    @DocInherit
    def default(self):
        self.keep_dim = False

    @DocInherit
    def declare(self):
        self.num_of_inputs = 1
        self.input_ranks = [-1]

    @DocInherit
    def inference(self):
        self.attention_weight_dim = self.input_dims[0][-1]
        if self.keep_dim:
            self.output_dim = copy.deepcopy(self.input_dims[0])
        else:
            # drop the sequence_length axis (the second-to-last dimension)
            self.output_dim = []
            for idx, dim in enumerate(self.input_dims[0]):
                if idx != len(self.input_dims[0]) - 2:
                    self.output_dim.append(dim)

        super(LinearAttentionConf, self).inference()  # PUT THIS LINE AT THE END OF inference()

    @DocInherit
    def verify_before_inference(self):
        super(LinearAttentionConf, self).verify_before_inference()
        necessary_attrs_for_user = ['keep_dim']
        for attr in necessary_attrs_for_user:
            self.add_attr_exist_assertion_for_user(attr)

    @DocInherit
    def verify(self):
        super(LinearAttentionConf, self).verify()

        necessary_attrs_for_user = ['attention_weight_dim', 'keep_dim']
        for attr in necessary_attrs_for_user:
            self.add_attr_exist_assertion_for_user(attr)

        type_checks = [('attention_weight_dim', int), ('keep_dim', bool)]
        for attr, attr_type in type_checks:
            self.add_attr_type_assertion(attr, attr_type)

        # the input must be rank 2 or 3, and a rank-2 input cannot be summed along a sequence axis
        assert self.input_ranks[0] in set([2, 3]) and not (self.input_ranks[0] == 2 and self.keep_dim is False)

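# --- Illustration (not part of the original module) ---------------------------
# A minimal sketch of the shape bookkeeping in inference() above, with
# hypothetical dimensions: when keep_dim is False, the sequence_length axis
# (the second-to-last entry of input_dims[0]) is dropped from output_dim.
#
#     input_dims = [[-1, 30, 128]]      # (batch_size, sequence_length, dim); -1 marks a dynamic axis
#     keep_dim = False
#     if keep_dim:
#         output_dim = list(input_dims[0])                  # [-1, 30, 128]
#     else:
#         output_dim = [d for i, d in enumerate(input_dims[0])
#                       if i != len(input_dims[0]) - 2]      # [-1, 128]
# ------------------------------------------------------------------------------
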
class LinearAttention(BaseLayer):
    """ Linear attention.

    Combines the original sequence along the sequence_length dimension.

    Args:
        layer_conf (LinearAttentionConf): configuration of a layer
    """
    def __init__(self, layer_conf):
        """
        Args:
            layer_conf (LinearAttentionConf): configuration of a layer
        """
        super(LinearAttention, self).__init__(layer_conf)
        # a learned (dim, 1) vector used to score each position of the sequence
        self.attention_weight = nn.Parameter(torch.randn(self.layer_conf.attention_weight_dim, 1))

    def forward(self, string, string_len=None):
        """ Process inputs

        Args:
            string (Tensor): (batch_size, sequence_length, dim)
            string_len (ndarray or None): [batch_size]

        Returns:
            Tensor: if keep_dim == False, the output shape is (batch_size, dim);
                otherwise the sequence is only reweighted along the sequence_length
                dimension and the output shape is (batch_size, sequence_length, dim).
        """
        # score each position: (batch_size * sequence_length, dim) x (dim, 1)
        attention_weight = torch.mm(string.contiguous().view(string.shape[0] * string.shape[1], string.shape[2]), self.attention_weight)
        # normalize the scores along the sequence axis
        attention_weight = nn.functional.softmax(attention_weight.view(string.shape[0], string.shape[1]), dim=1)
        # broadcast the weights over the feature dimension and reweight the sequence
        attention_tiled = attention_weight.unsqueeze(2).expand_as(string)
        string_reweighted = torch.mul(string, attention_tiled)
        if self.layer_conf.keep_dim is False:
            string_reweighted = torch.sum(string_reweighted, 1)

        return string_reweighted, string_len
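
# --- Usage sketch (illustrative, not part of the original module) -------------
# A minimal, self-contained demonstration of the same weighting scheme in plain
# PyTorch, assuming a rank-3 input (batch_size, sequence_length, dim) and
# keep_dim == False; the tensor shapes below are hypothetical.
if __name__ == '__main__':
    batch_size, seq_len, dim = 4, 10, 8
    x = torch.randn(batch_size, seq_len, dim)
    w = torch.randn(dim, 1)                                               # stands in for self.attention_weight

    scores = torch.mm(x.contiguous().view(batch_size * seq_len, dim), w)  # (batch_size * seq_len, 1)
    weights = F.softmax(scores.view(batch_size, seq_len), dim=1)          # attention over positions
    reweighted = x * weights.unsqueeze(2).expand_as(x)                    # (batch_size, seq_len, dim)
    pooled = torch.sum(reweighted, 1)                                     # keep_dim == False case
    print(pooled.shape)                                                   # torch.Size([4, 8])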