from __future__ import annotations

from ..runtime.jit import jit, constexpr_function
from . import core
from . import math

# constexpr utilities


@constexpr_function
def _log2(i):
    log2 = 0
    n = i
    while n > 1:
        n >>= 1
        log2 += 1
    return log2


@constexpr_function
def _is_power_of_two(i):
    return (i & (i - 1)) == 0 and i != 0
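
# Note: both helpers run at compile time via @constexpr_function. _log2 returns
# floor(log2(i)) by counting right shifts, e.g. _log2(8) == 3, and the sort and
# flip routines below pair it with _is_power_of_two so a tensor can be reshaped
# into a [2] * _log2(numel) hypercube.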


# -----------------------
# Standard library
# -----------------------


@core._tensor_member_fn
@jit
def cdiv(x, div):
    """
    Computes the ceiling division of :code:`x` by :code:`div`

    :param x: the input number
    :type x: Block
    :param div: the divisor
    :type div: Block
    """
    return (x + div - 1) // div
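
# For example, cdiv(10, 3) == (10 + 3 - 1) // 3 == 4; kernels typically use it
# to compute how many fixed-size blocks are needed to cover a dimension.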


@core._tensor_member_fn
@jit
@math._add_math_1arg_docstr("sigmoid")
def sigmoid(x):
    return 1 / (1 + math.exp(-x))


@core._tensor_member_fn
@jit
@math._add_math_1arg_docstr("softmax")
def softmax(x, dim=None, keep_dims=False, ieee_rounding=False):
    if dim is None:
        _dim: core.constexpr = 0
    else:
        _dim: core.constexpr = dim
    z = x - max(x, _dim, keep_dims=keep_dims)
    num = math.exp(z)
    den = sum(num, _dim, keep_dims=keep_dims)
    return math.fdiv(num, den, ieee_rounding)
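
# Note: subtracting max(x) before exponentiating keeps exp() in a safe range,
# so the softmax is numerically stable for large inputs; the shift cancels out
# in the final division.
# Usage sketch (illustrative, hypothetical names; assumes this module is
# exposed as `tl`):
#     row = tl.load(row_ptr + cols, mask=cols < n_cols, other=-float("inf"))
#     probs = tl.softmax(row)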


@core._tensor_member_fn
@jit
def ravel(x, can_reorder=False):
    """
    Returns a contiguous flattened view of :code:`x`.

    :param x: the input tensor
    :type x: Block
    """
    return core.reshape(x, [x.numel], can_reorder=can_reorder)


@jit
def swizzle2d(i, j, size_i, size_j, size_g):
    """
    Transforms the indices of a row-major `size_i * size_j` matrix into
    the indices of a column-major matrix for each group of `size_g` rows.

    For example, for :code:`size_i = size_j = 4` and :code:`size_g = 2`, it will
    transform ::

        [[0 , 1 , 2 , 3 ],
         [4 , 5 , 6 , 7 ],
         [8 , 9 , 10, 11],
         [12, 13, 14, 15]]

    into ::

        [[0, 2,  4 , 6 ],
         [1, 3,  5 , 7 ],
         [8, 10, 12, 14],
         [9, 11, 13, 15]]
    """
    # "unrolled index in array"
    ij = i * size_j + j
    # number of elements in `size_g` groups
    # of `size_j` columns
    size_gj = size_g * size_j
    # index of the group in which (i,j) is
    group_id = ij // size_gj
    # row-index of the first element of this group
    off_i = group_id * size_g
    # last group may have fewer rows
    size_g = core.minimum(size_i - off_i, size_g)
    # linear index with respect to the first element in this group
    ij = ij % size_gj
    # new row and column indices
    new_i = off_i + ij % size_g
    new_j = ij // size_g
    return new_i, new_j
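
# Usage sketch (illustrative, hypothetical names such as GROUP_M): remapping a
# linear program id into "grouped" tile coordinates, a common trick for
# improving L2 reuse in tiled matmul launches:
#     pid = tl.program_id(0)
#     pid_m, pid_n = tl.swizzle2d(pid // num_pid_n, pid % num_pid_n,
#                                 num_pid_m, num_pid_n, GROUP_M)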


@jit
def zeros(shape, dtype):
    """
    Returns a tensor filled with the scalar value 0 for the given :code:`shape` and :code:`dtype`.

    :param shape: Shape of the new array, e.g., (8, 16) or (8, )
    :type shape: tuple of ints
    :param dtype: Data-type of the new array, e.g., :code:`tl.float16`
    :type dtype: DType
    """
    return core.full(shape, 0, dtype)


@jit
def zeros_like(input):
    """
    Returns a tensor of zeros with the same shape and type as a given tensor.

    :param input: input tensor
    :type input: Tensor
    """
    return zeros(input.shape, input.dtype)
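
# Typical use (illustrative; BLOCK_M and BLOCK_N are hypothetical constexpr
# block sizes): initializing a float32 accumulator inside a kernel:
#     acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)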


# max and argmax


@jit
def _argmax_combine(value1, index1, value2, index2, tie_break_left):
    if tie_break_left:
        tie = value1 == value2 and index1 < index2
    else:
        tie = False
    gt = value1 > value2 or tie
    v_ret = core.where(gt, value1, value2)
    i_ret = core.where(gt, index1, index2)
    return v_ret, i_ret


@jit
def _argmax_combine_tie_break_left(value1, index1, value2, index2):
    return _argmax_combine(value1, index1, value2, index2, True)


@jit
def _argmax_combine_tie_break_fast(value1, index1, value2, index2):
    return _argmax_combine(value1, index1, value2, index2, False)


@jit
def _elementwise_max(a, b):
    return core.maximum(a, b)


@core._tensor_member_fn
@jit
@core._add_reduction_docstr("maximum", return_indices_arg="return_indices",
                            tie_break_arg="return_indices_tie_break_left")
def max(input, axis=None, return_indices=False, return_indices_tie_break_left=True, keep_dims=False):
    input = core._promote_bfloat16_to_float32(input)
    if return_indices:
        if return_indices_tie_break_left:
            return core._reduce_with_indices(input, axis, _argmax_combine_tie_break_left, keep_dims=keep_dims)
        else:
            return core._reduce_with_indices(input, axis, _argmax_combine_tie_break_fast, keep_dims=keep_dims)
    else:
        if core.constexpr(input.dtype.primitive_bitwidth) < core.constexpr(32):
            if core.constexpr(input.dtype.is_floating()):
                input = input.to(core.float32)
            else:
                assert input.dtype.is_int(), "Expecting input to be integer type"
                input = input.to(core.int32)
        return core.reduce(input, axis, _elementwise_max, keep_dims=keep_dims)


@core._tensor_member_fn
@jit
@core._add_reduction_docstr("maximum index", tie_break_arg="tie_break_left")
def argmax(input, axis, tie_break_left=True, keep_dims=False):
    (_, ret) = max(input, axis, return_indices=True, return_indices_tie_break_left=tie_break_left, keep_dims=keep_dims)
    return ret
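
# Note: with return_indices=True, max() returns a (values, indices) pair and
# argmax() simply discards the values. tie_break_left=True picks the lowest
# index on ties, e.g. the argmax of [3, 1, 3] is 0; tie_break_left=False lets
# the reduction order decide, which is faster but not deterministic for ties.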


# min and argmin


@jit
def _argmin_combine(value1, index1, value2, index2, tie_break_left):
    if tie_break_left:
        tie = value1 == value2 and index1 < index2
    else:
        tie = False
    lt = value1 < value2 or tie
    value_ret = core.where(lt, value1, value2)
    index_ret = core.where(lt, index1, index2)
    return value_ret, index_ret


@jit
def _argmin_combine_tie_break_left(value1, index1, value2, index2):
    return _argmin_combine(value1, index1, value2, index2, True)


@jit
def _argmin_combine_tie_break_fast(value1, index1, value2, index2):
    return _argmin_combine(value1, index1, value2, index2, False)


@jit
def _elementwise_min(a, b):
    return core.minimum(a, b)


@core._tensor_member_fn
@jit
@core._add_reduction_docstr("minimum", return_indices_arg="return_indices",
                            tie_break_arg="return_indices_tie_break_left")
def min(input, axis=None, return_indices=False, return_indices_tie_break_left=True, keep_dims=False):
    input = core._promote_bfloat16_to_float32(input)
    if return_indices:
        if return_indices_tie_break_left:
            return core._reduce_with_indices(input, axis, _argmin_combine_tie_break_left, keep_dims=keep_dims)
        else:
            return core._reduce_with_indices(input, axis, _argmin_combine_tie_break_fast, keep_dims=keep_dims)
    else:
        if core.constexpr(input.dtype.primitive_bitwidth) < 32:
            if core.constexpr(input.dtype.is_floating()):
                input = input.to(core.float32)
            else:
                assert input.dtype.is_int(), "Expecting input to be integer type"
                input = input.to(core.int32)
        return core.reduce(input, axis, _elementwise_min, keep_dims=keep_dims)


@core._tensor_member_fn
@jit
@core._add_reduction_docstr("minimum index", tie_break_arg="tie_break_left")
def argmin(input, axis, tie_break_left=True, keep_dims=False):
    _, ret = min(input, axis, return_indices=True, return_indices_tie_break_left=tie_break_left, keep_dims=keep_dims)
    return ret


@jit
def _sum_combine(a, b):
    return a + b


# sum


@constexpr_function
def _pick_sum_dtype(in_dtype, dtype):
    if dtype is not None:
        return dtype

    # For integer bitwidths less than 32, pick int32 with the same sign to
    # avoid overflow.
    out_dtype = None
    if in_dtype.is_int_signed():
        out_dtype = core.int32 if in_dtype.int_bitwidth < 32 else None
    elif in_dtype.is_int_unsigned():
        out_dtype = core.uint32 if in_dtype.int_bitwidth < 32 else None
    return out_dtype


@core._tensor_member_fn
@jit
@core._add_reduction_docstr("sum", dtype_arg="dtype")
def sum(input, axis=None, keep_dims=False, dtype: core.constexpr = None):
    # Pick a default dtype for the reduction if one was not specified.
    out_dtype: core.constexpr = _pick_sum_dtype(input.dtype, dtype)

    if out_dtype is not None:
        input = input.to(out_dtype)
    return core.reduce(input, axis, _sum_combine, keep_dims=keep_dims)
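
# Note: the widening above matters for narrow integer inputs, e.g. summing 256
# int8 values that are all 1 gives 256, which overflows int8 (max 127) but fits
# in the default int32 accumulator; pass dtype=... to force a specific type.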


@jit
def _xor_combine(a, b):
    return a ^ b


# xor sum


@core._tensor_member_fn
@jit
@core._add_reduction_docstr("xor sum")
def xor_sum(input, axis=None, keep_dims=False):
    core.static_assert(input.type.scalar.is_int(), "xor_sum only supported for integers")
    return core.reduce(input, axis, _xor_combine, keep_dims=keep_dims)


# or reduction


@jit
def _or_combine(x, y):
    return x | y


@core._tensor_member_fn
@jit
@core._add_reduction_docstr("reduce_or")
def reduce_or(input, axis, keep_dims=False):
    core.static_assert(input.type.scalar.is_int(), "reduce_or only supported for integers")
    return core.reduce(input, axis, _or_combine, keep_dims=keep_dims)


# cumsum


@core._tensor_member_fn
@jit
@core._add_scan_docstr("cumsum", dtype_arg="dtype")
def cumsum(input, axis=0, reverse=False, dtype: core.constexpr = None):
    # todo rename this to a generic function name
    input = core._promote_bfloat16_to_float32(input)
    out_dtype: core.constexpr = _pick_sum_dtype(input.dtype, dtype)

    if out_dtype is not None:
        input = input.to(out_dtype)

    return core.associative_scan(input, axis, _sum_combine, reverse)
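
# For example (illustrative values): a forward cumsum of [1, 2, 3] yields
# [1, 3, 6], while reverse=True scans from the other end and yields [6, 5, 3].
# The same int32/uint32 widening as in sum() applies to narrow integer inputs.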


# cumprod


@jit
def _prod_combine(a, b):
    return a * b


@core._tensor_member_fn
@jit
@core._add_scan_docstr("cumprod")
def cumprod(input, axis=0, reverse=False):
    # todo rename this to a generic function name
    input = core._promote_bfloat16_to_float32(input)
    return core.associative_scan(input, axis, _prod_combine, reverse)


# sort


@jit
def _indicator(n_dims: core.constexpr, j: core.constexpr):
    ar = core.arange(0, 2)
    ar = core.reshape(ar, [1] * (n_dims - j - 1) + [2] + [1] * j)
    return ar


@jit
def _compare_and_swap(x, flip, i: core.constexpr):
    # compare-and-swap on the ith *innermost* dimension
    n_dims: core.constexpr = _log2(x.numel)

    # flip along middle dimension (the bitwise XORs will be optimised away):
    idtype = core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True)
    ix = x.to(idtype, bitcast=True)
    iy = ix ^ xor_sum(ix, n_dims - 1 - i, True)
    y = iy.to(x.dtype, bitcast=True)

    # determines whether we are in the right (rather than left) position along the axis:
    is_right = _indicator(n_dims, i)

    # conditional swap:
    ret = core.where((x > y) != (flip ^ is_right), y, x)
    return ret
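
# Note: xor_sum over an axis of length 2 produces a ^ b in both positions, so
# ix ^ xor_sum(ix, ...) replaces every element with its partner across that
# axis. y therefore holds each element's compare-and-swap partner, and
# core.where keeps either x or y depending on the requested direction.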


@jit
def _bitonic_merge_hypercube(x, stage: core.constexpr, order: core.constexpr):
    '''
    order == 0: ascending
    order == 1: descending
    order == 2: alternating
    '''
    # flip denotes whether to re-arrange sub-sequences of elements in ascending or
    # descending order.
    # if flip = 00000000... then all elements will be re-arranged ascendingly at this stage
    # if flip = 00110011... then all the elements will be re-arranged alternatingly (with
    # a stride of 2) at this stage
    if order == 2:
        flip = _indicator(_log2(x.numel), stage)
    else:
        flip = order
    # perform `stage` rounds of `compare-and-swap`
    for i in core.static_range(stage):
        x = _compare_and_swap(x, flip, stage - 1 - i)
    return x


@jit
def _bitonic_merge(x, stage: core.constexpr, order: core.constexpr, n_dims: core.constexpr):
    h = core.reshape(x, [2] * _log2(x.numel))
    h = _bitonic_merge_hypercube(h, stage, order)
    x = core.reshape(h, x.shape)
    return x


@jit
def sort_impl(x, k: core.constexpr = None, dim: core.constexpr = None, descending: core.constexpr = core.CONSTEXPR_0):
    """
    Sorts a tensor along a specified dimension.

    :param x: The input tensor to be sorted.
    :type x: Tensor
    :param k: The number of top elements to select. If None, k = x.shape[dim] and the whole dimension is sorted.
    :type k: int, optional
    :param dim: The dimension along which to sort the tensor. If None, the tensor is sorted along the last dimension. Currently, only sorting along the last dimension is supported.
    :type dim: int, optional
    :param descending: If set to True, the tensor is sorted in descending order. If set to False, the tensor is sorted in ascending order.
    :type descending: bool, optional
    """
    # handle default dimension or check that it is the most minor dim
    _dim: core.constexpr = len(x.shape) - 1 if dim is None else dim
    core.static_assert(_dim == len(x.shape) - 1, "only minor dimension is currently supported")

    log_n: core.constexpr = _log2(x.shape[_dim])
    log_k: core.constexpr = log_n if k is None else _log2(k)

    n_dims: core.constexpr = _log2(x.numel)

    # reshape to hypercube:
    h = core.reshape(x, [2] * n_dims)

    # run first log_k bitonic sort iterations:
    for i in core.static_range(1, log_k + 1):
        h = _bitonic_merge_hypercube(h, i, 2 if i < log_n else descending)

    # select top k elements using bitonic top-k
    # https://www.doc.ic.ac.uk/~hlgr/pdfs/MassivelyParallelTopK.pdf
    for i in core.static_range(log_k + 1, log_n + 1):
        h = max(h, axis=(_log2(h.numel) - 1 - log_k)) if descending else min(h, axis=(_log2(h.numel) - 1 - log_k))
        h = _bitonic_merge_hypercube(h, log_k, 2 if i < log_n else descending)

    # reshape back:
    x = core.reshape(h, x.shape[:-1] + [2**log_k])
    return x
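
# Usage sketch (illustrative; assumes the sort/topk wrappers defined below are
# exported as tl.sort and tl.topk, and that the trailing dimension is a power
# of two):
#     vals = tl.sort(x, descending=True)   # full sort of the last dimension
#     top8 = tl.topk(x, 8)                 # 8 largest values along the last dimension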


@jit
def sort(x, dim: core.constexpr = None, descending: core.constexpr = core.CONSTEXPR_0):
    return sort_impl(x, dim=dim, descending=descending)


@jit
def topk(x, k: core.constexpr, dim: core.constexpr = None):
    return sort_impl(x, k=k, dim=dim, descending=True)


@jit
def bitonic_merge(x, dim: core.constexpr = None, descending: core.constexpr = core.CONSTEXPR_0):
    # handle default dimension or check that it is the most minor dim
    _dim: core.constexpr = len(x.shape) - 1 if dim is None else dim
    core.static_assert(_dim == len(x.shape) - 1, "only minor dimension is currently supported")
    n_dims: core.constexpr = _log2(x.shape[-1])
    return _bitonic_merge(x, n_dims, descending, n_dims)


@constexpr_function
def _get_flip_dim(dim, shape):
    if dim is None:
        dim = len(shape) - 1
    if dim < 0:  # flip doesn't work if dim < 0 because the xor-swap for loop will start/end at the wrong index
        dim += len(shape)
    return dim


@core._tensor_member_fn
@jit
def flip(x, dim=None):
    """
    Flips a tensor `x` along the dimension `dim`.

    :param x: the input tensor
    :type x: Block
    :param dim: the dimension to flip along
    :type dim: int
    """
    core.static_assert(-len(x.shape) <= dim and dim < len(x.shape))
    _dim: core.constexpr = _get_flip_dim(dim, x.shape)
    core.static_assert(_is_power_of_two(x.shape[_dim]))
    steps: core.constexpr = _log2(x.shape[_dim])

    # reshape the swap dimension to (2, 2, ..., 2)
    idtype = core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True)
    y = core.reshape(x.to(idtype, bitcast=True), x.shape[:_dim] + [2] * steps + x.shape[_dim + 1:])
    for i in core.static_range(steps):
        y = y ^ xor_sum(y, _dim + i, True)
    x = core.reshape(y, x.shape).to(x.dtype, bitcast=True)
    return x
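
# Note: xor_sum along each of the `steps` size-2 axes replaces every element
# with its partner in the corresponding half (the same trick as in
# _compare_and_swap), so applying it to all `steps` axes maps index i to
# (2**steps - 1) - i, i.e. it reverses the dimension: a length-4 dim [a, b, c, d]
# becomes [d, c, b, a]. This is why the flipped dimension must be a power of two.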


@jit
def interleave(a, b):
    """
    Interleaves the values of two tensors along their last dimension. The two tensors must have the same shape.
    Equivalent to `tl.join(a, b).reshape(a.shape[:-1] + [2 * a.shape[-1]])`

    :param a: The first input tensor.
    :type a: Tensor
    :param b: The second input tensor.
    :type b: Tensor
    """
    c = core.join(a, b)

    if len(c.shape) == 1:
        # We must have interleaved two scalars.
        return c
    else:
        # This `else` is necessary because Triton's AST parser doesn't
        # understand that if we take the `if` above we definitely don't run this
        # `else`.
        return core.reshape(c, c.shape[:-2] + [2 * c.shape[-2]])
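
# For example (illustrative values): interleave([a0, a1], [b0, b1]) produces
# [a0, b0, a1, b1]; interleaving two (M, N) blocks yields an (M, 2 * N) block.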