# Copyright (c) ONNX Project Contributors # # SPDX-License-Identifier: Apache-2.0 from __future__ import annotations import sys from typing import TYPE_CHECKING, Any import ml_dtypes import numpy as np import numpy.typing as npt import typing_extensions import onnx.external_data_helper from onnx import helper, subbyte if TYPE_CHECKING: from collections.abc import Sequence # System is little endian _IS_LITTLE_ENDIAN = sys.byteorder == "little" @typing_extensions.deprecated( "Deprecated since 1.18. Scheduled to remove in 1.20. Consider using libraries like ml_dtypes for dtype conversion", category=DeprecationWarning, ) def bfloat16_to_float32( data: np.int16 | np.int32 | np.ndarray, dims: int | Sequence[int] | None = None, ) -> np.ndarray: """Converts ndarray of bf16 (as uint32) to f32 (as uint32). Args: data: A numpy array, empty dimensions are allowed if dims is None. dims: If specified, the function reshapes the results. Returns: A numpy array of float32 with the same dimension if dims is None, or reshaped to dims if specified """ shift = lambda x: x << 16 # noqa: E731 if dims is None: if len(data.shape) == 0: return shift(np.array([data]).astype(np.int32)).view(np.float32)[0] # type: ignore[no-any-return] return shift(data.astype(np.int32)).view(np.float32) # type: ignore[no-any-return] return shift(data.astype(np.int32)).reshape(dims).view(np.float32) # type: ignore[no-any-return] def _float8e4m3_to_float32_scalar(ival: int, fn: bool, uz: bool) -> np.float32: if not fn: raise NotImplementedError("fn=False is not implemented.") if ival < 0 or ival > 255: # noqa: PLR2004 raise ValueError(f"{ival} is not a float8.") if uz: exponent_bias = 8 if ival == 0x80: # noqa: PLR2004 return np.nan # type: ignore[return-value] else: exponent_bias = 7 if ival == 255: # noqa: PLR2004 return np.float32(-np.nan) if ival == 127: # noqa: PLR2004 return np.float32(np.nan) ival = np.uint32(ival) # type: ignore[assignment] expo = (ival & 0x78) >> 3 mant = ival & 0x07 sign = ival & 0x80 res = sign << 24 if expo == 0: if mant > 0: expo = 0x7F - exponent_bias if mant & 0x4 == 0: mant &= 0x3 mant <<= 1 expo -= 1 if mant & 0x4 == 0: mant &= 0x3 mant <<= 1 expo -= 1 res |= (mant & 0x3) << 21 res |= expo << 23 else: res |= mant << 20 expo += 0x7F - exponent_bias res |= expo << 23 f = np.uint32(res).view(np.float32) return f _float8e4m3_to_float32 = np.vectorize( _float8e4m3_to_float32_scalar, excluded=["fn", "uz"] ) @typing_extensions.deprecated( "Deprecated since 1.18. Scheduled to remove in 1.20. Consider using libraries like ml_dtypes for dtype conversion", category=DeprecationWarning, ) def float8e4m3_to_float32( data: np.int16 | np.int32 | np.ndarray, dims: int | Sequence[int] | None = None, fn: bool = True, uz: bool = False, ) -> np.ndarray: """Converts ndarray of float8, e4m3 (as uint32) to f32 (as uint32). See :ref:`onnx-detail-float8` for technical details. Args: data: A numpy array, empty dimensions are allowed if dims is None. dims: If specified, the function reshapes the results. fn: No infinite values. uz: No negative zero. Returns: A numpy array of float32 with the same dimension if dims is None, or reshaped to dims if specified. """ if not fn: raise NotImplementedError( "float32_to_float8e4m3 not implemented with fn=False." ) res = _float8e4m3_to_float32(data, fn=fn, uz=uz) if dims is None: return res # type: ignore[no-any-return] return res.reshape(dims) # type: ignore[no-any-return] def _float8e5m2_to_float32_scalar(ival: int, fn: bool, uz: bool) -> np.float32: if fn and uz: if ival == 0x80: # noqa: PLR2004 return np.float32(np.nan) exponent_bias = 16 elif not fn and not uz: if ival in {253, 254, 255}: return np.float32(-np.nan) if ival in {125, 126, 127}: return np.float32(np.nan) if ival == 252: # noqa: PLR2004 return np.float32(-np.inf) if ival == 124: # noqa: PLR2004 return np.float32(np.inf) exponent_bias = 15 else: raise NotImplementedError("fn and uz must be both False or True.") ival = np.uint32(ival) # type: ignore[assignment] expo = (ival & 0x7C) >> 2 mant = ival & 0x03 sign = ival & 0x80 res = sign << 24 if expo == 0: if mant > 0: expo = 0x7F - exponent_bias if mant & 0x2 == 0: mant &= 0x1 mant <<= 1 expo -= 1 res |= (mant & 0x1) << 22 res |= expo << 23 else: res |= mant << 21 expo += 0x7F - exponent_bias res |= expo << 23 f = np.uint32(res).view(np.float32) return f _float8e5m2_to_float32 = np.vectorize( _float8e5m2_to_float32_scalar, excluded=["fn", "uz"] ) @typing_extensions.deprecated( "Deprecated since 1.18. Scheduled to remove in 1.20. Consider using libraries like ml_dtypes for dtype conversion", category=DeprecationWarning, ) def float8e5m2_to_float32( data: np.int16 | np.int32 | np.ndarray, dims: int | Sequence[int] | None = None, fn: bool = False, uz: bool = False, ) -> np.ndarray: """Converts ndarray of float8, e5m2 (as uint32) to f32 (as uint32). See :ref:`onnx-detail-float8` for technical details. Args: data: A numpy array, empty dimensions are allowed if dims is None. dims: If specified, the function reshapes the results. fn: No infinite values. uz: No negative zero. Returns: A numpy array of float32 with the same dimension if dims is None, or reshaped to dims if specified """ res = _float8e5m2_to_float32(data, fn=fn, uz=uz) if dims is None: return res # type: ignore[no-any-return] return res.reshape(dims) # type: ignore[no-any-return] def to_float8e8m0( x: np.ndarray, saturate: bool = True, round_mode: str = "up", ) -> np.ndarray: """Convert float32 NumPy array to float8e8m0 representation. If the input is not a float32 array, it will be cast to one first. Args: x: Input array to convert. saturate: Whether to saturate at max/min float8e8m0 value. round_mode: "nearest", "up", or "down". Returns: np.ndarray: Array of ml_dtypes.float8_e8m0fnu values. """ x_f32 = np.asarray(x, dtype=np.float32) f_bits = x_f32.view(np.uint32) # Extract exponent bits exponent = (f_bits >> 23) & 0xFF exponent = exponent.astype( np.uint16 ) # use uint16 to prevent overflow during computation # Identify NaN or Inf special_mask = exponent == 0xFF # noqa: PLR2004 output = np.zeros_like(exponent, dtype=np.uint8) output[special_mask] = 0xFF # Preserve NaN/Inf as max exponent # Process normal numbers normal_mask = ~special_mask if round_mode == "nearest": # Get guard, round, sticky, and least significant bits g = ((f_bits & 0x400000) > 0).astype(np.uint8) r = ((f_bits & 0x200000) > 0).astype(np.uint8) s = ((f_bits & 0x1FFFFF) > 0).astype(np.uint8) lsb = (exponent > 0).astype(np.uint8) round_up = (g == 1) & ((r == 1) | (s == 1) | (lsb == 1)) increment = np.zeros_like(exponent) increment[round_up & normal_mask] = 1 if saturate: max_mask = (exponent == 0xFE) & round_up & normal_mask # noqa: PLR2004 increment[max_mask] = 0 # Don't overflow past max value exponent += increment elif round_mode == "up": has_fraction = (f_bits & 0x4FFFFF) > 0 round_up = has_fraction & normal_mask if saturate: max_mask = (exponent == 0xFE) & round_up # noqa: PLR2004 round_up[max_mask] = False exponent += round_up.astype(np.uint16) elif round_mode == "down": pass # No rounding needed else: raise ValueError(f"Unsupported rounding mode: {round_mode}") # Clip exponent to uint8 range exponent = exponent.astype(np.uint8) output[normal_mask] = exponent[normal_mask] return output.view(ml_dtypes.float8_e8m0fnu) @typing_extensions.deprecated( "Deprecated since 1.18. Scheduled to remove in 1.20. Consider implementing your own unpack logic", category=DeprecationWarning, ) def unpack_int4( data: np.int32 | np.ndarray, dims: int | Sequence[int], signed: bool, ) -> np.ndarray: """Converts ndarray of int4 (as packed uint8) to f32 See :ref:`onnx-detail-int4` for technical details. Args: data: A numpy array, empty dimensions are allowed if dims is None. dims: The dimensions are used to reshape the unpacked buffer signed: Whether the 4 bit integer is signed or unsigned Returns: A numpy array of float32 reshaped to dims. """ single_func = lambda x: subbyte.unpack_single_4bitx2(x, signed) # noqa: E731 func = np.frompyfunc(single_func, 1, 2) res_high, res_low = func(data.ravel()) res = np.empty((res_high.size + res_low.size,), dtype=np.float32) res[0::2] = res_high res[1::2] = res_low if ( res.size == np.prod(dims) + 1 ): # handle single-element padding due to odd number of elements res = res.ravel()[:-1] res = res.reshape(dims) return res def _unpacked_float4e2m1_to_float32( x: npt.NDArray[np.uint8], ) -> npt.NDArray[np.float32]: """Evaluate the numerical value of an array of unpacked float4e2m1 values (as uint8) See :ref:`onnx-detail-int4` for technical details. Args: x: an array of uint8 elements representing a float4e2m1 (using the 4 LSB) Returns: An array of float32 elements representing the values of the float4e2m1 input. """ # x is stored in 4 LSB of int sign = np.where(np.bitwise_and(x, 0x08), -1, 1) mantissa = (x & 0x01).astype(np.float32) exponent = ((x & 0x06) >> 1).astype(np.float32) val = np.where( exponent == 0, sign * (mantissa / 2.0), sign * (1.0 + mantissa / 2.0) * 2.0 ** (exponent - 1), ) # denormalized, normalized return val def _unpack_4bit( data: npt.NDArray[np.uint8], dims: Sequence[int] ) -> npt.NDArray[np.uint8]: """Convert a packed uint4 array to unpacked uint4 array represented as uint8. Args: data: A numpy array. dims: The dimensions are used to reshape the unpacked buffer. Returns: A numpy array of int8/uint8 reshaped to dims. """ result = np.empty([data.size * 2], dtype=data.dtype) array_low = data & np.uint8(0x0F) array_high = data & np.uint8(0xF0) array_high >>= np.uint8(4) result[0::2] = array_low result[1::2] = array_high if result.size == np.prod(dims) + 1: # handle single-element padding due to odd number of elements result = result[:-1] result.resize(dims, refcheck=False) return result def _pack_4bitx2(array: np.ndarray) -> npt.NDArray[np.uint8]: """Convert a numpy array to flatten, packed int4/uint4. Elements must be in the correct range.""" # Create a 1D copy array_flat = array.ravel().view(np.uint8).copy() size = array.size odd_sized = size % 2 == 1 if odd_sized: array_flat.resize([size + 1], refcheck=False) array_flat &= 0x0F array_flat[1::2] <<= 4 return array_flat[0::2] | array_flat[1::2] # type: ignore[return-type] def to_array(tensor: onnx.TensorProto, base_dir: str = "") -> np.ndarray: # noqa: PLR0911 """Converts a tensor def object to a numpy array. This function uses ml_dtypes if the dtype is not a native numpy dtype. Args: tensor: a TensorProto object. base_dir: if external tensor exists, base_dir can help to find the path to it Returns: arr: the converted array. """ if tensor.HasField("segment"): raise ValueError("Currently not supporting loading segments.") if tensor.data_type == onnx.TensorProto.UNDEFINED: raise TypeError("The element type in the input tensor is UNDEFINED.") tensor_dtype = tensor.data_type np_dtype = helper.tensor_dtype_to_np_dtype(tensor_dtype) storage_np_dtype = helper.tensor_dtype_to_np_dtype( helper.tensor_dtype_to_storage_tensor_dtype(tensor_dtype) ) storage_field = helper.tensor_dtype_to_field(tensor_dtype) dims = tensor.dims if tensor.data_type == onnx.TensorProto.STRING: utf8_strings = getattr(tensor, storage_field) ss = [s.decode("utf-8") for s in utf8_strings] return np.asarray(ss).astype(np_dtype).reshape(dims) # Load raw data from external tensor if it exists if onnx.external_data_helper.uses_external_data(tensor): onnx.external_data_helper.load_external_data_for_tensor(tensor, base_dir) if tensor.HasField("raw_data"): # Raw_bytes support: using frombuffer. raw_data = tensor.raw_data if sys.byteorder == "big": # Convert endian from little to big raw_data = np.frombuffer(raw_data, dtype=np_dtype).byteswap().tobytes() if tensor_dtype in { onnx.TensorProto.INT4, onnx.TensorProto.UINT4, onnx.TensorProto.FLOAT4E2M1, }: data = np.frombuffer(raw_data, dtype=np.uint8) return _unpack_4bit(data, dims).view(np_dtype) return np.frombuffer(raw_data, dtype=np_dtype).reshape(dims) if tensor_dtype in { onnx.TensorProto.BFLOAT16, onnx.TensorProto.FLOAT16, onnx.TensorProto.INT16, onnx.TensorProto.UINT16, }: return ( np.array(tensor.int32_data, dtype=np.int32) .view(np.uint32) .astype(np.uint16) .reshape(dims) .view(np_dtype) ) if tensor_dtype in { onnx.TensorProto.FLOAT8E4M3FN, onnx.TensorProto.FLOAT8E4M3FNUZ, onnx.TensorProto.FLOAT8E5M2, onnx.TensorProto.FLOAT8E5M2FNUZ, onnx.TensorProto.FLOAT8E8M0, onnx.TensorProto.BOOL, }: return ( np.array(tensor.int32_data, dtype=np.int32) .view(np.uint32) .astype(np.uint8) .view(np_dtype) .reshape(dims) ) if tensor_dtype in { onnx.TensorProto.UINT4, onnx.TensorProto.INT4, onnx.TensorProto.FLOAT4E2M1, }: data = ( np.array(tensor.int32_data, dtype=np.int32).view(np.uint32).astype(np.uint8) ) return _unpack_4bit(data, dims).view(np_dtype) data = getattr(tensor, storage_field) if tensor_dtype in (onnx.TensorProto.COMPLEX64, onnx.TensorProto.COMPLEX128): return np.array(data, dtype=storage_np_dtype).view(dtype=np_dtype).reshape(dims) return np.asarray(data, dtype=storage_np_dtype).astype(np_dtype).reshape(dims) def from_array(array: np.ndarray, /, name: str | None = None) -> onnx.TensorProto: """Converts an array into a TensorProto including Args: array: a numpy array. name: (optional) the name of the tensor. Returns: TensorProto: the converted tensor def. """ tensor = onnx.TensorProto() tensor.dims.extend(array.shape) if name: tensor.name = name if array.dtype == object or np.issubdtype(array.dtype, np.str_): # Special care for strings. tensor.data_type = onnx.TensorProto.STRING # TODO: Introduce full string support. # We flatten the array in case there are n-D arrays are specified # If you want more complex shapes then follow the below instructions. # Unlike other types where the shape is automatically inferred from # nested arrays of values, the only reliable way now to feed strings # is to put them into a flat array then specify type astype(object) # (otherwise all strings may have different types depending on their length) # and then specify shape .reshape([x, y, z]) flat_array = array.flatten() for e in flat_array: if isinstance(e, str): tensor.string_data.append(e.encode("utf-8")) elif isinstance(e, bytes): tensor.string_data.append(e) else: raise NotImplementedError( "Unrecognized object in the object array, expect a string, or array of bytes: ", str(type(e)), ) return tensor dtype = helper.np_dtype_to_tensor_dtype(array.dtype) if dtype in { onnx.TensorProto.INT4, onnx.TensorProto.UINT4, onnx.TensorProto.FLOAT4E2M1, }: # Pack the array into int4 array = _pack_4bitx2(array) if not _IS_LITTLE_ENDIAN: array = array.view(array.dtype.newbyteorder("<")) tensor.raw_data = array.tobytes() tensor.data_type = dtype return tensor def to_list(sequence: onnx.SequenceProto) -> list[Any]: """Converts a sequence def to a Python list. Args: sequence: a SequenceProto object. Returns: list: the converted list. """ elem_type = sequence.elem_type if elem_type == onnx.SequenceProto.TENSOR: return [to_array(v) for v in sequence.tensor_values] if elem_type == onnx.SequenceProto.SPARSE_TENSOR: return [to_array(v) for v in sequence.sparse_tensor_values] # type: ignore[arg-type] if elem_type == onnx.SequenceProto.SEQUENCE: return [to_list(v) for v in sequence.sequence_values] if elem_type == onnx.SequenceProto.MAP: return [to_dict(v) for v in sequence.map_values] raise TypeError("The element type in the input sequence is not supported.") def from_list( lst: list[Any], name: str | None = None, dtype: int | None = None ) -> onnx.SequenceProto: """Converts a list into a sequence def. Args: lst: a Python list name: (optional) the name of the sequence. dtype: (optional) type of element in the input list, used for specifying sequence values when converting an empty list. Returns: SequenceProto: the converted sequence def. """ sequence = onnx.SequenceProto() if name: sequence.name = name if dtype: elem_type = dtype elif len(lst) > 0: first_elem = lst[0] if isinstance(first_elem, dict): elem_type = onnx.SequenceProto.MAP elif isinstance(first_elem, list): elem_type = onnx.SequenceProto.SEQUENCE else: elem_type = onnx.SequenceProto.TENSOR else: # if empty input list and no dtype specified # choose sequence of tensors on default elem_type = onnx.SequenceProto.TENSOR sequence.elem_type = elem_type if (len(lst) > 0) and not all(isinstance(elem, type(lst[0])) for elem in lst): raise TypeError( "The element type in the input list is not the same " "for all elements and therefore is not supported as a sequence." ) if elem_type == onnx.SequenceProto.TENSOR: for tensor in lst: sequence.tensor_values.extend([from_array(np.asarray(tensor))]) elif elem_type == onnx.SequenceProto.SEQUENCE: for seq in lst: sequence.sequence_values.extend([from_list(seq)]) elif elem_type == onnx.SequenceProto.MAP: for mapping in lst: sequence.map_values.extend([from_dict(mapping)]) else: raise TypeError( "The element type in the input list is not a tensor, " "sequence, or map and is not supported." ) return sequence def to_dict(map_proto: onnx.MapProto) -> dict[Any, Any]: """Converts a map def to a Python dictionary. Args: map_proto: a MapProto object. Returns: The converted dictionary. """ key_list: list[Any] = [] if map_proto.key_type == onnx.TensorProto.STRING: key_list = list(map_proto.string_keys) else: key_list = list(map_proto.keys) value_list = to_list(map_proto.values) if len(key_list) != len(value_list): raise IndexError( "Length of keys and values for MapProto (map name: ", map_proto.name, ") are not the same.", ) dictionary = dict(zip(key_list, value_list)) return dictionary def from_dict(dict_: dict[Any, Any], name: str | None = None) -> onnx.MapProto: """Converts a Python dictionary into a map def. Args: dict_: Python dictionary name: (optional) the name of the map. Returns: MapProto: the converted map def. """ map_proto = onnx.MapProto() if name: map_proto.name = name keys = list(dict_) raw_key_type = np.result_type(keys[0]) key_type = helper.np_dtype_to_tensor_dtype(raw_key_type) valid_key_int_types = { onnx.TensorProto.INT8, onnx.TensorProto.INT16, onnx.TensorProto.INT32, onnx.TensorProto.INT64, onnx.TensorProto.UINT8, onnx.TensorProto.UINT16, onnx.TensorProto.UINT32, onnx.TensorProto.UINT64, } if not (all(np.result_type(key) == raw_key_type for key in keys)): raise TypeError( "The key type in the input dictionary is not the same " "for all keys and therefore is not valid as a map." ) values = list(dict_.values()) raw_value_type = np.result_type(values[0]) if not all(np.result_type(val) == raw_value_type for val in values): raise TypeError( "The value type in the input dictionary is not the same " "for all values and therefore is not valid as a map." ) value_seq = from_list(values) map_proto.key_type = key_type if key_type == onnx.TensorProto.STRING: map_proto.string_keys.extend(keys) elif key_type in valid_key_int_types: map_proto.keys.extend(keys) map_proto.values.CopyFrom(value_seq) return map_proto def to_optional(optional: onnx.OptionalProto) -> Any | None: """Converts an optional def to a Python optional. Args: optional: an OptionalProto object. Returns: opt: the converted optional. """ elem_type = optional.elem_type if elem_type == onnx.OptionalProto.UNDEFINED: return None if elem_type == onnx.OptionalProto.TENSOR: return to_array(optional.tensor_value) if elem_type == onnx.OptionalProto.SPARSE_TENSOR: return to_array(optional.sparse_tensor_value) # type: ignore[arg-type] if elem_type == onnx.OptionalProto.SEQUENCE: return to_list(optional.sequence_value) if elem_type == onnx.OptionalProto.MAP: return to_dict(optional.map_value) if elem_type == onnx.OptionalProto.OPTIONAL: return to_optional(optional.optional_value) raise TypeError("The element type in the input optional is not supported.") def from_optional( opt: Any | None, name: str | None = None, dtype: int | None = None ) -> onnx.OptionalProto: """Converts an optional value into a Optional def. Args: opt: a Python optional name: (optional) the name of the optional. dtype: (optional) type of element in the input, used for specifying optional values when converting empty none. dtype must be a valid OptionalProto.DataType value Returns: optional: the converted optional def. """ # TODO: create a map and replace conditional branches optional = onnx.OptionalProto() if name: optional.name = name if dtype is not None: # dtype must be a valid onnx.OptionalProto.DataType if dtype not in onnx.OptionalProto.DataType.values(): raise TypeError(f"{dtype} must be a valid OptionalProto.DataType.") elem_type = dtype elif isinstance(opt, dict): elem_type = onnx.OptionalProto.MAP elif isinstance(opt, list): elem_type = onnx.OptionalProto.SEQUENCE elif opt is None: elem_type = onnx.OptionalProto.UNDEFINED else: elem_type = onnx.OptionalProto.TENSOR optional.elem_type = elem_type if opt is not None: if elem_type == onnx.OptionalProto.TENSOR: optional.tensor_value.CopyFrom(from_array(opt)) elif elem_type == onnx.OptionalProto.SEQUENCE: optional.sequence_value.CopyFrom(from_list(opt)) elif elem_type == onnx.OptionalProto.MAP: optional.map_value.CopyFrom(from_dict(opt)) else: raise TypeError( "The element type in the input is not a tensor, " "sequence, or map and is not supported." ) return optional def create_random_int( input_shape: tuple[int], dtype: np.dtype, seed: int = 1 ) -> np.ndarray: """Create random integer array for backend/test/case/node. Args: input_shape: The shape for the returned integer array. dtype: The NumPy data type for the returned integer array. seed: The seed for np.random. Returns: np.ndarray: Random integer array. """ np.random.seed(seed) if dtype in ( np.uint8, np.uint16, np.uint32, np.uint64, np.int8, np.int16, np.int32, np.int64, ): # the range of np.random.randint is int32; set a fixed boundary if overflow end = min(np.iinfo(dtype).max, np.iinfo(np.int32).max) start = max(np.iinfo(dtype).min, np.iinfo(np.int32).min) return np.random.randint(start, end, size=input_shape).astype(dtype) else: raise TypeError(f"{dtype} is not supported by create_random_int.") def saturate_cast(x: np.ndarray, dtype: np.dtype) -> np.ndarray: """Saturate cast for numeric types. This function ensures that values outside the representable range of the target dtype are clamped to the maximum or minimum representable value of that dtype. """ if np.issubdtype(dtype, np.integer) or dtype in (ml_dtypes.int4, ml_dtypes.uint4): info = ml_dtypes.iinfo(dtype) x = np.round(x) else: info = ml_dtypes.finfo(dtype) # type: ignore[assignment] return np.clip(x, info.min, info.max).astype(dtype)