# Copyright (c) Microsoft Corporation.# Licensed under the MIT License."""ONNX IR enums that matches the ONNX spec."""from__future__importannotationsimportenumimportml_dtypesimportnumpyasnpclassAttributeType(enum.IntEnum):"""Enum for the types of ONNX attributes."""UNDEFINED=0FLOAT=1INT=2STRING=3TENSOR=4GRAPH=5FLOATS=6INTS=7STRINGS=8TENSORS=9GRAPHS=10SPARSE_TENSOR=11SPARSE_TENSORS=12TYPE_PROTO=13TYPE_PROTOS=14def__repr__(self)->str:returnself.namedef__str__(self)->str:returnself.__repr__()classDataType(enum.IntEnum):"""Enum for the data types of ONNX tensors, defined in ``onnx.TensorProto``."""# NOTE: Naming: It is tempting to use shorter and more modern names like f32, i64,# but we should stick to the names used in the ONNX spec for consistency.UNDEFINED=0FLOAT=1UINT8=2INT8=3UINT16=4INT16=5INT32=6INT64=7STRING=8BOOL=9FLOAT16=10DOUBLE=11UINT32=12UINT64=13COMPLEX64=14COMPLEX128=15BFLOAT16=16FLOAT8E4M3FN=17FLOAT8E4M3FNUZ=18FLOAT8E5M2=19FLOAT8E5M2FNUZ=20UINT4=21INT4=22FLOAT4E2M1=23
[docs]@classmethoddeffrom_numpy(cls,dtype:np.dtype)->DataType:"""Returns the ONNX data type for the numpy dtype. Raises: TypeError: If the data type is not supported by ONNX. """ifdtypein_NP_TYPE_TO_DATA_TYPE:returncls(_NP_TYPE_TO_DATA_TYPE[dtype])ifnp.issubdtype(dtype,np.str_):returnDataType.STRING# Special cases for handling custom dtypes defined in ONNX (as of onnx 1.18)# Ref: https://github.com/onnx/onnx/blob/2d42b6a60a52e925e57c422593e88cc51890f58a/onnx/_custom_element_types.pyifhasattr(dtype,"names"):ifdtype.names==("bfloat16",):returnDataType.BFLOAT16ifdtype.names==("e4m3fn",):returnDataType.FLOAT8E4M3FNifdtype.names==("e4m3fnuz",):returnDataType.FLOAT8E4M3FNUZifdtype.names==("e5m2",):returnDataType.FLOAT8E5M2ifdtype.names==("e5m2fnuz",):returnDataType.FLOAT8E5M2FNUZifdtype.names==("uint4",):returnDataType.UINT4ifdtype.names==("int4",):returnDataType.INT4ifdtype.names==("float4e2m1",):returnDataType.FLOAT4E2M1raiseTypeError(f"Unsupported numpy data type: {dtype}")
@propertydefitemsize(self)->float:"""Returns the size of the data type in bytes."""return_ITEMSIZE_MAP[self]
[docs]defnumpy(self)->np.dtype:"""Returns the numpy dtype for the ONNX data type. Raises: TypeError: If the data type is not supported by numpy. """ifselfnotin_DATA_TYPE_TO_NP_TYPE:raiseTypeError(f"Numpy does not support ONNX data type: {self}")return_DATA_TYPE_TO_NP_TYPE[self]
def__repr__(self)->str:returnself.namedef__str__(self)->str:returnself.__repr__()_ITEMSIZE_MAP={DataType.FLOAT:4,DataType.UINT8:1,DataType.INT8:1,DataType.UINT16:2,DataType.INT16:2,DataType.INT32:4,DataType.INT64:8,DataType.STRING:1,DataType.BOOL:1,DataType.FLOAT16:2,DataType.DOUBLE:8,DataType.UINT32:4,DataType.UINT64:8,DataType.COMPLEX64:8,DataType.COMPLEX128:16,DataType.BFLOAT16:2,DataType.FLOAT8E4M3FN:1,DataType.FLOAT8E4M3FNUZ:1,DataType.FLOAT8E5M2:1,DataType.FLOAT8E5M2FNUZ:1,DataType.UINT4:0.5,DataType.INT4:0.5,DataType.FLOAT4E2M1:0.5,}# We use ml_dtypes to support dtypes that are not in numpy._NP_TYPE_TO_DATA_TYPE={np.dtype("bool"):DataType.BOOL,np.dtype("complex128"):DataType.COMPLEX128,np.dtype("complex64"):DataType.COMPLEX64,np.dtype("float16"):DataType.FLOAT16,np.dtype("float32"):DataType.FLOAT,np.dtype("float64"):DataType.DOUBLE,np.dtype("int16"):DataType.INT16,np.dtype("int32"):DataType.INT32,np.dtype("int64"):DataType.INT64,np.dtype("int8"):DataType.INT8,np.dtype("object"):DataType.STRING,np.dtype("uint16"):DataType.UINT16,np.dtype("uint32"):DataType.UINT32,np.dtype("uint64"):DataType.UINT64,np.dtype("uint8"):DataType.UINT8,np.dtype(ml_dtypes.bfloat16):DataType.BFLOAT16,np.dtype(ml_dtypes.float8_e4m3fn):DataType.FLOAT8E4M3FN,np.dtype(ml_dtypes.float8_e4m3fnuz):DataType.FLOAT8E4M3FNUZ,np.dtype(ml_dtypes.float8_e5m2):DataType.FLOAT8E5M2,np.dtype(ml_dtypes.float8_e5m2fnuz):DataType.FLOAT8E5M2FNUZ,np.dtype(ml_dtypes.int4):DataType.INT4,np.dtype(ml_dtypes.uint4):DataType.UINT4,}# TODO(after min req for ml_dtypes>=0.5): Move this inside _NP_TYPE_TO_DATA_TYPE_NP_TYPE_TO_DATA_TYPE.update({np.dtype(ml_dtypes.float4_e2m1fn):DataType.FLOAT4E2M1}ifhasattr(ml_dtypes,"float4_e2m1fn")else{})# ONNX DataType to Numpy dtype._DATA_TYPE_TO_NP_TYPE={v:kfork,vin_NP_TYPE_TO_DATA_TYPE.items()}