181 lines
5.4 KiB
Python
181 lines
5.4 KiB
Python
from __future__ import annotations
|
|
|
|
import onnx
|
|
|
|
from onnxslim.argparser import OnnxSlimKwargs
|
|
|
|
|
|
def slim(model: str | onnx.ModelProto | list[str | onnx.ModelProto], *args, **kwargs: OnnxSlimKwargs):
    """Optimize an ONNX model, or summarize several models side by side.

    Args:
        model: A path to an ``.onnx`` file, an in-memory ``onnx.ModelProto``, or a
            list of either. When a list is given, each model is loaded, frozen and
            summarized, a table is printed, and nothing is optimized or saved
            (this is what "inspect" mode means here).
        *args: Optional positional ``output_model`` (takes precedence over the
            ``output_model`` keyword).
        **kwargs: Options mirroring the CLI (see ``OnnxSlimKwargs``):
            ``output_model``, ``model_check``, ``input_shapes``, ``inputs``,
            ``outputs``, ``no_shape_infer``, ``skip_optimizations``, ``dtype``,
            ``skip_fusion_patterns``, ``size_threshold``, ``dump_to_disk``,
            ``save_as_external_data``, ``model_check_inputs``, ``verbose``.
            ``inspect`` is accepted for CLI compatibility but is not read;
            inspect mode is triggered by passing a list as ``model``.

    Returns:
        ``None`` in inspect mode; ``None`` when ``model_check`` is enabled and
        ``check_result`` rejects the optimized outputs; the optimized model when
        no ``output_model`` is given; otherwise ``None`` after saving the model
        and printing a before/after comparison table.

    Environment:
        ``ONNXSLIM_MAX_ITER``: upper bound on optimize/shape-infer passes
        (default 10). The loop stops early once the graph checkpoint stops changing.
    """
    import os
    import time
    from pathlib import Path

    from onnxslim.core import (
        OptimizationSettings,
        convert_data_format,
        freeze,
        input_modification,
        input_shape_modification,
        optimize,
        output_modification,
        shape_infer,
    )
    from onnxslim.utils import (
        TensorInfo,
        check_onnx,
        check_point,
        check_result,
        dump_model_info_to_disk,
        init_logging,
        onnxruntime_inference,
        print_model_info_as_table,
        save,
        summarize_model,
        update_outputs_dims,
    )

    output_model = args[0] if args else kwargs.get("output_model", None)
    model_check = kwargs.get("model_check", False)
    input_shapes = kwargs.get("input_shapes", None)
    inputs = kwargs.get("inputs", None)
    outputs = kwargs.get("outputs", None)
    no_shape_infer = kwargs.get("no_shape_infer", False)
    skip_optimizations = kwargs.get("skip_optimizations", None)
    dtype = kwargs.get("dtype", None)
    skip_fusion_patterns = kwargs.get("skip_fusion_patterns", None)
    size_threshold = kwargs.get("size_threshold", None)
    # Any falsy value (None, 0, "") disables the threshold; otherwise coerce to int
    # (e.g. a string coming from the command line).
    size_threshold = int(size_threshold) if size_threshold else None
    dump_to_disk = kwargs.get("dump_to_disk", False)
    save_as_external_data = kwargs.get("save_as_external_data", False)
    model_check_inputs = kwargs.get("model_check_inputs", None)
    verbose = kwargs.get("verbose", False)

    logger = init_logging(verbose)

    # Read the environment once; an unset or empty variable falls back to 10 passes.
    max_iter_env = os.getenv("ONNXSLIM_MAX_ITER")
    max_iter = int(max_iter_env) if max_iter_env else 10

    # perf_counter is monotonic, so the reported duration is immune to clock changes.
    start_time = time.perf_counter()

    def load_and_freeze(source):
        """Load ``source`` if it is a path, freeze it, and return ``(name, model)``."""
        if isinstance(source, str):
            name = Path(source).name
            source = onnx.load(source)
        else:
            name = "OnnxModel"
        freeze(source)
        return name, source

    def summarize(source):
        """Load and freeze ``source``, then return its ``summarize_model`` info."""
        name, proto = load_and_freeze(source)
        return summarize_model(proto, name)

    # Inspect mode: summarize every model and stop.
    if isinstance(model, list):
        model_info_list = [summarize(m) for m in model]
        if dump_to_disk:
            for info in model_info_list:
                dump_model_info_to_disk(info)
        print_model_info_as_table(model_info_list)
        return None

    model_name, model = load_and_freeze(model)
    # The "before" summary is only needed for the comparison table printed on save.
    if output_model:
        original_info = summarize_model(model, model_name)

    if inputs:
        model = input_modification(model, inputs)
    if input_shapes:
        model = input_shape_modification(model, input_shapes)
    if outputs:
        model = output_modification(model, outputs)

    if model_check:
        input_data_dict, raw_onnx_output, model = check_onnx(model, model_check_inputs)

    # Record output names/shapes before optimizing; they are handed to
    # update_outputs_dims afterwards (presumably to restore the original dims).
    output_info = {info.name: info.shape for info in map(TensorInfo, model.graph.output)}

    if not no_shape_infer:
        model = shape_infer(model)

    OptimizationSettings.reset(skip_optimizations)
    if OptimizationSettings.enabled():
        graph_check_point = check_point(model)
        # Repeat optimization until the checkpoint is stable or the budget runs out.
        # `remaining` counts down from max_iter, matching the original debug output.
        for remaining in range(max_iter, 0, -1):
            logger.debug(f"iter: {remaining}")
            model = optimize(model, skip_fusion_patterns, size_threshold)
            if not no_shape_infer:
                model = shape_infer(model)
            graph = check_point(model)
            if graph == graph_check_point:
                logger.debug(f"converged at iter: {remaining}")
                break
            graph_check_point = graph

    if dtype:
        model = convert_data_format(model, dtype)

    model = update_outputs_dims(model, output_dims=output_info)

    if model_check:
        slimmed_onnx_output, model = onnxruntime_inference(model, input_data_dict)
        if not check_result(raw_onnx_output, slimmed_onnx_output):
            return None

    if not output_model:
        return model

    slimmed_info = summarize_model(model, output_model)
    save(model, output_model, model_check, save_as_external_data, slimmed_info)

    elapsed_time = time.perf_counter() - start_time
    print_model_info_as_table(
        [original_info, slimmed_info],
        elapsed_time,
    )
    return None
