DriverTrac/venv/lib/python3.12/site-packages/onnxslim/utils.py

795 lines
27 KiB
Python

from __future__ import annotations
import importlib.util
import logging
import os
import sys
from collections import defaultdict
from pathlib import Path
import numpy as np
import onnx
from onnx import checker, helper
import onnxslim.third_party.onnx_graphsurgeon as gs
from onnxslim.misc.tabulate import SEPARATING_LINE, tabulate
from onnxslim.third_party.onnx_graphsurgeon.logger.logger import G_LOGGER
logger = logging.getLogger("onnxslim")

import ml_dtypes

# Try the private mapping module first and fall back to the public `onnx.mapping`
# when it cannot be imported.
try:
    from onnx._mapping import TensorDtypeMap
except ImportError:
    from onnx.mapping import TensorDtypeMap

# Each entry: (TensorProto enum name, ml_dtypes attribute name, storage TensorProto enum name).
candidates = [
    ("BFLOAT16", "bfloat16", "UINT16"),
    ("FLOAT8E4M3FN", "float8_e4m3fn", "UINT8"),
    ("FLOAT8E4M3FNUZ", "float8_e4m3fnuz", "UINT8"),
    ("FLOAT8E5M2", "float8_e5m2", "UINT8"),
    ("FLOAT8E5M2FNUZ", "float8_e5m2fnuz", "UINT8"),
    ("UINT4", "uint4", "INT32"),
    ("INT4", "int4", "INT32"),
    ("FLOAT4E2M1", "float4_e2m1fn", "UINT8"),
]

# Extended element types keyed by TensorProto enum value. Only types known to both the
# installed onnx and the installed ml_dtypes are registered.
TENSOR_TYPE_MAP = {
    int(getattr(onnx.TensorProto, proto_name)): TensorDtypeMap(
        np.dtype(getattr(ml_dtypes, ml_dtype_name)),
        int(getattr(onnx.TensorProto, storage_proto_name)),
        f"TensorProto.{proto_name}",
    )
    for proto_name, ml_dtype_name, storage_proto_name in candidates
    if hasattr(onnx.TensorProto, proto_name) and hasattr(ml_dtypes, ml_dtype_name)
}
def init_logging(verbose=False):
    """Configure the "onnxslim" logger and the graph-surgeon logger.

    Args:
        verbose: When true both loggers report DEBUG and above; otherwise ERROR and above.

    Returns:
        The ``logging.Logger`` named "onnxslim".
    """
    level = logging.DEBUG if verbose else logging.ERROR
    slim_logger = logging.getLogger("onnxslim")
    slim_logger.setLevel(level)
    G_LOGGER.severity = level

    # Attach a stderr handler only once so repeated calls do not duplicate output.
    if not slim_logger.handlers:
        stream_handler = logging.StreamHandler(sys.stderr)
        stream_handler.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s"))
        slim_logger.addHandler(stream_handler)

    G_LOGGER.colors = False

    if is_onnxruntime_available():
        import onnxruntime as ort

        # onnxruntime severity scale: 0=VERBOSE, 1=INFO, 2=WARNING, 3=ERROR, 4=FATAL.
        ort.set_default_logger_severity(3)

    return slim_logger
def format_bytes(size: int | tuple[int, ...]) -> str:
    """Render byte counts in B/KB/MB/GB with two decimals.

    A single count (Python or NumPy integer) yields e.g. ``"1.50 KB"``. For a sequence,
    the first count is followed by the second in parentheses, e.g. ``"1.00 MB (2.00 KB)"``;
    any further entries are ignored. GB is the largest unit.
    """
    if isinstance(size, (int, np.integer)):
        size = (int(size),)

    def _human(amount):
        # Divide by 1024 while the value is at least 1024, stopping at GB.
        for unit in ("B", "KB", "MB"):
            if amount >= 1024:
                amount /= 1024
                continue
            return f"{amount:.2f} {unit}"
        return f"{amount:.2f} GB"

    labels = [_human(amount) for amount in size]
    return labels[0] if len(labels) == 1 else f"{labels[0]} ({labels[1]})"
def onnx_dtype_to_numpy(onnx_dtype: int) -> np.dtype:
    """Translate an ONNX ``TensorProto`` element type into a NumPy dtype.

    Entries registered in ``TENSOR_TYPE_MAP`` take precedence over onnx's own mapping.
    An element type unknown to both yields the string ``"UNDEFINED"`` instead of a dtype.
    """
    extended = TENSOR_TYPE_MAP.get(onnx_dtype)
    if extended:
        return extended.np_dtype
    if onnx_dtype not in onnx.helper.get_all_tensor_dtypes():
        return "UNDEFINED"
    return np.dtype(helper.tensor_dtype_to_np_dtype(onnx_dtype))
