# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from thop.rnn_hooks import (
    count_gru,
    count_gru_cell,
    count_lstm,
    count_lstm_cell,
    count_rnn,
    count_rnn_cell,
    torch,
)
from thop.vision.basic_hooks import (
    count_adap_avgpool,
    count_avgpool,
    count_convNd,
    count_convtNd,
    count_linear,
    count_normalization,
    count_parameters,
    count_prelu,
    count_relu,
    count_softmax,
    count_upsample,
    logging,
    nn,
    zero_ops,
)

from .utils import prRed

default_dtype = torch.float64

register_hooks = {
    nn.ZeroPad2d: zero_ops,  # padding does not involve any multiplication.
    nn.Conv1d: count_convNd,
    nn.Conv2d: count_convNd,
    nn.Conv3d: count_convNd,
    nn.ConvTranspose1d: count_convtNd,
    nn.ConvTranspose2d: count_convtNd,
    nn.ConvTranspose3d: count_convtNd,
    nn.BatchNorm1d: count_normalization,
    nn.BatchNorm2d: count_normalization,
    nn.BatchNorm3d: count_normalization,
    nn.LayerNorm: count_normalization,
    nn.InstanceNorm1d: count_normalization,
    nn.InstanceNorm2d: count_normalization,
    nn.InstanceNorm3d: count_normalization,
    nn.PReLU: count_prelu,
    nn.Softmax: count_softmax,
    nn.ReLU: zero_ops,
    nn.ReLU6: zero_ops,
    nn.LeakyReLU: count_relu,
    nn.MaxPool1d: zero_ops,
    nn.MaxPool2d: zero_ops,
    nn.MaxPool3d: zero_ops,
    nn.AdaptiveMaxPool1d: zero_ops,
    nn.AdaptiveMaxPool2d: zero_ops,
    nn.AdaptiveMaxPool3d: zero_ops,
    nn.AvgPool1d: count_avgpool,
    nn.AvgPool2d: count_avgpool,
    nn.AvgPool3d: count_avgpool,
    nn.AdaptiveAvgPool1d: count_adap_avgpool,
    nn.AdaptiveAvgPool2d: count_adap_avgpool,
    nn.AdaptiveAvgPool3d: count_adap_avgpool,
    nn.Linear: count_linear,
    nn.Dropout: zero_ops,
    nn.Upsample: count_upsample,
    nn.UpsamplingBilinear2d: count_upsample,
    nn.UpsamplingNearest2d: count_upsample,
    nn.RNNCell: count_rnn_cell,
    nn.GRUCell: count_gru_cell,
    nn.LSTMCell: count_lstm_cell,
    nn.RNN: count_rnn,
    nn.GRU: count_gru,
    nn.LSTM: count_lstm,
    nn.Sequential: zero_ops,
    nn.PixelShuffle: zero_ops,
    nn.SyncBatchNorm: count_normalization,
}


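# The table above only covers built-in torch.nn layers; any other module type is treated as
# zero MACs and zero params (see the [WARN] message below) unless a counting rule is passed
# through the `custom_ops` argument of profile()/profile_origin(). A minimal sketch of such a
# rule follows; `YourModule` and `count_your_module` are hypothetical names, shown only to
# illustrate the forward-hook signature and the `total_ops` buffer that add_hooks() registers.
#
#     def count_your_module(module, inputs, output):
#         # accumulate an op count into the buffer registered by add_hooks()
#         module.total_ops += torch.DoubleTensor([int(output.numel())])
#
#     macs, params = profile(model, inputs=(x,), custom_ops={YourModule: count_your_module})

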
def profile_origin(model, inputs, custom_ops=None, verbose=True, report_missing=False):
    """Profiles a PyTorch model's operations and parameters, applying either custom or default hooks."""
    handler_collection = []
    types_collection = set()
    if custom_ops is None:
        custom_ops = {}
    if report_missing:
        # report_missing implies verbose output
        verbose = True

    def add_hooks(m):
        """Registers counting buffers and, where a rule exists, a forward hook on each leaf module."""
        if list(m.children()):
            return

        if hasattr(m, "total_ops") or hasattr(m, "total_params"):
            logging.warning(
                f"Either .total_ops or .total_params is already defined in {m!s}. "
                "Be careful, it might change your code's behavior."
            )

        m.register_buffer("total_ops", torch.zeros(1, dtype=default_dtype))
        m.register_buffer("total_params", torch.zeros(1, dtype=default_dtype))

        for p in m.parameters():
            m.total_params += torch.DoubleTensor([p.numel()])

        m_type = type(m)

        fn = None
        if m_type in custom_ops:  # if both op maps define a rule, custom_ops takes precedence.
            fn = custom_ops[m_type]
            if m_type not in types_collection and verbose:
                print(f"[INFO] Customized rule {fn.__qualname__}() for {m_type}.")
        elif m_type in register_hooks:
            fn = register_hooks[m_type]
            if m_type not in types_collection and verbose:
                print(f"[INFO] Register {fn.__qualname__}() for {m_type}.")
        else:
            if m_type not in types_collection and report_missing:
                prRed(f"[WARN] Cannot find rule for {m_type}. Treat it as zero MACs and zero params.")

        if fn is not None:
            handler = m.register_forward_hook(fn)
            handler_collection.append(handler)
        types_collection.add(m_type)

    training = model.training

    model.eval()
    model.apply(add_hooks)

    with torch.no_grad():
        model(*inputs)

    total_ops = 0
    total_params = 0
    for m in model.modules():
        if list(m.children()):  # skip non-leaf modules
            continue
        total_ops += m.total_ops
        total_params += m.total_params

    total_ops = total_ops.item()
    total_params = total_params.item()

    # reset model to its original status
    model.train(training)
    for handler in handler_collection:
        handler.remove()

    # remove temporary buffers
    for n, m in model.named_modules():
        if list(m.children()):
            continue
        if "total_ops" in m._buffers:
            m._buffers.pop("total_ops")
        if "total_params" in m._buffers:
            m._buffers.pop("total_params")

    return total_ops, total_params


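# Note: profile_origin() above accumulates counts into per-leaf buffers and then sums them,
# while profile() below additionally registers a count_parameters hook per hooked module and
# aggregates the results recursively (dfs_count), optionally returning a per-layer breakdown
# via `ret_layer_info`. Both take the model plus a tuple of example inputs, e.g. (a sketch,
# input shape arbitrary):
#
#     macs, params = profile_origin(net, inputs=(torch.randn(1, 3, 224, 224),))

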
def profile(
    model: nn.Module,
    inputs,
    custom_ops=None,
    verbose=True,
    ret_layer_info=False,
    report_missing=False,
):
    """Profiles a PyTorch model, returning total operations, parameters, and optionally layer-wise details."""
    handler_collection = {}
    types_collection = set()
    if custom_ops is None:
        custom_ops = {}
    if report_missing:
        # report_missing implies verbose: overwrite the `verbose` option when it is enabled.
        verbose = True

    def add_hooks(m: nn.Module):
        """Registers hooks to a neural network module to track total operations and parameters."""
        m.register_buffer("total_ops", torch.zeros(1, dtype=torch.float64))
        m.register_buffer("total_params", torch.zeros(1, dtype=torch.float64))

        # parameters are counted by the count_parameters forward hook registered below,
        # rather than being accumulated eagerly here.

        m_type = type(m)

        fn = None
        if m_type in custom_ops:
            # if both op maps define a rule, custom_ops takes precedence.
            fn = custom_ops[m_type]
            if m_type not in types_collection and verbose:
                print(f"[INFO] Customized rule {fn.__qualname__}() for {m_type}.")
        elif m_type in register_hooks:
            fn = register_hooks[m_type]
            if m_type not in types_collection and verbose:
                print(f"[INFO] Register {fn.__qualname__}() for {m_type}.")
        else:
            if m_type not in types_collection and report_missing:
                prRed(f"[WARN] Cannot find rule for {m_type}. Treat it as zero MACs and zero params.")

        if fn is not None:
            handler_collection[m] = (
                m.register_forward_hook(fn),
                m.register_forward_hook(count_parameters),
            )
        types_collection.add(m_type)

    prev_training_status = model.training

    model.eval()
    model.apply(add_hooks)

    with torch.no_grad():
        model(*inputs)

    def dfs_count(module: nn.Module, prefix="\t"):
        """Recursively counts the total operations and parameters of the given PyTorch module and its submodules."""
        total_ops, total_params = module.total_ops.item(), 0
        ret_dict = {}
        for n, m in module.named_children():
            next_dict = {}
            if m in handler_collection and not isinstance(m, (nn.Sequential, nn.ModuleList)):
                m_ops, m_params = m.total_ops.item(), m.total_params.item()
            else:
                m_ops, m_params, next_dict = dfs_count(m, prefix=prefix + "\t")
            ret_dict[n] = (m_ops, m_params, next_dict)
            total_ops += m_ops
            total_params += m_params
        # print(prefix, module._get_name(), (total_ops, total_params))  # optional debug trace
        return total_ops, total_params, ret_dict

    total_ops, total_params, ret_dict = dfs_count(model)

    # reset model to its original status
    model.train(prev_training_status)
    for m, (op_handler, params_handler) in handler_collection.items():
        op_handler.remove()
        params_handler.remove()
        m._buffers.pop("total_ops")
        m._buffers.pop("total_params")

    if ret_layer_info:
        return total_ops, total_params, ret_dict
    return total_ops, total_params
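

# Minimal usage sketch (illustrative only, guarded so it does not run on import). The tiny
# model below is an arbitrary stand-in chosen for this example, not anything prescribed by
# this module; it simply exercises a few of the default hooks registered above
# (Conv2d, ReLU, adaptive pooling, Linear).
if __name__ == "__main__":
    _demo_model = nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Linear(8, 10),
    )
    _demo_inputs = (torch.randn(1, 3, 32, 32),)
    macs, params = profile(_demo_model, inputs=_demo_inputs, verbose=False)
    print(f"MACs: {macs:,.0f}, parameters: {params:,.0f}")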