DriverTrac/venv/lib/python3.12/site-packages/onnx/reference/ops/op_constant.py
2025-11-28 09:08:33 +05:30

120 lines
4.2 KiB
Python

# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import numpy as np
from onnx.reference.op_run import OpRun, RefAttrName
def _check_dtype(val):
dtype = val.dtype
if not isinstance(dtype, np.dtype):
raise TypeError(
f"Type ({dtype}, {type(dtype)}) is not a numpy type (operator 'Constant')"
)
class ConstantCommon(OpRun):
def _check(self, cst):
if isinstance(cst, tuple):
raise TypeError(f"Unexpected type {type(cst)} for a constant.")
return cst
class Constant_1(ConstantCommon):
def __init__(self, onnx_node, run_params):
ConstantCommon.__init__(self, onnx_node, run_params)
self.cst = self.value
_check_dtype(self.cst)
def _run(self, **overridden_attributes):
if overridden_attributes and (
len(overridden_attributes) > 1
or "value" not in overridden_attributes
or id(overridden_attributes["value"]) != id(self.value)
):
raise RuntimeError(
"Function attributes are not implemented for opset <= 11. Use opset > 12."
)
return (self._check(self.cst),)
class Constant_9(Constant_1):
def __init__(self, onnx_node, run_params):
Constant_1.__init__(self, onnx_node, run_params)
class Constant_11(ConstantCommon):
def __init__(self, onnx_node, run_params):
ConstantCommon.__init__(self, onnx_node, run_params)
if getattr(self, "sparse_value", None) is None:
self.cst = self.value
else:
self.cst = self.sparse_value
_check_dtype(self.cst)
def _run(self, **overridden_attributes):
if overridden_attributes and (
len(overridden_attributes) > 1
or "value" not in overridden_attributes
or id(overridden_attributes["value"]) != id(self.value)
):
raise RuntimeError(
"Function attributes are not implemented for opset <= 11. Use opset > 12."
)
return (self._check(self.cst),)
class Constant_12(ConstantCommon):
def __init__(self, onnx_node, run_params):
ConstantCommon.__init__(self, onnx_node, run_params)
if hasattr(self, "sparse_value") and self.sparse_value is not None:
self.cst_name = "sparse_value"
self.cst = self.sparse_value
self.cst_convert = lambda v: v
elif hasattr(self, "value") and self.value is not None:
self.cst_name = "value"
self.cst = self.value
self.cst_convert = lambda v: v
else:
for attr, np_dtype in {
"value_float": np.float32,
"value_floats": np.float32,
"value_int": np.int64,
"value_ints": np.int64,
"value_string": np.str_,
"value_strings": np.str_,
}.items():
if hasattr(self, attr) and getattr(self, attr) is not None:
self.cst_name = attr
v = getattr(self, attr)
self.cst = (
v if isinstance(v, RefAttrName) else np.array(v, dtype=np_dtype)
)
self.cst_convert = lambda v, np_dtype=np_dtype: np.array(
v, dtype=np_dtype
)
break
if not hasattr(self, "cst_name"):
raise AttributeError(
f"No constant is defined for operator 'Constant', outputs are {onnx_node.output}."
)
def _run(self, **overridden_attributes):
if self.has_linked_attribute:
if overridden_attributes is None:
raise RuntimeError(
f"Attributes are empty, cannot retrieve value for {self.cst!r}."
)
if self.cst_name not in overridden_attributes:
raise RuntimeError(
f"Cannot find attribute {self.cst_name!r} in {list(overridden_attributes)!r}."
)
value = overridden_attributes[self.cst_name]
if isinstance(value, np.ndarray):
return (value,)
return (self.cst_convert(value),)
return (self._check(self.cst),)