import argparse import dataclasses from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser from dataclasses import dataclass, field from typing import List, Optional, Type, Union, get_args, get_origin, TypedDict, Dict, Literal from .core.optimization import OptimizationSettings from .core.pattern.registry import DEFAULT_FUSION_PATTERNS from .version import __version__ class OnnxSlimKwargs(TypedDict, total=False): model_check: bool input_shapes: Dict[str, List[int]] inputs: List[str] outputs: List[str] no_shape_infer: bool skip_optimizations: List[str] dtype: Literal["float16", "float32", "uint8", "int8"] skip_fusion_patterns: List[str] size_threshold: int inspect: bool dump_to_disk: bool save_as_external_data: bool model_check_inputs: Optional[List[str]] verbose: bool def _get_inner_type(arg_type): if get_origin(arg_type) is Union: return next((t for t in get_args(arg_type) if t is not type(None)), str) return arg_type @dataclass class ModelArguments: """ Args: model (Union[str, onnx.ModelProto]): The ONNX model to be slimmed. It can be either a file path or an `onnx.ModelProto` object. output_model (str, optional): File path to save the slimmed model. If None, the model will not be saved. """ input_model: str = field(metadata={"help": "input onnx model"}) output_model: Optional[str] = field(default=None, metadata={"help": "output onnx model"}) @dataclass class OptimizationArguments: """ Args: no_shape_infer (bool, optional): Flag indicating whether to perform shape inference. Default is False. no_constant_folding (bool, optional): Flag indicating whether to perform constant folding. Default is False. skip_fusion_patterns (str, optional): String representing fusion patterns to skip. Default is None. """ no_shape_infer: bool = field(default=False, metadata={"help": "whether to disable shape_infer, default false."}) skip_optimizations: Optional[List[str]] = field( default=None, metadata={ "help": "whether to skip some optimizations", "choices": list(OptimizationSettings.keys()), }, ) skip_fusion_patterns: Optional[List[str]] = field( default=None, metadata={ "help": "whether to skip the fusion of some patterns", "choices": list(DEFAULT_FUSION_PATTERNS.keys()), }, ) size_threshold: int = field( default=None, metadata={ "help": "size threshold in bytes, size larger than this value will not be folded, default None, which means fold all constants", }, ) @dataclass class ModificationArguments: """ Args: input_shapes (str, optional): String representing the input shapes. Default is None. outputs (str, optional): String representing the outputs. Default is None. dtype (str, optional): Data type. Default is None. save_as_external_data (bool, optional): Flag indicating whether to split onnx as model and weight. Default is False. """ input_shapes: Optional[List[str]] = field( default=None, metadata={ "help": "input shape of the model, INPUT_NAME:SHAPE, e.g. x:1,3,224,224 or x1:1,3,224,224 x2:1,3,224,224" }, ) inputs: Optional[List[str]] = field( default=None, metadata={ "help": "input of the model, INPUT_NAME:DTYPE, e.g. y:fp32 or y1:fp32 y2:fp32. If dtype is not specified, the dtype of the input will be the same as the original model if it has dtype, otherwise it will be fp32, available dtype: fp16, fp32, int32" }, ) outputs: Optional[List[str]] = field( default=None, metadata={ "help": "output of the model, OUTPUT_NAME:DTYPE, e.g. y:fp32 or y1:fp32 y2:fp32. If dtype is not specified, the dtype of the output will be the same as the original model if it has dtype, otherwise it will be fp32, available dtype: fp16, fp32, int32" }, ) dtype: Optional[str] = field( default=None, metadata={"help": "convert data format to fp16 or fp32.", "choices": ["fp16", "fp32"]} ) save_as_external_data: bool = field( default=False, metadata={"help": "split onnx as model and weight, default False."} ) @dataclass class CheckerArguments: """ Args: model_check (bool, optional): Flag indicating whether to perform model checking. Default is False. model_check_inputs (str, optional): The shape or tensor used for model check. Default is None. inspect (bool, optional): Flag indicating whether to inspect the model. Default is False. dump_to_disk (bool, optional): Flag indicating whether to dump the model detail to disk. Default is False. verbose (bool, optional): Flag indicating whether to print verbose logs. Default is False. """ model_check: bool = field(default=False, metadata={"help": "enable model check"}) model_check_inputs: Optional[List[str]] = field( default=None, metadata={ "help": "Works only when model_check is enabled, Input shape of the model or numpy data path, INPUT_NAME:SHAPE or INPUT_NAME:DATAPATH, e.g. x:1,3,224,224 or x1:1,3,224,224 x2:data.npy. Useful when input shapes are dynamic." }, ) inspect: bool = field(default=False, metadata={"help": "inspect model, default False."}) dump_to_disk: bool = field(default=False, metadata={"help": "dump model info to disk, default False."}) verbose: bool = field(default=False, metadata={"help": "verbose mode, default False."}) class OnnxSlimArgumentParser(ArgumentParser): def __init__(self, *argument_dataclasses: Type, **kwargs): if "formatter_class" not in kwargs: kwargs["formatter_class"] = ArgumentDefaultsHelpFormatter super().__init__(**kwargs) self.argument_dataclasses = argument_dataclasses self.parser = argparse.ArgumentParser( description="OnnxSlim: A Toolkit to Help Optimizer Onnx Model", formatter_class=argparse.RawDescriptionHelpFormatter, ) self._add_arguments() def _add_arguments(self): for dataclass_type in self.argument_dataclasses: if dataclass_type is ModelArguments: continue for field_name, field_def in dataclass_type.__dataclass_fields__.items(): arg_type = _get_inner_type(field_def.type) default_value = field_def.default if field_def.default is not field_def.default_factory else None help_text = field_def.metadata.get("help", "") nargs = "+" if get_origin(arg_type) == list else None choices = field_def.metadata.get("choices", None) if choices and default_value is not None and default_value not in choices: raise ValueError( f"Invalid default value '{default_value}' for argument '{field_name}'. Must be one of {choices}." ) arg_type = get_args(arg_type)[0] if get_args(arg_type) else arg_type if arg_type == bool: self.parser.add_argument( f"--{field_name.replace('_', '-')}", action="store_true", default=default_value, help=help_text, ) else: self.parser.add_argument( f"--{field_name.replace('_', '-')}", type=arg_type, default=default_value, nargs=nargs, choices=choices, help=help_text, ) # Add positional arguments separately for ModelArguments self.parser.add_argument("input_model", help="input onnx model") self.parser.add_argument("output_model", nargs="?", default=None, help="output onnx model") self.parser.add_argument("-v", "--version", action="version", version=__version__) def parse_args_into_dataclasses(self): # Pre-parse arguments to check for `--inspect` pre_parsed_args, _ = self.parser.parse_known_args() if pre_parsed_args.inspect: for action in self.parser._actions: if action.dest == "input_model": action.nargs = "+" break args = self.parser.parse_args() args_dict = vars(args) outputs = [] for dtype in self.argument_dataclasses: keys = {f.name for f in dataclasses.fields(dtype) if f.init} inputs = {k: v for k, v in args_dict.items() if k in keys} obj = dtype(**inputs) outputs.append(obj) return (*outputs,)