# Copyright 2025 The JAX Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import annotations

from collections.abc import Sequence
import dataclasses
import math
from typing import Any, cast, Callable

import itertools
from jaxlib.mlir import ir
from jaxlib.mlir.dialects import arith
from jaxlib.mlir.dialects import llvm
from jaxlib.mlir.dialects import memref
import numpy as np

from . import utils
from . import fragmented_array as fa
from . import mma_utils
from .launch_context import LaunchContext


TMEM_ROWS = 128
TMEM_MAX_COLS = 512
TCGEN05_SMEM_DESCRIPTOR_BIT = 1 << 46
LAYOUT = fa.TCGEN05_LAYOUT
TRANSPOSED_LAYOUT = fa.TCGEN05_TRANSPOSED_LAYOUT
ROW_LAYOUT = fa.TCGEN05_ROW_LAYOUT
COL_LAYOUT = fa.TCGEN05_COL_LAYOUT
TMEM_NATIVE_LAYOUT = fa.TMEM_NATIVE_LAYOUT


def create_instr_descriptor(
|
|
m: int,
|
|
n: int,
|
|
acc_dtype,
|
|
input_dtype,
|
|
transpose_a: bool = False,
|
|
transpose_b: bool = False,
|
|
sparsity_selector: int | None = None,
|
|
):
|
|
f16 = ir.F16Type.get()
|
|
f32 = ir.F32Type.get()
|
|
i32 = ir.IntegerType.get_signless(32)
|
|
|
|
desc = 0
|
|
if sparsity_selector is not None:
|
|
assert 0 <= sparsity_selector < 3
|
|
desc |= sparsity_selector
|
|
desc |= 1 << 2 # Enable sparsity
|
|
if acc_dtype == f16:
|
|
d_type_val = 0
|
|
elif acc_dtype == f32:
|
|
d_type_val = 1
|
|
elif acc_dtype == i32:
|
|
d_type_val = 2
|
|
else:
|
|
raise NotImplementedError(f"Unsupported accumulator dtype: {acc_dtype}")
|
|
desc |= (d_type_val << 4) # D type, bits 4-5
|
|
# Bit 6 is reserved
|
|
if input_dtype == f16:
|
|
assert acc_dtype in {f16, f32}
|
|
ab_type_val = 0
|
|
elif input_dtype == ir.BF16Type.get():
|
|
assert acc_dtype == f32
|
|
ab_type_val = 1
|
|
elif input_dtype == ir.Float8E4M3FNType.get():
|
|
assert acc_dtype in {f16, f32}
|
|
ab_type_val = 0
|
|
elif input_dtype == ir.Float8E5M2Type.get():
|
|
assert acc_dtype in {f16, f32}
|
|
ab_type_val = 1
|
|
elif input_dtype == ir.IntegerType.get_signless(8): # Only s8 for now.
|
|
assert acc_dtype == i32
|
|
ab_type_val = 1
|
|
else:
|
|
raise NotImplementedError(f"Unsupported input dtype: {input_dtype}")
|
|
desc |= (ab_type_val << 7) # A dtype, bits 7-9
|
|
desc |= (ab_type_val << 10) # B dtype, bits 10-12
|
|
# We ignore negate bits 13-14
|
|
desc |= transpose_a << 15 # Transpose A
|
|
desc |= transpose_b << 16 # Transpose B
|
|
if n % 8 or n > 256:
|
|
raise ValueError(f"N must be a multiple of 8 and <= 256, got: {n}")
|
|
desc |= (n >> 3) << 17 # N, bits 17-22
|
|
# Bit 23 is reserved
|
|
if m % 16 or m > 256:
|
|
raise ValueError(f"M must be a multiple of 16 and <= 256, got: {m}")
|
|
desc |= (m >> 4) << 24 # M >> 4, bits 24-28
|
|
# Bit 29 is reserved
|
|
# We ignore max shift under .ws, bits 30-31
|
|
return arith.constant(ir.IntegerType.get_signless(32), desc)
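

# Illustrative sketch only (not used anywhere in this module): the same field
# packing as create_instr_descriptor above, done with plain Python ints so the
# bit positions can be checked without an MLIR context. The configuration
# assumed here (bf16 inputs, f32 accumulator, M=128, N=256, no transposes, no
# sparsity) is just an example.
def _example_instr_descriptor_bits() -> int:
  m, n = 128, 256
  desc = 0
  desc |= 1 << 4          # D type: f32 (bits 4-5)
  desc |= 1 << 7          # A type: bf16 (bits 7-9)
  desc |= 1 << 10         # B type: bf16 (bits 10-12)
  desc |= (n >> 3) << 17  # N / 8 (bits 17-22)
  desc |= (m >> 4) << 24  # M / 16 (bits 24-28)
  return desc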
|
|
|
|
|
|
def _create_scaled_instr_descriptor(
|
|
get_input_encoding: Callable[[ir.Type], int],
|
|
m: int,
|
|
n: int,
|
|
a_type: ir.Type,
|
|
b_type: ir.Type,
|
|
a_scale_idx: int,
|
|
b_scale_idx: int,
|
|
transpose_a: bool = False,
|
|
transpose_b: bool = False,
|
|
):
|
|
desc = 0
|
|
# Bits 0, 1 are reserved
|
|
# We ignore sparsity (bit 2)
|
|
# Bit 3 is reserved
|
|
assert 0 <= b_scale_idx < 4
|
|
desc |= b_scale_idx << 4 # B scale factor data ID, bits 4-5
|
|
# Bit 6 is reserved
|
|
desc |= get_input_encoding(a_type) << 7 # A dtype, bits 7-9
|
|
desc |= get_input_encoding(b_type) << 10 # B dtype, bits 10-12
|
|
# We ignore negate bits 13-14
|
|
desc |= transpose_a << 15 # Transpose A
|
|
desc |= transpose_b << 16 # Transpose B
|
|
if n % 8 or n > 256:
|
|
raise ValueError(f"N must be a multiple of 8 and <= 256, got: {n}")
|
|
desc |= (n >> 3) << 17 # N, bits 17-22
|
|
desc |= 1 << 23 # Scale matrix type
|
|
# Bits 24-26 are reserved
|
|
if m % 128 or m > 256:
|
|
raise ValueError(f"M must be a multiple of 16 and <= 256, got: {m}")
|
|
desc |= (m >> 7) << 27 # M >> 7, bits 27-28
|
|
desc |= a_scale_idx << 29 # A scale factor data ID, bits 29-30
|
|
# Bit 31 is reserved
|
|
return arith.constant(ir.IntegerType.get_signless(32), desc)
|
|
|
|
|
|
def create_scaled_f8f6f4_instr_descriptor(*args, **kwargs):
|
|
def get_input_encoding(ty):
|
|
if ty == ir.Float8E4M3FNType.get():
|
|
return 0
|
|
elif ty == ir.Float8E5M2Type.get():
|
|
return 1
|
|
else:
|
|
raise NotImplementedError(f"Unsupported input dtype: {ty}")
|
|
return _create_scaled_instr_descriptor(get_input_encoding, *args, **kwargs)
|
|
|
|
|
|
def create_scaled_f4_instr_descriptor(*args, **kwargs):
|
|
def get_input_encoding(ty):
|
|
if ty == ir.Float4E2M1FNType.get():
|
|
return 1
|
|
else:
|
|
raise NotImplementedError(f"Unsupported input dtype: {ty}")
|
|
return _create_scaled_instr_descriptor(get_input_encoding, *args, **kwargs)
|
|
|
|
|
|
def mma(
|
|
d: TMEMRef,
|
|
a: ir.Value | TMEMRef,
|
|
b: ir.Value,
|
|
*,
|
|
a_swizzle: int = 128,
|
|
b_swizzle: int = 128,
|
|
a_scale: TMEMRef | None = None,
|
|
b_scale: TMEMRef | None = None,
|
|
a_sparse_metadata: TMEMRef | None = None,
|
|
accumulate: ir.Value | bool = True,
|
|
collective: bool = False,
|
|
) -> None:
|
|
if a_swizzle == 16 or b_swizzle == 16:
|
|
raise NotImplementedError("No swizzle is not supported")
|
|
i32 = ir.IntegerType.get_signless(32)
|
|
i64 = ir.IntegerType.get_signless(64)
|
|
if isinstance(accumulate, bool):
|
|
accumulate = arith.constant(ir.IntegerType.get_signless(1), accumulate)
|
|
num_cta = 2 if collective else 1
|
|
if (is_scaled := a_scale is not None) != (b_scale is not None):
|
|
raise ValueError("Either none or both scales should be provided")
|
|
is_sparse = a_sparse_metadata is not None
|
|
if is_scaled and is_sparse:
|
|
raise NotImplementedError("Block-scaled sparse matmuls unsupported")
|
|
|
|
# Step 1. Establish the shape and element type of the operation.
|
|
if not ir.MemRefType.isinstance(b.type):
|
|
raise ValueError(f"B must be a memref, got: {b.type}")
|
|
(k, n), element_type = mma_utils.tiled_memref_shape(b)
|
|
if isinstance(a, TMEMRef):
|
|
m, k2 = a.shape
|
|
element_type2 = a.dtype
|
|
if is_scaled or is_sparse:
|
|
raise NotImplementedError("A in TMEM unsupported for block-scaled and sparse matmuls")
|
|
if m != 128:
|
|
raise NotImplementedError(f"Only M=128 is supported for MMA with A in TMEM, but got M={m}")
|
|
# Watch out: this layout must be consistent with D's layout (up to packing).
|
|
expected_packing = 32 // utils.bitwidth(element_type)
|
|
expected_layout = _infer_tmem_layout(
|
|
a.shape, collective, packing=expected_packing
|
|
)
|
|
if a.layout != expected_layout:
|
|
raise ValueError(
|
|
f"A layout mismatch: expected {expected_layout}, got {a.layout}"
|
|
)
|
|
else:
|
|
if not ir.MemRefType.isinstance(a.type):
|
|
raise ValueError(f"A must be a memref, got {a.type}")
|
|
(m, k2), element_type2 = mma_utils.tiled_memref_shape(a)
|
|
if is_sparse:
|
|
k2 *= 2
|
|
if k != k2:
|
|
raise ValueError(
|
|
"MMA requires A and B to have the same contraction dimension (K),"
|
|
f" got: {k2} and {k}"
|
|
)
|
|
if element_type != element_type2:
|
|
raise ValueError(
|
|
"MMA requires A and B to have the same element type, got:"
|
|
f" {element_type2} and {element_type}"
|
|
)
|
|
if d.shape != (m, n * num_cta):
|
|
raise ValueError(
|
|
f"Accumulator shape mismatch: expected {(m, n * num_cta)}, got {d.shape}"
|
|
)
|
|
if m == 128:
|
|
if d.layout != (expected_d_layout := tmem_default_layout(packing=1)):
|
|
raise ValueError(
|
|
f"Accumulator layout mismatch: expected {expected_d_layout}, got {d.layout}"
|
|
)
|
|
n_lane_groups = 1
|
|
elif m == 64:
|
|
if is_scaled:
|
|
raise NotImplementedError("MMA with block scaling is not supported for M=64")
|
|
if is_sparse:
|
|
raise NotImplementedError("Sparse MMA not supported for M=64")
|
|
# Watch out: this layout must be consistent with A's layout (up to packing).
|
|
# 2CTA M=128 instruction uses a different TMEM layout than 1CTA M=64.
|
|
expected_d_layout = _infer_tmem_layout(d.shape, collective, packing=1)
|
|
if d.layout != expected_d_layout:
|
|
raise ValueError(
|
|
f"Accumulator layout mismatch: expected {expected_d_layout}, got {d.layout}"
|
|
)
|
|
if collective:
|
|
n_lane_groups = 1
|
|
else:
|
|
n_lane_groups = 2
|
|
# We can't split N into groups if we would partition it below the tile size.
|
|
# TODO: We only need to check this if N is the minormost dim in B.
|
|
if 8 * b_swizzle // utils.bitwidth(element_type) > n // n_lane_groups:
|
|
raise ValueError("Swizzle is too big for MMA with M=64. Try lowering it.")
|
|
else:
|
|
raise ValueError(f"Only M=128 and M=64 are supported for MMA, but got M={m}")
|
|
f32 = ir.F32Type.get()
|
|
f16 = ir.F16Type.get()
|
|
s32 = ir.IntegerType.get_signless(32)
|
|
if element_type == f32 or element_type == ir.BF16Type.get():
|
|
if element_type == f32 and is_sparse:
|
|
raise NotImplementedError("Only 16-bit types supported for sparse MMA")
|
|
if is_scaled:
|
|
raise ValueError(
|
|
f"MMA with element type {element_type} does not support block scaling"
|
|
)
|
|
if d.dtype != f32:
|
|
raise ValueError(
|
|
f"MMA with element type {element_type} only supports accumulators"
|
|
f" of type f32, but got: {d.dtype}"
|
|
)
|
|
elif element_type == f16:
|
|
if is_scaled:
|
|
raise ValueError(
|
|
f"MMA with element type {element_type} does not support block scaling"
|
|
)
|
|
if d.dtype != f16 and d.dtype != f32:
|
|
raise ValueError(
|
|
f"MMA with element type {element_type} only supports accumulators of"
|
|
f" type f32 or f16, but got: {d.dtype}"
|
|
)
|
|
elif any(
|
|
t.isinstance(element_type)
|
|
for t in {ir.Float8E5M2Type, ir.Float8E4M3FNType}
|
|
):
|
|
if is_sparse:
|
|
raise NotImplementedError("Only 16-bit types supported for sparse MMA")
|
|
if d.dtype != f16 and d.dtype != f32:
|
|
raise ValueError(
|
|
f"MMA with element type {element_type} only supports accumulators of"
|
|
f" type f32 or f16, but got: {d.dtype}"
|
|
)
|
|
if is_scaled and d.dtype != f32:
|
|
raise ValueError(
|
|
f"Block-scaled MMA with element type {element_type} only supports f32"
|
|
f" accumulators, but got: {d.dtype}"
|
|
)
|
|
elif any(
|
|
t.isinstance(element_type) for t in {ir.Float4E2M1FNType}
|
|
):
|
|
if is_sparse:
|
|
raise NotImplementedError("Only 16-bit types supported for sparse MMA")
|
|
if not is_scaled:
|
|
raise ValueError(
|
|
f"MMA with element type {element_type} only supports block scaling"
|
|
)
|
|
if d.dtype != f32:
|
|
raise ValueError(
|
|
f"Block-scaled MMA with element type {element_type} only supports f32"
|
|
f" accumulators, but got: {d.dtype}"
|
|
)
|
|
elif element_type == ir.IntegerType.get_signless(8):
|
|
if is_sparse:
|
|
raise NotImplementedError("Only 16-bit types supported for sparse MMA")
|
|
if is_scaled:
|
|
raise ValueError(
|
|
f"MMA with element type {element_type} does not support block scaling"
|
|
)
|
|
if d.dtype != s32:
|
|
raise ValueError(
|
|
"MMA with element type s8 only supports s32 accumulators, but got:"
|
|
f" {d.dtype}"
|
|
)
|
|
else:
|
|
raise NotImplementedError(f"Unsupported element type: {element_type}")
|
|
|
|
# Step 2. Decide on the instruction shapes we'll use. Note that with swizzles,
|
|
# instructions must be issued in groups of the same width as the swizzle.
|
|
m_group_elems = m # We have already verified M is supported above.
|
|
k_group_elems = 8 * max(a_swizzle * (1 + is_sparse), b_swizzle) // utils.bitwidth(element_type)
|
|
if is_sparse and k_group_elems < 64:
|
|
# This is a limitation of the implementation below. We could relax it if we
|
|
# ever need to support k=32.
|
|
k_group_elems = 64
|
|
if n % 8:
|
|
raise ValueError(f"N must be a multiple of 8, got: {n}")
|
|
if n.bit_count() != 1:
|
|
raise ValueError(f"N must be a power of 2, got: {n}")
|
|
# TODO: We could relax those constraints if we have multiple n_lane_groups,
|
|
# since we will be unrolling the instructions anyway.
|
|
if collective and n > 128:
|
|
raise ValueError("Only N <= 128 are supported for collective MMA")
|
|
elif n > 512:
|
|
raise ValueError("Only N <= 512 are supported for MMA")
|
|
n_group_elems = min(n // n_lane_groups, 256 // num_cta)
|
|
if m % m_group_elems:
|
|
raise ValueError(f"M must be a multiple of {m_group_elems}, got: {m}")
|
|
if k % k_group_elems:
|
|
raise ValueError(f"K must be a multiple of {k_group_elems}, got: {k}")
|
|
if n % n_group_elems:
|
|
raise ValueError(f"N must be a multiple of {n_group_elems}, got: {n}")
|
|
m_groups = m // m_group_elems
|
|
k_groups = k // k_group_elems
|
|
n_groups = n // n_group_elems
|
|
# TODO(apaszke): Require users to bitcast input refs to tf32 before MMA.
|
|
mma_element_type = (
|
|
ir.FloatTF32Type.get() if element_type == ir.F32Type.get() else element_type
|
|
)
|
|
|
|
# Check that the shapes and element types are correct for block scaling.
|
|
if is_scaled:
|
|
if collective:
|
|
raise NotImplementedError("MMA with block scaling does not support collective")
|
|
assert m == 128 # Checked above.
|
|
if n % 32:
|
|
raise ValueError(
|
|
f"MMA with block scaling requires N to be divisible by 32, got: {n}"
|
|
)
|
|
if k_group_elems != 128 or a_swizzle != b_swizzle:
|
|
assert utils.bitwidth(element_type) <= 8
|
|
expected_swizzle = 128 // (8 // utils.bitwidth(element_type))
|
|
raise NotImplementedError(
|
|
"MMA with block scaling requires swizzle to be"
|
|
f" {expected_swizzle} for dtype {element_type}, got:"
|
|
f" {a_swizzle=} and {b_swizzle=}"
|
|
)
|
|
assert a_scale is not None and b_scale is not None
|
|
if a_scale.shape != (m, 4):
|
|
raise ValueError(
|
|
f"A scale shape mismatch: expected ({m}, 4), got {a_scale.shape}"
|
|
)
|
|
if a_scale.dtype != ir.Float8E8M0FNUType.get():
|
|
raise ValueError(
|
|
f"A scale dtype mismatch: expected f8e8m0fnu, got {a_scale.dtype}"
|
|
)
|
|
if b_scale.shape != (n, 4):
|
|
raise ValueError(
|
|
f"B scale shape mismatch: expected ({n}, 4), got {b_scale.shape}"
|
|
)
|
|
if b_scale.dtype != ir.Float8E8M0FNUType.get():
|
|
raise ValueError(
|
|
f"B scale dtype mismatch: expected f8e8m0fnu, got {b_scale.dtype}"
|
|
)
|
|
if is_sparse:
|
|
a_sparse_metadata = cast(TMEMRef, a_sparse_metadata)
|
|
if collective:
|
|
raise NotImplementedError("Collective sparse MMA unsupported")
|
|
if n % 32:
|
|
raise ValueError(f"Sparse MMA requires N to be divisible by 32, got: {n}")
|
|
if a_sparse_metadata.shape != (m, k // 2):
|
|
raise ValueError(
|
|
f"A sparse metadata shape mismatch: expected {(m, k // 2)}, got"
|
|
f" {a_sparse_metadata.shape}"
|
|
)
|
|
if a_sparse_metadata.dtype != ir.IntegerType.get_signless(2):
|
|
raise ValueError(
|
|
"A sparse metadata dtype mismatch: expected i2, got"
|
|
f" {a_sparse_metadata.dtype}"
|
|
)
|
|
|
|
# Step 3. Compute the operand descriptors.
|
|
if not isinstance(a, TMEMRef):
|
|
    # Both dense and sparse matmuls consume A with a K bytewidth of 32; only
|
|
# the group size is halved when it's sparse.
|
|
(
|
|
(a_desc_base, a_k_instr_strides),
|
|
(a_m_group_stride, a_k_group_stride),
|
|
a_fastest,
|
|
) = mma_utils.create_descriptor(
|
|
a,
|
|
swizzle=a_swizzle,
|
|
group_size=(m_group_elems, k_group_elems // (1 + is_sparse)),
|
|
logical_k_major=False,
|
|
mma_bytewidth_k=32,
|
|
)
|
|
else:
|
|
a_fastest = mma_utils.Dim.K
|
|
a_k_instr_strides = None
|
|
a_m_group_stride = a_k_group_stride = a_desc_base = None
|
|
(
|
|
(b_desc_base, b_k_instr_strides),
|
|
(b_n_group_stride, b_k_group_stride),
|
|
b_fastest,
|
|
) = mma_utils.create_descriptor(
|
|
b,
|
|
swizzle=b_swizzle,
|
|
group_size=(k_group_elems, n_group_elems),
|
|
logical_k_major=True,
|
|
mma_bytewidth_k=64 if is_sparse else 32,
|
|
)
|
|
|
|
if is_scaled and utils.bitwidth(mma_element_type) == 4:
|
|
if a_fastest != mma_utils.Dim.K:
|
|
raise ValueError(
|
|
"4-bit block scaled MMA only supports K-fastest operands, but A is M-fastest"
|
|
)
|
|
if b_fastest != mma_utils.Dim.K:
|
|
raise ValueError(
|
|
"4-bit block scaled MMA only supports K-fastest operands, but B is N-fastest"
|
|
)
|
|
if is_sparse:
|
|
if b_swizzle == 32 and b_fastest == mma_utils.Dim.K:
|
|
raise NotImplementedError(
|
|
"B tiling too small. Increase swizzle or transpose the input."
|
|
)
|
|
|
|
# Step 4. Issue the instructions.
|
|
true = arith.constant(ir.IntegerType.get_signless(1), 1)
|
|
n_collective_group_elems = n_group_elems * num_cta
|
|
n_col_groups = n_groups // n_lane_groups
|
|
assert d.layout.base_tile_shape[0] % 4 == 0
|
|
lanes_per_n_group = d.layout.base_tile_shape[0] // 4
|
|
a_sparse_addr_base = a_sparse_metadata.address if is_sparse else None # type: ignore
|
|
for mi, ni, ki in np.ndindex(m_groups, n_groups, k_groups):
|
|
if isinstance(a, TMEMRef):
|
|
if m_groups != 1:
|
|
raise NotImplementedError("A address calculation for multiple M tiles")
|
|
a_mk = a.slice(slice(None), utils.ds(ki * k_group_elems, k_group_elems)).address
|
|
else:
|
|
a_offset = mi * a_m_group_stride + ki * a_k_group_stride
|
|
a_mk = arith.addi(a_desc_base, utils.c(mma_utils.encode_addr(a_offset), i64))
|
|
b_offset = ni * b_n_group_stride + ki * b_k_group_stride
|
|
b_nk = arith.addi(b_desc_base, utils.c(mma_utils.encode_addr(b_offset), i64))
|
|
if a_sparse_addr_base is not None:
|
|
if n_groups != 1 or m_groups != 1:
|
|
raise NotImplementedError("A sparse metadata address calculation for multiple tiles")
|
|
assert k_group_elems % 32 == 0
|
|
cols_per_k_group = k_group_elems // 32
|
|
a_sparse_addr = arith.addi(a_sparse_addr_base, utils.c(ki * cols_per_k_group, i32))
|
|
else:
|
|
a_sparse_addr = None
|
|
if is_scaled and (m_groups != 1 or n_groups != 1 or k_groups != 1):
|
|
raise NotImplementedError("Block-scaled metadata address calculation for multiple tiles")
|
|
acc = accumulate if ki == 0 else true
|
|
ni_lane_group, ni_col = ni // n_col_groups, ni % n_col_groups
|
|
d_offset = (
|
|
((ni_lane_group * lanes_per_n_group) << 16)
|
|
+ ni_col * n_collective_group_elems
|
|
)
|
|
if m_groups != 1:
|
|
raise NotImplementedError("D address calculation for multiple M tiles")
|
|
_do_mma(
|
|
arith.addi(d.address, arith.constant(i32, d_offset)),
|
|
a_mk,
|
|
b_nk,
|
|
d_type=d.dtype,
|
|
m=m_group_elems,
|
|
n=n_group_elems,
|
|
k=k_group_elems,
|
|
collective=collective,
|
|
a_transpose=a_fastest != mma_utils.Dim.K,
|
|
b_transpose=b_fastest != mma_utils.Dim.K,
|
|
a_k_strides=a_k_instr_strides,
|
|
b_k_strides=b_k_instr_strides,
|
|
a_scale_addr=a_scale.address if a_scale is not None else None,
|
|
b_scale_addr=b_scale.address if b_scale is not None else None,
|
|
a_sparse_addr=a_sparse_addr,
|
|
accumulate=acc,
|
|
element_type=mma_element_type,
|
|
)
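

# Illustrative sketch only: the "Step 2" group-size arithmetic from mma()
# above, evaluated with plain ints for a dense (non-sparse, non-scaled) bf16
# matmul using 128-byte swizzles on a single CTA. The problem size assumed
# here (M=128, N=256, K=256) is just an example.
def _example_mma_group_sizes() -> tuple[int, int, int]:
  m, n, k = 128, 256, 256
  bitwidth, a_swizzle, b_swizzle = 16, 128, 128
  num_cta = n_lane_groups = 1  # single CTA, M=128
  k_group_elems = 8 * max(a_swizzle, b_swizzle) // bitwidth  # 64
  n_group_elems = min(n // n_lane_groups, 256 // num_cta)  # 256
  m_group_elems = m  # M is never split further
  # One M group, four K groups and one N group of instructions are issued.
  return m // m_group_elems, k // k_group_elems, n // n_group_elems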
|
|
|
|
|
|
def _do_mma(
|
|
d_addr: ir.Value,
|
|
    a_desc_or_addr: ir.Value,  # TMEM address if a_k_strides is None
|
|
b_desc: ir.Value,
|
|
a_transpose: bool,
|
|
b_transpose: bool,
|
|
a_k_strides: tuple[tuple[int, ...], tuple[int, ...]] | None,
|
|
b_k_strides: tuple[tuple[int, ...], tuple[int, ...]],
|
|
a_scale_addr: ir.Value | None,
|
|
b_scale_addr: ir.Value | None,
|
|
a_sparse_addr: ir.Value | None,
|
|
m: int,
|
|
n: int,
|
|
k: int,
|
|
element_type: ir.Type,
|
|
d_type: ir.Type,
|
|
accumulate: ir.Value,
|
|
collective: bool,
|
|
):
|
|
i1 = ir.IntegerType.get_signless(1)
|
|
i32 = ir.IntegerType.get_signless(32)
|
|
i64 = ir.IntegerType.get_signless(64)
|
|
a_k_idx_tiling, a_k_strides = a_k_strides or (None, None)
|
|
b_k_idx_tiling, b_k_strides = b_k_strides
|
|
assert all(s % 16 == 0 for s in itertools.chain(a_k_strides or (), b_k_strides))
|
|
assert (a_scale_addr is None) == (b_scale_addr is None)
|
|
is_scaled = a_scale_addr is not None
|
|
is_sparse = a_sparse_addr is not None
|
|
elem_bitwidth = utils.bitwidth(element_type)
|
|
instr_k = (1 + is_sparse) * 8 * 32 // elem_bitwidth
|
|
packing = 8 * 4 // elem_bitwidth
|
|
assert not is_sparse or elem_bitwidth == 16 # Only 16-bit supported for now.
|
|
|
|
extra_args: Sequence[object]
|
|
scale_steps = None
|
|
if is_scaled:
|
|
if (ir.Float8E5M2Type.isinstance(element_type) or
|
|
ir.Float8E4M3FNType.isinstance(element_type)):
|
|
kind = "mxf8f6f4.block_scale.scale_vec::1X"
|
|
scale_steps = 4
|
|
create_scaled_instr_descriptor = create_scaled_f8f6f4_instr_descriptor
|
|
elif ir.Float4E2M1FNType.isinstance(element_type):
|
|
assert not a_transpose and not b_transpose
|
|
kind = "mxf4.block_scale.scale_vec::2X"
|
|
scale_steps = 2
|
|
create_scaled_instr_descriptor = create_scaled_f4_instr_descriptor
|
|
else:
|
|
raise NotImplementedError(f"Unsupported element type for block scaling: {element_type}")
|
|
extra_args = (a_scale_addr, b_scale_addr)
|
|
extra_ptx = "[$5], [$6], "
|
|
extra_constraints = ",r,r"
|
|
else:
|
|
if ir.F16Type.isinstance(element_type) or ir.BF16Type.isinstance(element_type):
|
|
kind = "f16"
|
|
elif ir.Float8E5M2Type.isinstance(element_type):
|
|
kind = "f8f6f4"
|
|
elif ir.Float8E4M3FNType.isinstance(element_type):
|
|
kind = "f8f6f4"
|
|
elif ir.IntegerType.get_signless(8).isinstance(element_type):
|
|
kind = "i8"
|
|
else:
|
|
raise NotImplementedError(f"Unsupported input element type: {element_type}")
|
|
extra_args = ()
|
|
extra_constraints = extra_ptx = ""
|
|
|
|
def create_scaled_instr_descriptor(*args):
|
|
raise NotImplementedError
|
|
|
|
num_cta = 2 if collective else 1
|
|
a_in_tmem = a_k_strides is None
|
|
a_ptx = "[$1]" if a_in_tmem else "$1"
|
|
a_ptx_constraint = "r" if a_in_tmem else "l"
|
|
sparse_mod = ".sp" if is_sparse else ""
|
|
sparse_meta_ptx = "[$5], " if is_sparse else ""
|
|
extra_constraints += ",r" if is_sparse else ""
|
|
sparse_addr: tuple[Any, ...] = ()
|
|
assert a_desc_or_addr.type == ir.IntegerType.get_signless(32 if a_in_tmem else 64)
|
|
assert scale_steps is None or scale_steps == k // instr_k
|
|
def _get_offset(idx: int, idx_tiling: tuple[int, ...], strides: tuple[int, ...]):
|
|
assert len(idx_tiling) + 1 == len(strides)
|
|
idxs = []
|
|
for t in idx_tiling:
|
|
idxs.append(idx // t)
|
|
idx = idx % t
|
|
idxs.append(idx)
|
|
offset = sum(i * s for i, s in zip(idxs, strides, strict=True))
|
|
return arith.constant(i64, offset >> 4)
|
|
for k_step in range(k // instr_k):
|
|
if is_scaled:
|
|
assert scale_steps is not None
|
|
scale_vec_width = 4 // scale_steps
|
|
scale_id = k_step * scale_vec_width
|
|
i_desc = create_scaled_instr_descriptor(
|
|
m, n, element_type, element_type, scale_id, scale_id, a_transpose, b_transpose
|
|
)
|
|
else:
|
|
sp_selector = None
|
|
if is_sparse:
|
|
assert (k // instr_k) % 2 == 0
|
|
sp_selector = k_step % 2
|
|
selector_width = 64
|
|
k_steps_for_col_inc = selector_width // instr_k
|
|
# If the K group is large, we need to increment the sparse metadata.
|
|
# TODO(apaszke): At this point the purpose of this function is becoming
|
|
# less clear, since we end up replicating address arithmetic that's
|
|
# already there in the caller. We should unify them into a single loop.
|
|
sparse_addr = (
|
|
arith.addi(
|
|
a_sparse_addr, utils.c(k_step // k_steps_for_col_inc * 2, i32)
|
|
),
|
|
)
|
|
i_desc = create_instr_descriptor(
|
|
m * num_cta, n * num_cta, d_type, element_type, a_transpose, b_transpose, sparsity_selector=sp_selector
|
|
)
|
|
if a_in_tmem:
|
|
a_desc_or_addr_instr = arith.addi(
|
|
a_desc_or_addr, arith.constant(i32, k_step * instr_k // packing)
|
|
)
|
|
else:
|
|
assert a_k_idx_tiling is not None and a_k_strides is not None
|
|
a_desc_or_addr_instr = arith.addi(
|
|
a_desc_or_addr, _get_offset(k_step, a_k_idx_tiling, a_k_strides)
|
|
)
|
|
b_desc_instr = arith.addi(b_desc, _get_offset(k_step, b_k_idx_tiling, b_k_strides))
|
|
llvm.inline_asm(
|
|
ir.Type.parse("!llvm.void"),
|
|
[d_addr, a_desc_or_addr_instr, b_desc_instr, i_desc, accumulate, *extra_args, *sparse_addr],
|
|
f"tcgen05.mma{sparse_mod}.cta_group::{num_cta}.kind::{kind} [$0], {a_ptx}, $2, {sparse_meta_ptx}$3, {extra_ptx}$4;",
|
|
f"r,{a_ptx_constraint},l,r,b" + extra_constraints,
|
|
has_side_effects=True,
|
|
)
|
|
accumulate = arith.constant(i1, 1)
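

# Illustrative sketch only: the index decomposition performed by _get_offset
# inside _do_mma above, with plain ints. The tiling and strides used here are
# hypothetical (they just satisfy the multiple-of-16 requirement); the point is
# that a flat K-step index is split digit-by-digit before the strides apply.
def _example_k_step_offset(k_step: int = 6) -> int:
  idx_tiling, strides = (4,), (1024, 256)
  idxs = []
  idx = k_step
  for t in idx_tiling:
    idxs.append(idx // t)
    idx = idx % t
  idxs.append(idx)
  return sum(i * s for i, s in zip(idxs, strides, strict=True))  # 1*1024 + 2*256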
|
|
|
|
|
|
def commit_arrive(
|
|
barrier: utils.BarrierRef | ir.Value,
|
|
collective: bool = False,
|
|
ctx: LaunchContext | None = None,
|
|
):
|
|
if isinstance(barrier, utils.BarrierRef):
|
|
barrier = barrier.get_ptr()
|
|
elif barrier.type != ir.Type.parse("!llvm.ptr<3>"):
|
|
raise ValueError(
|
|
"barrier must be a Mosaic barrier or a SMEM pointer, got:"
|
|
f" {barrier.type}"
|
|
)
|
|
if collective:
|
|
if ctx is None:
|
|
raise ValueError("ctx must be provided for collective barriers")
|
|
# TODO(apaszke): This is just 0b11 shifted by the even CTA index.
|
|
if ctx.cluster_size != (2, 1, 1):
|
|
raise NotImplementedError("Collective arrivals only support (2, 1, 1)-shaped clusters")
|
|
ptx = """
|
|
{
|
|
.reg .b16 msk;
|
|
mov.b16 msk, 3;
|
|
tcgen05.commit.cta_group::2.mbarrier::arrive::one.multicast::cluster.b64 [$0], msk;
|
|
}
|
|
"""
|
|
else:
|
|
ptx = "tcgen05.commit.cta_group::1.mbarrier::arrive::one.b64 [$0];"
|
|
return llvm.inline_asm(
|
|
ir.Type.parse("!llvm.void"), [barrier], ptx, "r", has_side_effects=True
|
|
)
|
|
|
|
|
|
def tmem_alloc_exact_ncols(ncols: int, exact: bool) -> int:
|
|
"""Returns the exact number of columns to allocate in TMEM.
|
|
|
|
The number of columns is rounded up to the nearest power of 2.
|
|
|
|
Args:
|
|
ncols: The number of columns to allocate.
|
|
exact: If true, throws an error if the number of columns is not a power of 2
|
|
and within [32, 512].
|
|
"""
|
|
if exact:
|
|
if ncols.bit_count() != 1 or not 32 <= ncols <= 512:
|
|
raise ValueError(f"ncols must be a power of 2 and within [32, 512], got: {ncols}")
|
|
else:
|
|
ncols = max(32, 1 << (ncols - 1).bit_length())
|
|
if ncols > 512:
|
|
raise ValueError(
|
|
f"After rounding up, got {ncols} columns, exceeding the limit of 512"
|
|
)
|
|
return ncols
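

# Illustrative sketch only: a few data points for the rounding rules above,
# using plain ints (no MLIR context needed).
def _example_tmem_alloc_rounding() -> None:
  assert tmem_alloc_exact_ncols(40, exact=False) == 64   # rounded up
  assert tmem_alloc_exact_ncols(7, exact=False) == 32    # clamped to the minimum
  assert tmem_alloc_exact_ncols(512, exact=True) == 512  # already a valid size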
|
|
|
|
|
|
def tmem_alloc(tmem_addr: ir.Value, ncols: int, collective: bool = False, exact: bool = True) -> tuple[ir.Value, int]:
|
|
if ir.MemRefType.isinstance(tmem_addr.type):
|
|
ref_ty = ir.MemRefType(tmem_addr.type)
|
|
if ref_ty.element_type != ir.IntegerType.get_signless(32):
|
|
raise ValueError(f"tmem_addr must be an i32 memref, got: {ref_ty}")
|
|
if not utils.is_smem_ref(ref_ty):
|
|
raise ValueError(f"tmem_addr must be in shared memory, got: {ref_ty}")
|
|
if math.prod(ref_ty.shape) != 1:
|
|
raise ValueError(f"tmem_addr must contain a single element, got: {ref_ty}")
|
|
tmem_addr = utils.memref_ptr(tmem_addr, memory_space=3)
|
|
elif tmem_addr.type != ir.Type.parse("!llvm.ptr<3>"):
|
|
raise ValueError(f"tmem_addr must be an SMEM pointer or a memref, got: {tmem_addr.type}")
|
|
ncols = tmem_alloc_exact_ncols(ncols, exact)
|
|
num_cta = 2 if collective else 1
|
|
return llvm.inline_asm(
|
|
ir.Type.parse("!llvm.void"),
|
|
[tmem_addr],
|
|
f"tcgen05.alloc.cta_group::{num_cta}.sync.aligned.shared::cta.b32 [$0], {ncols};",
|
|
"r",
|
|
has_side_effects=True,
|
|
), ncols
|
|
|
|
|
|
def tmem_dealloc(tmem_addr: ir.Value, ncols: int, collective: bool = False, exact: bool = True):
|
|
if tmem_addr.type != ir.IntegerType.get_signless(32):
|
|
raise ValueError(f"tmem_addr must be an i32, got: {tmem_addr.type}")
|
|
ncols = tmem_alloc_exact_ncols(ncols, exact)
|
|
num_cta = 2 if collective else 1
|
|
return llvm.inline_asm(
|
|
ir.Type.parse("!llvm.void"),
|
|
[tmem_addr],
|
|
f"tcgen05.dealloc.cta_group::{num_cta}.sync.aligned.b32 $0, {ncols};",
|
|
"r",
|
|
has_side_effects=True,
|
|
)
|
|
|
|
|
|
def tmem_relinquish_alloc_permit(collective: bool):
|
|
num_cta = 2 if collective else 1
|
|
return llvm.inline_asm(
|
|
ir.Type.parse("!llvm.void"),
|
|
[],
|
|
f"tcgen05.relinquish_alloc_permit.cta_group::{num_cta}.sync.aligned;",
|
|
"",
|
|
has_side_effects=True,
|
|
)
|
|
|
|
def _tmem_access_helper(shape, num):
|
|
if num.bit_count() != 1 or num > 128:
|
|
raise ValueError(f"num must be a power of 2 and <= 128, got: {num}")
|
|
match shape:
|
|
case "32x32b":
|
|
num_regs = 1
|
|
case "16x128b":
|
|
num_regs = 2
|
|
case "16x256b":
|
|
num_regs = 4
|
|
case _:
|
|
raise NotImplementedError(f"{shape=} is unsupported")
|
|
num_regs *= num
|
|
if num_regs > 255:
|
|
raise ValueError(
|
|
f"TMEM translation too big : {shape=} and {num=} involve"
|
|
f" {num_regs} registers per-thread, which exceeds the limit of 255"
|
|
)
|
|
regs_vector = ",".join(f"${i}" for i in range(num_regs))
|
|
regs_vector = "{" + regs_vector + "}"
|
|
return num_regs, regs_vector
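

# Illustrative sketch only: per-thread register counts implied by the helper
# above. A 16x256b access uses 4 registers per "num" step and a 32x32b access
# uses 1, so both calls below stay within the 255-register limit, while e.g.
# ("16x256b", 128) would be rejected.
def _example_tmem_access_regs() -> None:
  assert _tmem_access_helper("16x256b", 32)[0] == 128
  assert _tmem_access_helper("32x32b", 128)[0] == 128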
|
|
|
|
|
|
def tmem_load(tmem_addr, shape, num, pack: bool):
|
|
i32 = ir.IntegerType.get_signless(32)
|
|
num_out_regs, regs_vector = _tmem_access_helper(shape, num)
|
|
pack_mod = ".pack::16b" if pack else ""
|
|
regs = llvm.inline_asm(
|
|
ir.Type.parse(
|
|
"!llvm.struct<(" + ",".join("i32" for _ in range(num_out_regs)) + ")>"
|
|
),
|
|
[tmem_addr],
|
|
f"tcgen05.ld.sync.aligned.{shape}.x{num}{pack_mod}.b32 {regs_vector}, [${num_out_regs}];",
|
|
"=r," * num_out_regs + "r",
|
|
has_side_effects=True,
|
|
)
|
|
return [llvm.extractvalue(i32, regs, [i]) for i in range(num_out_regs)]
|
|
|
|
|
|
def wait_tmem_load():
|
|
llvm.inline_asm(
|
|
ir.Type.parse("!llvm.void"),
|
|
[],
|
|
"tcgen05.wait::ld.sync.aligned;",
|
|
"",
|
|
has_side_effects=True,
|
|
)
|
|
utils.warpgroup_barrier()
|
|
|
|
|
|
def tmem_store(tmem_addr, shape, num, regs, unpack: bool):
|
|
num_out_regs, regs_vector = _tmem_access_helper(shape, num)
|
|
pack_mod = ".unpack::16b" if unpack else ""
|
|
llvm.inline_asm(
|
|
ir.Type.parse("!llvm.void"),
|
|
[*regs, tmem_addr],
|
|
f"tcgen05.st.sync.aligned.{shape}.x{num}{pack_mod}.b32 [${num_out_regs}], {regs_vector};",
|
|
"r," * num_out_regs + "r",
|
|
has_side_effects=True,
|
|
)
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class TMEMLayout(fa.TiledLayout):
|
|
"""Represents the way a shape is laid out in TMEM.
|
|
|
|
The layout describes how the shape is split across the 128 rows (lanes) of
|
|
TMEM. We reinterpret warp_dims as the partitioning of TMEM into 4 banks, each
|
|
accessible from a single warp. The 32 lanes inside each bank are assigned
|
|
consecutive elements from lane_dims. The data within each lane is linearized
|
|
in row-major order, with each vector padded up to 32 bits (wider vectors are
|
|
unsupported).
|
|
"""
|
|
|
|
def check_type(self, shape: tuple[int, ...], bitwidth: int):
|
|
if len(shape) != 2:
|
|
raise ValueError(f"TMEM can only represent 2D shapes, got {shape}")
|
|
if any(s % t for s, t in zip(shape, self.base_tile_shape)):
|
|
raise ValueError(
|
|
f"{shape} is not divisible into tiles of shape {self.base_tile_shape}"
|
|
)
|
|
if self.vector_length not in {1, fully_packed := 32 // bitwidth}:
|
|
raise ValueError(
|
|
f"For {bitwidth}-bit types, the vector length must be 1 or"
|
|
f" {fully_packed} , but got: {self.vector_length}"
|
|
)
|
|
|
|
def cols_in_shape(self, shape: tuple[int, int], bitwidth: int):
|
|
self.check_type(shape, bitwidth)
|
|
return math.prod(shape) // TMEM_ROWS // self.vector_length
|
|
|
|
def canonicalize(self) -> "TMEMLayout":
|
|
layout = super().canonicalize()
|
|
return TMEMLayout(
|
|
layout.tiling,
|
|
layout.warp_dims,
|
|
layout.lane_dims,
|
|
layout.vector_dim,
|
|
_check_canonical=False,
|
|
)
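

# Illustrative sketch only: the column-footprint formula used by
# TMEMLayout.cols_in_shape, with plain ints. A (128, 64) f32 accumulator stored
# with vector_length 1 spans 64 TMEM columns; packing two 16-bit elements per
# 32-bit cell (vector_length 2) would halve that.
def _example_tmem_cols(shape: tuple[int, int] = (128, 64), vector_length: int = 1) -> int:
  return math.prod(shape) // TMEM_ROWS // vector_length  # == 64 for the defaults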
|
|
|
|
|
|
def _infer_tmem_load_registers_layout(
|
|
tmem_layout: TMEMLayout, columns: int, packing: int
|
|
) -> fa.TiledLayout:
|
|
if tmem_layout == tmem_default_layout(packing=packing):
|
|
return LAYOUT
|
|
if tmem_layout == tmem_half_lane_layout(columns, packing=packing):
|
|
return fa.WGMMA_LAYOUT
|
|
if tmem_layout == tmem_m64_collective_layout(columns, packing=packing):
|
|
return fa_m64_collective_layout(columns)
|
|
raise ValueError(f"TMEM layout {tmem_layout} is not supported")
|
|
|
|
|
|
def _infer_tmem_layout(shape: tuple[int, int], collective: bool, packing: int) -> TMEMLayout:
|
|
if len(shape) != 2:
|
|
raise ValueError(f"TMEM can only represent 2D shapes, got {shape}")
|
|
if packing > 8 or packing.bit_count() != 1:
|
|
raise ValueError(f"Packing must be <= 8 and a power of 2, got: {packing}")
|
|
if shape[1] % packing:
|
|
raise ValueError(f"Minor dimension of shape must be divisible by packing, got: {shape}")
|
|
if shape[0] == TMEM_ROWS:
|
|
return tmem_default_layout(packing)
|
|
elif shape[0] == TMEM_ROWS // 2:
|
|
if collective:
|
|
return tmem_m64_collective_layout(shape[1], packing)
|
|
else:
|
|
return tmem_half_lane_layout(shape[1], packing)
|
|
else:
|
|
raise ValueError(f"Unsupported shape: {shape}")
|
|
|
|
|
|
def tmem_default_layout(packing: int = 1):
|
|
"""A TMEM layout used for 1CTA MMA with M=128 and 2CTA MMA with M=256."""
|
|
if packing.bit_count() != 1:
|
|
raise ValueError(f"Packing must be a power of 2, got: {packing}")
|
|
return TMEMLayout(
|
|
fa.Tiling(((TMEM_ROWS, packing), (fa.WARP_SIZE, packing))),
|
|
warp_dims=(-4,),
|
|
lane_dims=(-2,),
|
|
vector_dim=-1,
|
|
)
|
|
|
|
|
|
def tmem_half_lane_layout(columns, packing: int = 1):
|
|
"""A TMEM layout used for 1CTA MMA with M=64."""
|
|
if packing > columns or packing.bit_count() != 1:
|
|
raise ValueError(f"Packing must be <= 8 and a power of 2, got: {packing}")
|
|
if columns % 16:
|
|
raise ValueError(f"Columns must be a multiple of 16, got: {columns}")
|
|
return TMEMLayout(
|
|
fa.Tiling((
|
|
(TMEM_ROWS // 2, columns),
|
|
(fa.WARP_SIZE // 2, columns // 2),
|
|
(packing,),
|
|
)),
|
|
warp_dims=(-5,),
|
|
lane_dims=(-4, -3),
|
|
vector_dim=-1,
|
|
)
|
|
|
|
|
|
def tmem_m64_collective_layout(columns, packing: int = 1):
|
|
"""A TMEM layout used for 2CTA MMA with M=128."""
|
|
if packing > 8 or packing.bit_count() != 1:
|
|
raise ValueError(f"Packing must be <= 8 and a power of 2, got: {packing}")
|
|
if columns % 16:
|
|
raise ValueError(f"Columns must be a multiple of 16, got: {columns}")
|
|
return TMEMLayout(
|
|
fa.Tiling((
|
|
(TMEM_ROWS // 2, columns),
|
|
(fa.WARP_SIZE, columns // 2),
|
|
(packing,),
|
|
)),
|
|
warp_dims=(-4, -5,),
|
|
lane_dims=(-3,),
|
|
vector_dim=-1,
|
|
)
|
|
|
|
|
|
def fa_m64_collective_layout(columns):
|
|
"""The register layout for transfers to/from tmem_m64_collective_layout."""
|
|
if columns % 8:
|
|
raise ValueError(f"Columns must be a multiple of 8, got: {columns}")
|
|
return fa.TiledLayout(
|
|
fa.Tiling((
|
|
(TMEM_ROWS // 2, columns), (fa.WARP_SIZE, columns // 2), (8, 8), (2,)
|
|
)),
|
|
warp_dims=(-6, -7),
|
|
lane_dims=(-3, -2),
|
|
vector_dim=-1,
|
|
)
|
|
|
|
|
|
def scales_layout():
|
|
"""A TMEM layout for A and B scales in .scale_vec::1X configuration.
|
|
|
|
See https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-a-layout-1x
|
|
"""
|
|
return TMEMLayout(
|
|
fa.Tiling(((TMEM_ROWS, 4), (TMEM_ROWS // 4, 1))),
|
|
warp_dims=(fa.Replicated(times=4),),
|
|
lane_dims=(-2,),
|
|
vector_dim=-3,
|
|
)
|
|
|
|
|
|
def sparse_meta_layout():
|
|
"""A TMEM layout for A sparsity metadata.
|
|
|
|
See https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-sparse-matrices-sparsity-selector-kind-tf32-m128-256
|
|
"""
|
|
# TODO(apaszke): This does not really describe this layout and we can't do it
|
|
# until we add support for multiple vector dims. Still, it's ok to do for now,
|
|
# because we don't use TMEM layouts for any automatic transformations at the
|
|
# moment and only ever compare it for equality.
|
|
return TMEMLayout(
|
|
fa.Tiling(((TMEM_ROWS, 16), (TMEM_ROWS // 4, 1), (16, 1), (8, 1))),
|
|
warp_dims=(-8,),
|
|
lane_dims=(-2, -4, -6),
|
|
vector_dim=-7,
|
|
)
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class TMEMRef:
|
|
address: ir.Value
|
|
shape: tuple[int, int]
|
|
dtype: ir.Type
|
|
layout: TMEMLayout
|
|
|
|
@classmethod
|
|
def from_alloc(
|
|
cls,
|
|
tmem_addr_ref: ir.Value,
|
|
shape: tuple[int, int],
|
|
dtype,
|
|
collective: bool | None = None,
|
|
layout: TMEMLayout | None = None,
|
|
):
|
|
i32 = ir.IntegerType.get_signless(32)
|
|
if not ir.MemRefType.isinstance(tmem_addr_ref.type):
|
|
raise ValueError(f"tmem_addr_ref must be a memref or a pointer, got: {tmem_addr_ref.type}")
|
|
addr_ref_ty = ir.MemRefType(tmem_addr_ref.type)
|
|
if not utils.is_smem_ref(addr_ref_ty):
|
|
raise ValueError(f"tmem_addr_ref must be in shared memory, got: {addr_ref_ty}")
|
|
if addr_ref_ty.element_type != i32:
|
|
raise ValueError(f"tmem_addr_ref must be an i32 memref, got: {addr_ref_ty}")
|
|
if math.prod(addr_ref_ty.shape) != 1:
|
|
raise ValueError(f"tmem_addr_ref must contain a single element, got: {addr_ref_ty}")
|
|
i0 = arith.ConstantOp.create_index(0)
|
|
tmem_addr = memref.load(tmem_addr_ref, [i0] * addr_ref_ty.rank)
|
|
if shape[0] < 32:
|
|
raise ValueError(f"TMEM refs must have at least 32 rows, got: {shape[0]}")
|
|
if layout is None:
|
|
if collective is None:
|
|
raise ValueError(
|
|
"collective argument must be provided when TMEM layout is inferred"
|
|
)
|
|
layout = _infer_tmem_layout(shape, collective, packing=1)
|
|
else:
|
|
layout.check_type(shape, utils.bitwidth(dtype))
|
|
# TODO: Do we have to do this??
|
|
# warp_idx = utils.warp_idx(sync=False)
|
|
# tmem_addr = arith.ori(tmem_addr, arith.shli(warp_idx, utils.c(21, i32)))
|
|
return cls(tmem_addr, shape, dtype, layout)
|
|
|
|
def slice(self, *idxs):
|
|
i32 = ir.IntegerType.get_signless(32)
|
|
base_idx, slice_shape, is_squeezed = utils.parse_indices(idxs, self.shape)
|
|
if any(is_squeezed):
|
|
raise ValueError("TMEM can only be sliced, not indexed")
|
|
if base_idx == [0] * len(base_idx) and slice_shape == list(self.shape):
|
|
      return self  # Trivial slice
|
|
if self.layout == tmem_default_layout(packing=1):
|
|
packing = 1
|
|
elif self.layout == tmem_default_layout(packing=2):
|
|
packing = 2
|
|
else:
|
|
raise NotImplementedError(
|
|
"Slicing only implemented for refs with standard layout, got:"
|
|
f" {self.layout}"
|
|
)
|
|
if base_idx[0] != 0 or slice_shape[0] != TMEM_ROWS:
|
|
raise NotImplementedError("TMEM cannot be sliced along rows")
|
|
if slice_shape[1] % 8:
|
|
raise NotImplementedError(
|
|
"TMEM column slice length must be a multiple of 8. "
|
|
f"Got {slice_shape[1]}."
|
|
)
|
|
col_idx = base_idx[1]
|
|
if not isinstance(col_idx, ir.Value):
|
|
col_idx = arith.constant(i32, col_idx)
|
|
if col_idx.type == ir.IndexType.get():
|
|
col_idx = arith.index_cast(i32, col_idx)
|
|
if packing != 1:
|
|
col_idx = arith.divui(col_idx, arith.constant(i32, packing))
|
|
return TMEMRef(
|
|
address=arith.addi(self.address, col_idx),
|
|
shape=tuple(slice_shape),
|
|
layout=self.layout,
|
|
dtype=self.dtype,
|
|
)
|
|
|
|
def load(self, layout: fa.TiledLayout | None = None, is_signed: bool | None = None):
|
|
if utils.bitwidth(self.dtype) not in {16, 32}:
|
|
raise NotImplementedError(f"Unsupported dtype: {self.dtype}")
|
|
packing = self.layout.vector_length
|
|
if layout is None:
|
|
layout = _infer_tmem_load_registers_layout(
|
|
self.layout, self.shape[1], packing
|
|
)
|
|
regs_shape = layout.registers_shape(self.shape)
|
|
if regs_shape[0] != 1: # We'll need to issue multiple loads below.
|
|
raise NotImplementedError("Loading multiple row tiles")
|
|
if layout == LAYOUT and self.layout == tmem_default_layout(packing=packing):
|
|
registers = _load_32xcols(
|
|
self.address, self.shape[1], self.dtype, packing
|
|
).T.reshape(regs_shape)
|
|
elif layout == TMEM_NATIVE_LAYOUT and self.layout == tmem_default_layout(packing=packing):
|
|
registers = _load_32xcols_native(
|
|
self.address, self.shape[1], self.dtype, packing
|
|
).reshape(regs_shape)
|
|
elif layout == fa.WGMMA_LAYOUT and self.layout == tmem_half_lane_layout(self.shape[1], packing=packing):
|
|
# Load half the columns, since they are folded over lanes.
|
|
raw_registers = _load_32xcols(
|
|
self.address, self.shape[1] // 2, self.dtype, packing
|
|
)
|
|
assert raw_registers.shape[0] == 4
|
|
registers = np.concatenate([raw_registers[:2], raw_registers[2:]], axis=1)
|
|
registers = registers.T.reshape(regs_shape)
|
|
elif layout == fa_m64_collective_layout(self.shape[1]) and self.layout == tmem_m64_collective_layout(self.shape[1], packing=packing):
|
|
regs_shape = layout.registers_shape(self.shape)
|
|
# We take half the columns, because they are split over halves of TMEM.
|
|
registers = _load_32xcols(
|
|
self.address, self.shape[1] // 2, self.dtype, packing
|
|
).reshape(regs_shape)
|
|
else:
|
|
raise ValueError(
|
|
f"Loads from TMEM layout {self.layout} to register layout"
|
|
f" {layout} are not supported"
|
|
)
|
|
return fa.FragmentedArray(
|
|
_registers=registers, _layout=layout, _is_signed=is_signed
|
|
)
|
|
|
|
def store(self, value):
|
|
if utils.bitwidth(self.dtype) not in {16, 32}:
|
|
raise NotImplementedError(f"Unsupported dtype: {self.dtype}")
|
|
if not isinstance(value, fa.FragmentedArray):
|
|
raise ValueError(f"TMEM stores expect a FragmentedArray, got: {value}")
|
|
if value.shape != self.shape:
|
|
raise ValueError(
|
|
f"Stored array has shape {value.shape}, but TMEM has shape"
|
|
f" {self.shape}"
|
|
)
|
|
if value.mlir_dtype != self.dtype:
|
|
raise ValueError(
|
|
f"Stored array has dtype {value.mlir_dtype}, but TMEM has dtype"
|
|
f" {self.dtype}"
|
|
)
|
|
packing = self.layout.vector_length
|
|
if value.layout == LAYOUT and self.layout == tmem_default_layout(packing=packing):
|
|
_store_32xcols(
|
|
self.address, value.registers.T.reshape((4, -1)), packing
|
|
)
|
|
elif value.layout == TMEM_NATIVE_LAYOUT and self.layout == tmem_default_layout(packing=packing):
|
|
_store_32xcols_native(
|
|
self.address, value.registers.reshape(-1), packing
|
|
)
|
|
elif value.layout == fa.WGMMA_LAYOUT and self.layout == tmem_half_lane_layout(self.shape[1], packing=packing):
|
|
registers = value.registers.T.reshape(2, -1)
|
|
registers = np.concatenate(np.split(registers, 2, axis=1), axis=0)
|
|
_store_32xcols(self.address, registers, packing)
|
|
elif value.layout == fa_m64_collective_layout(self.shape[1]) and self.layout == tmem_m64_collective_layout(self.shape[1], packing=packing):
|
|
_store_32xcols(self.address, value.registers.reshape(4, -1), packing)
|
|
else:
|
|
raise ValueError(
|
|
f"Storing from register layout {value.layout} to TMEM layout"
|
|
f" {self.layout} is not supported"
|
|
)
|
|
|
|
def _debug_print(self):
|
|
i32 = ir.IntegerType.get_signless(32)
|
|
num_cols = self.layout.cols_in_shape(self.shape, utils.bitwidth(self.dtype))
|
|
lane = arith.remui(utils.thread_idx(), arith.constant(i32, utils.WARPGROUP_SIZE))
|
|
for c in range(num_cols):
|
|
val = llvm.inline_asm(
|
|
i32,
|
|
[arith.addi(self.address, arith.constant(i32, c))],
|
|
"tcgen05.ld.sync.aligned.32x32b.x1.b32 {$0}, [$1];",
|
|
"=r,r",
|
|
)
|
|
dtype_bitwidth = utils.bitwidth(self.dtype)
|
|
full_packing = 32 // dtype_bitwidth
|
|
if self.layout.vector_length == 1:
|
|
if dtype_bitwidth < 32:
|
|
val = arith.trunci(ir.IntegerType.get_signless(dtype_bitwidth), val)
|
|
val = utils.bitcast(val, self.dtype)
|
|
elif self.layout.vector_length == full_packing:
|
|
val = utils.bitcast(val, ir.VectorType.get((full_packing,), self.dtype))
|
|
else:
|
|
raise NotImplementedError(
|
|
f"Unsupported packing: {self.layout.vector_length}"
|
|
)
|
|
# TODO(apaszke): Make this print logical, not physical location.
|
|
utils.debug_print(f"[{{}}, {c}]: {{}}", lane, val, uniform=False)
|
|
|
|
|
|
def _transfer_32xcols(
|
|
base_addr: ir.Value,
|
|
cols: int,
|
|
atom_shape: tuple[int, int],
|
|
tmem_packing: int,
|
|
reg_packing: int,
|
|
):
|
|
"""Generates a sequence of parameters for a given TMEM read or write.
|
|
|
|
Arguments:
|
|
base_addr: The base address of the TMEM region.
|
|
cols: The number of logical columns to transfer.
|
|
atom_shape: The logical shape of the tile written by the warp in a single
|
|
TMEM transfer.
|
|
tmem_packing: Packing degree in TMEM. When packing is 1, but the data is
|
|
16-bit, we expect that each transfer actually involves double the number
|
|
of physical columns.
|
|
reg_packing: The number of elements that fit in a single 32-bit register.
|
|
"""
|
|
i32 = ir.IntegerType.get_signless(32)
|
|
atom_rows, atom_cols = atom_shape
|
|
assert cols % atom_cols == 0
|
|
total_num = cols // atom_cols
|
|
assert total_num.bit_count() == 1
|
|
regs_per_instr = atom_shape[0] * atom_shape[1] // (utils.WARP_SIZE * reg_packing)
|
|
# We artificially lower the instr_num compared to its limits, because higher
|
|
  # values can lead to register spills.
|
|
instr_num = min(total_num, 32 // regs_per_instr)
|
|
assert 32 % atom_rows == 0
|
|
num_row_steps = 32 // atom_rows
|
|
for lane_step in range(num_row_steps):
|
|
addr_row = arith.addi(base_addr, utils.c((lane_step * atom_rows) << 16, i32))
|
|
cols_per_instr = instr_num * atom_cols
|
|
for num_step in range(total_num // instr_num):
|
|
num_slice = slice(num_step * instr_num, (num_step + 1) * instr_num)
|
|
addr_row_col = arith.addi(
|
|
addr_row, utils.c(num_step * cols_per_instr // tmem_packing, i32)
|
|
)
|
|
yield addr_row_col, instr_num, lane_step, num_slice
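

# Illustrative sketch only: the schedule shape _transfer_32xcols produces for a
# 32-bit transfer of 64 columns with a (16, 8) atom (the 16x256b case). The
# numbers are recomputed here with plain ints instead of calling the generator,
# since the generator needs live IR values for the addresses.
def _example_transfer_schedule() -> tuple[int, int, int]:
  cols, (atom_rows, atom_cols), reg_packing = 64, (16, 8), 1
  total_num = cols // atom_cols  # 8 column atoms
  regs_per_instr = atom_rows * atom_cols // (utils.WARP_SIZE * reg_packing)  # 4
  instr_num = min(total_num, 32 // regs_per_instr)  # 8
  # 2 lane steps, each issuing a single x8 instruction.
  return 32 // atom_rows, total_num // instr_num, instr_num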
|
|
|
|
|
|
def _store_32xcols(base_addr, vector_regs, tmem_packing):
|
|
i32 = ir.IntegerType.get_signless(32)
|
|
assert vector_regs.ndim == 2 and vector_regs.shape[0] == 4
|
|
cols = vector_regs.shape[1] * 8
|
|
|
|
reg_packing = 64 // utils.bitwidth(vector_regs.flat[0].type)
|
|
if reg_packing == 1:
|
|
store_shape = "16x256b" # 4 threads * 64 bits per vreg = 256 bits
|
|
regs = np.empty((4, vector_regs.shape[1], 2), dtype=object)
|
|
c0 = arith.constant(i32, 0)
|
|
c1 = arith.constant(i32, 1)
|
|
for idx, vreg in np.ndenumerate(vector_regs):
|
|
regs[(*idx, 0)] = llvm.extractelement(vreg, c0)
|
|
regs[(*idx, 1)] = llvm.extractelement(vreg, c1)
|
|
regs = regs.reshape(2, 2, vector_regs.shape[1], 2).swapaxes(1, 2)
|
|
# From a single lane perspective a num tile consists of a 2x2, with the
|
|
# minor dim traversing columns and major being 8 rows apart.
|
|
# See https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-fragments-shape-16256b
|
|
assert regs.shape[-2:] == (2, 2)
|
|
assert tmem_packing == 1
|
|
unpack = False
|
|
elif reg_packing == 2:
|
|
store_shape = "16x128b" # 4 threads * 32 bits per vreg = 128 bits
|
|
# From a single lane perspective a num tile has 2 registers, 8 rows apart.
|
|
# See https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-fragments-shape-16128b
|
|
regs = vector_regs.reshape(2, 2, vector_regs.shape[1]).swapaxes(1, 2)
|
|
assert 1 <= tmem_packing <= 2
|
|
unpack = tmem_packing == 1
|
|
else:
|
|
raise NotImplementedError(reg_packing)
|
|
|
|
it = _transfer_32xcols(base_addr, cols, (16, 8), tmem_packing, reg_packing)
|
|
for addr_row_col, instr_num, lane_step, num_slice in it:
|
|
regs_slice = regs[lane_step, num_slice].flat
|
|
tmem_store(addr_row_col, store_shape, instr_num, regs_slice, unpack)
|
|
|
|
|
|
def _store_32xcols_native(base_addr, vector_regs, tmem_packing):
|
|
i32 = ir.IntegerType.get_signless(32)
|
|
assert vector_regs.ndim == 1
|
|
cols = len(vector_regs) * TMEM_NATIVE_LAYOUT.vector_length
|
|
|
|
reg_packing = 64 // utils.bitwidth(vector_regs.flat[0].type)
|
|
store_shape = "32x32b"
|
|
if reg_packing == 1:
|
|
store_atom_shape = (32, 1)
|
|
regs = [None] * (len(vector_regs) * 2)
|
|
c0 = arith.constant(i32, 0)
|
|
c1 = arith.constant(i32, 1)
|
|
for idx, vreg in enumerate(vector_regs):
|
|
regs[2 * idx] = llvm.extractelement(vreg, c0)
|
|
regs[2 * idx + 1] = llvm.extractelement(vreg, c1)
|
|
assert tmem_packing == 1
|
|
unpack = False
|
|
elif reg_packing == 2:
|
|
store_atom_shape = (32, 2)
|
|
regs = vector_regs
|
|
assert 1 <= tmem_packing <= 2
|
|
unpack = tmem_packing == 1
|
|
else:
|
|
raise NotImplementedError(reg_packing)
|
|
|
|
it = _transfer_32xcols(base_addr, cols, store_atom_shape, tmem_packing, reg_packing)
|
|
for addr_row_col, instr_num, lane_step, num_slice in it:
|
|
assert lane_step == 0
|
|
regs_slice = regs[num_slice]
|
|
tmem_store(addr_row_col, store_shape, instr_num, regs_slice, unpack)
|
|
|
|
|
|
def _load_32xcols(base_addr, cols, dtype, tmem_packing):
|
|
i32 = ir.IntegerType.get_signless(32)
|
|
vec_ty = ir.VectorType.get((2,), dtype)
|
|
reg_packing = 32 // utils.bitwidth(dtype)
|
|
if reg_packing == 1:
|
|
load_shape = "16x256b" # 4 threads * 64 bits per vreg = 256 bits
|
|
assert tmem_packing == 1
|
|
pack = False
|
|
elif reg_packing == 2:
|
|
load_shape = "16x128b" # 4 threads * 32 bits per vreg = 128 bits
|
|
assert 1 <= tmem_packing <= 2
|
|
pack = tmem_packing == 1
|
|
else:
|
|
raise NotImplementedError(reg_packing)
|
|
|
|
vector_regs = np.ndarray((4, cols // 8), dtype=object)
|
|
|
|
it = _transfer_32xcols(base_addr, cols, (16, 8), tmem_packing, reg_packing)
|
|
c0 = arith.constant(i32, 0)
|
|
c1 = arith.constant(i32, 1)
|
|
for addr_row_col, instr_num, lane_step, num_slice in it:
|
|
regs = tmem_load(addr_row_col, load_shape, instr_num, pack)
|
|
row_slice = slice(lane_step * 2, (lane_step + 1) * 2)
|
|
# This aliases the original array, so updates will be reflected there.
|
|
vector_regs_update = vector_regs[row_slice, num_slice]
|
|
assert vector_regs_update.shape == (2, instr_num), (vector_regs_update.shape, instr_num)
|
|
if reg_packing == 1:
|
|
regs = [llvm.bitcast(dtype, r) for r in regs]
|
|
# From a single lane perspective a num tile consists of a 2x2, with the
|
|
# minor dim traversing columns and major being 8 rows apart.
|
|
# See https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-fragments-shape-16256b
|
|
regs = np.asarray(regs, dtype=object).reshape(instr_num, 2, 2).swapaxes(0, 1)
|
|
undef = llvm.mlir_undef(vec_ty)
|
|
assert regs.shape == (*vector_regs_update.shape, 2)
|
|
for idx in np.ndindex(vector_regs_update.shape):
|
|
high_undef = llvm.insertelement(undef, regs[(*idx, 0)], c0)
|
|
vreg = llvm.insertelement(high_undef, regs[(*idx, 1)], c1)
|
|
vector_regs_update[idx] = vreg
|
|
else:
|
|
assert reg_packing == 2
|
|
regs = [llvm.bitcast(vec_ty, r) for r in regs]
|
|
# From a single lane perspective a num tile has 2 registers, 8 rows apart.
|
|
# See https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-fragments-shape-16128b
|
|
regs = np.asarray(regs, dtype=object).reshape(instr_num, 2).swapaxes(0, 1)
|
|
vector_regs_update[...] = regs
|
|
|
|
return vector_regs
|
|
|
|
|
|
def _load_32xcols_native(base_addr, cols, dtype, tmem_packing):
|
|
i32 = ir.IntegerType.get_signless(32)
|
|
vec_ty = ir.VectorType.get((2,), dtype)
|
|
reg_packing = 32 // utils.bitwidth(dtype)
|
|
load_shape = "32x32b"
|
|
if reg_packing == 1:
|
|
load_atom_shape = (32, 1)
|
|
assert tmem_packing == 1
|
|
pack = False
|
|
elif reg_packing == 2:
|
|
load_atom_shape = (32, 2)
|
|
assert 1 <= tmem_packing <= 2
|
|
pack = tmem_packing == 1
|
|
else:
|
|
raise NotImplementedError(reg_packing)
|
|
|
|
it = _transfer_32xcols(base_addr, cols, load_atom_shape, tmem_packing, reg_packing)
|
|
c0 = arith.constant(i32, 0)
|
|
c1 = arith.constant(i32, 1)
|
|
regs = [None] * (cols // reg_packing)
|
|
for addr_row_col, instr_num, lane_step, num_slice in it:
|
|
assert lane_step == 0, lane_step
|
|
instr_regs = tmem_load(addr_row_col, load_shape, instr_num, pack)
|
|
if reg_packing == 1:
|
|
regs[num_slice] = [llvm.bitcast(dtype, r) for r in instr_regs]
|
|
else:
|
|
assert reg_packing == 2
|
|
regs[num_slice] = [llvm.bitcast(vec_ty, r) for r in instr_regs]
|
|
|
|
if reg_packing == 1:
|
|
vector_regs = np.ndarray((cols // 2,), dtype=object)
|
|
undef = llvm.mlir_undef(vec_ty)
|
|
for idx in range(vector_regs.size):
|
|
high_undef = llvm.insertelement(undef, regs[2 * idx], c0)
|
|
vreg = llvm.insertelement(high_undef, regs[2 * idx + 1], c1)
|
|
vector_regs[idx] = vreg
|
|
else:
|
|
assert reg_packing == 2
|
|
vector_regs = np.asarray(regs, dtype=object)
|
|
|
|
assert vector_regs.shape == (cols // TMEM_NATIVE_LAYOUT.vector_length,)
|
|
return vector_regs
|
|
|
|
|
|
def commit_tmem():
|
|
void = ir.Type.parse("!llvm.void")
|
|
llvm.inline_asm(
|
|
void, [], "tcgen05.wait::st.sync.aligned;", "", has_side_effects=True,
|
|
)
|
|
utils.warpgroup_barrier()
|
|
|
|
|
|
def wait_load_tmem():
|
|
void = ir.Type.parse("!llvm.void")
|
|
llvm.inline_asm(
|
|
void, [], "tcgen05.wait::ld.sync.aligned;", "", has_side_effects=True,
|
|
)
|
|
utils.warpgroup_barrier()
|
|
|
|
|
|
def async_copy_scales_smem_to_tmem(smem_ref: ir.Value, tmem_ref: TMEMRef):
|
|
"""Asynchronously copies the scale data from SMEM to TMEM.
|
|
|
|
The result of the copy can be awaited by calling ``commit_arrive`` and waiting
|
|
  on the chosen ``Barrier``. However, if the TMEM reference is to be consumed
  by an MMA issued in the same thread, no additional synchronization is needed.
|
|
|
|
At the moment the function requires ``smem_ref`` to be contiguous and have a
|
|
shape of (MN // 128, 32, 16) for 8-bit scales (here MN stands for the size of
|
|
the non-contracting dimension which is M or N), matching the scale layout for
|
|
.scale_vec::1X. See https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-a-layout-1x
|
|
for more details. Note that we always put the non-contracting dimension first.
|
|
If you have a (MN, 4) array of scales in JAX (where MN is divisible by 128),
|
|
you can prepare it for use in the kernel this way::
|
|
|
|
scales.reshape(-1, 4, 32, 4).swapaxes(1, 2).reshape(-1, 32, 16)
|
|
|
|
The TMEM ref is expected to have the logical shape of the scales (MN, 4), and
|
|
the layout created by ``scales_layout()``.
|
|
"""
|
|
i32 = ir.IntegerType.get_signless(32)
|
|
smem_ty = ir.MemRefType(smem_ref.type)
|
|
if (dtype := smem_ty.element_type) != tmem_ref.dtype:
|
|
raise ValueError(f"Incompatible dtypes: SMEM has {dtype}, TMEM has {tmem_ref.dtype}")
|
|
if dtype != ir.Float8E8M0FNUType.get():
|
|
raise NotImplementedError(f"Unsupported dtype: {dtype}, only f8e8m0fnu supported")
|
|
if tmem_ref.shape[0] % TMEM_ROWS:
|
|
raise ValueError(f"TMEM reference must have a multiple of {TMEM_ROWS} rows, but got {tmem_ref.shape[0]}")
|
|
if tmem_ref.shape[1] != 4:
|
|
raise ValueError(f"TMEM reference must have 4 colums, but got {tmem_ref.shape[1]}")
|
|
if tmem_ref.layout != scales_layout():
|
|
raise ValueError(f"TMEM layout {tmem_ref.layout} is not supported")
|
|
smem_shape = tuple(smem_ty.shape)
|
|
expected_smem_shape = (tmem_ref.shape[0] // TMEM_ROWS, 32, 16)
|
|
if smem_shape != expected_smem_shape:
|
|
raise NotImplementedError(
|
|
f"SMEM has {smem_shape}, but expected {expected_smem_shape} for TMEM"
|
|
f" ref shape {tmem_ref.shape}"
|
|
)
|
|
strides, _ = smem_ty.get_strides_and_offset()
|
|
if strides != utils.get_contiguous_strides(smem_shape):
|
|
raise ValueError("Only copies from contiguous SMEM references are supported")
|
|
row_tile_stride = strides[0]
|
|
if row_tile_stride % 4:
|
|
raise ValueError("Column tile stride must be a multiple of 4")
|
|
row_tile_stride_i32 = row_tile_stride // 4
|
|
smem_base_ptr = utils.memref_ptr(smem_ref, 3)
|
|
for row_tile in range(expected_smem_shape[0]):
|
|
load_ptr = utils.getelementptr(
|
|
smem_base_ptr, [row_tile * row_tile_stride_i32], i32
|
|
)
|
|
store_ptr = arith.addi(tmem_ref.address, arith.constant(i32, 4 * row_tile))
|
|
# The "core matrix" here is the same as in MMA: 8x(16 bytes).
|
|
desc = mma_utils.encode_descriptor(load_ptr, 0, 8 * 16, swizzle=None)
|
|
llvm.inline_asm(
|
|
ir.Type.parse("!llvm.void"),
|
|
[store_ptr, desc],
|
|
"tcgen05.cp.cta_group::1.32x128b.warpx4 [$0], $1;",
|
|
"r,l",
|
|
has_side_effects=True,
|
|
)
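

# Illustrative sketch only: the host-side reshape suggested in the docstring of
# async_copy_scales_smem_to_tmem, checked with NumPy shapes. MN = 256 and the
# uint8 stand-in for f8e8m0 scale bytes are just example choices.
def _example_scales_reshape() -> tuple[int, ...]:
  mn = 256
  scales = np.zeros((mn, 4), dtype=np.uint8)
  tiled = scales.reshape(-1, 4, 32, 4).swapaxes(1, 2).reshape(-1, 32, 16)
  return tiled.shape  # == (mn // 128, 32, 16) == (2, 32, 16)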
|
|
|
|
|
|
def async_copy_sparse_metadata_smem_to_tmem(smem_ref: ir.Value, tmem_ref: TMEMRef):
|
|
i8 = ir.IntegerType.get_signless(8)
|
|
i32 = ir.IntegerType.get_signless(32)
|
|
smem_ty = ir.MemRefType(smem_ref.type)
|
|
if (dtype := smem_ty.element_type) != tmem_ref.dtype:
|
|
raise ValueError(f"Incompatible dtypes: SMEM has {dtype}, TMEM has {tmem_ref.dtype}")
|
|
if dtype != ir.IntegerType.get_signless(2):
|
|
raise NotImplementedError(f"Unsupported dtype: {dtype}, only i2 supported")
|
|
if tmem_ref.shape[0] % 128:
|
|
raise ValueError(f"TMEM reference must have a multiple of 128 rows, but got {tmem_ref.shape[0]}")
|
|
if tmem_ref.shape[1] % 64:
|
|
raise ValueError(f"TMEM reference must have a multiple of 64 colums, but got {tmem_ref.shape[1]}")
|
|
if tmem_ref.layout != sparse_meta_layout():
|
|
raise ValueError(f"TMEM layout {tmem_ref.layout} is not supported")
|
|
smem_shape = tuple(smem_ty.shape)
|
|
expected_smem_shape = (tmem_ref.shape[0] // 128, tmem_ref.shape[1] // 64, 128, 64)
|
|
if smem_shape != expected_smem_shape:
|
|
raise NotImplementedError(
|
|
f"SMEM has {smem_shape}, but expected {expected_smem_shape} for TMEM"
|
|
f" ref shape {tmem_ref.shape}"
|
|
)
|
|
strides, _ = smem_ty.get_strides_and_offset()
|
|
if strides != utils.get_contiguous_strides(smem_shape):
|
|
raise ValueError("Only copies from contiguous SMEM references are supported")
|
|
if expected_smem_shape[0] != 1:
|
|
raise NotImplementedError("Only M=128 supported")
|
|
k_tile_stride = strides[1]
|
|
if k_tile_stride % 16:
|
|
raise ValueError("K tile stride must be a multiple of 16")
|
|
k_tile_byte_stride = k_tile_stride // 4
|
|
smem_base_ptr = utils.memref_ptr(smem_ref, 3)
|
|
for k_tile in range(expected_smem_shape[1]):
|
|
load_ptr = utils.getelementptr(
|
|
smem_base_ptr, [k_tile * k_tile_byte_stride], i8
|
|
)
|
|
store_ptr = arith.addi(tmem_ref.address, arith.constant(i32, 4 * k_tile))
|
|
# The "core matrix" here is the same as in MMA: 8x(16 bytes).
|
|
desc = mma_utils.encode_descriptor(load_ptr, 0, 8 * 16, swizzle=None)
|
|
llvm.inline_asm(
|
|
ir.Type.parse("!llvm.void"),
|
|
[store_ptr, desc],
|
|
"tcgen05.cp.cta_group::1.128x128b [$0], $1;",
|
|
"r,l",
|
|
has_side_effects=True,
|
|
)
|