def gen_onnxruntime_input_data(
    model: onnx.ModelProto, model_check_inputs: list[str] | None = None
) -> dict[str, np.ndarray]:
    """Build a feed dict of input arrays for running ``model`` with onnxruntime.

    Args:
        model: Model whose ``graph.input`` entries are fed.
        model_check_inputs: Optional overrides, each ``"name:spec"`` (split at the last
            ``:``). When ``spec`` ends with ``.npy`` the array is loaded from that file and
            used verbatim; otherwise ``spec`` is a comma-separated list of ints used as
            the shape of randomly generated data for that input.

    Returns:
        Mapping from input name to a NumPy array.

    Raises:
        Exception: If an override names an input the model does not have.
        ValueError: If an override has no ``:`` or its shape is not a list of ints.
    """
    input_info = {}
    for input_tensor in model.graph.input:
        shape = []
        for dim in input_tensor.type.tensor_type.shape.dim:
            if dim.HasField("dim_param"):
                shape.append(dim.dim_param)
            elif dim.HasField("dim_value"):
                shape.append(dim.dim_value)
            else:
                shape.append(None)
        input_info[input_tensor.name] = {
            "shape": shape,
            "dtype": onnx_dtype_to_numpy(input_tensor.type.tensor_type.elem_type),
        }

    def _unknown_input(key):
        return Exception(
            f"model_check_input name:{key} not found in model, available keys: {' '.join(input_info.keys())}"
        )

    for model_check_input in model_check_inputs or []:
        key, separator, value = model_check_input.rpartition(":")
        if not separator:
            raise ValueError(f"model_check_input must look like 'name:value', got: {model_check_input!r}")
        if value.endswith(".npy"):
            if key not in input_info:
                raise _unknown_input(key)
            input_info[key] = {"data": np.load(value)}
        else:
            values_list = [int(val) for val in value.split(",")]
            if key not in input_info:
                raise _unknown_input(key)
            input_info[key]["shape"] = values_list

    return {
        name: info["data"] if "data" in info else _random_input_data(info["shape"], info["dtype"])
        for name, info in input_info.items()
    }


def _random_input_data(shape, dtype):
    """Return random data of ``dtype`` for ``shape``, substituting 1 for every unknown dim."""
    # Symbolic (str), missing (None) and -1 dims become 1; an empty shape becomes [1].
    concrete = [1 if dim is None or isinstance(dim, str) or dim == -1 else dim for dim in shape] or [1]
    # An np.dtype compares equal to its scalar type but hashes differently, so a set
    # lookup like ``dtype in {np.int32, np.int64}`` does not match; inspect the type
    # hierarchy instead. Integer inputs get values in [0, 10), booleans get 0/1.
    if isinstance(dtype, np.dtype) and np.issubdtype(dtype, np.integer):
        return np.random.randint(10, size=concrete).astype(dtype)
    if isinstance(dtype, np.dtype) and np.issubdtype(dtype, np.bool_):
        return np.random.randint(2, size=concrete).astype(dtype)
    return np.random.rand(*concrete).astype(dtype)
def onnxruntime_inference(
    model: onnx.ModelProto, input_data: dict
) -> tuple[dict[str, np.ndarray], onnx.ModelProto]:
    """Run ``model`` once with onnxruntime's CPU execution provider.

    Models whose serialized size reaches ``onnx.checker.MAXIMUM_PROTOBUF`` cannot be
    passed in memory. They are written with external data into a private temporary
    directory, run from that path, and reloaded from it before the directory is removed.

    Returns:
        ``(outputs, model)``. ``outputs`` maps each session output name to its value.
        ``model`` is the given model (in-memory path) or the copy reloaded from the
        temporary file (external-data path).
    """
    import os
    import tempfile

    import onnx

    if model.ByteSize() < onnx.checker.MAXIMUM_PROTOBUF:
        return _run_onnxruntime_session(model.SerializeToString(), input_data), model

    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_path = os.path.join(tmp_dir, "tmp.onnx")
        # The directory is brand new, so no stale data file can exist. The old
        # existence check was resolved against the working directory and could delete
        # an unrelated file there.
        onnx.save(
            model,
            tmp_path,
            save_as_external_data=True,
            all_tensors_to_one_file=True,
            location=f"{os.path.basename(tmp_path)}.data",
        )
        onnx_output = _run_onnxruntime_session(tmp_path, input_data)
        # Reload (with its external data) before the directory is deleted. Presumably
        # this is needed because saving with external data rewrites the tensors of
        # ``model`` in place — confirm.
        model = onnx.load(tmp_path)
    return onnx_output, model


