# Source code for archai.discrete_search.search_spaces.nlp.transformer_flex.models.modeling_mem_transformer
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
#
# Copyright (c) 2018, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0.

from typing import List, Optional, Tuple

import torch
import torch.nn as nn
from transformers.file_utils import ModelOutput
from transformers.models.transfo_xl.modeling_transfo_xl import (
    TransfoXLModel,
    TransfoXLPreTrainedModel,
)

from archai.discrete_search.search_spaces.nlp.transformer_flex.models.configuration_mem_transformer import (
    MemTransformerConfig,
)
from archai.discrete_search.search_spaces.nlp.transformer_flex.models.mem_transformer_utils.adaptive_embedding import (
    AdaptiveEmbedding,
)
from archai.discrete_search.search_spaces.nlp.transformer_flex.models.mem_transformer_utils.positional_embedding import (
    PositionalEmbedding,
)
from archai.discrete_search.search_spaces.nlp.transformer_flex.models.mem_transformer_utils.projected_adaptive_log_softmax import (
    ProjectedAdaptiveLogSoftmax,
)
from archai.discrete_search.search_spaces.nlp.transformer_flex.models.mem_transformer_utils.rel_partial_learnable_decoder import (
    RelPartialLearnableDecoderLayer,
)


class MemTransformerBaseOutput(ModelOutput):
    last_hidden_state: torch.FloatTensor
    past_key_values: Optional[Tuple[torch.FloatTensor]] = None
    mems: Optional[List[torch.FloatTensor]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


class MemTransformerOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    prediction_scores: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Tuple[torch.FloatTensor]] = None
    mems: Optional[List[torch.FloatTensor]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None

    @property
    def logits(self) -> torch.FloatTensor:
        return self.prediction_scores


class MemTransformerModel(TransfoXLModel):
    config_class = MemTransformerConfig

    def __init__(self, config: MemTransformerConfig) -> None:
        super().__init__(config)

        self.word_emb = AdaptiveEmbedding(
            config.vocab_size,
            config.d_embed,
            config.d_model,
            config.cutoffs,
            div_val=config.div_val,
            fp16=config.fp16,
        )
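
        # Stack of relative-attention decoder layers; the model-level r_w_bias and
        # r_r_bias are shared across layers unless `config.untie_r` is set, in which
        # case each layer receives `None` and keeps its own biases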
        self.layers = nn.ModuleList()
        for _ in range(config.n_layer):
            layer_i = RelPartialLearnableDecoderLayer(
                config.n_head,
                config.d_model,
                config.d_head,
                config.d_inner,
                config.dropout,
                dropatt=config.dropatt,
                primer_conv=config.primer_conv,
                primer_square=config.primer_square,
                pre_lnorm=config.pre_lnorm,
                layer_norm_epsilon=config.layer_norm_epsilon,
                r_w_bias=None if config.untie_r else self.r_w_bias,
                r_r_bias=None if config.untie_r else self.r_r_bias,
            )
            self.layers.append(layer_i)

        self.pos_emb = PositionalEmbedding(self.config.d_model)

        self.init_weights()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
        mems: Optional[List[torch.FloatTensor]] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> MemTransformerBaseOutput:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # The original Transformer-XL uses [q_length, batch_size],
        # whereas we prefer [batch_size, q_length]
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both `input_ids` and `inputs_embeds` at the same time")
        elif input_ids is not None:
            input_ids = input_ids.transpose(0, 1).contiguous()
            q_length, batch_size = input_ids.size()
        elif inputs_embeds is not None:
            inputs_embeds = inputs_embeds.transpose(0, 1).contiguous()
            q_length, batch_size = inputs_embeds.shape[0], inputs_embeds.shape[1]
        else:
            raise ValueError("You have to specify either `input_ids` or `inputs_embeds`")

        if mems is None:
            mems = self.init_mems(batch_size)

        # (n_hidden_layers, q_length, k_length, batch_size, n_head)
        if head_mask is not None:
            if head_mask.dim() == 1:
                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0).unsqueeze(0)
                head_mask = head_mask.expand(self.n_layer, -1, -1, -1, -1)
            elif head_mask.dim() == 2:
                head_mask = head_mask.unsqueeze(1).unsqueeze(1).unsqueeze(1)

            # Guarantees 16-bit floating point compatibility
            head_mask = head_mask.to(dtype=next(self.parameters()).dtype)
        else:
            head_mask = [None] * self.n_layer
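
        # Look up token embeddings with the adaptive embedding unless
        # pre-computed `inputs_embeds` were provided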
        if inputs_embeds is not None:
            word_embeds = inputs_embeds
        else:
            word_embeds = self.word_emb(input_ids)

        if past_key_values is None:
            past_length = 0
            past_key_values = tuple([None] * self.n_layer)
        else:
            past_length = past_key_values[0][0].size(0)

        mem_length = mems[0].size(0) if mems is not None else 0
        k_length = mem_length + q_length
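
        # Causal attention mask over memory + cached keys + current tokens;
        # with `same_length`, every position attends to the same amount of context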
        if self.same_length:
            all_ones = word_embeds.new_ones((q_length, k_length + past_length), dtype=torch.uint8)

            mask_length = k_length - self.mem_len
            if mask_length > 0:
                mask_shifted_length = q_length - mask_length
            else:
                mask_shifted_length = q_length

            dec_attn_mask = (
                torch.triu(all_ones, 1 + mem_length + past_length) + torch.tril(all_ones, -mask_shifted_length)
            )[:, :, None]
        else:
            dec_attn_mask = torch.triu(
                word_embeds.new_ones((q_length, k_length + past_length), dtype=torch.uint8),
                diagonal=1 + mem_length + past_length,
            )[:, :, None]

        hidden_states = []
        attentions = [] if output_attentions else None
        presents = () if use_cache else None
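
        # Relative position indices run from the farthest key position down to the
        # most recent one (descending), as in Transformer-XL's relative attention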
        pos_sequence = torch.arange(
            k_length + past_length - 1,
            past_length - 1,
            -1.0,
            device=word_embeds.device,
            dtype=word_embeds.dtype,
        )
        if self.clamp_len > 0:
            pos_sequence.clamp_(max=self.clamp_len)

        pos_embeds = self.pos_emb(pos_sequence)
        pos_embeds = self.drop(pos_embeds)

        output = self.drop(word_embeds)
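
        # Run every decoder layer, collecting hidden states, cached key/values,
        # and attention probabilities when requested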
        for i, (layer, layer_past) in enumerate(zip(self.layers, past_key_values)):
            hidden_states.append(output)

            mems_i = None if mems is None else mems[i]

            layer_output = layer(
                output,
                pos_embeds,
                layer_past=layer_past,
                dec_attn_mask=dec_attn_mask,
                mems=mems_i,
                head_mask=head_mask[i],
                use_cache=use_cache,
                output_attentions=output_attentions,
            )
            output = layer_output[0]

            if use_cache is True:
                presents = presents + (layer_output[1],)

            if output_attentions:
                attentions.append(layer_output[2])

        output = self.drop(output)
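
        # Update the memory with the hidden states computed for this segment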
        new_mems = self._update_mems(hidden_states, mems, mem_length, q_length)

        if output_hidden_states:
            # (batch_size, length, d_model)
            hidden_states.append(output)
            hidden_states = tuple(t.transpose(0, 1).contiguous() for t in hidden_states)
        else:
            hidden_states = None

        if output_attentions:
            # (batch_size, n_heads, q_length, k_length)
            attentions = tuple(t.permute(2, 3, 0, 1).contiguous() for t in attentions)

        # (batch_size, length, d_model)
        output = output.transpose(0, 1).contiguous()

        if not return_dict:
            return tuple(v for v in [output, presents, new_mems, hidden_states, attentions] if v is not None)

        return MemTransformerBaseOutput(
            last_hidden_state=output,
            past_key_values=presents,
            mems=new_mems,
            hidden_states=hidden_states,
            attentions=attentions,
        )


class MemTransformerLMHeadModel(TransfoXLPreTrainedModel):
    config_class = MemTransformerConfig

    def __init__(self, config: MemTransformerConfig) -> None:
        super().__init__(config)

        self.transformer = MemTransformerModel(config)
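
        # When word embeddings are tied, the adaptive embedding weights are reused by
        # the adaptive softmax; projections are shared according to `config.tie_projs`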
        if self.config.tie_word_embeddings:
            emb_weights = [emb_layer.weight for emb_layer in self.transformer.word_emb.emb_layers]
        else:
            emb_weights = None
        emb_projs = self.transformer.word_emb.emb_projs

        self.crit = ProjectedAdaptiveLogSoftmax(
            config.vocab_size,
            config.d_embed,
            config.d_model,
            config.cutoffs,
            config.tie_projs,
            emb_projs=emb_projs,
            emb_weights=emb_weights,
            div_val=config.div_val,
        )

        self.init_weights()

    def tie_weights(self) -> None:
        # Mockup to disable weight tying, as it is already handled when the
        # adaptive softmax is built in __init__
        pass

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
        mems: Optional[List[torch.FloatTensor]] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> MemTransformerOutput:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None:
            batch_size, target_length = input_ids.size(0), input_ids.size(1)
        elif inputs_embeds is not None:
            batch_size, target_length = inputs_embeds.size(0), inputs_embeds.size(1)
        else:
            raise ValueError("You have to specify either `input_ids` or `inputs_embeds`")

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            mems=mems,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = transformer_outputs[0]
        pred_hidden_state = last_hidden_state[:, -target_length:]
        if labels is not None:
            # Prevents all labels from being -100, which would throw an error
            # when back-propagating the loss
            miss_valid_label = labels[0, 1:].sum() == (labels.size(1) - 1) * -100
            if miss_valid_label:
                # Sets an <EOS> token, just to prevent the loss from being NaN
                labels[0, 1] = self.config.eos_token_id
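
        # The adaptive softmax returns per-token negative log-likelihoods when
        # labels are given, and log-probabilities over the vocabulary otherwise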
        softmax_output = self.crit(pred_hidden_state, labels)

        if labels is not None:
            prediction_scores = self.crit(pred_hidden_state, None).detach()
            prediction_scores = prediction_scores.view(batch_size, target_length, -1)

            loss = softmax_output.view(batch_size, target_length - 1)
            loss = loss[loss != 0].mean()
        else:
            prediction_scores = softmax_output.view(batch_size, target_length, -1)
            loss = None

        if not return_dict:
            output = (prediction_scores,) + transformer_outputs[1:]
            if loss is not None:
                return (loss,) + output
            return output

        return MemTransformerOutput(
            loss=loss,
            prediction_scores=prediction_scores,
            past_key_values=transformer_outputs.past_key_values,
            mems=transformer_outputs.mems,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
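

# Minimal usage sketch (not part of the original module). The exact
# MemTransformerConfig fields are assumptions inferred from how they are read
# above (n_layer, n_head, d_model, d_embed, d_head, d_inner, ...); consult
# configuration_mem_transformer.py for the real names and defaults.
#
#     config = MemTransformerConfig(n_layer=2, n_head=2, d_model=64, d_embed=64, d_head=32, d_inner=128)
#     model = MemTransformerLMHeadModel(config)
#     input_ids = torch.randint(0, config.vocab_size, (1, 16))
#     outputs = model(input_ids, labels=input_ids)
#     print(outputs.loss, outputs.logits.shape)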