# File: DriverTrac/venv/lib/python3.12/site-packages/thop/vision/calc_func.py
#
# 133 lines
# 4.4 KiB
# Python

# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
import math
import warnings

import numpy as np
import torch
def l_prod(in_list):
    """Return the product of all elements of ``in_list``.

    Args:
        in_list: Iterable of numbers, for example a shape such as ``[n, c, h, w]``.

    Returns:
        The product of the elements, or ``1`` when ``in_list`` is empty.
    """
    # math.prod starts from 1 and multiplies left to right, like the manual loop it replaces.
    return math.prod(in_list)
def l_sum(in_list):
    """Return the sum of the items in ``in_list`` (``0`` when it is empty)."""
    total = sum(in_list)
    return total
def calculate_parameters(param_list):
    """Count the total number of elements across an iterable of tensors.

    Args:
        param_list: Iterable of objects exposing ``nelement()`` (for example the
            result of ``module.parameters()``).

    Returns:
        A one-element ``torch.DoubleTensor`` holding the total element count.
        An empty iterable yields ``tensor([0.])``, so the return type is the same
        whether or not any parameters were supplied.
    """
    # Add the plain integer counts first and wrap once, instead of allocating a
    # tensor per parameter and falling back to the int 0 for empty input.
    total = sum(p.nelement() for p in param_list)
    return torch.DoubleTensor([total])
def calculate_zero_ops():
    """Return a new one-element ``torch.DoubleTensor`` whose single value is zero."""
    zero_ops = torch.DoubleTensor([0])
    return zero_ops
