343 lines
6.4 KiB
Python
343 lines
6.4 KiB
Python
"""isort:skip_file"""
|
|
# Import order is significant here.
|
|
|
|
from . import math
|
|
from . import extra
|
|
from .standard import (
|
|
argmax,
|
|
argmin,
|
|
bitonic_merge,
|
|
cdiv,
|
|
cumprod,
|
|
cumsum,
|
|
flip,
|
|
interleave,
|
|
max,
|
|
min,
|
|
ravel,
|
|
reduce_or,
|
|
sigmoid,
|
|
softmax,
|
|
sort,
|
|
sum,
|
|
swizzle2d,
|
|
topk,
|
|
xor_sum,
|
|
zeros,
|
|
zeros_like,
|
|
)
|
|
from .core import (
|
|
PropagateNan,
|
|
TRITON_MAX_TENSOR_NUMEL,
|
|
load_tensor_descriptor,
|
|
store_tensor_descriptor,
|
|
make_tensor_descriptor,
|
|
tensor_descriptor,
|
|
tensor_descriptor_type,
|
|
add,
|
|
advance,
|
|
arange,
|
|
associative_scan,
|
|
assume,
|
|
async_task,
|
|
atomic_add,
|
|
atomic_and,
|
|
atomic_cas,
|
|
atomic_max,
|
|
atomic_min,
|
|
atomic_or,
|
|
atomic_xchg,
|
|
atomic_xor,
|
|
bfloat16,
|
|
block_type,
|
|
broadcast,
|
|
broadcast_to,
|
|
cat,
|
|
cast,
|
|
clamp,
|
|
condition,
|
|
const,
|
|
constexpr,
|
|
constexpr_type,
|
|
debug_barrier,
|
|
device_assert,
|
|
device_print,
|
|
dot,
|
|
dot_scaled,
|
|
dtype,
|
|
expand_dims,
|
|
float16,
|
|
float32,
|
|
float64,
|
|
float8e4b15,
|
|
float8e4nv,
|
|
float8e4b8,
|
|
float8e5,
|
|
float8e5b16,
|
|
full,
|
|
gather,
|
|
histogram,
|
|
inline_asm_elementwise,
|
|
int1,
|
|
int16,
|
|
int32,
|
|
int64,
|
|
int8,
|
|
join,
|
|
load,
|
|
make_block_ptr,
|
|
map_elementwise,
|
|
max_constancy,
|
|
max_contiguous,
|
|
maximum,
|
|
minimum,
|
|
multiple_of,
|
|
num_programs,
|
|
permute,
|
|
pi32_t,
|
|
pointer_type,
|
|
program_id,
|
|
range,
|
|
reduce,
|
|
reshape,
|
|
slice,
|
|
split,
|
|
static_assert,
|
|
static_print,
|
|
static_range,
|
|
store,
|
|
tensor,
|
|
trans,
|
|
tuple,
|
|
tuple_type,
|
|
uint16,
|
|
uint32,
|
|
uint64,
|
|
uint8,
|
|
view,
|
|
void,
|
|
where,
|
|
)
|
|
from .math import (umulhi, exp, exp2, fma, log, log2, cos, rsqrt, sin, sqrt, sqrt_rn, abs, fdiv, div_rn, erf, floor,
|
|
ceil)
|
|
from .random import (
|
|
pair_uniform_to_normal,
|
|
philox,
|
|
philox_impl,
|
|
rand,
|
|
rand4x,
|
|
randint,
|
|
randint4x,
|
|
randn,
|
|
randn4x,
|
|
uint_to_uniform_float,
|
|
)
|
|
from . import target_info
|
|
|
|
__all__ = [
|
|
"PropagateNan",
|
|
"TRITON_MAX_TENSOR_NUMEL",
|
|
"load_tensor_descriptor",
|
|
"store_tensor_descriptor",
|
|
"make_tensor_descriptor",
|
|
"tensor_descriptor",
|
|
"abs",
|
|
"add",
|
|
"advance",
|
|
"arange",
|
|
"argmax",
|
|
"argmin",
|
|
"associative_scan",
|
|
"assume",
|
|
"async_task",
|
|
"atomic_add",
|
|
"atomic_and",
|
|
"atomic_cas",
|
|
"atomic_max",
|
|
"atomic_min",
|
|
"atomic_or",
|
|
"atomic_xchg",
|
|
"atomic_xor",
|
|
"bfloat16",
|
|
"bitonic_merge",
|
|
"block_type",
|
|
"broadcast",
|
|
"broadcast_to",
|
|
"cat",
|
|
"cast",
|
|
"cdiv",
|
|
"ceil",
|
|
"clamp",
|
|
"condition",
|
|
"const",
|
|
"constexpr",
|
|
"constexpr_type",
|
|
"cos",
|
|
"cumprod",
|
|
"cumsum",
|
|
"debug_barrier",
|
|
"device_assert",
|
|
"device_print",
|
|
"div_rn",
|
|
"dot",
|
|
"dot_scaled",
|
|
"dtype",
|
|
"erf",
|
|
"exp",
|
|
"exp2",
|
|
"expand_dims",
|
|
"extra",
|
|
"fdiv",
|
|
"flip",
|
|
"float16",
|
|
"float32",
|
|
"float64",
|
|
"float8e4b15",
|
|
"float8e4nv",
|
|
"float8e4b8",
|
|
"float8e5",
|
|
"float8e5b16",
|
|
"floor",
|
|
"fma",
|
|
"full",
|
|
"gather",
|
|
"histogram",
|
|
"inline_asm_elementwise",
|
|
"interleave",
|
|
"int1",
|
|
"int16",
|
|
"int32",
|
|
"int64",
|
|
"int8",
|
|
"join",
|
|
"load",
|
|
"log",
|
|
"log2",
|
|
"make_block_ptr",
|
|
"map_elementwise",
|
|
"math",
|
|
"max",
|
|
"max_constancy",
|
|
"max_contiguous",
|
|
"maximum",
|
|
"min",
|
|
"minimum",
|
|
"multiple_of",
|
|
"num_programs",
|
|
"pair_uniform_to_normal",
|
|
"permute",
|
|
"philox",
|
|
"philox_impl",
|
|
"pi32_t",
|
|
"pointer_type",
|
|
"program_id",
|
|
"rand",
|
|
"rand4x",
|
|
"randint",
|
|
"randint4x",
|
|
"randn",
|
|
"randn4x",
|
|
"range",
|
|
"ravel",
|
|
"reduce",
|
|
"reduce_or",
|
|
"reshape",
|
|
"rsqrt",
|
|
"slice",
|
|
"sigmoid",
|
|
"sin",
|
|
"softmax",
|
|
"sort",
|
|
"split",
|
|
"sqrt",
|
|
"sqrt_rn",
|
|
"static_assert",
|
|
"static_print",
|
|
"static_range",
|
|
"store",
|
|
"sum",
|
|
"swizzle2d",
|
|
"target_info",
|
|
"tensor",
|
|
"topk",
|
|
"trans",
|
|
"tuple",
|
|
"uint16",
|
|
"uint32",
|
|
"uint64",
|
|
"uint8",
|
|
"uint_to_uniform_float",
|
|
"umulhi",
|
|
"view",
|
|
"void",
|
|
"where",
|
|
"xor_sum",
|
|
"zeros",
|
|
"zeros_like",
|
|
]
|
|
|
|
|
|
def str_to_ty(name, c):
    """Translate a signature type string (or a tuple of them) into a type object.

    Args:
        name: Either a tuple (plain or named) whose items are themselves valid
            ``name`` values, or a string in one of the following forms:

            * ``"*<ty>"`` or ``"*k<ty>"`` -- pointer to ``<ty>``; the ``k``
              prefix marks the pointer as const.
            * ``"tensordesc<<ty>[d0,d1,...]>"`` -- tensor descriptor with the
              given element type and block shape. Text following the closing
              ``]`` (after a comma) is treated as a gluon layout expression.
            * any string starting with ``"constexpr"``.
            * one of the scalar abbreviations listed in ``scalar_types`` below
              (``"fp32"``, ``"i64"``, ``"u8"``, ``"B"``, ...).
        c: Value handed to ``constexpr_type`` when ``name`` starts with
            ``"constexpr"``. It is forwarded unchanged when recursing into
            tuples and pointers; tensor-descriptor element types are resolved
            with ``None`` instead.

    Returns:
        The result of ``tuple_type``, ``pointer_type``,
        ``tensor_descriptor_type`` (or its gluon counterpart),
        ``constexpr_type``, or the scalar type object looked up by name.

    Raises:
        KeyError: If ``name`` is a string that matches none of the forms above.
    """
    # The module-level name `tuple` is imported from `.core`, so the builtin
    # has to be fetched explicitly for the isinstance check.
    from builtins import tuple

    if isinstance(name, tuple):
        # Named tuples expose `_fields` on their class; plain tuples give None.
        fields = type(name).__dict__.get("_fields", None)
        return tuple_type([str_to_ty(item, c) for item in name], fields)

    # `startswith` (rather than `name[0]`) keeps an empty string or a bare "*"
    # from failing with an unrelated IndexError; such input now falls through
    # to the descriptive KeyError at the bottom.
    if name.startswith("*"):
        pointee = name[1:]
        is_const = pointee.startswith("k")
        if is_const:
            pointee = pointee[1:]
        pointee_ty = str_to_ty(pointee, c)
        return pointer_type(element_ty=pointee_ty, const=is_const)

    if name.startswith("tensordesc"):
        # Split only on the first "<" so a "<" inside the trailing layout text
        # cannot silently truncate the payload.
        inner = name.split("<", maxsplit=1)[1].rstrip(">")
        elem_name, rest = inner.split("[", maxsplit=1)
        dims, rest = rest.split("]", maxsplit=1)
        block_shape = [int(dim.strip()) for dim in dims.split(",")]
        # Whatever follows "]," is the layout expression; empty means no layout.
        layout_src = rest.lstrip(",")

        elem_ty = str_to_ty(elem_name, None)
        ndim = len(block_shape)
        shape_type = tuple_type([int32] * ndim)
        # FIXME: Last dim stride should be constexpr(1)
        stride_type = tuple_type([int64] * ndim)
        block = block_type(elem_ty, block_shape)

        if not layout_src:
            return tensor_descriptor_type(block, shape_type, stride_type)

        # Imported lazily so the gluon modules are only loaded when a layout
        # is actually present in the signature.
        from triton.experimental.gluon.language._layouts import NVMMASharedLayout
        from triton.experimental.gluon.language.nvidia.hopper.tma import (
            tensor_descriptor_type as gluon_tensor_descriptor_type,
        )

        # NOTE(review): `eval` runs the layout text as Python. The globals dict
        # only names NVMMASharedLayout, but `eval` still injects the builtins,
        # so this is safe only while signature strings are produced internally.
        # Confirm no externally supplied string can reach this point.
        layout = eval(layout_src, dict(NVMMASharedLayout=NVMMASharedLayout))
        assert isinstance(layout, NVMMASharedLayout)
        return gluon_tensor_descriptor_type(block, shape_type, stride_type, layout)

    if name.startswith("constexpr"):
        return constexpr_type(c)

    scalar_types = {
        "fp8e4nv": float8e4nv,
        "fp8e4b8": float8e4b8,
        "fp8e5": float8e5,
        "fp8e5b16": float8e5b16,
        "fp8e4b15": float8e4b15,
        "fp16": float16,
        "bf16": bfloat16,
        "fp32": float32,
        "fp64": float64,
        "i1": int1,
        "i8": int8,
        "i16": int16,
        "i32": int32,
        "i64": int64,
        "u1": int1,
        "u8": uint8,
        "u16": uint16,
        "u32": uint32,
        "u64": uint64,
        "B": int1,
    }
    try:
        return scalar_types[name]
    except KeyError:
        # Same exception type as a plain lookup, but says what was being parsed.
        raise KeyError(f"unrecognized type string: {name!r}") from None
|