# Copyright (c) ONNX Project Contributors # # SPDX-License-Identifier: Apache-2.0 from __future__ import annotations __all__ = [ # Constants "ONNX_ML", "IR_VERSION", "IR_VERSION_2017_10_10", "IR_VERSION_2017_10_30", "IR_VERSION_2017_11_3", "IR_VERSION_2019_1_22", "IR_VERSION_2019_3_18", "IR_VERSION_2019_9_19", "IR_VERSION_2020_5_8", "IR_VERSION_2021_7_30", "IR_VERSION_2023_5_5", "IR_VERSION_2024_3_25", "EXPERIMENTAL", "STABLE", # Modules "checker", "compose", "defs", "gen_proto", "helper", "hub", "numpy_helper", "parser", "printer", "shape_inference", "utils", "version_converter", # Proto classes "AttributeProto", "DeviceConfigurationProto", "FunctionProto", "GraphProto", "IntIntListEntryProto", "MapProto", "ModelProto", "NodeDeviceConfigurationProto", "NodeProto", "OperatorProto", "OperatorSetIdProto", "OperatorSetProto", "OperatorStatus", "OptionalProto", "SequenceProto", "SimpleShardedDimProto", "ShardedDimProto", "ShardingSpecProto", "SparseTensorProto", "StringStringEntryProto", "TensorAnnotation", "TensorProto", "TensorShapeProto", "TrainingInfoProto", "TypeProto", "ValueInfoProto", "Version", # Utility functions "convert_model_to_external_data", "load_external_data_for_model", "load_model_from_string", "load_model", "load_tensor_from_string", "load_tensor", "save_model", "save_tensor", "write_external_data_tensors", ] # isort:skip_file import os import typing from typing import IO, Literal, Union from onnx import serialization from onnx.onnx_cpp2py_export import ONNX_ML from onnx.external_data_helper import ( load_external_data_for_model, write_external_data_tensors, convert_model_to_external_data, ) from onnx.onnx_pb import ( AttributeProto, DeviceConfigurationProto, EXPERIMENTAL, FunctionProto, GraphProto, IntIntListEntryProto, IR_VERSION, IR_VERSION_2017_10_10, IR_VERSION_2017_10_30, IR_VERSION_2017_11_3, IR_VERSION_2019_1_22, IR_VERSION_2019_3_18, IR_VERSION_2019_9_19, IR_VERSION_2020_5_8, IR_VERSION_2021_7_30, IR_VERSION_2023_5_5, IR_VERSION_2024_3_25, ModelProto, NodeDeviceConfigurationProto, NodeProto, OperatorSetIdProto, OperatorStatus, STABLE, SimpleShardedDimProto, ShardedDimProto, ShardingSpecProto, SparseTensorProto, StringStringEntryProto, TensorAnnotation, TensorProto, TensorShapeProto, TrainingInfoProto, TypeProto, ValueInfoProto, Version, ) from onnx.onnx_operators_pb import OperatorProto, OperatorSetProto from onnx.onnx_data_pb import MapProto, OptionalProto, SequenceProto import onnx.version # Import common subpackages so they're available when you 'import onnx' from onnx import ( checker, compose, defs, gen_proto, helper, hub, numpy_helper, parser, printer, shape_inference, utils, version_converter, ) if typing.TYPE_CHECKING: from collections.abc import Sequence __version__ = onnx.version.version # Supported model formats that can be loaded from and saved to # The literals are formats with built-in support. But we also allow users to # register their own formats. So we allow str as well. _SupportedFormat = Union[ Literal["protobuf", "textproto", "onnxtxt", "json"], str # noqa: PYI051 ] # Default serialization format _DEFAULT_FORMAT = "protobuf" def _load_bytes(f: IO[bytes] | str | os.PathLike) -> bytes: if hasattr(f, "read") and callable(typing.cast("IO[bytes]", f).read): content = typing.cast("IO[bytes]", f).read() else: f = typing.cast("Union[str, os.PathLike]", f) with open(f, "rb") as readable: content = readable.read() return content def _save_bytes(content: bytes, f: IO[bytes] | str | os.PathLike) -> None: if hasattr(f, "write") and callable(typing.cast("IO[bytes]", f).write): typing.cast("IO[bytes]", f).write(content) else: f = typing.cast("Union[str, os.PathLike]", f) with open(f, "wb") as writable: writable.write(content) def _get_file_path(f: IO[bytes] | str | os.PathLike | None) -> str | None: if isinstance(f, (str, os.PathLike)): return os.path.abspath(f) if hasattr(f, "name"): assert f is not None return os.path.abspath(f.name) return None def _get_serializer( fmt: _SupportedFormat | None, f: str | os.PathLike | IO[bytes] | None = None ) -> serialization.ProtoSerializer: """Get the serializer for the given path and format from the serialization registry.""" # Use fmt if it is specified if fmt is not None: return serialization.registry.get(fmt) if (file_path := _get_file_path(f)) is not None: _, ext = os.path.splitext(file_path) fmt = serialization.registry.get_format_from_file_extension(ext) # Failed to resolve format if fmt is None. Use protobuf as default fmt = fmt or _DEFAULT_FORMAT assert fmt is not None return serialization.registry.get(fmt) def load_model( f: IO[bytes] | str | os.PathLike, format: _SupportedFormat | None = None, # noqa: A002 load_external_data: bool = True, ) -> ModelProto: """Loads a serialized ModelProto into memory. Args: f: can be a file-like object (has "read" function) or a string/PathLike containing a file name format: The serialization format. When it is not specified, it is inferred from the file extension when ``f`` is a path. If not specified _and_ ``f`` is not a path, 'protobuf' is used. The encoding is assumed to be "utf-8" when the format is a text format. load_external_data: Whether to load the external data. Set to True if the data is under the same directory of the model. If not, users need to call :func:`load_external_data_for_model` with directory to load external data from. Returns: Loaded in-memory ModelProto. """ model = _get_serializer(format, f).deserialize_proto(_load_bytes(f), ModelProto()) if load_external_data: model_filepath = _get_file_path(f) if model_filepath: base_dir = os.path.dirname(model_filepath) load_external_data_for_model(model, base_dir) return model def load_tensor( f: IO[bytes] | str | os.PathLike, format: _SupportedFormat | None = None, # noqa: A002 ) -> TensorProto: """Loads a serialized TensorProto into memory. Args: f: can be a file-like object (has "read" function) or a string/PathLike containing a file name format: The serialization format. When it is not specified, it is inferred from the file extension when ``f`` is a path. If not specified _and_ ``f`` is not a path, 'protobuf' is used. The encoding is assumed to be "utf-8" when the format is a text format. Returns: Loaded in-memory TensorProto. """ return _get_serializer(format, f).deserialize_proto(_load_bytes(f), TensorProto()) def load_model_from_string( s: bytes | str, format: _SupportedFormat = _DEFAULT_FORMAT, # noqa: A002 ) -> ModelProto: """Loads a binary string (bytes) that contains serialized ModelProto. Args: s: a string, which contains serialized ModelProto format: The serialization format. When it is not specified, it is inferred from the file extension when ``f`` is a path. If not specified _and_ ``f`` is not a path, 'protobuf' is used. The encoding is assumed to be "utf-8" when the format is a text format. Returns: Loaded in-memory ModelProto. """ return _get_serializer(format).deserialize_proto(s, ModelProto()) def load_tensor_from_string( s: bytes, format: _SupportedFormat = _DEFAULT_FORMAT, # noqa: A002 ) -> TensorProto: """Loads a binary string (bytes) that contains serialized TensorProto. Args: s: a string, which contains serialized TensorProto format: The serialization format. When it is not specified, it is inferred from the file extension when ``f`` is a path. If not specified _and_ ``f`` is not a path, 'protobuf' is used. The encoding is assumed to be "utf-8" when the format is a text format. Returns: Loaded in-memory TensorProto. """ return _get_serializer(format).deserialize_proto(s, TensorProto()) def save_model( proto: ModelProto | bytes, f: IO[bytes] | str | os.PathLike, format: _SupportedFormat | None = None, # noqa: A002 *, save_as_external_data: bool = False, all_tensors_to_one_file: bool = True, location: str | None = None, size_threshold: int = 1024, convert_attribute: bool = False, ) -> None: """Saves the ModelProto to the specified path and optionally, serialize tensors with raw data as external data before saving. Args: proto: should be a in-memory ModelProto f: can be a file-like object (has "write" function) or a string containing a file name or a pathlike object format: The serialization format. When it is not specified, it is inferred from the file extension when ``f`` is a path. If not specified _and_ ``f`` is not a path, 'protobuf' is used. The encoding is assumed to be "utf-8" when the format is a text format. save_as_external_data: If true, save tensors to external file(s). all_tensors_to_one_file: Effective only if save_as_external_data is True. If true, save all tensors to one external file specified by location. If false, save each tensor to a file named with the tensor name. location: Effective only if save_as_external_data is true. Specify the external file that all tensors to save to. Path is relative to the model path. If not specified, will use the model name. size_threshold: Effective only if save_as_external_data is True. Threshold for size of data. Only when tensor's data is >= the size_threshold it will be converted to external data. To convert every tensor with raw data to external data set size_threshold=0. convert_attribute: Effective only if save_as_external_data is True. If true, convert all tensors to external data If false, convert only non-attribute tensors to external data """ if isinstance(proto, bytes): proto = _get_serializer(_DEFAULT_FORMAT).deserialize_proto(proto, ModelProto()) if save_as_external_data: convert_model_to_external_data( proto, all_tensors_to_one_file, location, size_threshold, convert_attribute ) model_filepath = _get_file_path(f) if model_filepath is not None: basepath = os.path.dirname(model_filepath) proto = write_external_data_tensors(proto, basepath) serialized = _get_serializer(format, model_filepath).serialize_proto(proto) _save_bytes(serialized, f) def save_tensor( proto: TensorProto, f: IO[bytes] | str | os.PathLike, format: _SupportedFormat | None = None, # noqa: A002 ) -> None: """Saves the TensorProto to the specified path. Args: proto: should be a in-memory TensorProto f: can be a file-like object (has "write" function) or a string containing a file name or a pathlike object. format: The serialization format. When it is not specified, it is inferred from the file extension when ``f`` is a path. If not specified _and_ ``f`` is not a path, 'protobuf' is used. The encoding is assumed to be "utf-8" when the format is a text format. """ serialized = _get_serializer(format, f).serialize_proto(proto) _save_bytes(serialized, f) # For backward compatibility load = load_model load_from_string = load_model_from_string save = save_model def _model_proto_repr(self: ModelProto) -> str: if self.domain: domain = f", domain='{self.domain}'" else: domain = "" if self.producer_name: producer_name = f", producer_name='{self.producer_name}'" else: producer_name = "" if self.producer_version: producer_version = f", producer_version='{self.producer_version}'" else: producer_version = "" if self.graph: graph = f", graph={self.graph!r}" else: graph = "" if self.functions: functions = f", functions=<{len(self.functions)} functions>" else: functions = "" if self.opset_import: opset_import = f", opset_import={_operator_set_protos_repr(self.opset_import)}" else: opset_import = "" return f"ModelProto(ir_version={self.ir_version}{opset_import}{domain}{producer_name}{producer_version}{graph}{functions})" def _graph_proto_repr(self: GraphProto) -> str: if self.initializer: initializer = f", initializer=<{len(self.initializer)} initializers>" else: initializer = "" if self.node: node = f", node=<{len(self.node)} nodes>" else: node = "" if self.value_info: value_info = f", value_info=<{len(self.value_info)} value_info>" else: value_info = "" if self.input: input = f", input=<{len(self.input)} inputs>" else: input = "" if self.output: output = f", output=<{len(self.output)} outputs>" else: output = "" return f"GraphProto('{self.name}'{input}{output}{initializer}{node}{value_info})" def _function_proto_repr(self: FunctionProto) -> str: if self.domain: domain = f", domain='{self.domain}'" else: domain = "" if self.overload: overload = f", overload='{self.overload}'" else: overload = "" if self.node: node = f", node=<{len(self.node)} nodes>" else: node = "" if self.attribute: attribute = f", attribute={self.attribute}" else: attribute = "" if self.opset_import: opset_import = f", opset_import={_operator_set_protos_repr(self.opset_import)}" else: opset_import = "" if self.input: input = f", input=<{len(self.input)} inputs>" else: input = "" if self.output: output = f", output=<{len(self.output)} outputs>" else: output = "" return f"FunctionProto('{self.name}'{domain}{overload}{opset_import}{input}{output}{attribute}{node})" def _operator_set_protos_repr(protos: Sequence[OperatorSetIdProto]) -> str: opset_imports = {proto.domain: proto.version for proto in protos} return repr(opset_imports) # Override __repr__ for some proto classes to make it more efficient ModelProto.__repr__ = _model_proto_repr # type: ignore[method-assign,assignment] GraphProto.__repr__ = _graph_proto_repr # type: ignore[method-assign,assignment] FunctionProto.__repr__ = _function_proto_repr # type: ignore[method-assign,assignment]