def _run_onnxruntime_session(onnx_model, input_data):
    """Create a CPU session for ``onnx_model`` (bytes or path), run it, and map output names to values."""
    import onnxruntime as rt

    # The session is local, so it is released when this helper returns, before the
    # caller removes any temporary files.
    sess = rt.InferenceSession(onnx_model, providers=["CPUExecutionProvider"])
    output_names = [output.name for output in sess.get_outputs()]
    return dict(zip(output_names, sess.run(None, input_data)))
def format_model_info(model_info_list: dict | list[dict], elapsed_time: float | None = None):
    """Build the rows of the model comparison table.

    Args:
        model_info_list: One model-info object or a list/tuple of them. Each must expose
            ``tag``, ``op_set``, ``ir_version``, ``input_info``, ``output_info``,
            ``input_maps``, ``output_maps``, ``op_type_counts`` and ``model_size``.
        elapsed_time: Optional duration in seconds; when truthy an "Elapsed Time" row
            is appended.

    Returns:
        A list of rows (lists of cells) suitable for ``tabulate``. Op counts in later
        columns that are lower than the first model's count are wrapped in colorama
        green/white codes.
    """
    assert model_info_list, "model_info_list must contain at least one model info"
    from colorama import Fore, init

    init()
    if not isinstance(model_info_list, (list, tuple)):
        model_info_list = [model_info_list]
    column_count = len(model_info_list) + 1

    final_op_info = []
    final_op_info.extend(
        (
            ["Model Name"] + [item.tag for item in model_info_list],
            [SEPARATING_LINE] * column_count,
            ["Model Info"]
            + ["Op Set: " + item.op_set + " / IR Version: " + item.ir_version for item in model_info_list],
            [SEPARATING_LINE] * column_count,
        )
    )

    def get_io_info(model_info_list, tag=None):
        """Return one row per distinct input ("IN") or output ("OUT") name, in first-seen order.

        Each row is ``"<tag>: <name>"`` followed, per model, by ``"<dtype>: <shape>"`` or
        ``""`` when that model lacks the tensor.
        """
        if tag == "OUT":
            ios = [op_type for model_info in model_info_list for op_type in model_info.output_info]
        else:
            ios = [op_type for model_info in model_info_list for op_type in model_info.input_info]
        ios = list(dict.fromkeys([io.name for io in ios]))
        io_info = []
        for io in ios:
            input_info_list = [f"{tag}: {io}"]
            for model_info in model_info_list:
                if tag == "OUT":
                    io_tensor = model_info.output_maps.get(io, None)
                else:
                    io_tensor = model_info.input_maps.get(io, None)
                inputs_shape = (io_tensor.dtype, io_tensor.shape) if io_tensor else ""
                if isinstance(inputs_shape, (list, tuple)):
                    inputs_shape = ": ".join([str(i) for i in inputs_shape])
                input_info_list.append(inputs_shape)
            io_info.append(input_info_list)
        return io_info

    final_op_info.extend(get_io_info(model_info_list, "IN"))
    final_op_info.extend(get_io_info(model_info_list, "OUT"))
    final_op_info.append([SEPARATING_LINE] * column_count)

    all_ops = {op_type for model_info in model_info_list for op_type in model_info.op_type_counts}
    for op in sorted(all_ops):
        op_info_list = [op]
        float_number = model_info_list[0].op_type_counts.get(op, 0)
        op_info_list.append(float_number)
        for model_info in model_info_list[1:]:
            slimmed_number = model_info.op_type_counts.get(op, 0)
            # Highlight ops whose count dropped relative to the first model.
            if float_number > slimmed_number:
                slimmed_number = Fore.GREEN + str(slimmed_number) + Fore.WHITE
            op_info_list.append(slimmed_number)
        final_op_info.append(op_info_list)

    final_op_info.extend(
        (
            [SEPARATING_LINE] * column_count,
            ["Model Size"] + [format_bytes(model_info.model_size) for model_info in model_info_list],
        )
    )
    if elapsed_time:
        final_op_info.extend(
            (
                [SEPARATING_LINE] * column_count,
                ["Elapsed Time", f"{elapsed_time:.2f} s"],
            )
        )

    return final_op_info
