# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

import logging
from distutils.version import LooseVersion

import torch
import torch as th
import torch.nn as nn
from torch.fx import symbolic_trace
from torch.fx.passes.shape_prop import ShapeProp

from .utils import prRed, prYellow
from .vision.calc_func import calculate_conv

if LooseVersion(torch.__version__) < LooseVersion("1.8.0"):
    logging.warning(
        f"torch.fx requires PyTorch 1.8.0 or newer, but you are running an older version: {torch.__version__}."
    )


def count_clamp(input_shapes, output_shapes):
    """Returns 0 FLOPs for clamp operations, which this profiler treats as free."""
    return 0


def count_mul(input_shapes, output_shapes):
    """Returns the number of elements in the first output shape."""
    return output_shapes[0].numel()


def count_matmul(input_shapes, output_shapes):
    """Calculates matrix multiplication ops based on input and output tensor shapes for performance profiling."""
    in_shape = input_shapes[0]
    out_shape = output_shapes[0]
    in_features = in_shape[-1]
    num_elements = out_shape.numel()
    return in_features * num_elements


def count_fn_linear(input_shapes, output_shapes, *args, **kwargs):
    """Calculates the total FLOPs for a linear layer, including bias operations if specified."""
    flops = count_matmul(input_shapes, output_shapes)
    if "bias" in kwargs:
        flops += output_shapes[0].numel()
    return flops


def count_fn_conv2d(input_shapes, output_shapes, *args, **kwargs):
    """Calculates total operations (FLOPs) for a 2D conv layer based on input and output shapes using
    `calculate_conv`.
    """
    _inputs, _weight, _bias, _stride, _padding, _dilation, groups = args
    if len(input_shapes) == 2:
        x_shape, k_shape = input_shapes
    elif len(input_shapes) == 3:
        x_shape, k_shape, _b_shape = input_shapes
    out_shape = output_shapes[0]

    kernel_parameters = k_shape[2:].numel()
    bias_op = 0  # check it later
    in_channel = x_shape[1]

    total_ops = calculate_conv(bias_op, kernel_parameters, out_shape.numel(), in_channel, groups).item()
    return int(total_ops)


def count_nn_linear(module: nn.Module, input_shapes, output_shapes):
    """Counts the FLOPs for a fully connected (linear) layer in a neural network module."""
    return count_matmul(input_shapes, output_shapes)


def count_zero_ops(module: nn.Module, input_shapes, output_shapes, *args, **kwargs):
    """Returns 0 FLOPs for modules and functions that this profiler treats as free (activations, pooling, adds)."""
    return 0


def count_nn_conv2d(module: nn.Conv2d, input_shapes, output_shapes):
    """Calculates FLOPs for an nn.Conv2d layer from the module's attributes and the output shape."""
    bias_op = 1 if module.bias is not None else 0
    out_shape = output_shapes[0]

    in_channel = module.in_channels
    groups = module.groups
    kernel_ops = module.weight.shape[2:].numel()
    total_ops = calculate_conv(bias_op, kernel_ops, out_shape.numel(), in_channel, groups).item()
    return int(total_ops)


def count_nn_bn2d(module: nn.BatchNorm2d, input_shapes, output_shapes):
    """Calculate FLOPs for an nn.BatchNorm2d layer based on the given output shape."""
    assert len(output_shapes) == 1, "nn.BatchNorm2d should only have one output"
    y = output_shapes[0]
    return 2 * y.numel()


zero_ops = (
    nn.ReLU,
    nn.ReLU6,
    nn.Dropout,
    nn.MaxPool2d,
    nn.AvgPool2d,
    nn.AdaptiveAvgPool2d,
)
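# Keys are either nn.Module subclasses (matched for call_module nodes) or normalized
# strings derived from traced call_function / call_method targets (see the key
# construction inside fx_profile below).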
count_map = {
    nn.Linear: count_nn_linear,
    nn.Conv2d: count_nn_conv2d,
    nn.BatchNorm2d: count_nn_bn2d,
    "function linear": count_fn_linear,
    "clamp": count_clamp,
    "built-in function add": count_zero_ops,
    "built-in method fl": count_zero_ops,
    "built-in method conv2d of type object": count_fn_conv2d,
    "built-in function mul": count_mul,
    "built-in function truediv": count_mul,
}

for k in zero_ops:
    count_map[k] = count_zero_ops

missing_maps = {}


def null_print(*args, **kwargs):
    """A no-op print function that takes any arguments without performing any actions."""
    return


def fx_profile(mod: nn.Module, input: th.Tensor, verbose=False):
    """Profiles an nn.Module by summing FLOPs over its traced operations, printing per-node details if verbose."""
    gm: torch.fx.GraphModule = symbolic_trace(mod)
    ShapeProp(gm).propagate(input)
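    # ShapeProp runs the example input through the traced graph and stores the resulting
    # tensor metadata on each node; it is read below via node.meta["tensor_meta"].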

    fprint = null_print
    if verbose:
        fprint = print

    v_maps = {}
    total_flops = 0

    for node in gm.graph.nodes:
        # print(f"{node.target},\t{node.op},\t{node.meta['tensor_meta'].dtype},\t{node.meta['tensor_meta'].shape}")
        fprint(f"NodeOP:{node.op},\tTarget:{node.target},\tNodeName:{node.name},\tNodeArgs:{node.args}")
        # node_op_type = str(node.target).split(".")[-1]
        node_flops = None

        input_shapes = []
        fprint("input_shape:", end="\t")
        for arg in node.args:
            if str(arg) not in v_maps:
                continue
            fprint(f"{v_maps[str(arg)]}", end="\t")
            input_shapes.append(v_maps[str(arg)])
        fprint()
        fprint(f"output_shape:\t{node.meta['tensor_meta'].shape}")
        output_shapes = [node.meta["tensor_meta"].shape]
        if node.op in ["output", "placeholder"]:
            node_flops = 0
        elif node.op == "call_function":
            # torch internal functions
            key = str(node.target).split("at")[0].replace("<", "").replace(">", "").strip()
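            # Built-in targets stringify like "<built-in method conv2d of type object at 0x...>";
            # strip the address and angle brackets to get a stable count_map key.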
            if key in count_map:
                node_flops = count_map[key](input_shapes, output_shapes, *node.args, **node.kwargs)
            else:
                missing_maps[key] = (node.op, key)
                prRed(f"|{key}| is missing")
        elif node.op == "call_method":
            # tensor methods, e.g. `x.clamp(...)`
            # fprint(str(node.target) in count_map, str(node.target), count_map.keys())
            key = str(node.target)
            if key in count_map:
                node_flops = count_map[key](input_shapes, output_shapes)
            else:
                missing_maps[key] = (node.op, key)
                prRed(f"{key} is missing")
        elif node.op == "call_module":
            # torch.nn modules
            # m = getattr(mod, node.target, None)
            m = mod.get_submodule(node.target)
            key = type(m)
            fprint(type(m), type(m) in count_map)
            if type(m) in count_map:
                node_flops = count_map[type(m)](m, input_shapes, output_shapes)
            else:
                missing_maps[key] = (node.op,)
                prRed(f"{key} is missing")
            print("module type:", type(m))
            if isinstance(m, zero_ops):
                print("weight_shape: None")
            else:
                print(type(m))
                print(f"weight_shape: {mod.state_dict()[f'{node.target}.weight'].shape}")

        v_maps[str(node.name)] = node.meta["tensor_meta"].shape
        if node_flops is not None:
            total_flops += node_flops
        prYellow(f"Current node's FLOPs: {node_flops}, total FLOPs: {total_flops}")
        fprint("==" * 40)

    if len(missing_maps.keys()) > 0:
        from pprint import pprint

        print("Missing operators: ")
        pprint(missing_maps)
    return total_flops


if __name__ == "__main__":

    class MyOP(nn.Module):
        """Custom operator that performs a simple forward pass dividing input by 1."""

        def forward(self, input):
            """Performs forward pass on given input data."""
            return input / 1

    class MyModule(torch.nn.Module):
        """Neural network module with two linear layers and a custom MyOP operator."""

        def __init__(self):
            """Initializes MyModule with two linear layers and a custom MyOP operator."""
            super().__init__()
            self.linear1 = torch.nn.Linear(5, 3)
            self.linear2 = torch.nn.Linear(5, 3)
            self.myop = MyOP()

        def forward(self, x):
            """Applies two linear transformations to the input tensor, clamps the second, then combines and processes
            with MyOP operator.
            """
            out1 = self.linear1(x)
            out2 = self.linear2(x).clamp(min=0.0, max=1.0)
            return self.myop(out1 + out2)

    net = MyModule()
    data = th.randn(20, 5)
    flops = fx_profile(net, data, verbose=False)
    print(flops)