317 lines
13 KiB
Python
317 lines
13 KiB
Python
# Copyright (c) ONNX Project Contributors
|
|
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
from __future__ import annotations
|
|
|
|
import io
|
|
import os
|
|
import pathlib
|
|
import tempfile
|
|
import unittest
|
|
|
|
import google.protobuf.message
|
|
import google.protobuf.text_format
|
|
import parameterized
|
|
|
|
import onnx
|
|
from onnx import serialization
|
|
|
|
|
|
def _simple_model() -> onnx.ModelProto:
    """Return a minimal ModelProto used as round-trip fixture.

    The model carries the IR version of the running ONNX package, the
    producer name "onnx-test" and an otherwise empty graph named "test".
    """
    return onnx.ModelProto(
        ir_version=onnx.IR_VERSION,
        producer_name="onnx-test",
        graph=onnx.GraphProto(name="test"),
    )
|
|
|
|
|
|
def _simple_tensor() -> onnx.TensorProto:
    """Return a FLOAT TensorProto named "test-tensor" with dims (2, 3, 4).

    The i-th value handed to ``make_tensor`` is ``i + 0.5``, so all 24
    values are distinct and none is an integer.
    """
    shape = (2, 3, 4)
    element_count = shape[0] * shape[1] * shape[2]
    return onnx.helper.make_tensor(
        name="test-tensor",
        data_type=onnx.TensorProto.FLOAT,
        dims=shape,
        vals=[index + 0.5 for index in range(element_count)],
    )
|
|
|
|
|
|
@parameterized.parameterized_class(
    [
        {"format": "protobuf"},
        {"format": "textproto"},
        {"format": "json"},
        {"format": "onnxtxt"},
    ]
)
class TestIO(unittest.TestCase):
    """Round-trip a ModelProto through each serialization format.

    ``parameterized_class`` generates one subclass per dict in the decorator,
    each with ``format`` set to the serializer name under test.
    """

    format: str

    def test_load_model_when_input_is_bytes(self) -> None:
        """Bytes produced by the format's serializer load back unchanged."""
        expected = _simple_model()
        serializer = serialization.registry.get(self.format)
        encoded = serializer.serialize_proto(expected)
        reloaded = onnx.load_model_from_string(encoded, format=self.format)
        self.assertEqual(expected, reloaded)

    def test_save_and_load_model_when_input_has_read_function(self) -> None:
        """Save into and load from an in-memory file-like object."""
        expected = _simple_model()
        # A bytes argument to `save_model` is always interpreted as a binary
        # protobuf serialization (i.e. format="protobuf"); the `format`
        # argument only controls how the output file is written.
        as_binary = serialization.registry.get("protobuf").serialize_proto(expected)
        sink = io.BytesIO()
        onnx.save_model(as_binary, sink, format=self.format)
        source = io.BytesIO(sink.getvalue())
        reloaded = onnx.load_model(source, format=self.format)
        self.assertEqual(expected, reloaded)

    def test_save_and_load_model_when_input_is_file_name(self) -> None:
        """Save to and load from a file addressed by a ``str`` path."""
        expected = _simple_model()
        with tempfile.TemporaryDirectory() as tmp:
            target = os.path.join(tmp, "model.onnx")
            onnx.save_model(expected, target, format=self.format)
            reloaded = onnx.load_model(target, format=self.format)
            self.assertEqual(expected, reloaded)

    def test_save_and_load_model_when_input_is_pathlike(self) -> None:
        """Save to and load from a file addressed by a ``pathlib.Path``."""
        expected = _simple_model()
        with tempfile.TemporaryDirectory() as tmp:
            target = pathlib.Path(tmp) / "model.onnx"
            onnx.save_model(expected, target, format=self.format)
            reloaded = onnx.load_model(target, format=self.format)
            self.assertEqual(expected, reloaded)
