from __future__ import annotations import importlib.util import logging import os import sys from collections import defaultdict from pathlib import Path import numpy as np import onnx from onnx import checker, helper import onnxslim.third_party.onnx_graphsurgeon as gs from onnxslim.misc.tabulate import SEPARATING_LINE, tabulate from onnxslim.third_party.onnx_graphsurgeon.logger.logger import G_LOGGER logger = logging.getLogger("onnxslim") import ml_dtypes try: from onnx._mapping import TensorDtypeMap except ImportError: from onnx.mapping import TensorDtypeMap TENSOR_TYPE_MAP = {} candidates = [ ("BFLOAT16", "bfloat16", "UINT16"), ("FLOAT8E4M3FN", "float8_e4m3fn", "UINT8"), ("FLOAT8E4M3FNUZ", "float8_e4m3fnuz", "UINT8"), ("FLOAT8E5M2", "float8_e5m2", "UINT8"), ("FLOAT8E5M2FNUZ", "float8_e5m2fnuz", "UINT8"), ("UINT4", "uint4", "INT32"), ("INT4", "int4", "INT32"), ("FLOAT4E2M1", "float4_e2m1fn", "UINT8"), ] for onnx_name, ml_name, storage_name in candidates: if hasattr(onnx.TensorProto, onnx_name) and hasattr(ml_dtypes, ml_name): TENSOR_TYPE_MAP[int(getattr(onnx.TensorProto, onnx_name))] = TensorDtypeMap( np.dtype(getattr(ml_dtypes, ml_name)), int(getattr(onnx.TensorProto, storage_name)), f"TensorProto.{onnx_name}", ) def init_logging(verbose=False): """Configure the logging settings for the application based on the verbosity level.""" logger = logging.getLogger("onnxslim") if verbose: # DEBUG logger.setLevel(logging.DEBUG) G_LOGGER.severity = logging.DEBUG else: # ERROR logger.setLevel(logging.ERROR) G_LOGGER.severity = logging.ERROR if not logger.handlers: handler = logging.StreamHandler(sys.stderr) formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") handler.setFormatter(formatter) logger.addHandler(handler) G_LOGGER.colors = False if is_onnxruntime_available(): import onnxruntime as ort ort.set_default_logger_severity(3) return logger def format_bytes(size: int | tuple[int, ...]) -> str: """Convert byte sizes into human-readable format with appropriate units (B, KB, MB, GB).""" if isinstance(size, int): size = (size,) elif isinstance(size, np.integer): size = (int(size),) units = ["B", "KB", "MB", "GB"] formatted_sizes = [] for size_in_bytes in size: unit_index = 0 while size_in_bytes >= 1024 and unit_index < len(units) - 1: size_in_bytes /= 1024 unit_index += 1 formatted_size = f"{size_in_bytes:.2f} {units[unit_index]}" formatted_sizes.append(formatted_size) if len(formatted_sizes) == 1: return formatted_sizes[0] else: return f"{formatted_sizes[0]} ({formatted_sizes[1]})" def onnx_dtype_to_numpy(onnx_dtype: int) -> np.dtype: """Maps an ONNX dtype to its corresponding NumPy dtype.""" tensor_dtype = TENSOR_TYPE_MAP.get(onnx_dtype) if tensor_dtype: return tensor_dtype.np_dtype if onnx_dtype in onnx.helper.get_all_tensor_dtypes(): return np.dtype(helper.tensor_dtype_to_np_dtype(onnx_dtype)) return "UNDEFINED" def gen_onnxruntime_input_data( model: onnx.ModelProto, model_check_inputs: list[str] | None = None ) -> dict[str, np.ndarray]: """Generate random input data for an ONNX model considering potential specific input shapes and types.""" input_info = {} for input_tensor in model.graph.input: name = input_tensor.name shape = [] for dim in input_tensor.type.tensor_type.shape.dim: if dim.HasField("dim_param"): shape.append(dim.dim_param) elif dim.HasField("dim_value"): shape.append(dim.dim_value) else: shape.append(None) dtype = onnx_dtype_to_numpy(input_tensor.type.tensor_type.elem_type) input_info[name] = {"shape": shape, "dtype": dtype} if model_check_inputs: for model_check_input in model_check_inputs: key, value = model_check_input.rsplit(":", 1) if value.endswith(".npy"): if key not in input_info: raise Exception( f"model_check_input name:{key} not found in model, available keys: {' '.join(input_info.keys())}" ) data = np.load(value) input_info[key] = {"data": data} else: values_list = [int(val) for val in value.split(",")] if key in input_info: input_info[key]["shape"] = values_list else: raise Exception( f"model_check_input name:{key} not found in model, available keys: {' '.join(input_info.keys())}" ) input_data_dict = {} for name, info in input_info.items(): if "data" in info: input_data_dict[name] = info["data"] else: shapes = [shape if (shape != -1 and not isinstance(shape, str)) else 1 for shape in info["shape"]] shapes = shapes or [1] dtype = info["dtype"] if dtype in {np.int32, np.int64}: random_data = np.random.randint(10, size=shapes).astype(dtype) else: random_data = np.random.rand(*shapes).astype(dtype) input_data_dict[name] = random_data return input_data_dict def onnxruntime_inference(model: onnx.ModelProto, input_data: dict) -> dict[str, np.array]: """Perform inference using ONNX Runtime on the given model and input data.""" import os import tempfile import onnx import onnxruntime as rt if model.ByteSize() >= onnx.checker.MAXIMUM_PROTOBUF: tmp_dir = tempfile.TemporaryDirectory() tmp_path = os.path.join(tmp_dir.name, "tmp.onnx") location = f"{os.path.basename(tmp_path)}.data" if os.path.exists(location): os.remove(location) onnx.save( model, tmp_path, save_as_external_data=True, all_tensors_to_one_file=True, location=location, ) onnx_model = tmp_path else: onnx_model = model.SerializeToString() sess = rt.InferenceSession(onnx_model, providers=["CPUExecutionProvider"]) onnx_output = sess.run(None, input_data) output_names = [output.name for output in sess.get_outputs()] onnx_output = dict(zip(output_names, onnx_output)) if isinstance(onnx_model, str): model = onnx.load(onnx_model) return onnx_output, model def format_model_info(model_info_list: dict | list[dict], elapsed_time: float | None = None): assert model_info_list, "model_info_list must contain more than one model info" from colorama import Fore, init init() if not isinstance(model_info_list, (list, tuple)): model_info_list = [model_info_list] final_op_info = [] final_op_info.extend( ( ["Model Name"] + [item.tag for item in model_info_list], [SEPARATING_LINE] * (len(model_info_list) + 1), ["Model Info"] + ["Op Set: " + item.op_set + " / IR Version: " + item.ir_version for item in model_info_list], [SEPARATING_LINE] * (len(model_info_list) + 1), ) ) def get_io_info(model_info_list, tag=None): if tag == "OUT": ios = [op_type for model_info in model_info_list for op_type in model_info.output_info] else: ios = [op_type for model_info in model_info_list for op_type in model_info.input_info] ios = list(dict.fromkeys([io.name for io in ios])) io_info = [] for io in ios: input_info_list = [f"{tag}: {io}"] for model_info in model_info_list: if tag == "OUT": io_tensor = model_info.output_maps.get(io, None) else: io_tensor = model_info.input_maps.get(io, None) inputs_shape = (io_tensor.dtype, io_tensor.shape) if io_tensor else "" if isinstance(inputs_shape, (list, tuple)): inputs_shape = ": ".join([str(i) for i in inputs_shape]) input_info_list.append(inputs_shape) io_info.append(input_info_list) return io_info final_op_info.extend(get_io_info(model_info_list, "IN")) final_op_info.extend(get_io_info(model_info_list, "OUT")) final_op_info.append([SEPARATING_LINE] * (len(model_info_list) + 1)) all_ops = {op_type for model_info in model_info_list for op_type in model_info.op_type_counts} sorted_ops = sorted(all_ops) for op in sorted_ops: op_info_list = [op] float_number = model_info_list[0].op_type_counts.get(op, 0) op_info_list.append(float_number) for model_info in model_info_list[1:]: slimmed_number = model_info.op_type_counts.get(op, 0) if float_number > slimmed_number: slimmed_number = Fore.GREEN + str(slimmed_number) + Fore.WHITE op_info_list.append(slimmed_number) final_op_info.append(op_info_list) final_op_info.extend( ( [SEPARATING_LINE] * (len(model_info_list) + 1), ["Model Size"] + [format_bytes(model_info.model_size) for model_info in model_info_list], ) ) if elapsed_time: final_op_info.extend( ( [SEPARATING_LINE] * (len(model_info_list) + 1), ["Elapsed Time", f"{elapsed_time:.2f} s"], ) ) return final_op_info def print_model_info_as_table(model_info_list: dict | list[dict], elapsed_time: float | None = None): """Prints the model information as a formatted table for the given model name and list of model details.""" if not isinstance(model_info_list, (list, tuple)): model_info_list = [model_info_list] final_op_info = format_model_info(model_info_list, elapsed_time) lines = tabulate( final_op_info, headers=[], tablefmt="pretty", maxcolwidths=[None] + [40] * len(model_info_list), ).split("\n") if elapsed_time: time_row = lines[-2].split("|") time_row[-3] = ( time_row[-2][: len(time_row[-2]) // 2 + 1] + time_row[-3] + time_row[-2][len(time_row[-2]) // 2 :] ) time_row.pop(-2) lines[-2] = "|".join(time_row) output = "\n".join([line if line != "| \x01 |" else lines[0] for line in lines]) print(output) def dump_model_info_to_disk(model_info: dict): """Writes model information to a CSV file for a given model name and dictionary of model info.""" import csv csv_file_path = f"{model_info.tag}_model_info.csv" with open(csv_file_path, "a", newline="") as csvfile: # Use 'a' for append mode fieldnames = ["NodeName", "OpType", "OutputDtype", "OutputShape"] writer = csv.DictWriter(csvfile, fieldnames=fieldnames) # If the file is empty, write the header if csvfile.tell() == 0: writer.writeheader() # Write the data for node_name, info in model_info.op_info.items(): op_type, output_info_list = info.op, info.outputs if len(output_info_list) >= 1: # Write the first row with actual NodeName and OpType row_data_first = { "NodeName": node_name, "OpType": op_type, "OutputDtype": output_info_list[0].dtype, # First entry in the list "OutputShape": output_info_list[0].shape, # First entry in the list } writer.writerow(row_data_first) # Write subsequent rows with empty strings for NodeName and OpType for output_dtype, output_shape in output_info_list[1:]: row_data_empty = { "NodeName": "", "OpType": "", "OutputDtype": output_dtype, "OutputShape": output_shape, } writer.writerow(row_data_empty) print(f"Model info written to {csv_file_path}") def get_opset(model: onnx.ModelProto) -> int: """Returns the ONNX opset version for a given model.""" try: for importer in model.opset_import: if importer.domain in {"", "ai.onnx"}: return importer.version return None except Exception: return None def get_ir_version(model: onnx.ModelProto) -> int: """Returns the ONNX ir version for a given model.""" try: return model.ir_version except Exception: return None class TensorInfo: def __init__(self, tensor): self.dtype: np.dtype = np.float32 self.shape: tuple[str | int] = None self._extract_info(tensor) def _extract_info(self, tensor): """Extract the data type and shape of an ONNX tensor.""" self.dtype = onnx_dtype_to_numpy(tensor.type.tensor_type.elem_type) shape = None if tensor.type.tensor_type.HasField("shape"): shape = [] for dim in tensor.type.tensor_type.shape.dim: if dim.HasField("dim_param"): shape.append(dim.dim_param) elif dim.HasField("dim_value"): shape.append(dim.dim_value) else: shape.append("?") self.shape = tuple(shape) if shape is not None else None self.name = tensor.name class OperatorInfo: def __init__(self, operator, outputs=None): self.name: str = None self.op: str = None self._extract_info(operator) self.outputs = outputs def _extract_info(self, operator): self.name: str = operator.name self.op: str = operator.op_type class ModelInfo: def __init__(self, model: str | onnx.ModelProto, tag: str = "OnnxSlim"): if isinstance(model, str): tag = Path(model).name model = onnx.load(model) self.tag: str = tag self.model_size: int = -1 self.op_set: str = None self.ir_version: str = None self.op_type_counts: dict[str, int] = defaultdict(int) self.op_info: dict[str, dict] = {} self.input_info: list[str, tuple[str, tuple]] = [] self.output_info: list[str, tuple[str, tuple]] = [] self._summarize_model(model) def _summarize_model(self, model): self.op_set = str(get_opset(model)) self.ir_version = str(get_ir_version(model)) self.model_size = get_initializer_size(model) for input in model.graph.input: self.input_info.append(TensorInfo(input)) for output in model.graph.output: self.output_info.append(TensorInfo(output)) value_info_dict = {value_info.name: value_info for value_info in model.graph.value_info} def get_graph_node_info(graph: onnx.GraphProto) -> dict[str, list[str]]: for node in graph.node: op_type = node.op_type self.op_type_counts[op_type] += 1 output_tensor_info = [] for output in node.output: if output in value_info_dict: tensor = value_info_dict[output] tensor_info = TensorInfo(tensor) output_tensor_info.append(tensor_info) self.op_info[node.name] = OperatorInfo(node, output_tensor_info) for attr in node.attribute: ATTR_TYPE_MAPPING = {v: k for k, v in onnx.AttributeProto.AttributeType.items()} if attr.type in ATTR_TYPE_MAPPING: attr_str = ATTR_TYPE_MAPPING[attr.type] if attr_str == "GRAPH": get_graph_node_info(attr.g) get_graph_node_info(model.graph) @property def input_maps(self): self.input_dict = {input_info.name: input_info for input_info in self.input_info} return self.input_dict @property def output_maps(self): self.output_dict = {output_info.name: output_info for output_info in self.output_info} return self.output_dict def summarize_model(model: str | onnx.ModelProto, tag="OnnxModel") -> dict: """Generates a summary of the ONNX model, including model size, operations, and tensor shapes.""" logger.debug("Start summarizing model.") model_info = ModelInfo(model, tag) logger.debug("Finish summarizing model.") return model_info def model_save_as_external_data(model: onnx.ModelProto, model_path: str): """Save an ONNX model with tensor data as an external file.""" location = f"{os.path.basename(model_path)}.data" if os.path.exists(location): os.remove(location) onnx.save( model, model_path, save_as_external_data=True, all_tensors_to_one_file=True, location=location, ) def check_onnx(model: onnx.ModelProto, model_check_inputs=None): """Validates an ONNX model by generating input data and performing inference to check outputs.""" input_data_dict = gen_onnxruntime_input_data(model, model_check_inputs) raw_onnx_output, model = onnxruntime_inference(model, input_data_dict) return input_data_dict, raw_onnx_output, model def check_point(model: onnx.ModelProto): """Imports an ONNX model checkpoint into a Graphsurgeon graph representation.""" return gs.import_onnx(model) def save( model: onnx.ModelProto, model_path: str, model_check: bool = False, save_as_external_data: bool = False, model_info: dict | None = None, ): """Save an ONNX model to a specified path, with optional model checking for validity.""" if model_check: try: checker.check_model(model) except ValueError: logger.warning("Model too large and cannot be checked.") if model_path: # model larger than 2GB can be saved, but compiler like trtexec won't parse it if get_initializer_size(model) <= checker.MAXIMUM_PROTOBUF and not save_as_external_data: onnx.save(model, model_path) else: import os location = f"{os.path.basename(model_path)}.data" if os.path.exists(location): os.remove(location) onnx.save( model, model_path, save_as_external_data=True, all_tensors_to_one_file=True, location=location, ) logger.debug("Model too large and saved as external data automatically.") if model_info: model_size = model.ByteSize() model_info.model_size = [model_size, model_info.model_size] def check_result(raw_onnx_output, slimmed_onnx_output): """Verify the consistency of outputs between the raw and slimmed ONNX models, logging warnings if discrepancies are detected. """ if set(raw_onnx_output.keys()) != set(slimmed_onnx_output.keys()): print("Model output mismatch after slimming.") print(f"Raw model output keys: {raw_onnx_output.keys()}") print(f"Slimmed model output keys: {slimmed_onnx_output.keys()}") print("Please check the model carefully.") return False else: for key in raw_onnx_output.keys(): if not np.allclose( raw_onnx_output[key], slimmed_onnx_output[key], rtol=1e-03, atol=1e-04, equal_nan=True, ): print(f"\033[31mModel output {key} mismatch after slimming.") print("\033[31mPlease check the model carefully.") return False return True def get_numpy_type(onnx_type): if not isinstance(onnx_type, int): # Already a NumPy type return onnx_type numpy_unsupported_types = [ onnx.TensorProto.BFLOAT16, onnx.TensorProto.FLOAT8E4M3FN, onnx.TensorProto.FLOAT8E4M3FNUZ, onnx.TensorProto.FLOAT8E5M2, onnx.TensorProto.FLOAT8E5M2FNUZ, ] # TENSOR_TYPE_TO_NP_TYPE maps types unsupported by NumPy to random other types. # This obviously breaks things, so we need to treat this as a special case. if onnx_type not in numpy_unsupported_types and onnx_type in onnx.helper.get_all_tensor_dtypes(): return onnx.helper.tensor_dtype_to_np_dtype(onnx_type) return None def get_itemsize(dtype): np_dtype = get_numpy_type(dtype) if np_dtype is not None: return np.dtype(np_dtype).itemsize if dtype == onnx.TensorProto.BFLOAT16: return 2 if dtype in [ onnx.TensorProto.FLOAT8E4M3FN, onnx.TensorProto.FLOAT8E4M3FNUZ, onnx.TensorProto.FLOAT8E5M2, onnx.TensorProto.FLOAT8E5M2FNUZ, ]: return 1 print(f"Unknown ONNX dtype: {dtype}") raise ValueError(f"Unsupported TensorProto dtype: {dtype}") def calculate_tensor_size(tensor): """Calculates the size of an ONNX tensor in bytes based on its shape and data type size.""" shape = tensor.dims num_elements = np.prod(shape) if shape else 0 element_size = get_itemsize(tensor.data_type) return num_elements * element_size def get_initializer_size(model): """Calculate total size of all subgraphs in an ONNX model.""" total_size = get_graph_initializer_size(model.graph) return total_size def get_graph_initializer_size(graph): initializer_size = 0 for tensor in graph.initializer: tensor_size = calculate_tensor_size(tensor) initializer_size += tensor_size for node in graph.node: if node.op_type == "Constant": for attr in node.attribute: if attr.name == "value" and attr.type == onnx.AttributeProto.TENSOR: initializer_size += calculate_tensor_size(attr.t) elif node.op_type == "If": initializer_size += get_graph_initializer_size(node.attribute[0].g) initializer_size += get_graph_initializer_size(node.attribute[1].g) elif node.op_type == "Loop": initializer_size += get_graph_initializer_size(node.attribute[0].g) elif node.op_type == "Scan": initializer_size += get_graph_initializer_size(node.attribute[0].g) return initializer_size def is_onnxruntime_available(): if importlib.util.find_spec("onnxruntime") is None: logger = logging.getLogger("onnxslim") logger.debug("onnxruntime is not available, please install it first for better optimization") return False else: try: # in case of onnxruntime import error import onnxruntime as ort if hasattr(ort, "__version__"): return True else: return False except: logger = logging.getLogger("onnxslim") logger.debug("onnxruntime is not available, please install it first for better optimization") return False def check_onnx_compatibility(): """Ensure ONNX Runtime and ONNX versions are compatible for model inference.""" compatibility_dict = { "1.20": "1.16", "1.19": "1.16", "1.18": "1.16", "1.17": "1.15", "1.16": "1.14.1", "1.15": "1.14", "1.14": "1.13", "1.13": "1.12", "1.12": "1.12", "1.11": "1.11", "1.10": "1.10", "1.9": "1.10", "1.8": "1.9", "1.7": "1.8", "1.6": "1.8", "1.5": "1.7", "1.4": "1.7", "1.3": "1.7", "1.2": "1.6", "1.1": "1.6", "1.0": "1.6", "0.5": "1.5", "0.4": "1.5", "0.3": "1.4", "0.2": "1.3", "0.1": "1.3", } import onnx import onnxruntime onnx_version = onnx.__version__ # ort_version = onnxruntime.__version__ ort_version = ".".join(onnxruntime.__version__.split("+")[0].split(".")[:2]) # Check compatibility expected_onnx_version = compatibility_dict.get(ort_version) if expected_onnx_version is None: print( f"Warning: Onnx Runtime version {ort_version} has no specified compatible ONNX version. Compatibility issues may occur." ) elif expected_onnx_version == ".".join(onnx_version.split("+")[0].split(".")[:2]): logger.info( f"Installed Onnx Runtime version {ort_version} is compatible with installed ONNX version {onnx_version}." ) else: print( f"Warning: Installed Onnx Runtime version {ort_version} is not compatible with installed ONNX version {onnx_version}. Expected ONNX version: {expected_onnx_version}." ) def get_max_tensor(model, topk=5): graph = gs.import_onnx(model) tensor_map = graph.tensors() constant_tensors = [tensor for tensor in tensor_map.values() if isinstance(tensor, gs.Constant)] sub_graphs = graph.subgraphs(recursive=True) sub_graphs_constant_tensors = [ [tensor for name, tensor in sub_graph.tensors().items() if isinstance(tensor, gs.Constant)] for sub_graph in sub_graphs ] constant_tensors.extend([tensor for tensors in sub_graphs_constant_tensors for tensor in tensors]) sizes = [tensor.values.size for tensor in constant_tensors] sorted_indices = np.argsort(sizes)[::-1][:topk] for i in sorted_indices: tensor = constant_tensors[i] print( f"Tensor name: {tensor.name}, shape: {tensor.values.shape}, dtype: {tensor.values.dtype} size: {tensor.values.size}" ) # copied from https://onnx.ai/onnx/api/tools.html def update_outputs_dims( model, output_dims, ): dim_param_set: set[str] = set() def init_dim_param_set(dim_param_set, value_infos): for info in value_infos: shape = info.type.tensor_type.shape for dim in shape.dim: if dim.HasField("dim_param"): dim_param_set.add(dim.dim_param) # type: ignore init_dim_param_set(dim_param_set, model.graph.output) # type: ignore def update_dim(tensor, dim, j, name) -> None: dim_proto = tensor.type.tensor_type.shape.dim[j] # if it's int in model, it won't be replaced by original symbol if dim_proto.HasField("dim_value"): return if isinstance(dim, int): if dim >= 0: if dim_proto.HasField("dim_value") and dim_proto.dim_value != dim: raise ValueError( f"Unable to set dimension value to {dim} for axis {j} of {name}. Contradicts existing dimension value {dim_proto.dim_value}." ) dim_proto.dim_value = dim else: generated_dim_param = name + "_" + str(j) if generated_dim_param in dim_param_set: raise ValueError( f"Unable to generate unique dim_param for axis {j} of {name}. Please manually provide a dim_param value." ) dim_proto.dim_param = generated_dim_param elif isinstance(dim, str): dim_proto.dim_param = dim else: raise ValueError(f"Only int or str is accepted as dimension value, incorrect type: {type(dim)}") for output in model.graph.output: output_name = output.name output_dim_arr = output_dims[output_name] if output_dim_arr is None: continue if len(output.type.tensor_type.shape.dim) == 0: for _ in range(len(output_dim_arr)): output.type.tensor_type.shape.dim.add() for j, dim in enumerate(output_dim_arr): update_dim(output, dim, j, output_name) return model