DriverTrac/venv/lib/python3.12/site-packages/onnx/_mapping.py

# Copyright (c) ONNX Project Contributors
#
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

from typing import NamedTuple

import ml_dtypes
import numpy as np

from onnx.onnx_pb import TensorProto


class TensorDtypeMap(NamedTuple):
    np_dtype: np.dtype
    storage_dtype: int
    name: str


# tensor_dtype: (numpy type, storage type, string name)
# The storage type is the type used to store the tensor in the *_data field of
# a TensorProto. All available fields are float_data, int32_data, int64_data,
# string_data, uint64_data and double_data.
TENSOR_TYPE_MAP: dict[int, TensorDtypeMap] = {
    int(TensorProto.FLOAT): TensorDtypeMap(
        np.dtype("float32"), int(TensorProto.FLOAT), "TensorProto.FLOAT"
    ),
    int(TensorProto.UINT8): TensorDtypeMap(
        np.dtype("uint8"), int(TensorProto.INT32), "TensorProto.UINT8"
    ),
    int(TensorProto.INT8): TensorDtypeMap(
        np.dtype("int8"), int(TensorProto.INT32), "TensorProto.INT8"
    ),
    int(TensorProto.UINT16): TensorDtypeMap(
        np.dtype("uint16"), int(TensorProto.INT32), "TensorProto.UINT16"
    ),
    int(TensorProto.INT16): TensorDtypeMap(
        np.dtype("int16"), int(TensorProto.INT32), "TensorProto.INT16"
    ),
    int(TensorProto.INT32): TensorDtypeMap(
        np.dtype("int32"), int(TensorProto.INT32), "TensorProto.INT32"
    ),
    int(TensorProto.INT64): TensorDtypeMap(
        np.dtype("int64"), int(TensorProto.INT64), "TensorProto.INT64"
    ),
    int(TensorProto.BOOL): TensorDtypeMap(
        np.dtype("bool"), int(TensorProto.INT32), "TensorProto.BOOL"
    ),
    int(TensorProto.FLOAT16): TensorDtypeMap(
        np.dtype("float16"), int(TensorProto.INT32), "TensorProto.FLOAT16"
    ),
    int(TensorProto.BFLOAT16): TensorDtypeMap(
        np.dtype(ml_dtypes.bfloat16),
        int(TensorProto.INT32),
        "TensorProto.BFLOAT16",
    ),
    int(TensorProto.DOUBLE): TensorDtypeMap(
        np.dtype("float64"), int(TensorProto.DOUBLE), "TensorProto.DOUBLE"
    ),
    int(TensorProto.COMPLEX64): TensorDtypeMap(
        np.dtype("complex64"), int(TensorProto.FLOAT), "TensorProto.COMPLEX64"
    ),
    int(TensorProto.COMPLEX128): TensorDtypeMap(
        np.dtype("complex128"),
        int(TensorProto.DOUBLE),
        "TensorProto.COMPLEX128",
    ),
    int(TensorProto.UINT32): TensorDtypeMap(
        np.dtype("uint32"), int(TensorProto.UINT64), "TensorProto.UINT32"
    ),
    int(TensorProto.UINT64): TensorDtypeMap(
        np.dtype("uint64"), int(TensorProto.UINT64), "TensorProto.UINT64"
    ),
    int(TensorProto.STRING): TensorDtypeMap(
        np.dtype("object"), int(TensorProto.STRING), "TensorProto.STRING"
    ),
    int(TensorProto.FLOAT8E4M3FN): TensorDtypeMap(
        np.dtype(ml_dtypes.float8_e4m3fn),
        int(TensorProto.INT32),
        "TensorProto.FLOAT8E4M3FN",
    ),
    int(TensorProto.FLOAT8E4M3FNUZ): TensorDtypeMap(
        np.dtype(ml_dtypes.float8_e4m3fnuz),
        int(TensorProto.INT32),
        "TensorProto.FLOAT8E4M3FNUZ",
    ),
    int(TensorProto.FLOAT8E5M2): TensorDtypeMap(
        np.dtype(ml_dtypes.float8_e5m2),
        int(TensorProto.INT32),
        "TensorProto.FLOAT8E5M2",
    ),
    int(TensorProto.FLOAT8E5M2FNUZ): TensorDtypeMap(
        np.dtype(ml_dtypes.float8_e5m2fnuz),
        int(TensorProto.INT32),
        "TensorProto.FLOAT8E5M2FNUZ",
    ),
    int(TensorProto.UINT4): TensorDtypeMap(
        np.dtype(ml_dtypes.uint4), int(TensorProto.INT32), "TensorProto.UINT4"
    ),
    int(TensorProto.INT4): TensorDtypeMap(
        np.dtype(ml_dtypes.int4), int(TensorProto.INT32), "TensorProto.INT4"
    ),
    int(TensorProto.FLOAT4E2M1): TensorDtypeMap(
        np.dtype(ml_dtypes.float4_e2m1fn),
        int(TensorProto.INT32),
        "TensorProto.FLOAT4E2M1",
    ),
    int(TensorProto.FLOAT8E8M0): TensorDtypeMap(
        np.dtype(ml_dtypes.float8_e8m0fnu),
        int(TensorProto.INT32),
        "TensorProto.FLOAT8E8M0",
    ),
}
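

# Illustrative usage sketch (not part of the upstream onnx module): it shows how
# the TENSOR_TYPE_MAP table above is typically consumed, e.g. to look up the
# numpy dtype and the *_data storage field type for a given element type.
if __name__ == "__main__":
    entry = TENSOR_TYPE_MAP[int(TensorProto.BFLOAT16)]
    # BFLOAT16 tensors use ml_dtypes.bfloat16 as the in-memory numpy
    # representation but are serialized into the int32_data field.
    print(entry.name)                                     # TensorProto.BFLOAT16
    print(entry.np_dtype)                                 # bfloat16
    print(entry.storage_dtype == int(TensorProto.INT32))  # True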