# DriverTrac/venv/lib/python3.12/site-packages/onnx/test/basic_test.py
#
# 317 lines
# 13 KiB
# Python

# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import io
import os
import pathlib
import tempfile
import unittest
import google.protobuf.message
import google.protobuf.text_format
import parameterized
import onnx
from onnx import serialization
def _simple_model() -> onnx.ModelProto:
    """Return a minimal ModelProto used as the round-trip reference.

    The model carries the running ONNX build's IR version, the producer name
    ``"onnx-test"`` and an otherwise empty graph named ``"test"``.
    """
    return onnx.ModelProto(
        ir_version=onnx.IR_VERSION,
        producer_name="onnx-test",
        graph=onnx.GraphProto(name="test"),
    )
def _simple_tensor() -> onnx.TensorProto:
    """Return a FLOAT tensor named ``"test-tensor"`` with dims (2, 3, 4).

    Its 24 values are 0.5, 1.5, ..., 23.5 in order.
    """
    shape = (2, 3, 4)
    values = [0.5 + i for i in range(24)]
    return onnx.helper.make_tensor(
        name="test-tensor",
        data_type=onnx.TensorProto.FLOAT,
        dims=shape,
        vals=values,
    )
@parameterized.parameterized_class(
    [{"format": fmt} for fmt in ("protobuf", "textproto", "json", "onnxtxt")]
)
class TestIO(unittest.TestCase):
    """Round-trip a ModelProto through each serialization format and input kind."""

    # Serialization format name, injected into each generated subclass.
    format: str

    def _assert_round_trip_via_path(self, path: str | os.PathLike[str]) -> None:
        """Save the reference model to ``path``, reload it and compare."""
        expected = _simple_model()
        onnx.save_model(expected, path, format=self.format)
        self.assertEqual(expected, onnx.load_model(path, format=self.format))

    def test_load_model_when_input_is_bytes(self) -> None:
        expected = _simple_model()
        serializer = serialization.registry.get(self.format)
        payload = serializer.serialize_proto(expected)
        self.assertEqual(
            expected, onnx.load_model_from_string(payload, format=self.format)
        )

    def test_save_and_load_model_when_input_has_read_function(self) -> None:
        expected = _simple_model()
        # Bytes handed to `save_model` must always be a binary protobuf encoding
        # (i.e. format="protobuf"); the `format` argument only selects how the
        # destination stream is written.
        payload = serialization.registry.get("protobuf").serialize_proto(expected)
        sink = io.BytesIO()
        onnx.save_model(payload, sink, format=self.format)
        source = io.BytesIO(sink.getvalue())
        self.assertEqual(expected, onnx.load_model(source, format=self.format))

    def test_save_and_load_model_when_input_is_file_name(self) -> None:
        with tempfile.TemporaryDirectory() as temp_dir:
            self._assert_round_trip_via_path(os.path.join(temp_dir, "model.onnx"))

    def test_save_and_load_model_when_input_is_pathlike(self) -> None:
        with tempfile.TemporaryDirectory() as temp_dir:
            self._assert_round_trip_via_path(pathlib.Path(temp_dir, "model.onnx"))
@parameterized.parameterized_class(
    [
        {"format": "protobuf"},
        {"format": "textproto"},
        {"format": "json"},
        # The onnxtxt format does not support saving/loading tensors yet
    ]
)
class TestIOTensor(unittest.TestCase):
    """Test loading and saving of TensorProto."""

    # Serialization format name, injected into each generated subclass.
    format: str

    def _assert_round_trip_via_path(self, path: str | os.PathLike[str]) -> None:
        """Save the reference tensor to ``path``, reload it and compare."""
        proto = _simple_tensor()
        onnx.save_tensor(proto, path, format=self.format)
        loaded_proto = onnx.load_tensor(path, format=self.format)
        self.assertEqual(proto, loaded_proto)

    def test_load_tensor_when_input_is_bytes(self) -> None:
        # Serialize through the registry, then parse through the public API.
        proto = _simple_tensor()
        proto_string = serialization.registry.get(self.format).serialize_proto(proto)
        loaded_proto = onnx.load_tensor_from_string(proto_string, format=self.format)
        self.assertEqual(proto, loaded_proto)

    def test_save_and_load_tensor_when_input_has_read_function(self) -> None:
        # Test if input is a file-like object (has a read function)
        proto = _simple_tensor()
        f = io.BytesIO()
        onnx.save_tensor(proto, f, format=self.format)
        loaded_proto = onnx.load_tensor(io.BytesIO(f.getvalue()), format=self.format)
        self.assertEqual(proto, loaded_proto)

    def test_save_and_load_tensor_when_input_is_file_name(self) -> None:
        # Test if input is a file name (str)
        with tempfile.TemporaryDirectory() as temp_dir:
            self._assert_round_trip_via_path(os.path.join(temp_dir, "tensor.onnx"))

    def test_save_and_load_tensor_when_input_is_pathlike(self) -> None:
        # Test if input is a path-like object (pathlib.Path)
        with tempfile.TemporaryDirectory() as temp_dir:
            self._assert_round_trip_via_path(pathlib.Path(temp_dir, "tensor.onnx"))
