DriverTrac/venv/lib/python3.12/site-packages/triton/language/target_info.py

55 lines
1.3 KiB
Python

from triton.runtime import driver
from triton.runtime.jit import constexpr_function
__all__ = ["current_target"]
def current_target():
try:
active_driver = driver.active
except RuntimeError:
# If there is no active driver, return None
return None
return active_driver.get_current_target()
current_target.__triton_builtin__ = True
@constexpr_function
def is_cuda():
target = current_target()
return target is not None and target.backend == "cuda"
@constexpr_function
def cuda_capability_geq(major, minor=0):
"""
Determines whether we have compute capability >= (major, minor) and
returns this as a constexpr boolean. This can be used for guarding
inline asm implementations that require a certain compute capability.
"""
target = current_target()
if target is None or target.backend != "cuda":
return False
assert isinstance(target.arch, int)
return target.arch >= major * 10 + minor
@constexpr_function
def is_hip():
target = current_target()
return target is not None and target.backend == "hip"
@constexpr_function
def is_hip_cdna3():
target = current_target()
return target is not None and target.arch == "gfx942"
@constexpr_function
def is_hip_cdna4():
target = current_target()
return target is not None and target.arch == "gfx950"