# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

import io
import os
import pathlib
import tempfile
import unittest

import google.protobuf.message
import google.protobuf.text_format
import parameterized

import onnx
from onnx import serialization


def _simple_model() -> onnx.ModelProto:
    model = onnx.ModelProto()
    model.ir_version = onnx.IR_VERSION
    model.producer_name = "onnx-test"
    model.graph.name = "test"
    return model


def _simple_tensor() -> onnx.TensorProto:
    tensor = onnx.helper.make_tensor(
        name="test-tensor",
        data_type=onnx.TensorProto.FLOAT,
        dims=(2, 3, 4),
        vals=[x + 0.5 for x in range(24)],
    )
    return tensor


@parameterized.parameterized_class(
    [
        {"format": "protobuf"},
        {"format": "textproto"},
        {"format": "json"},
        {"format": "onnxtxt"},
    ]
)
class TestIO(unittest.TestCase):
    format: str

    def test_load_model_when_input_is_bytes(self) -> None:
        proto = _simple_model()
        proto_string = serialization.registry.get(self.format).serialize_proto(proto)
        loaded_proto = onnx.load_model_from_string(proto_string, format=self.format)
        self.assertEqual(proto, loaded_proto)

    def test_save_and_load_model_when_input_has_read_function(self) -> None:
        proto = _simple_model()
        # When the proto is provided to `save_model` as bytes, it must always be
        # a serialized binary protobuf representation (i.e. format="protobuf").
        # The format of the saved file is specified by the `format` argument.
        proto_string = serialization.registry.get("protobuf").serialize_proto(proto)
        f = io.BytesIO()
        onnx.save_model(proto_string, f, format=self.format)
        loaded_proto = onnx.load_model(io.BytesIO(f.getvalue()), format=self.format)
        self.assertEqual(proto, loaded_proto)

    def test_save_and_load_model_when_input_is_file_name(self) -> None:
        proto = _simple_model()
        with tempfile.TemporaryDirectory() as temp_dir:
            model_path = os.path.join(temp_dir, "model.onnx")
            onnx.save_model(proto, model_path, format=self.format)
            loaded_proto = onnx.load_model(model_path, format=self.format)
            self.assertEqual(proto, loaded_proto)

    def test_save_and_load_model_when_input_is_pathlike(self) -> None:
        proto = _simple_model()
        with tempfile.TemporaryDirectory() as temp_dir:
            model_path = pathlib.Path(temp_dir, "model.onnx")
            onnx.save_model(proto, model_path, format=self.format)
            loaded_proto = onnx.load_model(model_path, format=self.format)
            self.assertEqual(proto, loaded_proto)


@parameterized.parameterized_class(
    [
        {"format": "protobuf"},
        {"format": "textproto"},
        {"format": "json"},
        # The onnxtxt format does not support saving/loading tensors yet
    ]
)
class TestIOTensor(unittest.TestCase):
    """Test loading and saving of TensorProto."""

    format: str

    def test_load_tensor_when_input_is_bytes(self) -> None:
        proto = _simple_tensor()
        proto_string = serialization.registry.get(self.format).serialize_proto(proto)
        loaded_proto = onnx.load_tensor_from_string(proto_string, format=self.format)
        self.assertEqual(proto, loaded_proto)

    def test_save_and_load_tensor_when_input_has_read_function(self) -> None:
        # Test if input has a read function
        proto = _simple_tensor()
        f = io.BytesIO()
        onnx.save_tensor(proto, f, format=self.format)
        loaded_proto = onnx.load_tensor(io.BytesIO(f.getvalue()), format=self.format)
        self.assertEqual(proto, loaded_proto)

    def test_save_and_load_tensor_when_input_is_file_name(self) -> None:
        # Test if input is a file name
        proto = _simple_tensor()
        with tempfile.TemporaryDirectory() as temp_dir:
            model_path = os.path.join(temp_dir, "model.onnx")
            onnx.save_tensor(proto, model_path, format=self.format)
            loaded_proto = onnx.load_tensor(model_path, format=self.format)
            self.assertEqual(proto, loaded_proto)

    def test_save_and_load_tensor_when_input_is_pathlike(self) -> None:
        # Test if input is a path-like object
        proto = _simple_tensor()
        with tempfile.TemporaryDirectory() as temp_dir:
            model_path = pathlib.Path(temp_dir, "model.onnx")
            onnx.save_tensor(proto, model_path, format=self.format)
            loaded_proto = onnx.load_tensor(model_path, format=self.format)
            self.assertEqual(proto, loaded_proto)


