# File-viewer metadata: 70 lines, 2.3 KiB, Python.
# Peak-throughput formulas, keyed first by device type and then by architecture
# string. Every entry is a callable that takes the operand `width` in bits plus
# keyword arguments (`num_sms`, `clock_rate`); entries that do not need a
# keyword absorb it through **kwargs, so all entries can be called the same way.
# Each formula yields its base figure at width == 8 and divides it by
# (width / 8) for wider operands.
flops_by_device = {
    "CUDA": {
        # Fixed figure; SM count and clock rate are ignored.
        "80": lambda width, **kwargs: 624e12 / (width / 8),
        # Fixed figure; SM count and clock rate are ignored.
        # TODO(Keren): Implement fp16 acc-> 660.6 fp8
        "89": lambda width, **kwargs: (330.3 * 1e12) / (width / 8),
        # 1513e12 rescaled linearly from a reference of 114 SMs at a clock of
        # 1755 * 1e3 to the actual SM count and clock rate.
        # NOTE(review): reference numbers presumably come from a vendor
        # datasheet (clock in kHz) -- confirm.
        "90": lambda width, num_sms, clock_rate, **kwargs: (
            (num_sms / 114 * clock_rate / (1755 * 1e3) * 1513) * 1e12
        ) / (width / 8),
        # 16384 per SM per clock; (clock_rate / 1e3) * 1e6 scales the clock
        # rate by a net factor of 1e3.
        "100": lambda width, num_sms, clock_rate, **kwargs: (
            num_sms * 16384 * (clock_rate / 1e3) * 1e6
        ) / (width / 8),
    }
}
# Peak memory bandwidth per AMD architecture; `max_bps` returns these values
# unchanged as bytes per second for HIP devices.
amd_bps_by_arch = {
    'gfx90a': 3.2 * 1e12,
    'gfx942': 5.3 * 1e12,
    'gfx950': 8.0 * 1e12,
}
# FP8 matrix performance (FLOPS/clock/CU) per AMD architecture.
# gfx90a does not support FP8 matrix operations, so its INT8 performance is
# used instead.
amd_fp8_flops_by_arch = {
    'gfx90a': 1024,
    'gfx942': 4096,
    'gfx950': 8192,
}
def max_flops(device_type: str, arch: str, width: int, num_sms: int, clock_rate: float) -> float:
    """
    Calculate the maximum FLOPS for a given device type, architecture and width.

    Args:
        device_type (str): The type of device (e.g., "CUDA", "HIP").
        arch (str): The architecture of the device (e.g., "80", "90" for CUDA;
            a key of ``amd_fp8_flops_by_arch`` such as "gfx942" for HIP).
        width (int): The operand width in bits. The 8-bit figure is divided by
            ``width / 8``.
        num_sms (int): The number of streaming multiprocessors. On HIP it is
            multiplied with the per-CU figure, i.e. used as the compute-unit
            count.
        clock_rate (float): The clock rate in kHz. The formulas multiply it by
            1e3 (or compare it against ``1755 * 1e3``) to reach a per-second
            figure, so a value in GHz would be off by a factor of 1e6.
            NOTE(review): confirm callers pass the driver-reported kHz value.

    Returns:
        float: The maximum FLOPS for the given device type and width.

    Raises:
        ValueError: If ``device_type`` is neither "HIP" nor a key of
            ``flops_by_device``, or if ``arch`` is unknown for that device type.
        KeyError: If ``device_type`` is "HIP" and ``arch`` has no entry in
            ``amd_fp8_flops_by_arch``.
    """
    if device_type == "HIP":
        # FLOPS/clock/CU * CU count * clock (kHz -> Hz), scaled down for
        # operands wider than 8 bits.
        return amd_fp8_flops_by_arch[arch] * num_sms * clock_rate * 1e3 / (width / 8)

    if device_type not in flops_by_device:
        raise ValueError(f"Unsupported device type: {device_type}")

    if arch not in flops_by_device[device_type]:
        raise ValueError(f"Unsupported architecture: {arch}")

    flops_func = flops_by_device[device_type][arch]

    # Every table entry accepts **kwargs, so passing both keywords is safe even
    # for architectures whose formula ignores them.
    return flops_func(width, num_sms=num_sms, clock_rate=clock_rate)
def max_bps(device_type: str, arch: str, bus_width: int, memory_clock_rate: float) -> float:
    """
    Calculate the maximum bytes per second of device memory.

    For "CUDA" the value is derived from the bus width and memory clock rate;
    for "HIP" it is looked up per architecture in ``amd_bps_by_arch`` and
    ``bus_width`` / ``memory_clock_rate`` are ignored.

    Args:
        device_type (str): The type of device ("CUDA" or "HIP").
        arch (str): The architecture of the device. Only used for "HIP", where
            it must be a key of ``amd_bps_by_arch`` (e.g., "gfx942").
        bus_width (int): The bus width in bits.
        memory_clock_rate (float): The memory clock rate in kHz; it is
            multiplied by 1e3 to reach Hz.
            NOTE(review): confirm callers pass the driver-reported kHz value.

    Returns:
        float: The maximum bytes per second.

    Raises:
        ValueError: If ``device_type`` is neither "CUDA" nor "HIP".
        KeyError: If ``device_type`` is "HIP" and ``arch`` has no entry in
            ``amd_bps_by_arch``.
    """
    if device_type == "CUDA":
        # bits * clock (kHz -> Hz) / 8 gives bytes per second; the factor 2 is
        # presumably two transfers per clock (double data rate) -- confirm.
        return 2 * bus_width * memory_clock_rate * 1e3 / 8

    if device_type == "HIP":
        return amd_bps_by_arch[arch]

    # Raise explicitly rather than `assert`: assertions are stripped under
    # `python -O`, which would let an unsupported device type fall through to
    # the HIP lookup. This also matches the error raised by `max_flops`.
    raise ValueError(f"Unsupported device type: {device_type}")