class TestSaveAndLoadFileExtensions(unittest.TestCase):
    """Check how file extensions and the explicit `format` argument interact."""

    def test_save_model_picks_correct_format_from_extension(self) -> None:
        proto = _simple_model()
        with tempfile.TemporaryDirectory() as temp_dir:
            model_path = os.path.join(temp_dir, "model.textproto")
            # No format is specified, so the extension should be used to determine the format
            onnx.save_model(proto, model_path)
            loaded_proto = onnx.load_model(model_path, format="textproto")
            self.assertEqual(proto, loaded_proto)

    def test_load_model_picks_correct_format_from_extension(self) -> None:
        proto = _simple_model()
        with tempfile.TemporaryDirectory() as temp_dir:
            model_path = os.path.join(temp_dir, "model.textproto")
            onnx.save_model(proto, model_path, format="textproto")
            # No format is specified, so the extension should be used to determine the format
            loaded_proto = onnx.load_model(model_path)
            self.assertEqual(proto, loaded_proto)

    def test_save_model_uses_format_when_it_is_specified(self) -> None:
        proto = _simple_model()
        with tempfile.TemporaryDirectory() as temp_dir:
            model_path = os.path.join(temp_dir, "model.textproto")
            # `format` is specified. It should take precedence over the extension
            onnx.save_model(proto, model_path, format="protobuf")
            loaded_proto = onnx.load_model(model_path, format="protobuf")
            self.assertEqual(proto, loaded_proto)
            with self.assertRaises(google.protobuf.text_format.ParseError):
                # Loading it as textproto (by file extension) should fail
                onnx.load_model(model_path)

    def test_load_model_uses_format_when_it_is_specified(self) -> None:
        proto = _simple_model()
        with tempfile.TemporaryDirectory() as temp_dir:
            # Use an extension that is registered for binary protobuf, so the
            # save picks protobuf from the extension itself (not just from the
            # no-extension fallback) and the load below really overrides it.
            model_path = os.path.join(temp_dir, "model.pb")
            onnx.save_model(proto, model_path)
            with self.assertRaises(google.protobuf.text_format.ParseError):
                # `format` is specified. It should take precedence over the extension
                # Loading it as textproto should fail
                onnx.load_model(model_path, format="textproto")
            loaded_proto = onnx.load_model(model_path, format="protobuf")
            self.assertEqual(proto, loaded_proto)

    def test_load_and_save_model_to_path_without_specifying_extension_succeeds(
        self,
    ) -> None:
        proto = _simple_model()
        with tempfile.TemporaryDirectory() as temp_dir:
            # No extension is specified
            model_path = os.path.join(temp_dir, "model")
            onnx.save_model(proto, model_path, format="textproto")
            with self.assertRaises(google.protobuf.message.DecodeError):
                # `format` is not specified. load_model should assume protobuf
                # and fail to load it
                onnx.load_model(model_path)
            loaded_proto = onnx.load_model(model_path, format="textproto")
            self.assertEqual(proto, loaded_proto)

    def test_load_and_save_model_without_specifying_extension_or_format_defaults_to_protobuf(
        self,
    ) -> None:
        proto = _simple_model()
        with tempfile.TemporaryDirectory() as temp_dir:
            # No extension is specified
            model_path = os.path.join(temp_dir, "model")
            onnx.save_model(proto, model_path)
            with self.assertRaises(google.protobuf.text_format.ParseError):
                # The model is saved as protobuf, so loading it as textproto should fail
                onnx.load_model(model_path, format="textproto")
            loaded_proto = onnx.load_model(model_path)
            self.assertEqual(proto, loaded_proto)
            loaded_proto_as_explicitly_protobuf = onnx.load_model(
                model_path, format="protobuf"
            )
            self.assertEqual(proto, loaded_proto_as_explicitly_protobuf)
