# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from thop.rnn_hooks import (
    count_gru,
    count_gru_cell,
    count_lstm,
    count_lstm_cell,
    count_rnn,
    count_rnn_cell,
    torch,
)
from thop.vision.basic_hooks import (
    count_adap_avgpool,
    count_avgpool,
    count_convNd,
    count_convtNd,
    count_linear,
    count_normalization,
    count_parameters,
    count_prelu,
    count_relu,
    count_softmax,
    count_upsample,
    logging,
    nn,
    zero_ops,
)

from .utils import prRed

default_dtype = torch.float64

register_hooks = {
    nn.ZeroPad2d: zero_ops,  # padding does not involve any multiplication.
    nn.Conv1d: count_convNd,
    nn.Conv2d: count_convNd,
    nn.Conv3d: count_convNd,
    nn.ConvTranspose1d: count_convtNd,
    nn.ConvTranspose2d: count_convtNd,
    nn.ConvTranspose3d: count_convtNd,
    nn.BatchNorm1d: count_normalization,
    nn.BatchNorm2d: count_normalization,
    nn.BatchNorm3d: count_normalization,
    nn.LayerNorm: count_normalization,
    nn.InstanceNorm1d: count_normalization,
    nn.InstanceNorm2d: count_normalization,
    nn.InstanceNorm3d: count_normalization,
    nn.PReLU: count_prelu,
    nn.Softmax: count_softmax,
    nn.ReLU: zero_ops,
    nn.ReLU6: zero_ops,
    nn.LeakyReLU: count_relu,
    nn.MaxPool1d: zero_ops,
    nn.MaxPool2d: zero_ops,
    nn.MaxPool3d: zero_ops,
    nn.AdaptiveMaxPool1d: zero_ops,
    nn.AdaptiveMaxPool2d: zero_ops,
    nn.AdaptiveMaxPool3d: zero_ops,
    nn.AvgPool1d: count_avgpool,
    nn.AvgPool2d: count_avgpool,
    nn.AvgPool3d: count_avgpool,
    nn.AdaptiveAvgPool1d: count_adap_avgpool,
    nn.AdaptiveAvgPool2d: count_adap_avgpool,
    nn.AdaptiveAvgPool3d: count_adap_avgpool,
    nn.Linear: count_linear,
    nn.Dropout: zero_ops,
    nn.Upsample: count_upsample,
    nn.UpsamplingBilinear2d: count_upsample,
    nn.UpsamplingNearest2d: count_upsample,
    nn.RNNCell: count_rnn_cell,
    nn.GRUCell: count_gru_cell,
    nn.LSTMCell: count_lstm_cell,
    nn.RNN: count_rnn,
    nn.GRU: count_gru,
    nn.LSTM: count_lstm,
    nn.Sequential: zero_ops,
    nn.PixelShuffle: zero_ops,
    nn.SyncBatchNorm: count_normalization,
}


def profile_origin(model, inputs, custom_ops=None, verbose=True, report_missing=False):
    """Profiles a PyTorch model's operations and parameters, applying either custom or default hooks."""
    handler_collection = []
    types_collection = set()
    if custom_ops is None:
        custom_ops = {}
    if report_missing:
        verbose = True

    def add_hooks(m):
        if list(m.children()):
            return

        if hasattr(m, "total_ops") or hasattr(m, "total_params"):
            logging.warning(
                f"Either .total_ops or .total_params is already defined in {m!s}. "
                "Be careful, it might change your code's behavior."
            )

        m.register_buffer("total_ops", torch.zeros(1, dtype=default_dtype))
        m.register_buffer("total_params", torch.zeros(1, dtype=default_dtype))

        for p in m.parameters():
            m.total_params += torch.DoubleTensor([p.numel()])

        m_type = type(m)

        fn = None
        if m_type in custom_ops:  # if defined both op maps, use custom_ops to overwrite.
            fn = custom_ops[m_type]
            if m_type not in types_collection and verbose:
                print(f"[INFO] Customize rule {fn.__qualname__}() {m_type}.")
        elif m_type in register_hooks:
            fn = register_hooks[m_type]
            if m_type not in types_collection and verbose:
                print(f"[INFO] Register {fn.__qualname__}() for {m_type}.")
        else:
            if m_type not in types_collection and report_missing:
                prRed(f"[WARN] Cannot find rule for {m_type}. Treat it as zero Macs and zero Params.")

        if fn is not None:
            handler = m.register_forward_hook(fn)
            handler_collection.append(handler)
        types_collection.add(m_type)

    training = model.training

    model.eval()
    model.apply(add_hooks)

    with torch.no_grad():
        model(*inputs)

    total_ops = 0
    total_params = 0
    for m in model.modules():
        if list(m.children()):  # skip for non-leaf module
            continue
        total_ops += m.total_ops
        total_params += m.total_params

    total_ops = total_ops.item()
    total_params = total_params.item()

    # reset model to original status
    model.train(training)
    for handler in handler_collection:
        handler.remove()

    # remove temporal buffers
    for n, m in model.named_modules():
        if list(m.children()):
            continue
        if "total_ops" in m._buffers:
            m._buffers.pop("total_ops")
        if "total_params" in m._buffers:
            m._buffers.pop("total_params")

    return total_ops, total_params


def profile(
    model: nn.Module,
    inputs,
    custom_ops=None,
    verbose=True,
    ret_layer_info=False,
    report_missing=False,
):
    """Profiles a PyTorch model, returning total operations, parameters, and optionally layer-wise details."""
    handler_collection = {}
    types_collection = set()
    if custom_ops is None:
        custom_ops = {}
    if report_missing:
        # overwrite `verbose` option when enable report_missing
        verbose = True

    def add_hooks(m: nn.Module):
        """Registers hooks to a neural network module to track total operations and parameters."""
        m.register_buffer("total_ops", torch.zeros(1, dtype=torch.float64))
        m.register_buffer("total_params", torch.zeros(1, dtype=torch.float64))

        # for p in m.parameters():
        #     m.total_params += torch.DoubleTensor([p.numel()])

        m_type = type(m)

        fn = None
        if m_type in custom_ops:
            # if defined both op maps, use custom_ops to overwrite.
            fn = custom_ops[m_type]
            if m_type not in types_collection and verbose:
                print(f"[INFO] Customize rule {fn.__qualname__}() {m_type}.")
        elif m_type in register_hooks:
            fn = register_hooks[m_type]
            if m_type not in types_collection and verbose:
                print(f"[INFO] Register {fn.__qualname__}() for {m_type}.")
        else:
            if m_type not in types_collection and report_missing:
                prRed(f"[WARN] Cannot find rule for {m_type}. Treat it as zero Macs and zero Params.")

        if fn is not None:
            handler_collection[m] = (
                m.register_forward_hook(fn),
                m.register_forward_hook(count_parameters),
            )
        types_collection.add(m_type)

    prev_training_status = model.training

    model.eval()
    model.apply(add_hooks)

    with torch.no_grad():
        model(*inputs)

    def dfs_count(module: nn.Module, prefix="\t") -> (int, int):
        """Recursively counts the total operations and parameters of the given PyTorch module and its submodules."""
        total_ops, total_params = module.total_ops.item(), 0
        ret_dict = {}
        for n, m in module.named_children():
            # if not hasattr(m, "total_ops") and not hasattr(m, "total_params"):  # and len(list(m.children())) > 0:
            #     m_ops, m_params = dfs_count(m, prefix=prefix + "\t")
            # else:
            #     m_ops, m_params = m.total_ops, m.total_params
            next_dict = {}
            if m in handler_collection and not isinstance(m, (nn.Sequential, nn.ModuleList)):
                m_ops, m_params = m.total_ops.item(), m.total_params.item()
            else:
                m_ops, m_params, next_dict = dfs_count(m, prefix=prefix + "\t")
            ret_dict[n] = (m_ops, m_params, next_dict)
            total_ops += m_ops
            total_params += m_params
        # print(prefix, module._get_name(), (total_ops, total_params))
        return total_ops, total_params, ret_dict

    total_ops, total_params, ret_dict = dfs_count(model)

    # reset model to original status
    model.train(prev_training_status)
    for m, (op_handler, params_handler) in handler_collection.items():
        op_handler.remove()
        params_handler.remove()
        m._buffers.pop("total_ops")
        m._buffers.pop("total_params")

    if ret_layer_info:
        return total_ops, total_params, ret_dict
    return total_ops, total_params
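

# --- Usage sketch (illustrative addition, not part of the upstream file) ---
# Running this module as a script in package context (e.g. `python -m thop.profile`)
# profiles a small toy network with profile() defined above. The layer sizes and the
# 1x3x32x32 input shape are assumptions chosen only to demonstrate the call signature;
# they do not come from this file.
if __name__ == "__main__":
    _model = nn.Sequential(
        nn.Conv2d(3, 16, kernel_size=3, padding=1),  # convolution counted by count_convNd
        nn.ReLU(),  # counted as zero ops
        nn.Flatten(),  # no registered rule; contributes zero MACs and params
        nn.Linear(16 * 32 * 32, 10),  # counted by count_linear
    )
    _input = torch.randn(1, 3, 32, 32)
    _macs, _params = profile(_model, inputs=(_input,), verbose=False)
    print(f"MACs: {_macs:.0f}, Params: {_params:.0f}")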