def print_model_info_as_table(model_info_list: dict | list[dict], elapsed_time: float | None = None):
    """Print the model comparison table built by ``format_model_info``.

    When ``elapsed_time`` is given and more than one model is shown, the elapsed-time
    value is widened to span all model columns.
    """
    if not isinstance(model_info_list, (list, tuple)):
        model_info_list = [model_info_list]

    final_op_info = format_model_info(model_info_list, elapsed_time)
    lines = tabulate(
        final_op_info,
        headers=[],
        tablefmt="pretty",
        maxcolwidths=[None] + [40] * len(model_info_list),
    ).split("\n")

    # The elapsed-time row (second-to-last line, above the bottom border) has one value
    # and blank cells for the remaining model columns. Merge them into a single cell.
    # With a single model there is nothing to merge; the old index-based merge spliced
    # the label cell into the value cell in that case.
    if elapsed_time and len(model_info_list) > 1:
        cells = lines[-2].split("|")
        # cells == ["", label, value, blank_1, ..., blank_k, ""]
        value_cell = cells[2]
        # Join blanks with spaces so the merged cell keeps the row's total width.
        blank = " ".join(cells[3:-1])
        half = len(blank) // 2
        merged = blank[: half + 1] + value_cell + blank[half:]
        lines[-2] = "|".join([*cells[:2], merged, cells[-1]])

    # Rows built from SEPARATING_LINE render as "| \x01 |"; replace them with the
    # table's first line (presumably its top border).
    output = "\n".join([line if line != "| \x01 |" else lines[0] for line in lines])
    print(output)
def dump_model_info_to_disk(model_info: ModelInfo):
    """Append one CSV row per recorded node output to ``<model_info.tag>_model_info.csv``.

    The file is opened in append mode, and the header is written only when the file is
    empty. For each node in ``model_info.op_info``, the first output row carries the
    node name and op type; rows for further outputs leave those columns blank. Nodes
    without recorded outputs produce no rows.
    """
    import csv

    csv_file_path = f"{model_info.tag}_model_info.csv"
    fieldnames = ["NodeName", "OpType", "OutputDtype", "OutputShape"]
    with open(csv_file_path, "a", newline="") as csvfile:  # Use 'a' for append mode
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        if csvfile.tell() == 0:
            writer.writeheader()

        for node_name, info in model_info.op_info.items():
            # Outputs are TensorInfo-like objects: read .dtype/.shape. Tuple-unpacking
            # them, as the old code did for the 2nd+ outputs, raised TypeError.
            for index, tensor_info in enumerate(info.outputs or []):
                is_first = index == 0
                writer.writerow(
                    {
                        "NodeName": node_name if is_first else "",
                        "OpType": info.op if is_first else "",
                        "OutputDtype": tensor_info.dtype,
                        "OutputShape": tensor_info.shape,
                    }
                )

    print(f"Model info written to {csv_file_path}")
def get_opset(model: onnx.ModelProto) -> int:
    """Return the opset version imported for the default ONNX domain ("" or "ai.onnx").

    The first matching ``opset_import`` entry wins. ``None`` is returned when no entry
    matches or when ``opset_import`` cannot be read.
    """
    try:
        return next(
            (entry.version for entry in model.opset_import if entry.domain in {"", "ai.onnx"}),
            None,
        )
    except Exception:
        return None
def get_ir_version(model: onnx.ModelProto) -> int:
    """Return ``model.ir_version``, or ``None`` if it cannot be read."""
    try:
        version = model.ir_version
    except Exception:
        version = None
    return version
class TensorInfo:
    """Name, NumPy dtype and shape of one ONNX value (graph input, output or value_info entry)."""

    def __init__(self, tensor):
        # Placeholders; _extract_info overwrites them.
        self.dtype: np.dtype = np.float32
        self.shape: tuple[str | int] = None
        self._extract_info(tensor)

    @staticmethod
    def _dim_repr(dim):
        """Return the dim's symbolic name, its integer size, or "?" when it has neither."""
        if dim.HasField("dim_param"):
            return dim.dim_param
        if dim.HasField("dim_value"):
            return dim.dim_value
        return "?"

    def _extract_info(self, tensor):
        """Fill ``dtype``, ``shape`` and ``name`` from ``tensor``.

        ``shape`` is ``None`` when the tensor type has no shape field; otherwise it is a
        tuple with one ``_dim_repr`` entry per dimension.
        """
        tensor_type = tensor.type.tensor_type
        self.dtype = onnx_dtype_to_numpy(tensor_type.elem_type)
        if tensor_type.HasField("shape"):
            self.shape = tuple(self._dim_repr(dim) for dim in tensor_type.shape.dim)
        else:
            self.shape = None
        self.name = tensor.name
class OperatorInfo:
    """Name, op type and (optionally) output descriptions of one graph node."""

    def __init__(self, operator, outputs=None):
        # Declared up front so the attributes always exist on the instance.
        self.name: str = None
        self.op: str = None
        self._extract_info(operator)
        self.outputs = outputs

    def _extract_info(self, operator):
        """Copy the node's ``name`` and ``op_type`` onto ``name`` and ``op``."""
        node_name = operator.name
        node_op_type = operator.op_type
        self.name = node_name
        self.op = node_op_type
