55 lines
1.3 KiB
Python
55 lines
1.3 KiB
Python
from triton.runtime import driver
|
|
from triton.runtime.jit import constexpr_function
|
|
|
|
# Names exported by a star-import of this module. Only `current_target` is
# listed here, so any other helper defined in this file is not part of the
# star-import surface.
__all__ = ["current_target"]
|
|
|
|
|
|
def current_target():
    """Return the current target reported by the active driver.

    Returns:
        The value of ``get_current_target()`` on ``driver.active``, or
        ``None`` when reading ``driver.active`` raises ``RuntimeError``
        (there is no active driver).
    """
    try:
        backend_driver = driver.active
    except RuntimeError:
        # Without an active driver there is no target to report.
        return None
    else:
        # Errors raised while querying the target itself are not swallowed.
        return backend_driver.get_current_target()


# Tag the function with the `__triton_builtin__` attribute.
# NOTE(review): presumably read by the Triton compiler to treat this helper
# as a builtin -- confirm against the consumer of this flag.
current_target.__triton_builtin__ = True
|
|
|
|
|
|
@constexpr_function
def is_cuda():
    """Return True when a current target exists and its backend is "cuda"."""
    tgt = current_target()
    if tgt is None:
        # No target available, so it cannot be a CUDA target.
        return False
    return tgt.backend == "cuda"
|
|
|
|
|
|
@constexpr_function
def cuda_capability_geq(major, minor=0):
    """
    Check whether the current target is a CUDA target whose compute
    capability is >= (major, minor), returning the answer as a constexpr
    boolean. Useful for guarding inline asm implementations that need a
    minimum compute capability.

    Returns False when there is no current target or when its backend is
    not "cuda". Otherwise the target's integer ``arch`` is compared against
    ``major * 10 + minor``.
    """
    tgt = current_target()
    on_cuda = tgt is not None and tgt.backend == "cuda"
    if not on_cuda:
        return False
    # The numeric comparison below relies on an integer arch value.
    assert isinstance(tgt.arch, int)
    required_arch = major * 10 + minor
    return tgt.arch >= required_arch
|
|
|
|
|
|
@constexpr_function
def is_hip():
    """Return True when a current target exists and its backend is "hip"."""
    tgt = current_target()
    if tgt is None:
        # No target available, so it cannot be a HIP target.
        return False
    return tgt.backend == "hip"
|
|
|
|
|
|
@constexpr_function
def is_hip_cdna3():
    """Return True when a current target exists and its arch is "gfx942".

    Only the arch string is inspected; the target's backend is not checked.
    """
    tgt = current_target()
    if tgt is None:
        return False
    return tgt.arch == "gfx942"
|
|
|
|
|
|
@constexpr_function
def is_hip_cdna4():
    """Return True when a current target exists and its arch is "gfx950".

    Only the arch string is inspected; the target's backend is not checked.
    """
    tgt = current_target()
    if tgt is None:
        return False
    return tgt.arch == "gfx950"
|