|
|
|
|
|
|
@parameterized.parameterized_class(
    [
        {"format": "protobuf"},
        {"format": "textproto"},
        {"format": "json"},
        # The onnxtxt format does not support saving/loading tensors yet
    ]
)
class TestIOTensor(unittest.TestCase):
    """Test loading and saving of TensorProto.

    Runs once per format listed in the decorator; the format under test is
    available as ``self.format``.
    """

    format: str

    def test_load_tensor_when_input_is_bytes(self) -> None:
        """Bytes produced by the format's serializer load back unchanged."""
        proto = _simple_tensor()
        proto_string = serialization.registry.get(self.format).serialize_proto(proto)
        loaded_proto = onnx.load_tensor_from_string(proto_string, format=self.format)
        self.assertEqual(proto, loaded_proto)

    def test_save_and_load_tensor_when_input_has_read_function(self) -> None:
        """Save into and load from an in-memory file-like object."""
        proto = _simple_tensor()
        f = io.BytesIO()
        onnx.save_tensor(proto, f, format=self.format)
        loaded_proto = onnx.load_tensor(io.BytesIO(f.getvalue()), format=self.format)
        self.assertEqual(proto, loaded_proto)

    def test_save_and_load_tensor_when_input_is_file_name(self) -> None:
        """Save to and load from a file addressed by a ``str`` path."""
        proto = _simple_tensor()
        with tempfile.TemporaryDirectory() as temp_dir:
            # `format` is passed explicitly, so the file name's extension
            # plays no role in choosing the serializer.
            tensor_path = os.path.join(temp_dir, "tensor.pb")
            onnx.save_tensor(proto, tensor_path, format=self.format)
            loaded_proto = onnx.load_tensor(tensor_path, format=self.format)
            self.assertEqual(proto, loaded_proto)

    def test_save_and_load_tensor_when_input_is_pathlike(self) -> None:
        """Save to and load from a file addressed by a ``pathlib.Path``."""
        proto = _simple_tensor()
        with tempfile.TemporaryDirectory() as temp_dir:
            tensor_path = pathlib.Path(temp_dir, "tensor.pb")
            onnx.save_tensor(proto, tensor_path, format=self.format)
            loaded_proto = onnx.load_tensor(tensor_path, format=self.format)
            self.assertEqual(proto, loaded_proto)
|
|
|
|
|
|
class TestSaveAndLoadFileExtensions(unittest.TestCase):
    """Check how save_model / load_model pick a serialization format.

    An explicit ``format`` argument wins; otherwise the file extension
    decides; with neither, binary protobuf is assumed.
    """

    def test_save_model_picks_correct_format_from_extension(self) -> None:
        """Without `format`, save_model writes textproto for ".textproto"."""
        expected = _simple_model()
        with tempfile.TemporaryDirectory() as tmp:
            target = os.path.join(tmp, "model.textproto")
            # Only the extension tells save_model which format to use.
            onnx.save_model(expected, target)
            self.assertEqual(expected, onnx.load_model(target, format="textproto"))

    def test_load_model_picks_correct_format_from_extension(self) -> None:
        """Without `format`, load_model reads ".textproto" as textproto."""
        expected = _simple_model()
        with tempfile.TemporaryDirectory() as tmp:
            target = os.path.join(tmp, "model.textproto")
            onnx.save_model(expected, target, format="textproto")
            # Only the extension tells load_model which format to use.
            self.assertEqual(expected, onnx.load_model(target))

    def test_save_model_uses_format_when_it_is_specified(self) -> None:
        """An explicit `format` on save overrides the file extension."""
        expected = _simple_model()
        with tempfile.TemporaryDirectory() as tmp:
            target = os.path.join(tmp, "model.textproto")
            onnx.save_model(expected, target, format="protobuf")
            self.assertEqual(expected, onnx.load_model(target, format="protobuf"))
            # Letting the ".textproto" extension decide makes the loader parse
            # binary protobuf as text, which must fail.
            self.assertRaises(
                google.protobuf.text_format.ParseError,
                onnx.load_model,
                target,
            )

    def test_load_model_uses_format_when_it_is_specified(self) -> None:
        """An explicit `format` on load overrides the file extension."""
        expected = _simple_model()
        with tempfile.TemporaryDirectory() as tmp:
            target = os.path.join(tmp, "model.protobuf")
            onnx.save_model(expected, target)
            # Forcing textproto on a binary protobuf file must fail.
            self.assertRaises(
                google.protobuf.text_format.ParseError,
                onnx.load_model,
                target,
                format="textproto",
            )
            self.assertEqual(expected, onnx.load_model(target, format="protobuf"))

    def test_load_and_save_model_to_path_without_specifying_extension_succeeds(
        self,
    ) -> None:
        """An extension-less path works when `format` is given explicitly."""
        expected = _simple_model()
        with tempfile.TemporaryDirectory() as tmp:
            target = os.path.join(tmp, "model")  # deliberately no extension
            onnx.save_model(expected, target, format="textproto")
            # With neither extension nor `format`, load_model assumes binary
            # protobuf and cannot decode the textproto contents.
            self.assertRaises(
                google.protobuf.message.DecodeError,
                onnx.load_model,
                target,
            )
            self.assertEqual(expected, onnx.load_model(target, format="textproto"))

    def test_load_and_save_model_without_specifying_extension_or_format_defaults_to_protobuf(
        self,
    ) -> None:
        """With neither extension nor `format`, protobuf is the default."""
        expected = _simple_model()
        with tempfile.TemporaryDirectory() as tmp:
            target = os.path.join(tmp, "model")  # deliberately no extension
            onnx.save_model(expected, target)
            # The file holds binary protobuf, so a textproto parse must fail.
            self.assertRaises(
                google.protobuf.text_format.ParseError,
                onnx.load_model,
                target,
                format="textproto",
            )
            self.assertEqual(expected, onnx.load_model(target))
            self.assertEqual(expected, onnx.load_model(target, format="protobuf"))