class TestBasicFunctions(unittest.TestCase):
    """Sanity checks for proto classes, IR version handling and __repr__ output."""

    def test_protos_exist(self) -> None:
        # The proto classes should exist
        _ = onnx.AttributeProto
        _ = onnx.NodeProto
        _ = onnx.GraphProto
        _ = onnx.ModelProto

    def test_version_exists(self) -> None:
        model = onnx.ModelProto()
        # A freshly constructed model has no ir_version field set.
        self.assertFalse(model.HasField("ir_version"))
        # Set ir_version to the IR version of the running ONNX and check that it
        # survives a serialize/parse round trip.
        model.ir_version = onnx.IR_VERSION
        model_string = model.SerializeToString()
        model.ParseFromString(model_string)
        self.assertTrue(model.HasField("ir_version"))
        # Check if the version is correct.
        self.assertEqual(model.ir_version, onnx.IR_VERSION)

    def test_model_and_graph_repr(self) -> None:
        # Check if the __repr__ methods work without error
        model = _simple_model()
        model_repr = repr(model)
        # _simple_model() stamps onnx.IR_VERSION, so derive the expected value
        # from it rather than hard-coding one ONNX release's IR version.
        self.assertEqual(
            model_repr,
            f"ModelProto(ir_version={onnx.IR_VERSION}, producer_name='onnx-test', graph=GraphProto('test'))",
        )
        text_model = """
            <
                ir_version: 10,
                opset_import: [ "" : 19]
            >
            agraph (float[N] X) => (float[N] C)
            <
                float[1] weight = {1}
            >
            {
                C = Cast<to=1>(X)
            }
            """
        model = onnx.parser.parse_model(text_model)
        model_repr = repr(model)
        self.assertEqual(
            model_repr,
            "ModelProto(ir_version=10, opset_import={'': 19}, graph=GraphProto('agraph', input=<1 inputs>, output=<1 outputs>, initializer=<1 initializers>, node=<1 nodes>))",
        )
        graph_repr = repr(model.graph)
        self.assertEqual(
            graph_repr,
            "GraphProto('agraph', input=<1 inputs>, output=<1 outputs>, initializer=<1 initializers>, node=<1 nodes>)",
        )

    def test_function_repr(self) -> None:
        text = """
            <
                ir_version: 9,
                opset_import: [ "" : 15, "custom_domain" : 1],
                producer_name: "FunctionProtoTest",
                producer_version: "1.0",
                model_version: 1,
                doc_string: "A test model for model local functions."
            >
            agraph (float[N] x) => (float[N] out)
            {
                out = custom_domain.Selu<alpha=2.0, gamma=3.0>(x)
            }
            <
                domain: "custom_domain",
                opset_import: [ "" : 15],
                doc_string: "Test function proto"
            >
            Selu
            <alpha: float=1.67326319217681884765625, gamma: float=1.05070102214813232421875>
            (X) => (C)
            {
                constant_alpha = Constant<value_float: float=@alpha>()
                constant_gamma = Constant<value_float: float=@gamma>()
                alpha_x = CastLike(constant_alpha, X)
                gamma_x = CastLike(constant_gamma, X)
                exp_x = Exp(X)
                alpha_x_exp_x = Mul(alpha_x, exp_x)
                alpha_x_exp_x_ = Sub(alpha_x_exp_x, alpha_x)
                neg = Mul(gamma_x, alpha_x_exp_x_)
                pos = Mul(gamma_x, X)
                _zero = Constant<value_float=0.0>()
                zero = CastLike(_zero, X)
                less_eq = LessOrEqual(X, zero)
                C = Where(less_eq, neg, pos)
            }
            """
        model = onnx.parser.parse_model(text)
        self.assertEqual(
            repr(model),
            "ModelProto(ir_version=9, opset_import={'': 15, 'custom_domain': 1}, producer_name='FunctionProtoTest', producer_version='1.0', graph=GraphProto('agraph', input=<1 inputs>, output=<1 outputs>, node=<1 nodes>), functions=<1 functions>)",
        )
        function_repr = repr(model.functions[0])
        self.assertEqual(
            function_repr,
            "FunctionProto('Selu', domain='custom_domain', opset_import={'': 15}, input=<1 inputs>, output=<1 outputs>, node=<13 nodes>)",
        )
if __name__ == "__main__":
    # Run every TestCase in this module when executed directly as a script;
    # "__main__" is unittest.main's default module, spelled out for clarity.
    unittest.main(module="__main__")