|
|
|
|
|
|
def main():
    """Run OnnxSlim from the command line.

    Parses the command line into model, optimization, modification and checker
    argument groups, and rejects ``--dump_to_disk`` without ``--inspect``. When
    shape inference is enabled and onnxruntime is available, it runs the
    onnx/onnxruntime compatibility check. Finally it forwards every option to
    :func:`slim`.

    Returns:
        int: ``0``. Invalid option combinations are reported through the
        parser's ``error`` method.
    """
    from onnxslim.argparser import (
        CheckerArguments,
        ModelArguments,
        ModificationArguments,
        OnnxSlimArgumentParser,
        OptimizationArguments,
    )

    parser = OnnxSlimArgumentParser(ModelArguments, OptimizationArguments, ModificationArguments, CheckerArguments)
    model_cfg, optim_cfg, modif_cfg, checker_cfg = parser.parse_args_into_dataclasses()

    # Dumping model info only makes sense in inspect mode.
    if checker_cfg.dump_to_disk and not checker_cfg.inspect:
        parser.error("dump_to_disk can only be used with --inspect")

    shape_inference_enabled = not optim_cfg.no_shape_infer
    if shape_inference_enabled:
        # Imported lazily: only needed when shape inference is going to run.
        from onnxslim.utils import check_onnx_compatibility, is_onnxruntime_available

        if is_onnxruntime_available():
            check_onnx_compatibility()

    # Each argument group is unpacked separately, so a field name shared by two
    # groups raises TypeError instead of being silently overwritten.
    slim(
        model_cfg.input_model,
        model_cfg.output_model,
        **vars(optim_cfg),
        **vars(modif_cfg),
        **vars(checker_cfg),
    )
    return 0
|