"""isort:skip_file"""
# Import order is significant here.
from . import math
from . import extra
from .standard import (
    argmax,
    argmin,
    bitonic_merge,
    cdiv,
    cumprod,
    cumsum,
    flip,
    interleave,
    max,
    min,
    ravel,
    reduce_or,
    sigmoid,
    softmax,
    sort,
    sum,
    swizzle2d,
    topk,
    xor_sum,
    zeros,
    zeros_like,
)
from .core import (
    PropagateNan,
    TRITON_MAX_TENSOR_NUMEL,
    load_tensor_descriptor,
    store_tensor_descriptor,
    make_tensor_descriptor,
    tensor_descriptor,
    tensor_descriptor_type,
    add,
    advance,
    arange,
    associative_scan,
    assume,
    async_task,
    atomic_add,
    atomic_and,
    atomic_cas,
    atomic_max,
    atomic_min,
    atomic_or,
    atomic_xchg,
    atomic_xor,
    bfloat16,
    block_type,
    broadcast,
    broadcast_to,
    cat,
    cast,
    clamp,
    condition,
    const,
    constexpr,
    constexpr_type,
    debug_barrier,
    device_assert,
    device_print,
    dot,
    dot_scaled,
    dtype,
    expand_dims,
    float16,
    float32,
    float64,
    float8e4b15,
    float8e4nv,
    float8e4b8,
    float8e5,
    float8e5b16,
    full,
    gather,
    histogram,
    inline_asm_elementwise,
    int1,
    int16,
    int32,
    int64,
    int8,
    join,
    load,
    make_block_ptr,
    map_elementwise,
    max_constancy,
    max_contiguous,
    maximum,
    minimum,
    multiple_of,
    num_programs,
    permute,
    pi32_t,
    pointer_type,
    program_id,
    range,
    reduce,
    reshape,
    slice,
    split,
    static_assert,
    static_print,
    static_range,
    store,
    tensor,
    trans,
    tuple,
    tuple_type,
    uint16,
    uint32,
    uint64,
    uint8,
    view,
    void,
    where,
)
from .math import (umulhi, exp, exp2, fma, log, log2, cos, rsqrt, sin, sqrt, sqrt_rn, abs, fdiv, div_rn, erf, floor,
                   ceil)
from .random import (
    pair_uniform_to_normal,
    philox,
    philox_impl,
    rand,
    rand4x,
    randint,
    randint4x,
    randn,
    randn4x,
    uint_to_uniform_float,
)
from . import target_info
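# Names re-exported as the public triton.language API (conventionally accessed
# as `tl.<name>` inside kernels).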
__all__ = [
"PropagateNan",
"TRITON_MAX_TENSOR_NUMEL",
"load_tensor_descriptor",
"store_tensor_descriptor",
"make_tensor_descriptor",
"tensor_descriptor",
"abs",
"add",
"advance",
"arange",
"argmax",
"argmin",
"associative_scan",
"assume",
"async_task",
"atomic_add",
"atomic_and",
"atomic_cas",
"atomic_max",
"atomic_min",
"atomic_or",
"atomic_xchg",
"atomic_xor",
"bfloat16",
"bitonic_merge",
"block_type",
"broadcast",
"broadcast_to",
"cat",
"cast",
"cdiv",
"ceil",
"clamp",
"condition",
"const",
"constexpr",
"constexpr_type",
"cos",
"cumprod",
"cumsum",
"debug_barrier",
"device_assert",
"device_print",
"div_rn",
"dot",
"dot_scaled",
"dtype",
"erf",
"exp",
"exp2",
"expand_dims",
"extra",
"fdiv",
"flip",
"float16",
"float32",
"float64",
"float8e4b15",
"float8e4nv",
"float8e4b8",
"float8e5",
"float8e5b16",
"floor",
"fma",
"full",
"gather",
"histogram",
"inline_asm_elementwise",
"interleave",
"int1",
"int16",
"int32",
"int64",
"int8",
"join",
"load",
"log",
"log2",
"make_block_ptr",
"map_elementwise",
"math",
"max",
"max_constancy",
"max_contiguous",
"maximum",
"min",
"minimum",
"multiple_of",
"num_programs",
"pair_uniform_to_normal",
"permute",
"philox",
"philox_impl",
"pi32_t",
"pointer_type",
"program_id",
"rand",
"rand4x",
"randint",
"randint4x",
"randn",
"randn4x",
"range",
"ravel",
"reduce",
"reduce_or",
"reshape",
"rsqrt",
"slice",
"sigmoid",
"sin",
"softmax",
"sort",
"split",
"sqrt",
"sqrt_rn",
"static_assert",
"static_print",
"static_range",
"store",
"sum",
"swizzle2d",
"target_info",
"tensor",
"topk",
"trans",
"tuple",
"uint16",
"uint32",
"uint64",
"uint8",
"uint_to_uniform_float",
"umulhi",
"view",
"void",
"where",
"xor_sum",
"zeros",
"zeros_like",
]
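
# A rough sketch of the signature type strings accepted by str_to_ty below;
# the example mappings are illustrative and inferred from the parsing code itself:
#   "fp32", "i32", "u64", "bf16", ...  -> the corresponding scalar dtype
#   "*fp16"                            -> pointer_type(float16)
#   "*kfp16"                           -> const-qualified pointer_type(float16)
#   "constexpr"                        -> constexpr_type(c), wrapping the given value
#   "tensordesc<fp16[64, 64]>"         -> a tensor_descriptor_type for a 64x64 fp16 block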
def str_to_ty(name, c):
    # The `tuple` imported from .core above shadows the built-in; restore the
    # built-in locally for the isinstance check.
    from builtins import tuple
    if isinstance(name, tuple):
        # Nested signatures: recurse over the elements, preserving namedtuple fields.
        fields = type(name).__dict__.get("_fields", None)
        return tuple_type([str_to_ty(x, c) for x in name], fields)
    if name[0] == "*":
        # Pointer types: "*ty" is a pointer, "*kty" a const-qualified pointer.
        name = name[1:]
        const = False
        if name[0] == "k":
            name = name[1:]
            const = True
        ty = str_to_ty(name, c)
        return pointer_type(element_ty=ty, const=const)
    if name.startswith("tensordesc"):
        # Tensor descriptors: "tensordesc<dtype[shape]>" with an optional
        # trailing layout (Gluon only).
        inner = name.split("<")[1].rstrip(">")
        dtype, rest = inner.split("[", maxsplit=1)
        block_shape, rest = rest.split("]", maxsplit=1)
        block_shape = [int(s.strip()) for s in block_shape.rstrip("]").split(",")]
        layout = rest.lstrip(",")
        is_gluon = len(layout)
        dtype = str_to_ty(dtype, None)
        ndim = len(block_shape)
        shape_type = tuple_type([int32] * ndim)
        # FIXME: Last dim stride should be constexpr(1)
        stride_type = tuple_type([int64] * ndim)
        block = block_type(dtype, block_shape)
        if is_gluon:
            from triton.experimental.gluon.language._layouts import NVMMASharedLayout
            from triton.experimental.gluon.language.nvidia.hopper.tma import tensor_descriptor_type as gluon_tensor_descriptor_type
            layout = eval(layout, dict(NVMMASharedLayout=NVMMASharedLayout))
            assert isinstance(layout, NVMMASharedLayout)
            return gluon_tensor_descriptor_type(block, shape_type, stride_type, layout)
        return tensor_descriptor_type(block, shape_type, stride_type)
    if name.startswith("constexpr"):
        return constexpr_type(c)
    # Scalar dtypes: look up the short name in the table below.
    tys = {
        "fp8e4nv": float8e4nv,
        "fp8e4b8": float8e4b8,
        "fp8e5": float8e5,
        "fp8e5b16": float8e5b16,
        "fp8e4b15": float8e4b15,
        "fp16": float16,
        "bf16": bfloat16,
        "fp32": float32,
        "fp64": float64,
        "i1": int1,
        "i8": int8,
        "i16": int16,
        "i32": int32,
        "i64": int64,
        "u1": int1,
        "u8": uint8,
        "u16": uint16,
        "u32": uint32,
        "u64": uint64,
        "B": int1,
    }
    return tys[name]
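
# Illustrative calls (assumed usage; nothing in this module executes these):
#   str_to_ty("*fp32", None)    # -> pointer_type(float32)
#   str_to_ty("i64", None)      # -> int64
#   str_to_ty("constexpr", 64)  # -> constexpr_type(64)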