# Source code for archai.onnx.onnx_forward

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import Dict, Optional, Tuple

import torch
import torch.nn.functional as F


def gpt2_onnx_forward(
    self,
    input_ids: torch.LongTensor,
    past_key_values: Optional[Tuple[torch.FloatTensor, ...]] = None,
) -> Dict[str, torch.FloatTensor]:
    """Forward pass through the GPT-2 model with ONNX exportability.

    This method overrides the default GPT-2 forward method and returns both
    output probabilities and past key/values.

    Args:
        input_ids: Input tensor.
        past_key_values: Past pre-computed key/values tensor, forwarded
            unchanged to ``self.transformer``.

    Returns:
        Dictionary with:

        * ``"logits"``: softmax (over the last dimension) of ``self.lm_head``
          applied to the final position along axis 1 of the transformer's
          last hidden state. Despite the key name these are normalized
          probabilities, not raw logits.
        * ``"past_key_values"``: only present when the transformer returns a
          truthy ``past_key_values``; a tuple with one ``torch.stack``-ed
          tensor per entry of that value.

    """

    outputs = self.transformer(input_ids, past_key_values=past_key_values)

    # Only the last position along axis 1 is projected through the LM head
    # (presumably the most recent token of each sequence - confirm with caller).
    final_position_state = outputs.last_hidden_state[:, -1, :]
    outputs_dict = {
        "logits": F.softmax(self.lm_head(final_position_state), dim=-1),
    }

    # Each entry is stacked into a single tensor so the exported output is a
    # flat tuple of tensors (assumes each entry is a sequence of same-shaped
    # tensors such as a (key, value) pair - TODO confirm).
    past_key_values = outputs.past_key_values
    if past_key_values:
        outputs_dict["past_key_values"] = tuple(torch.stack(p) for p in past_key_values)

    return outputs_dict