class TestSaveAndLoadFileExtensions(unittest.TestCase):
    def test_save_model_picks_correct_format_from_extension(self) -> None:
        proto = _simple_model()
        with tempfile.TemporaryDirectory() as temp_dir:
            model_path = os.path.join(temp_dir, "model.textproto")
            # No format is specified, so the extension should be used to
            # determine the format
            onnx.save_model(proto, model_path)
            loaded_proto = onnx.load_model(model_path, format="textproto")
            self.assertEqual(proto, loaded_proto)

    def test_load_model_picks_correct_format_from_extension(self) -> None:
        proto = _simple_model()
        with tempfile.TemporaryDirectory() as temp_dir:
            model_path = os.path.join(temp_dir, "model.textproto")
            onnx.save_model(proto, model_path, format="textproto")
            # No format is specified, so the extension should be used to
            # determine the format
            loaded_proto = onnx.load_model(model_path)
            self.assertEqual(proto, loaded_proto)

    def test_save_model_uses_format_when_it_is_specified(self) -> None:
        proto = _simple_model()
        with tempfile.TemporaryDirectory() as temp_dir:
            model_path = os.path.join(temp_dir, "model.textproto")
            # `format` is specified. It should take precedence over the extension
            onnx.save_model(proto, model_path, format="protobuf")
            loaded_proto = onnx.load_model(model_path, format="protobuf")
            self.assertEqual(proto, loaded_proto)
            with self.assertRaises(google.protobuf.text_format.ParseError):
                # Loading it as textproto (by file extension) should fail
                onnx.load_model(model_path)

    def test_load_model_uses_format_when_it_is_specified(self) -> None:
        proto = _simple_model()
        with tempfile.TemporaryDirectory() as temp_dir:
            model_path = os.path.join(temp_dir, "model.protobuf")
            onnx.save_model(proto, model_path)
            with self.assertRaises(google.protobuf.text_format.ParseError):
                # `format` is specified. It should take precedence over the
                # extension. Loading it as textproto should fail
                onnx.load_model(model_path, format="textproto")
            loaded_proto = onnx.load_model(model_path, format="protobuf")
            self.assertEqual(proto, loaded_proto)

    def test_load_and_save_model_to_path_without_specifying_extension_succeeds(
        self,
    ) -> None:
        proto = _simple_model()
        with tempfile.TemporaryDirectory() as temp_dir:
            # No extension is specified
            model_path = os.path.join(temp_dir, "model")
            onnx.save_model(proto, model_path, format="textproto")
            with self.assertRaises(google.protobuf.message.DecodeError):
                # `format` is not specified, so load_model should assume
                # protobuf and fail to load it
                onnx.load_model(model_path)
            loaded_proto = onnx.load_model(model_path, format="textproto")
            self.assertEqual(proto, loaded_proto)

    def test_load_and_save_model_without_specifying_extension_or_format_defaults_to_protobuf(
        self,
    ) -> None:
        proto = _simple_model()
        with tempfile.TemporaryDirectory() as temp_dir:
            # No extension is specified
            model_path = os.path.join(temp_dir, "model")
            onnx.save_model(proto, model_path)
            with self.assertRaises(google.protobuf.text_format.ParseError):
                # The model is saved as protobuf, so loading it as textproto should fail
                onnx.load_model(model_path, format="textproto")
            loaded_proto = onnx.load_model(model_path)
            self.assertEqual(proto, loaded_proto)
            loaded_proto_as_explicitly_protobuf = onnx.load_model(
                model_path, format="protobuf"
            )
            self.assertEqual(proto, loaded_proto_as_explicitly_protobuf)


