DriverTrac/venv/lib/python3.12/site-packages/onnxslim/cli/_main.py

181 lines
5.4 KiB
Python

from __future__ import annotations
import onnx
from onnxslim.argparser import OnnxSlimKwargs
def slim(model: str | onnx.ModelProto | list[str | onnx.ModelProto], *args, **kwargs: OnnxSlimKwargs):
import os
import time
from pathlib import Path
from onnxslim.core import (
OptimizationSettings,
convert_data_format,
freeze,
input_modification,
input_shape_modification,
optimize,
output_modification,
shape_infer,
)
from onnxslim.utils import (
TensorInfo,
check_onnx,
check_point,
check_result,
dump_model_info_to_disk,
init_logging,
onnxruntime_inference,
print_model_info_as_table,
save,
summarize_model,
update_outputs_dims,
)
output_model = args[0] if len(args) > 0 else kwargs.get("output_model", None)
model_check = kwargs.get("model_check", False)
input_shapes = kwargs.get("input_shapes", None)
inputs = kwargs.get("inputs", None)
outputs = kwargs.get("outputs", None)
no_shape_infer = kwargs.get("no_shape_infer", False)
skip_optimizations = kwargs.get("skip_optimizations", None)
dtype = kwargs.get("dtype", None)
skip_fusion_patterns = kwargs.get("skip_fusion_patterns", None)
size_threshold = kwargs.get("size_threshold", None)
size_threshold = int(size_threshold) if size_threshold else None
kwargs.get("inspect", False)
dump_to_disk = kwargs.get("dump_to_disk", False)
save_as_external_data = kwargs.get("save_as_external_data", False)
model_check_inputs = kwargs.get("model_check_inputs", None)
verbose = kwargs.get("verbose", False)
logger = init_logging(verbose)
MAX_ITER = int(os.getenv("ONNXSLIM_MAX_ITER")) if os.getenv("ONNXSLIM_MAX_ITER") else 10
start_time = time.time()
def get_info(model, inspect=False):
if isinstance(model, str):
model_name = Path(model).name
model = onnx.load(model)
else:
model_name = "OnnxModel"
freeze(model)
if not inspect:
return model_name, model
model_info = summarize_model(model, model_name)
return model_info
if isinstance(model, list):
model_info_list = [get_info(m, inspect=True) for m in model]
if dump_to_disk:
[dump_model_info_to_disk(info) for info in model_info_list]
print_model_info_as_table(model_info_list)
return
else:
model_name, model = get_info(model)
if output_model:
original_info = summarize_model(model, model_name)
if inputs:
model = input_modification(model, inputs)
if input_shapes:
model = input_shape_modification(model, input_shapes)
if outputs:
model = output_modification(model, outputs)
if model_check:
input_data_dict, raw_onnx_output, model = check_onnx(model, model_check_inputs)
output_info = {TensorInfo(o).name: TensorInfo(o).shape for o in model.graph.output}
if not no_shape_infer:
model = shape_infer(model)
OptimizationSettings.reset(skip_optimizations)
if OptimizationSettings.enabled():
graph_check_point = check_point(model)
while MAX_ITER > 0:
logger.debug(f"iter: {MAX_ITER}")
model = optimize(model, skip_fusion_patterns, size_threshold)
if not no_shape_infer:
model = shape_infer(model)
graph = check_point(model)
if graph == graph_check_point:
logger.debug(f"converged at iter: {MAX_ITER}")
break
else:
graph_check_point = graph
MAX_ITER -= 1
if dtype:
model = convert_data_format(model, dtype)
model = update_outputs_dims(model, output_dims=output_info)
if model_check:
slimmed_onnx_output, model = onnxruntime_inference(model, input_data_dict)
if not check_result(raw_onnx_output, slimmed_onnx_output):
return None
if not output_model:
return model
slimmed_info = summarize_model(model, output_model)
save(model, output_model, model_check, save_as_external_data, slimmed_info)
end_time = time.time()
elapsed_time = end_time - start_time
print_model_info_as_table(
[original_info, slimmed_info],
elapsed_time,
)
def main():
"""Entry point for the OnnxSlim toolkit, processes command-line arguments and passes them to the slim function."""
from onnxslim.argparser import (
CheckerArguments,
ModelArguments,
ModificationArguments,
OnnxSlimArgumentParser,
OptimizationArguments,
)
argument_parser = OnnxSlimArgumentParser(
ModelArguments, OptimizationArguments, ModificationArguments, CheckerArguments
)
model_args, optimization_args, modification_args, checker_args = argument_parser.parse_args_into_dataclasses()
if not checker_args.inspect and checker_args.dump_to_disk:
argument_parser.error("dump_to_disk can only be used with --inspect")
if not optimization_args.no_shape_infer:
from onnxslim.utils import check_onnx_compatibility, is_onnxruntime_available
if is_onnxruntime_available():
check_onnx_compatibility()
slim(
model_args.input_model,
model_args.output_model,
**optimization_args.__dict__,
**modification_args.__dict__,
**checker_args.__dict__,
)
return 0