class ModelInfo:
    """Summary of an ONNX model used for the comparison table and the CSV dump.

    Filled on construction:
    ``tag``, ``model_size`` (initializer + Constant bytes), ``op_set`` and ``ir_version``
    (as strings), ``op_type_counts``, ``op_info`` (node name -> OperatorInfo), and
    ``input_info``/``output_info`` (lists of TensorInfo).
    """

    def __init__(self, model: str | onnx.ModelProto, tag: str = "OnnxSlim"):
        # A path is loaded from disk, and its file name replaces the given tag.
        if isinstance(model, str):
            tag = Path(model).name
            model = onnx.load(model)

        self.tag: str = tag
        self.model_size: int = -1
        self.op_set: str = None
        self.ir_version: str = None
        self.op_type_counts: dict[str, int] = defaultdict(int)
        self.op_info: dict[str, OperatorInfo] = {}
        self.input_info: list[TensorInfo] = []
        self.output_info: list[TensorInfo] = []

        self._summarize_model(model)

    def _summarize_model(self, model):
        """Populate all summary attributes from ``model``."""
        self.op_set = str(get_opset(model))
        self.ir_version = str(get_ir_version(model))
        self.model_size = get_initializer_size(model)

        self.input_info.extend(TensorInfo(graph_input) for graph_input in model.graph.input)
        self.output_info.extend(TensorInfo(graph_output) for graph_output in model.graph.output)

        self._collect_graph_node_info(model.graph, {})

    def _collect_graph_node_info(self, graph, outer_value_info):
        """Count op types and record OperatorInfo for the nodes of ``graph`` and its subgraphs.

        Args:
            graph: GraphProto to walk.
            outer_value_info: value_info entries (name -> ValueInfoProto) visible from
                enclosing graphs.
        """
        # Subgraphs see the value_info of their enclosing graphs plus their own.
        value_info = {**outer_value_info, **{vi.name: vi for vi in graph.value_info}}
        for node in graph.node:
            self.op_type_counts[node.op_type] += 1
            outputs = [TensorInfo(value_info[name]) for name in node.output if name in value_info]
            # NOTE(review): nodes sharing a name (including empty names) overwrite each other here.
            self.op_info[node.name] = OperatorInfo(node, outputs)
            for attr in node.attribute:
                if attr.type == onnx.AttributeProto.GRAPH:
                    self._collect_graph_node_info(attr.g, value_info)
                elif attr.type == onnx.AttributeProto.GRAPHS:
                    for subgraph in attr.graphs:
                        self._collect_graph_node_info(subgraph, value_info)

    @property
    def input_maps(self):
        """Map each input name to its TensorInfo (the dict is also stored on ``self.input_dict``)."""
        self.input_dict = {input_info.name: input_info for input_info in self.input_info}
        return self.input_dict

    @property
    def output_maps(self):
        """Map each output name to its TensorInfo (the dict is also stored on ``self.output_dict``)."""
        self.output_dict = {output_info.name: output_info for output_info in self.output_info}
        return self.output_dict
def summarize_model(model: str | onnx.ModelProto, tag="OnnxModel") -> dict:
    """Build a ``ModelInfo`` summary of ``model`` (a ModelProto or a path to one).

    Despite the ``dict`` annotation, the returned object is a ``ModelInfo``. Debug
    messages are logged before and after it is built.
    """
    logger.debug("Start summarizing model.")
    summary = ModelInfo(model, tag)
    logger.debug("Finish summarizing model.")
    return summary
def model_save_as_external_data(model: onnx.ModelProto, model_path: str):
    """Save ``model`` to ``model_path`` with all tensor data in a single external file.

    The data file is named ``<model file name>.data`` and is stored next to
    ``model_path``. An existing file at that location is deleted first, presumably
    because onnx would otherwise append to it — confirm.
    """
    location = f"{os.path.basename(model_path)}.data"
    # ``location`` is resolved relative to the model's directory, so the stale file must
    # be looked up there. Looking it up relative to the CWD missed it and could delete
    # an unrelated file.
    data_path = os.path.join(os.path.dirname(model_path), location)
    if os.path.exists(data_path):
        os.remove(data_path)
    onnx.save(
        model,
        model_path,
        save_as_external_data=True,
        all_tensors_to_one_file=True,
        location=location,
    )