def calculate_conv2d_flops(
    input_size: list, output_size: list, kernel_size: list, groups: int, bias: bool = False, transpose: bool = False
):
    """Estimate the operation count of a (transposed) 2D convolution from its tensor shapes.

    The result is ``prod(positions) * (channels // groups) * prod(kernel_size[2:])``.
    For a regular convolution ``positions`` is ``output_size`` and ``channels`` is
    ``input_size[1]``; for a transposed convolution ``positions`` is ``input_size``
    and ``channels`` is ``output_size[1]``.

    Args:
        input_size: Shape of the input; index 1 is read as the channel count.
        output_size: Shape of the output; index 1 is read as the channel count.
        kernel_size: Shape of the weight; entries from index 2 onward are multiplied together.
        groups: Number of groups the channel count is integer-divided by.
        bias: Accepted for interface compatibility; it does not affect the result.
        transpose: Selects the transposed-convolution variant of the formula.

    Returns:
        The estimated operation count.
    """
    # The two variants differ only in which shape is multiplied out in full and
    # which shape supplies the channel count.
    if transpose:
        positions, channels = input_size, output_size[1]
    else:
        positions, channels = output_size, input_size[1]
    return l_prod(positions) * (channels // groups) * l_prod(kernel_size[2:])
def calculate_conv(bias, kernel_size, output_size, in_channel, group):
    """Estimate the operation count of a convolution (deprecated API).

    Computes ``output_size * (in_channel / group * kernel_size + bias)``.

    Args:
        bias: Numeric term added once per output element (e.g. 0 or 1).
        kernel_size: Numeric factor for the kernel; presumably the number of
            kernel elements per channel - confirm against callers.
        output_size: Numeric factor for the output; presumably the number of
            output elements - confirm against callers.
        in_channel: Number of input channels.
        group: Number of groups; ``in_channel`` is divided by it (true division).

    Returns:
        A one-element ``torch.DoubleTensor`` holding the estimate.
    """
    # stacklevel=2 attributes the warning to the caller that still uses this
    # API rather than to this module.
    warnings.warn("This API is being deprecated.", stacklevel=2)
    ops_per_output = in_channel / group * kernel_size + bias
    return torch.DoubleTensor([output_size * ops_per_output])
def calculate_norm(input_size):
    """Return the operation count for a normalization layer as ``2 * input_size`` in a one-element double tensor."""
    norm_ops = 2 * input_size
    return torch.DoubleTensor([norm_ops])
def calculate_relu_flops(input_size):
    """Return the FLOPs attributed to a ReLU activation, which is always ``0`` regardless of ``input_size``."""
    relu_flops = 0  # the argument is accepted but never consulted
    return relu_flops
def calculate_relu(input_size: torch.Tensor):
    """Wrap ``input_size`` in a one-element double tensor (deprecated API).

    Args:
        input_size: Value convertible with ``int()``, such as a Python number or
            a single-element tensor.

    Returns:
        A one-element ``torch.DoubleTensor`` holding ``int(input_size)``.
    """
    # stacklevel=2 attributes the warning to the caller that still uses this
    # API rather than to this module.
    warnings.warn("This API is being deprecated", stacklevel=2)
    return torch.DoubleTensor([int(input_size)])
def calculate_softmax(batch_size, nfeatures):
    """Return the softmax operation count for ``batch_size`` rows of ``nfeatures`` values as a double tensor."""
    # Per row: nfeatures exponentials, nfeatures - 1 additions, nfeatures divisions.
    ops_per_row = nfeatures + (nfeatures - 1) + nfeatures
    return torch.DoubleTensor([int(batch_size * ops_per_row)])
def calculate_avgpool(input_size):
    """Return the average-pooling operation count, ``int(input_size)``, as a one-element double tensor."""
    pool_ops = int(input_size)
    return torch.DoubleTensor([pool_ops])
def calculate_adaptive_avg(kernel_size, output_size):
    """Return the adaptive average-pooling operation count, ``(kernel_size + 1) * output_size``, as a double tensor."""
    # One division is counted on top of the kernel_size term for each output element.
    ops_per_output = kernel_size + 1
    return torch.DoubleTensor([int(ops_per_output * output_size)])
# Multiplier applied to the output size for each interpolation mode that has a
# cost assigned; the constants are carried over unchanged.
_UPSAMPLE_OPS_PER_ELEMENT = {
    "bicubic": 224 + 35,
    "bilinear": 11,
    "linear": 5,
    "trilinear": 13 * 2 + 5,
}


def calculate_upsample(mode: str, output_size):
    """Estimate the operation count of an upsampling layer.

    Args:
        mode: Interpolation mode. ``"bicubic"``, ``"bilinear"``, ``"linear"`` and
            ``"trilinear"`` scale the count by a fixed per-element factor; any
            other mode leaves the count equal to ``output_size``.
        output_size: Value convertible with ``int()`` that the count is scaled
            from, such as a Python number or a single-element tensor.

    Returns:
        A one-element ``torch.DoubleTensor`` holding the integer operation count.
    """
    factor = _UPSAMPLE_OPS_PER_ELEMENT.get(mode)
    # Multiply into a new value rather than using ``*=`` on an alias of the
    # argument, which would modify a tensor or ndarray passed by the caller.
    total_ops = output_size if factor is None else output_size * factor
    return torch.DoubleTensor([int(total_ops)])
def calculate_linear(in_feature, num_elements):
    """Return the linear-layer operation count, ``in_feature * num_elements``, as a one-element double tensor."""
    linear_ops = int(in_feature * num_elements)
    return torch.DoubleTensor([linear_ops])
def counter_matmul(input_size, output_size):
    """Return the matmul operation count: the product of every input dimension times the last output dimension."""
    in_dims = np.array(input_size)
    out_dims = np.array(output_size)
    last_out_dim = out_dims[-1]
    return np.prod(in_dims) * last_out_dim
def counter_mul(input_size):
    """Return the operation count for an element-wise multiplication, taken to be ``input_size`` unchanged."""
    mul_ops = input_size
    return mul_ops
def counter_pow(input_size):
    """Return the operation count for a power operation, taken to be ``input_size`` unchanged."""
    pow_ops = input_size
    return pow_ops
def counter_sqrt(input_size):
    """Return the operation count for a square-root operation, taken to be ``input_size`` unchanged."""
    sqrt_ops = input_size
    return sqrt_ops
def counter_div(input_size):
    """Return the operation count for a division operation, taken to be ``input_size`` unchanged."""
    div_ops = input_size
    return div_ops