class TestBasicFunctions(unittest.TestCase):
    def test_protos_exist(self) -> None:
        # The proto classes should exist
        _ = onnx.AttributeProto
        _ = onnx.NodeProto
        _ = onnx.GraphProto
        _ = onnx.ModelProto

    def test_version_exists(self) -> None:
        model = onnx.ModelProto()
        # A freshly created model should not have ir_version set.
        self.assertFalse(model.HasField("ir_version"))
        # Set the version so the model is annotated with the current IR version
        # of the running ONNX
        model.ir_version = onnx.IR_VERSION
        model_string = model.SerializeToString()
        model.ParseFromString(model_string)
        self.assertTrue(model.HasField("ir_version"))
        # Check if the version is correct.
        self.assertEqual(model.ir_version, onnx.IR_VERSION)

    def test_model_and_graph_repr(self) -> None:
        # Check if the __repr__ methods work without error
        model = _simple_model()
        model_repr = repr(model)
        self.assertEqual(
            model_repr,
            "ModelProto(ir_version=12, producer_name='onnx-test', graph=GraphProto('test'))",
        )
        text_model = """
            <
                ir_version: 10,
                opset_import: [ "" : 19]
            >
            agraph (float[N] X) => (float[N] C)
            <float[1] weight = {1}>
            {
                C = Cast(X)
            }
        """
        model = onnx.parser.parse_model(text_model)
        model_repr = repr(model)
        self.assertEqual(
            model_repr,
            "ModelProto(ir_version=10, opset_import={'': 19}, graph=GraphProto('agraph', input=<1 inputs>, output=<1 outputs>, initializer=<1 initializers>, node=<1 nodes>))",
        )
        graph_repr = repr(model.graph)
        self.assertEqual(
            graph_repr,
            "GraphProto('agraph', input=<1 inputs>, output=<1 outputs>, initializer=<1 initializers>, node=<1 nodes>)",
        )

    def test_function_repr(self) -> None:
        text = """
            <
                ir_version: 9,
                opset_import: [ "" : 15, "custom_domain" : 1],
                producer_name: "FunctionProtoTest",
                producer_version: "1.0",
                model_version: 1,
                doc_string: "A test model for model local functions."
            >
            agraph (float[N] x) => (float[N] out)
            {
                out = custom_domain.Selu(x)
            }

            <
                domain: "custom_domain",
                opset_import: [ "" : 15],
                doc_string: "Test function proto"
            >
            Selu (X) => (C)
            {
                constant_alpha = Constant()
                constant_gamma = Constant()
                alpha_x = CastLike(constant_alpha, X)
                gamma_x = CastLike(constant_gamma, X)
                exp_x = Exp(X)
                alpha_x_exp_x = Mul(alpha_x, exp_x)
                alpha_x_exp_x_ = Sub(alpha_x_exp_x, alpha_x)
                neg = Mul(gamma_x, alpha_x_exp_x_)
                pos = Mul(gamma_x, X)
                _zero = Constant()
                zero = CastLike(_zero, X)
                less_eq = LessOrEqual(X, zero)
                C = Where(less_eq, neg, pos)
            }
        """
        model = onnx.parser.parse_model(text)
        self.assertEqual(
            repr(model),
            "ModelProto(ir_version=9, opset_import={'': 15, 'custom_domain': 1}, producer_name='FunctionProtoTest', producer_version='1.0', graph=GraphProto('agraph', input=<1 inputs>, output=<1 outputs>, node=<1 nodes>), functions=<1 functions>)",
        )
        function_repr = repr(model.functions[0])
        self.assertEqual(
            function_repr,
            "FunctionProto('Selu', domain='custom_domain', opset_import={'': 15}, input=<1 inputs>, output=<1 outputs>, node=<13 nodes>)",
        )


if __name__ == "__main__":
    unittest.main()