114 lines
3.9 KiB
Python
114 lines
3.9 KiB
Python
# Copyright (c) ONNX Project Contributors
#
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
|
|
|
|
from typing import NamedTuple
|
|
|
|
import ml_dtypes
|
|
import numpy as np
|
|
|
|
from onnx.onnx_pb import TensorProto
|
|
|
|
|
|
class TensorDtypeMap(NamedTuple):
    """Describe how one tensor data type is represented.

    Instances are plain tuples of ``(np_dtype, storage_dtype, name)``, so they
    can be unpacked positionally or accessed by field name.
    """

    # NumPy dtype used to represent values of this tensor type.
    np_dtype: np.dtype
    # Integer tensor data type code whose *_data field of a TensorProto is
    # used to store values of this tensor type.
    storage_dtype: int
    # Human-readable name of the tensor type, e.g. "TensorProto.FLOAT".
    name: str
|
|
|
|
|
|
# tensor_dtype: (numpy type, storage type, string name)
#
# Keys are TensorProto data type codes converted to plain ints; values are
# TensorDtypeMap tuples.
#
# The storage type is the type used to store the tensor in the *_data field of
# a TensorProto. All available fields are float_data, int32_data, int64_data,
# string_data, uint64_data and double_data. Several tensor types therefore
# share one storage type (for example, every entry below whose storage type is
# INT32).
TENSOR_TYPE_MAP: dict[int, TensorDtypeMap] = {
    int(TensorProto.FLOAT): TensorDtypeMap(
        np.dtype("float32"), int(TensorProto.FLOAT), "TensorProto.FLOAT"
    ),
    int(TensorProto.UINT8): TensorDtypeMap(
        np.dtype("uint8"), int(TensorProto.INT32), "TensorProto.UINT8"
    ),
    int(TensorProto.INT8): TensorDtypeMap(
        np.dtype("int8"), int(TensorProto.INT32), "TensorProto.INT8"
    ),
    int(TensorProto.UINT16): TensorDtypeMap(
        np.dtype("uint16"), int(TensorProto.INT32), "TensorProto.UINT16"
    ),
    int(TensorProto.INT16): TensorDtypeMap(
        np.dtype("int16"), int(TensorProto.INT32), "TensorProto.INT16"
    ),
    int(TensorProto.INT32): TensorDtypeMap(
        np.dtype("int32"), int(TensorProto.INT32), "TensorProto.INT32"
    ),
    int(TensorProto.INT64): TensorDtypeMap(
        np.dtype("int64"), int(TensorProto.INT64), "TensorProto.INT64"
    ),
    int(TensorProto.BOOL): TensorDtypeMap(
        np.dtype("bool"), int(TensorProto.INT32), "TensorProto.BOOL"
    ),
    int(TensorProto.FLOAT16): TensorDtypeMap(
        np.dtype("float16"), int(TensorProto.INT32), "TensorProto.FLOAT16"
    ),
    # NumPy dtype comes from ml_dtypes rather than from NumPy itself.
    int(TensorProto.BFLOAT16): TensorDtypeMap(
        np.dtype(ml_dtypes.bfloat16),
        int(TensorProto.INT32),
        "TensorProto.BFLOAT16",
    ),
    int(TensorProto.DOUBLE): TensorDtypeMap(
        np.dtype("float64"), int(TensorProto.DOUBLE), "TensorProto.DOUBLE"
    ),
    # Complex types use the floating-point storage types (FLOAT / DOUBLE).
    int(TensorProto.COMPLEX64): TensorDtypeMap(
        np.dtype("complex64"), int(TensorProto.FLOAT), "TensorProto.COMPLEX64"
    ),
    int(TensorProto.COMPLEX128): TensorDtypeMap(
        np.dtype("complex128"),
        int(TensorProto.DOUBLE),
        "TensorProto.COMPLEX128",
    ),
    int(TensorProto.UINT32): TensorDtypeMap(
        np.dtype("uint32"), int(TensorProto.UINT64), "TensorProto.UINT32"
    ),
    int(TensorProto.UINT64): TensorDtypeMap(
        np.dtype("uint64"), int(TensorProto.UINT64), "TensorProto.UINT64"
    ),
    # Strings are represented with the generic NumPy "object" dtype.
    int(TensorProto.STRING): TensorDtypeMap(
        np.dtype("object"), int(TensorProto.STRING), "TensorProto.STRING"
    ),
    # The remaining entries all take their NumPy dtype from ml_dtypes and use
    # INT32 as the storage type.
    int(TensorProto.FLOAT8E4M3FN): TensorDtypeMap(
        np.dtype(ml_dtypes.float8_e4m3fn),
        int(TensorProto.INT32),
        "TensorProto.FLOAT8E4M3FN",
    ),
    int(TensorProto.FLOAT8E4M3FNUZ): TensorDtypeMap(
        np.dtype(ml_dtypes.float8_e4m3fnuz),
        int(TensorProto.INT32),
        "TensorProto.FLOAT8E4M3FNUZ",
    ),
    int(TensorProto.FLOAT8E5M2): TensorDtypeMap(
        np.dtype(ml_dtypes.float8_e5m2),
        int(TensorProto.INT32),
        "TensorProto.FLOAT8E5M2",
    ),
    int(TensorProto.FLOAT8E5M2FNUZ): TensorDtypeMap(
        np.dtype(ml_dtypes.float8_e5m2fnuz),
        int(TensorProto.INT32),
        "TensorProto.FLOAT8E5M2FNUZ",
    ),
    int(TensorProto.UINT4): TensorDtypeMap(
        np.dtype(ml_dtypes.uint4), int(TensorProto.INT32), "TensorProto.UINT4"
    ),
    int(TensorProto.INT4): TensorDtypeMap(
        np.dtype(ml_dtypes.int4), int(TensorProto.INT32), "TensorProto.INT4"
    ),
    int(TensorProto.FLOAT4E2M1): TensorDtypeMap(
        np.dtype(ml_dtypes.float4_e2m1fn),
        int(TensorProto.INT32),
        "TensorProto.FLOAT4E2M1",
    ),
    int(TensorProto.FLOAT8E8M0): TensorDtypeMap(
        np.dtype(ml_dtypes.float8_e8m0fnu),
        int(TensorProto.INT32),
        "TensorProto.FLOAT8E8M0",
    ),
}
|