# DriverTrac/venv/lib/python3.12/site-packages/triton/tools/ragged_tma.py

import triton
import triton.language as tl
from triton.tools.tensor_descriptor import TensorDescriptor
# fmt: off


def create_ragged_descriptor(T, block_shape, ragged_dim=0):
    """
    Given a 2- or 3-dimensional tensor T, this creates a 'ragged descriptor'
    which behaves like a concatenation (along the ragged dimension, by default
    the first axis) of subarrays of potentially unequal size.

    The load_ragged and store_ragged device functions can be used to read
    and write from subarrays T[batch_offset : batch_offset + batch_size]
    with hardware bounds-checking preventing any sort of leakage outside
    the subarray.
    """
    block_shape = list(block_shape)
    tensor_shape = list(T.shape)
    rank = len(tensor_shape)

    if ragged_dim < 0:
        ragged_dim += rank

    assert 0 <= ragged_dim < rank - 1, "last dimension cannot be ragged"
    assert rank <= 3, "read-write ragged descriptors must have at most 3 dimensions"
    assert len(block_shape) == rank, "block shape must have same length as tensor shape"

    max_int = 0x7fff0000
    billion = 0x40000000  # == 2**30

    assert tensor_shape[ragged_dim] <= billion, "number of rows may not exceed 2**30"
    tensor_shape[ragged_dim] = billion
    ragged_stride = T.stride(ragged_dim)

    # we prepend an extra two dimensions and rely on the fact that pointers
    # have 64-bit wraparound semantics:
    tma_stride = [2**34 - ragged_stride, ragged_stride] + [T.stride(i) for i in range(rank)]
    tma_shape = [max_int, max_int] + tensor_shape
    box_shape = [1, 1] + block_shape

    return TensorDescriptor(T, tma_shape, tma_stride, box_shape)
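

# ---------------------------------------------------------------------------
# Illustrative example (not part of the upstream module): the values below
# show what create_ragged_descriptor would produce for a hypothetical
# contiguous tensor T of shape (4096, 128) with strides (128, 1),
# block_shape = [64, 128] and ragged_dim = 0:
#
#   tensor_shape -> [2**30, 128]          (the ragged extent is replaced by 2**30)
#   tma_shape    -> [0x7fff0000, 0x7fff0000, 2**30, 128]
#   tma_stride   -> [2**34 - 128, 128, 128, 1]
#   box_shape    -> [1, 1, 64, 128]
#
# The two prepended dimensions exist only to implement the bounds check and
# the 64-bit wraparound trick; they are always indexed with a box of size 1.
# ---------------------------------------------------------------------------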


@triton.jit
def to_ragged_indices(batch_offset, batch_size, row):
    """
    Helper function for load_ragged and store_ragged.
    """
    billion = 0x40000000  # == 2**30

    x = billion - batch_size + row
    y = batch_offset + batch_size

    return billion, y, x
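

# ---------------------------------------------------------------------------
# Informal derivation (added for clarity, not executable): with the strides
# [2**34 - s, s, ...] built by create_ragged_descriptor (s = ragged stride)
# and the indices (2**30, batch_offset + batch_size, 2**30 - batch_size + row)
# returned by to_ragged_indices, the element offset contributed by the first
# three coordinates is
#
#   2**30 * (2**34 - s) + (batch_offset + batch_size) * s
#                       + (2**30 - batch_size + row) * s
#     = 2**64 + (batch_offset + row) * s
#     = (batch_offset + row) * s          (mod 2**64, pointer wraparound)
#
# so the access lands on row batch_offset + row of T, while the hardware
# bounds check 2**30 - batch_size + row < 2**30 enforces row < batch_size.
# ---------------------------------------------------------------------------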


@triton.jit
def load_ragged(TMA, batch_offset, batch_size, coords, ragged_dim: tl.constexpr = 0):
    """
    Read from a subarray T[batch_offset : batch_offset + batch_size] with
    hardware bounds-checking, where reading outside the subarray gives zeros.

    Coords should be an appropriately-sized list of integers, just like in
    TMA.load().
    """
    tl.static_assert(len(TMA.shape) == len(coords) + 2, "TMA must be a read-write ragged descriptor")

    c0, c1, c2 = to_ragged_indices(batch_offset, batch_size, coords[ragged_dim])
    data = TMA.load([c0, c1] + coords[:ragged_dim] + [c2] + coords[ragged_dim + 1:])
    data = tl.reshape(data, data.shape[2:])

    return data


@triton.jit
def store_ragged(TMA, batch_offset, batch_size, coords, data, ragged_dim: tl.constexpr = 0):
    """
    Write to a subarray T[batch_offset : batch_offset + batch_size] with
    hardware bounds-checking, where writes outside the subarray are masked
    correctly.

    Coords should be an appropriately-sized list of integers, just like in
    TMA.store().
    """
    c0, c1, c2 = to_ragged_indices(batch_offset, batch_size, coords[ragged_dim])
    data = tl.reshape(data, [1, 1] + data.shape)

    TMA.store([c0, c1] + coords[:ragged_dim] + [c2] + coords[ragged_dim + 1:], data)
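

# ---------------------------------------------------------------------------
# Minimal usage sketch (added for illustration; not part of the upstream
# module). It assumes a CUDA device with TMA support (e.g. Hopper), 2-D torch
# tensors, and Python-int batch_offset / batch_size; the names
# `_ragged_copy_kernel` and `copy_ragged_rows` are hypothetical.
# ---------------------------------------------------------------------------

@triton.jit
def _ragged_copy_kernel(in_desc, out_desc, batch_offset, batch_size,
                        BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    # Each program copies one (BLOCK_M, BLOCK_N) tile of the subarray
    # T[batch_offset : batch_offset + batch_size]. Rows past the end of the
    # subarray read back as zeros and are masked out on the store.
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    tile = load_ragged(in_desc, batch_offset, batch_size,
                       [pid_m * BLOCK_M, pid_n * BLOCK_N])
    store_ragged(out_desc, batch_offset, batch_size,
                 [pid_m * BLOCK_M, pid_n * BLOCK_N], tile)


def copy_ragged_rows(src, dst, batch_offset, batch_size, BLOCK_M=64, BLOCK_N=64):
    # Host-side driver: build one ragged descriptor per tensor and launch a
    # grid that tiles the rows of the subarray and all columns.
    in_desc = create_ragged_descriptor(src, [BLOCK_M, BLOCK_N])
    out_desc = create_ragged_descriptor(dst, [BLOCK_M, BLOCK_N])
    grid = (triton.cdiv(batch_size, BLOCK_M), triton.cdiv(src.shape[1], BLOCK_N))
    _ragged_copy_kernel[grid](in_desc, out_desc, batch_offset, batch_size,
                              BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N)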