def check_onnx(model: onnx.ModelProto, model_check_inputs=None):
    """Run ``model`` with onnxruntime on generated (or user-supplied) inputs.

    Returns:
        ``(input_data, outputs, model)``: the feed dict, the output dict, and the model
        returned by ``onnxruntime_inference``.
    """
    feeds = gen_onnxruntime_input_data(model, model_check_inputs)
    outputs, model = onnxruntime_inference(model, feeds)
    return feeds, outputs, model
def check_point(model: onnx.ModelProto):
    """Convert ``model`` into an onnx_graphsurgeon graph via ``gs.import_onnx``."""
    graph = gs.import_onnx(model)
    return graph
def save(
    model: onnx.ModelProto,
    model_path: str,
    model_check: bool = False,
    save_as_external_data: bool = False,
    model_info: dict | None = None,
):
    """Save ``model`` to ``model_path``, optionally validating it first.

    Args:
        model: Model to save.
        model_path: Destination file. When falsy, nothing is written.
        model_check: Run ``checker.check_model`` first; a ``ValueError`` from it is
            logged as a warning instead of raised.
        save_as_external_data: Store tensor data in ``<file name>.data`` next to the model
            even when the initializers fit within ``checker.MAXIMUM_PROTOBUF``.
        model_info: Optional model-info object whose ``model_size`` becomes
            ``[model.ByteSize(), previous model_size]`` after saving.
    """
    if model_check:
        try:
            checker.check_model(model)
        except ValueError:
            logger.warning("Model too large and cannot be checked.")

    if model_path:  # model larger than 2GB can be saved, but compiler like trtexec won't parse it
        if get_initializer_size(model) <= checker.MAXIMUM_PROTOBUF and not save_as_external_data:
            onnx.save(model, model_path)
        else:
            location = f"{os.path.basename(model_path)}.data"
            # Look for a stale data file next to the model, where onnx writes it, not in
            # the current working directory.
            data_path = os.path.join(os.path.dirname(model_path), location)
            if os.path.exists(data_path):
                os.remove(data_path)
            onnx.save(
                model,
                model_path,
                save_as_external_data=True,
                all_tensors_to_one_file=True,
                location=location,
            )
            logger.debug("Model too large and saved as external data automatically.")

        if model_info:
            model_size = model.ByteSize()
            model_info.model_size = [model_size, model_info.model_size]
def check_result(raw_onnx_output, slimmed_onnx_output):
    """Verify that the slimmed model reproduces the raw model's outputs.

    Outputs must have the same names and identical shapes, and their values must agree
    within ``rtol=1e-3`` / ``atol=1e-4`` (NaNs compare equal). The first discrepancy is
    printed and ``False`` returned; otherwise ``True``.
    """
    if set(raw_onnx_output.keys()) != set(slimmed_onnx_output.keys()):
        print("Model output mismatch after slimming.")
        print(f"Raw model output keys: {raw_onnx_output.keys()}")
        print(f"Slimmed model output keys: {slimmed_onnx_output.keys()}")
        print("Please check the model carefully.")
        return False

    for key, raw_value in raw_onnx_output.items():
        slimmed_value = slimmed_onnx_output[key]
        # np.allclose broadcasts, so a changed output shape would otherwise pass
        # silently (e.g. (3,) vs (1, 3)) or raise (e.g. (2, 3) vs (3, 2)).
        same_shape = np.shape(raw_value) == np.shape(slimmed_value)
        if not same_shape or not np.allclose(
            raw_value,
            slimmed_value,
            rtol=1e-03,
            atol=1e-04,
            equal_nan=True,
        ):
            # Reset the terminal colour after each red line.
            print(f"\033[31mModel output {key} mismatch after slimming.\033[0m")
            print("\033[31mPlease check the model carefully.\033[0m")
            return False
    return True
def get_numpy_type(onnx_type):
    """Return the NumPy dtype for an ONNX element type, or ``None`` if NumPy cannot represent it.

    Non-``int`` arguments are treated as NumPy types already and returned unchanged.
    """
    if not isinstance(onnx_type, int):
        # Already a NumPy type
        return onnx_type

    # TENSOR_TYPE_TO_NP_TYPE maps types unsupported by NumPy to random other types,
    # so they are reported as unsupported instead. Only enum names present in the
    # installed onnx are looked up (like the hasattr guard used for TENSOR_TYPE_MAP),
    # so older onnx releases do not raise AttributeError here.
    unsupported_names = ("BFLOAT16", "FLOAT8E4M3FN", "FLOAT8E4M3FNUZ", "FLOAT8E5M2", "FLOAT8E5M2FNUZ")
    numpy_unsupported_types = {
        getattr(onnx.TensorProto, name) for name in unsupported_names if hasattr(onnx.TensorProto, name)
    }
    if onnx_type not in numpy_unsupported_types and onnx_type in onnx.helper.get_all_tensor_dtypes():
        return onnx.helper.tensor_dtype_to_np_dtype(onnx_type)
    return None