|
|
|
|
|
|
class TestBasicFunctions(unittest.TestCase):
    """Sanity checks for the generated proto classes and their ``repr``."""

    def test_protos_exist(self) -> None:
        # The core proto classes should be reachable as attributes of `onnx`.
        _ = onnx.AttributeProto
        _ = onnx.NodeProto
        _ = onnx.GraphProto
        _ = onnx.ModelProto

    def test_version_exists(self) -> None:
        model = onnx.ModelProto()
        # A freshly constructed model has no ir_version field set.
        self.assertFalse(model.HasField("ir_version"))
        # Stamp the model with the IR version of the running ONNX and check
        # that the field survives a serialize/parse round trip.
        model.ir_version = onnx.IR_VERSION
        model_string = model.SerializeToString()
        model.ParseFromString(model_string)
        self.assertTrue(model.HasField("ir_version"))
        # Check if the version is correct.
        self.assertEqual(model.ir_version, onnx.IR_VERSION)

    def test_model_and_graph_repr(self) -> None:
        # Check that __repr__ produces the expected compact summaries.
        model = _simple_model()
        model_repr = repr(model)
        # _simple_model() stamps onnx.IR_VERSION, so derive the expectation
        # from it rather than hard-coding a number that changes whenever the
        # IR version is bumped.
        self.assertEqual(
            model_repr,
            f"ModelProto(ir_version={onnx.IR_VERSION}, producer_name='onnx-test', "
            "graph=GraphProto('test'))",
        )

        text_model = """
            <
                ir_version: 10,
                opset_import: [ "" : 19]
            >
            agraph (float[N] X) => (float[N] C)
            <
                float[1] weight = {1}
            >
            {
                C = Cast<to=1>(X)
            }
        """
        model = onnx.parser.parse_model(text_model)
        model_repr = repr(model)
        self.assertEqual(
            model_repr,
            "ModelProto(ir_version=10, opset_import={'': 19}, graph=GraphProto('agraph', input=<1 inputs>, output=<1 outputs>, initializer=<1 initializers>, node=<1 nodes>))",
        )

        graph_repr = repr(model.graph)
        self.assertEqual(
            graph_repr,
            "GraphProto('agraph', input=<1 inputs>, output=<1 outputs>, initializer=<1 initializers>, node=<1 nodes>)",
        )

    def test_function_repr(self) -> None:
        # A model with one local function: both the model repr (which
        # summarises `functions`) and the FunctionProto repr are checked.
        text = """
            <
                ir_version: 9,
                opset_import: [ "" : 15, "custom_domain" : 1],
                producer_name: "FunctionProtoTest",
                producer_version: "1.0",
                model_version: 1,
                doc_string: "A test model for model local functions."
            >
            agraph (float[N] x) => (float[N] out)
            {
                out = custom_domain.Selu<alpha=2.0, gamma=3.0>(x)
            }
            <
                domain: "custom_domain",
                opset_import: [ "" : 15],
                doc_string: "Test function proto"
            >
            Selu
            <alpha: float=1.67326319217681884765625, gamma: float=1.05070102214813232421875>
            (X) => (C)
            {
                constant_alpha = Constant<value_float: float=@alpha>()
                constant_gamma = Constant<value_float: float=@gamma>()
                alpha_x = CastLike(constant_alpha, X)
                gamma_x = CastLike(constant_gamma, X)
                exp_x = Exp(X)
                alpha_x_exp_x = Mul(alpha_x, exp_x)
                alpha_x_exp_x_ = Sub(alpha_x_exp_x, alpha_x)
                neg = Mul(gamma_x, alpha_x_exp_x_)
                pos = Mul(gamma_x, X)
                _zero = Constant<value_float=0.0>()
                zero = CastLike(_zero, X)
                less_eq = LessOrEqual(X, zero)
                C = Where(less_eq, neg, pos)
            }
        """
        model = onnx.parser.parse_model(text)
        self.assertEqual(
            repr(model),
            "ModelProto(ir_version=9, opset_import={'': 15, 'custom_domain': 1}, producer_name='FunctionProtoTest', producer_version='1.0', graph=GraphProto('agraph', input=<1 inputs>, output=<1 outputs>, node=<1 nodes>), functions=<1 functions>)",
        )
        function_repr = repr(model.functions[0])
        self.assertEqual(
            function_repr,
            "FunctionProto('Selu', domain='custom_domain', opset_import={'': 15}, input=<1 inputs>, output=<1 outputs>, node=<13 nodes>)",
        )
|
|
|
|
|
|
if __name__ == "__main__":
    # Run every TestCase defined in this module when executed as a script.
    # verbosity=1 is unittest.main's default; -v / -q still override it.
    unittest.main(verbosity=1)
|