Optimization Utilities#

Fusion Options#

class archai.onnx.optimization_utils.fusion_options.AttentionMaskFormat[source]#

Enumeration of the supported attention mask formats.

MaskIndexEnd = 0#
MaskIndexEndAndStart = 1#
AttentionMask = 2#
NoMask = 3#
class archai.onnx.optimization_utils.fusion_options.FusionOptions(model_type: str)[source]#

Options to control the fusion of operators in the ONNX graph.

use_raw_attention_mask(use_raw_mask: bool | None = True) None[source]#

Enable the usage of raw attention mask.

Parameters:

use_raw_mask – Whether raw mask should be used or not.

disable_attention_mask() None[source]#

Disable the usage of attention mask.
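As a rough illustration of how these two methods could relate to AttentionMaskFormat, the sketch below uses a hypothetical stand-in class, FusionOptionsSketch, with a hypothetical attribute attention_mask_format. It assumes that each method simply selects one of the enumeration values, that the raw mask is the default, and that turning the raw mask off falls back to MaskIndexEnd. None of this is confirmed by the documentation above, so treat it as a mental model rather than the archai implementation.

```python
from enum import IntEnum


class AttentionMaskFormat(IntEnum):
    # Values copied from the enumeration documented above.
    MaskIndexEnd = 0
    MaskIndexEndAndStart = 1
    AttentionMask = 2
    NoMask = 3


class FusionOptionsSketch:
    """Hypothetical stand-in for FusionOptions; not the archai class."""

    def __init__(self, model_type: str) -> None:
        self.model_type = model_type
        # Assumption: the raw attention mask is the default format.
        self.attention_mask_format = AttentionMaskFormat.AttentionMask

    def use_raw_attention_mask(self, use_raw_mask: bool = True) -> None:
        # Assumption: disabling the raw mask falls back to MaskIndexEnd.
        if use_raw_mask:
            self.attention_mask_format = AttentionMaskFormat.AttentionMask
        else:
            self.attention_mask_format = AttentionMaskFormat.MaskIndexEnd

    def disable_attention_mask(self) -> None:
        # No mask input is expected by the fused attention operator.
        self.attention_mask_format = AttentionMaskFormat.NoMask
```

With this stand-in, calling disable_attention_mask() after use_raw_attention_mask() leaves the format at NoMask, since the last call wins.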

Transformer-XL ONNX Model#

class archai.onnx.optimization_utils.transfo_xl_onnx_model.TransfoXLOnnxModel(model: ModelProto)[source]#

ONNX model optimized for Transformer-XL models.

This model extends the OnnxModel class by enabling additional ONNX optimizations.

change_graph_input_type(graph: GraphProto, graph_input: ValueInfoProto, new_type: int | None = 6) Tuple[NodeProto, List[NodeProto]][source]#

Change the input type of the graph and add Cast nodes if necessary.

Parameters:
  • graph – Graph instance.

  • graph_input – Graph input value.

  • new_type – New data type.

Returns:

A tuple containing a Cast node to be added and a list of Cast nodes to be removed.
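The default new_type of 6 is the ONNX TensorProto code for INT32 (INT64 is 7). The sketch below illustrates the general idea with plain Python stand-ins instead of onnx protobuf objects: the input's type is changed, consumers that were already casting to the new type become redundant, and the remaining consumers are fed through a new Cast back to the original type. The Node and GraphInput classes, the change_input_type function, and the "_as_original_type" suffix are all hypothetical; the actual method may handle more cases.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple

INT32, INT64 = 6, 7  # ONNX TensorProto data type codes


@dataclass
class Node:
    # Hypothetical, simplified stand-in for an ONNX NodeProto.
    op_type: str
    inputs: List[str]
    outputs: List[str]
    to: Optional[int] = None  # target type, only meaningful for Cast


@dataclass
class GraphInput:
    # Hypothetical, simplified stand-in for an ONNX ValueInfoProto.
    name: str
    elem_type: int


def change_input_type(
    nodes: List[Node], graph_input: GraphInput, new_type: int = INT32
) -> Tuple[Optional[Node], List[Node]]:
    old_type = graph_input.elem_type
    graph_input.elem_type = new_type

    consumers = [n for n in nodes if graph_input.name in n.inputs]
    # A Cast that already converts to new_type is now a no-op.
    casts_to_remove = [
        n for n in consumers if n.op_type == "Cast" and n.to == new_type
    ]
    others = [n for n in consumers if n not in casts_to_remove]

    new_cast = None
    if others:
        # Remaining consumers still expect the original type.
        casted = graph_input.name + "_as_original_type"
        new_cast = Node("Cast", [graph_input.name], [casted], to=old_type)
        for node in others:
            node.inputs = [
                casted if name == graph_input.name else name
                for name in node.inputs
            ]
    return new_cast, casts_to_remove
```

The caller is expected to add the returned Cast node to the graph and drop the listed ones, which matches the shape of the documented return value.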

change_graph_inputs_to_int32() None[source]#

Change the graph inputs to int32.

fuse_layer_norm() None[source]#

Fuse the appropriate nodes into a LayerNormalization layer.

fuse_skip_layer_norm() None[source]#

Fuse the appropriate nodes into a SkipLayerNormalization layer.

fuse_add_bias_skip_layer_norm() None[source]#

Fuse the appropriate nodes into a BiasSkipLayerNormalization layer.

fuse_attention() None[source]#

Fuse the appropriate nodes into an Attention layer.

fuse_reshape() None[source]#

Fuse the appropriate nodes into a Reshape layer.

fuse_shape() None[source]#

Fuse the appropriate nodes into a Shape layer.

use_dynamic_axes(dynamic_batch_dim: str | None = 'batch', dynamic_seq_len: str | None = 'sequence') None[source]#

Update inputs and outputs shapes to use dynamic axes.

Parameters:
  • dynamic_batch_dim – Name of batch size dimension.

  • dynamic_seq_len – Name of sequence length dimension.
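To make the effect of dynamic axes concrete, the sketch below works on plain shape lists rather than an ONNX graph. It assumes that the first dimension of every tensor is the batch size and the second is the sequence length, which the documentation above does not state. The function name with_dynamic_axes and the example tensor names are hypothetical.

```python
from typing import Dict, List, Union

Dim = Union[int, str]


def with_dynamic_axes(
    shapes: Dict[str, List[Dim]],
    dynamic_batch_dim: str = "batch",
    dynamic_seq_len: str = "sequence",
) -> Dict[str, List[Dim]]:
    """Return a copy of shapes with symbolic names in the first two axes."""
    updated: Dict[str, List[Dim]] = {}
    for name, shape in shapes.items():
        new_shape = list(shape)
        # Assumption: axis 0 is the batch size, axis 1 the sequence length.
        if len(new_shape) >= 1:
            new_shape[0] = dynamic_batch_dim
        if len(new_shape) >= 2:
            new_shape[1] = dynamic_seq_len
        updated[name] = new_shape
    return updated


fixed = {"input_ids": [1, 32], "probs": [1, 32, 1000]}
dynamic = with_dynamic_axes(fixed)
```

In an ONNX model, a symbolic name in a dimension tells the runtime that the axis may vary between calls, so the same exported model can serve different batch sizes and sequence lengths.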

adjust_reshape_and_expand() None[source]#

Clean up unnecessary Reshape and Expand nodes.

clean_graph() None[source]#

Clean the graph after fusing nodes.

optimize(options: FusionOptions | None = None, add_dynamic_axes: bool | None = False) None[source]#

Perform additional Transformer-specific optimizations.

Parameters:
  • options – Options holding which operators should be fused.

  • add_dynamic_axes – Whether dynamic axes should be added.