def get_itemsize(dtype):
    """Return the size in bytes of one element of ``dtype`` (ONNX enum or NumPy type).

    Raises:
        ValueError: If the element size of ``dtype`` is unknown.
    """
    np_dtype = get_numpy_type(dtype)
    if np_dtype is not None:
        return np.dtype(np_dtype).itemsize

    def _enum_values(*names):
        # Skip enum names missing from older onnx releases instead of raising AttributeError.
        return {getattr(onnx.TensorProto, name) for name in names if hasattr(onnx.TensorProto, name)}

    if dtype in _enum_values("BFLOAT16"):
        return 2
    if dtype in _enum_values("FLOAT8E4M3FN", "FLOAT8E4M3FNUZ", "FLOAT8E5M2", "FLOAT8E5M2FNUZ"):
        return 1
    print(f"Unknown ONNX dtype: {dtype}")
    raise ValueError(f"Unsupported TensorProto dtype: {dtype}")
def calculate_tensor_size(tensor):
    """Return the byte size of a tensor's data: element count times element size.

    A tensor without dims is a scalar and holds one element. The old code counted
    scalars as zero bytes.
    """
    # np.prod of an empty shape is 1; int() yields a plain Python int.
    num_elements = int(np.prod(tensor.dims, dtype=np.int64))
    element_size = get_itemsize(tensor.data_type)
    return num_elements * element_size
def get_initializer_size(model):
    """Return the total byte size of initializers and Constant tensors in ``model``'s main graph and its subgraphs."""
    return get_graph_initializer_size(model.graph)
def get_graph_initializer_size(graph):
    """Return the byte size of ``graph``'s initializers and Constant ``value`` tensors.

    Subgraphs held in graph-typed attributes of any node (e.g. ``If`` branches,
    ``Loop``/``Scan`` bodies) are included recursively.
    """
    initializer_size = sum(calculate_tensor_size(tensor) for tensor in graph.initializer)
    for node in graph.node:
        # Match subgraphs by attribute type rather than by position in node.attribute.
        for attr in node.attribute:
            if node.op_type == "Constant" and attr.name == "value" and attr.type == onnx.AttributeProto.TENSOR:
                initializer_size += calculate_tensor_size(attr.t)
            elif attr.type == onnx.AttributeProto.GRAPH:
                initializer_size += get_graph_initializer_size(attr.g)
            elif attr.type == onnx.AttributeProto.GRAPHS:
                for subgraph in attr.graphs:
                    initializer_size += get_graph_initializer_size(subgraph)
    return initializer_size
def is_onnxruntime_available():
    """Return True if onnxruntime can be imported and exposes ``__version__``.

    A debug message is logged to the "onnxslim" logger when it is unavailable.
    """
    unavailable_msg = "onnxruntime is not available, please install it first for better optimization"
    if importlib.util.find_spec("onnxruntime") is None:
        logging.getLogger("onnxslim").debug(unavailable_msg)
        return False
    try:
        # in case of onnxruntime import error
        import onnxruntime as ort
    except Exception:  # not a bare except: let KeyboardInterrupt/SystemExit propagate
        logging.getLogger("onnxslim").debug(unavailable_msg)
        return False
    return hasattr(ort, "__version__")
def check_onnx_compatibility(onnx_version: str | None = None, ort_version: str | None = None):
    """Warn when the ONNX and ONNX Runtime versions are not a known-compatible pair.

    Args:
        onnx_version: ONNX version to check; defaults to the installed ``onnx.__version__``.
        ort_version: ONNX Runtime version to check; defaults to the installed
            ``onnxruntime.__version__``.

    Returns:
        ``True`` if the pair matches the table, ``False`` if it does not, and ``None`` if
        the ONNX Runtime version is not listed. Warnings are printed to stdout.
    """
    compatibility_dict = {
        "1.20": "1.16",
        "1.19": "1.16",
        "1.18": "1.16",
        "1.17": "1.15",
        "1.16": "1.14.1",
        "1.15": "1.14",
        "1.14": "1.13",
        "1.13": "1.12",
        "1.12": "1.12",
        "1.11": "1.11",
        "1.10": "1.10",
        "1.9": "1.10",
        "1.8": "1.9",
        "1.7": "1.8",
        "1.6": "1.8",
        "1.5": "1.7",
        "1.4": "1.7",
        "1.3": "1.7",
        "1.2": "1.6",
        "1.1": "1.6",
        "1.0": "1.6",
        "0.5": "1.5",
        "0.4": "1.5",
        "0.3": "1.4",
        "0.2": "1.3",
        "0.1": "1.3",
    }

    def _major_minor(version):
        # Drop a local build tag such as "+cu118" and keep "major.minor".
        return ".".join(version.split("+")[0].split(".")[:2])

    if onnx_version is None:
        import onnx

        onnx_version = onnx.__version__
    if ort_version is None:
        import onnxruntime

        ort_version = onnxruntime.__version__
    ort_version = _major_minor(ort_version)

    # Check compatibility
    expected_onnx_version = compatibility_dict.get(ort_version)
    if expected_onnx_version is None:
        print(
            f"Warning: Onnx Runtime version {ort_version} has no specified compatible ONNX version. Compatibility issues may occur."
        )
        return None
    # Table entries such as "1.14.1" are compared on major.minor too; comparing the raw
    # entry with a major.minor string could never match.
    if _major_minor(expected_onnx_version) == _major_minor(onnx_version):
        logging.getLogger("onnxslim").info(
            f"Installed Onnx Runtime version {ort_version} is compatible with installed ONNX version {onnx_version}."
        )
        return True
    print(
        f"Warning: Installed Onnx Runtime version {ort_version} is not compatible with installed ONNX version {onnx_version}. Expected ONNX version: {expected_onnx_version}."
    )
    return False
def get_max_tensor(model, topk=5):
    """Print the ``topk`` largest constant tensors (by element count) of ``model`` and its subgraphs."""
    graph = gs.import_onnx(model)
    constants = [t for t in graph.tensors().values() if isinstance(t, gs.Constant)]
    for sub_graph in graph.subgraphs(recursive=True):
        constants.extend(t for t in sub_graph.tensors().values() if isinstance(t, gs.Constant))

    element_counts = [constant.values.size for constant in constants]
    # argsort is ascending; reverse it so the largest tensors come first.
    for position in np.argsort(element_counts)[::-1][:topk]:
        largest = constants[position]
        print(
            f"Tensor name: {largest.name}, shape: {largest.values.shape}, dtype: {largest.values.dtype} size: {largest.values.size}"
        )
# copied from https://onnx.ai/onnx/api/tools.html
def update_outputs_dims(
    model,
    output_dims,
):
    """Write the given dimensions into the shapes of ``model``'s graph outputs, in place.

    Args:
        model: Model whose ``graph.output`` shapes are updated.
        output_dims: Mapping from output name to a sequence of dims, or ``None`` to leave
            that output untouched. Outputs missing from the mapping are skipped. Each dim
            is an int (Python or NumPy; ``>= 0`` sets a fixed size, negative generates
            the symbol ``<name>_<axis>``) or a ``str`` symbolic name. Outputs without dims
            get as many new dims as provided.

    Returns:
        The same ``model``.

    Raises:
        ValueError: If a generated symbol already exists among the outputs' dim params,
            or a dim is neither an int nor a str.

    Dimensions that already carry a fixed ``dim_value`` in the model are left unchanged.
    """
    # Symbols already used by the outputs; generated names must not collide with them.
    dim_param_set: set[str] = {
        dim.dim_param
        for output in model.graph.output
        for dim in output.type.tensor_type.shape.dim
        if dim.HasField("dim_param")
    }

    def update_dim(tensor, dim, j, name) -> None:
        dim_proto = tensor.type.tensor_type.shape.dim[j]
        # if it's int in model, it won't be replaced by original symbol
        if dim_proto.HasField("dim_value"):
            return
        if isinstance(dim, np.integer):
            # Convert to a plain Python int before assigning to the protobuf field.
            dim = int(dim)
        if isinstance(dim, int):
            if dim >= 0:
                dim_proto.dim_value = dim
            else:
                generated_dim_param = name + "_" + str(j)
                if generated_dim_param in dim_param_set:
                    raise ValueError(
                        f"Unable to generate unique dim_param for axis {j} of {name}. Please manually provide a dim_param value."
                    )
                dim_proto.dim_param = generated_dim_param
        elif isinstance(dim, str):
            dim_proto.dim_param = dim
        else:
            raise ValueError(f"Only int or str is accepted as dimension value, incorrect type: {type(dim)}")

    for output in model.graph.output:
        output_name = output.name
        output_dim_arr = output_dims.get(output_name)
        if output_dim_arr is None:
            continue
        if len(output.type.tensor_type.shape.dim) == 0:
            for _ in range(len(output_dim_arr)):
                output.type.tensor_type.shape.dim.add()
        for j, dim in enumerate(output_dim_arr):
            update_dim(output, dim, j, output_name)

    return model