from __future__ import annotations
import hashlib
import json
from .._C.libtriton import get_cache_invalidating_env_vars, ir
from ..backends import backends
from ..backends.compiler import Language
from ..backends.compiler import BaseBackend, GPUTarget
from .. import __version__, knobs
from ..runtime.autotuner import OutOfResources
from ..runtime.cache import get_cache_manager, get_dump_manager, get_override_manager, get_cache_key
from ..runtime.driver import driver
from ..tools.disasm import get_sass
from pathlib import Path
import re
import functools
import os
import time
import copy

# Regex used to extract a kernel prototype from PTX source:
# - \.(?:visible|extern)\s+ : match the .visible or .extern linkage directive and trailing whitespace
# - \.(?:entry|func)\s+     : match the .entry or .func directive introducing the kernel
# - (\w+)\s*                : match the kernel name (letters, digits, or underscores) and capture it as group 1
# - \(([^)]*)\)             : match the parameter list between parentheses and capture it as group 2
ptx_prototype_pattern = r"\.(?:visible|extern)\s+\.(?:entry|func)\s+(\w+)\s*\(([^)]*)\)"
prototype_pattern = {
    "ptx": ptx_prototype_pattern,
}

ptx_arg_type_pattern = r"\.param\s+\.(\w+)"
arg_type_pattern = {
    "ptx": ptx_arg_type_pattern,
}


def convert_type_repr(x):
    # Currently we only capture the pointer type and assume the pointer is on global memory.
    # TODO: Capture and support shared memory space
    match = re.search(r'!tt\.ptr<([^,]+)', x)
    tma = re.search(r'tt.nv_tma_desc = 1', x)
    if tma is not None:
        return 'nvTmaDesc'
    x = re.sub(r' {[^}]+}', '', x)
    if match is not None:
        return '*' + convert_type_repr(match.group(1))
    return x


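# ASTSource describes a kernel whose entry point is a Python function decorated with
# @triton.jit; compilation starts from the Triton AST and runs the full pipeline from
# the "ttir" stage. IRSource below instead starts from an IR file on disk.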
class ASTSource:

    def __init__(self, fn, signature, constexprs=None, attrs=None) -> None:
        self.fn = fn
        self.language = Language.TRITON
        self.ext = "ttir"
        self.name = fn.__name__
        self.signature = signature
        self.constants = dict()
        if constexprs is not None:
            for k, v in constexprs.items():
                k = (fn.arg_names.index(k), ) if isinstance(k, str) else k
                assert isinstance(k, tuple)
                self.constants[k] = v
        self.attrs = attrs or dict()
        for k in self.signature.keys():
            if not isinstance(k, str):
                raise TypeError("Signature keys must be strings")

    def hash(self):
        sorted_sig = [v for k, v in sorted(self.signature.items())]
        get_key = lambda x: x.cache_key if hasattr(x, 'cache_key') else str(x)
        constants_key = '-'.join([get_key(v) for k, v in sorted(self.constants.items())])
        key = f"{self.fn.cache_key}-{str(self.attrs)}-{sorted_sig}-{constants_key}"
        return hashlib.sha256(key.encode("utf-8")).hexdigest()

    def make_ir(self, target: GPUTarget, options, codegen_fns, module_map, context):
        from .code_generator import ast_to_ttir
        return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns,
                           module_map=module_map)

    def parse_options(self):
        return dict()


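# IRSource describes a kernel whose entry point is an IR file on disk (.ttir/.ttgir MLIR,
# or .ptx parsed with the regexes above); compile() resumes lowering from the stage that
# matches the file extension.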
class IRSource:

    def __init__(self, path, context, backend):
        self.path = path
        path = Path(path)
        self.ext = path.suffix[1:]
        self.language = Language.TRITON
        self.src = path.read_text()
        ir.load_dialects(context)
        backend.load_dialects(context)

        # We don't have an easy-to-use PTX parser that we can use, so keep that regex for now.
        # TODO - replace with a proper parser
        if self.ext == "ptx":
            match = re.search(prototype_pattern[self.ext], self.src, re.MULTILINE)
            self.name = match.group(1)
            signature = match.group(2)
            types = re.findall(arg_type_pattern[self.ext], signature)
            self.signature = {k: convert_type_repr(ty) for k, ty in enumerate(types)}
        else:
            self.module = ir.parse_mlir_module(self.path, context)
            fn_name = self.module.get_entry_func_name()
            self.name = "@" + fn_name
            funcOp = self.module.get_function(fn_name)
            func_ty = self.module.get_function_signature(funcOp)
            self.signature = {k: ty for k, ty in enumerate(func_ty)}

    def hash(self):
        return hashlib.sha256(self.src.encode("utf-8")).hexdigest()

    def make_ir(self, target: GPUTarget, options, codegen_fns, module_map, context):
        self.module.context = context
        return self.module

    def parse_options(self):
        if self.ext == "ttgir":
            num_warps = self.module.get_int_attr("ttg.num-warps")
            assert num_warps is not None, "Unable to parse ttg.num-warps attribute"
            return {'num_warps': num_warps}
        return dict()


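# Cached per-device lookup of the maximum shared memory available to a kernel.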
@functools.lru_cache()
def max_shared_mem(device):
    return driver.active.utils.get_device_properties(device)["max_shared_mem"]


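# Load an IR/binary file in the representation the corresponding stage expects:
# an MLIR module for ttir/ttgir, text for llir/ptx/amdgcn, raw bytes for cubin/hsaco.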
def parse(full_name, ext, context):
    if ext == "ttir" or ext == "ttgir":
        module = ir.parse_mlir_module(full_name, context)
        module.context = context
        return module
    if ext == "llir" or ext == "ptx" or ext == "amdgcn":
        return Path(full_name).read_text()
    if ext == "cubin" or ext == "hsaco":
        return Path(full_name).read_bytes()


def filter_traceback(e: BaseException):
    """
    Removes code_generator.py and related files from tracebacks.

    These are uninteresting to the user -- "just show me *my* code!"
    """
    if knobs.compilation.front_end_debugging:
        return

    if e.__cause__ is not None:
        filter_traceback(e.__cause__)
    if e.__context__ is not None:
        filter_traceback(e.__context__)

    # If a user has a file that matches one of these, they're out of luck.
    BAD_FILES = [
        "/triton/compiler/code_generator.py",
        "/ast.py",
    ]
    BAD_FILES = [bad_file.replace("/", os.sep) for bad_file in BAD_FILES]

    tb = e.__traceback__
    frames = []
    while tb is not None:
        if not any(f for f in BAD_FILES if tb.tb_frame.f_code.co_filename.endswith(f)):
            frames.append(tb)
        tb = tb.tb_next

    for (cur_frame, next_frame) in zip(frames, frames[1:]):
        cur_frame.tb_next = next_frame

    if not frames:
        e.__traceback__ = None
    else:
        frames[-1].tb_next = None
        e.__traceback__ = frames[0]


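# Records wall-clock timestamps for the phases of compile() (IR initialization, each
# lowering stage, cache write-back) and reports the durations, in microseconds, as a
# knobs.CompileTimes object for the compilation listener.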
class CompileTimer:

    def __init__(self) -> None:
        self.start: float = time.time()
        self.ir_initialization_end: float | None = None
        self.lowering_stage_ends: list[tuple[str, float]] = []
        self.store_results_end: float | None = None

    def finished_ir_initialization(self) -> None:
        self.ir_initialization_end = time.time()

    def stage_finished(self, stage_name: str) -> None:
        self.lowering_stage_ends.append((stage_name, time.time()))

    def end(self) -> knobs.CompileTimes:
        timestamp = time.time()
        if self.ir_initialization_end is None:
            self.ir_initialization_end = timestamp
        else:
            self.store_results_end = timestamp

        def delta(start: float, end: float | None) -> int:
            if end is None:
                return 0
            return int((end - start) * 1000000)

        lowering_stage_durations = []
        stage_start = self.ir_initialization_end
        for stage_name, stage_end in self.lowering_stage_ends:
            lowering_stage_durations.append((stage_name, delta(stage_start, stage_end)))
            stage_start = stage_end

        return knobs.CompileTimes(
            ir_initialization=delta(self.start, self.ir_initialization_end),
            lowering_stages=lowering_stage_durations,
            store_results=delta(stage_start, self.store_results_end),
        )


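# compile() is the main entry point: given an ASTSource (or a path to an IR file), it
# resolves the backend for the target, runs the staged lowering pipeline, stores every
# intermediate in the cache and returns a CompiledKernel. A minimal sketch, assuming a
# hypothetical kernel file path (target and options are filled in from the active driver):
#
#     kernel = compile("/path/to/my_kernel.ttgir")
#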
def compile(src, target=None, options=None, _env_vars=None):
    compilation_listener = knobs.compilation.listener
    if compilation_listener:
        timer = CompileTimer()

    if target is None:
        target = driver.active.get_current_target()
    assert isinstance(target, GPUTarget), "target must be of GPUTarget type"
    backend = make_backend(target)
    ir_source = not isinstance(src, ASTSource)
    # create backend
    if ir_source:
        assert isinstance(src, str), "source must be either AST or a filepath"
        context = ir.context()
        src = IRSource(src, context, backend)

    extra_options = src.parse_options()
    options = backend.parse_options(dict(options or dict(), **extra_options))
    # create cache manager
    env_vars = get_cache_invalidating_env_vars() if _env_vars is None else _env_vars
    key = get_cache_key(src, backend, options, env_vars=env_vars)
    hash = hashlib.sha256(key.encode("utf-8")).hexdigest()
    fn_cache_manager = get_cache_manager(hash)
    # For dumping/overriding only hash the source as we want it to be independent of triton
    # core changes to make it easier to track kernels by hash.
    enable_override = knobs.compilation.override
    enable_ir_dump = knobs.compilation.dump_ir
    store_only_binary = knobs.compilation.store_binary_only
    fn_override_manager = get_override_manager(src.hash()) if enable_override else None
    fn_dump_manager = get_dump_manager(src.hash()) if enable_ir_dump else None
    # Pre-truncate the file name here to avoid hitting the 255 character limit on common platforms.
    # The final file name in the cache will have a format of f"{filename}.{ext}.tmp.pid_{pid}_{uuid}".
    # A PID string can be 5 characters long. A UUID string typically has 36 characters. Let's truncate
    # the file name to 150 characters to be safe.
    file_name = src.name[:150]
    metadata_filename = f"{file_name}.json"
    metadata_group = fn_cache_manager.get_group(metadata_filename) or {}
    metadata_path = metadata_group.get(metadata_filename)
    always_compile = knobs.compilation.always_compile
    if not always_compile and metadata_path is not None:
        # cache hit!
        res = CompiledKernel(src, metadata_group, hash)
        if compilation_listener:
            compilation_listener(
                src=src,
                metadata=res.metadata._asdict(),
                metadata_group=metadata_group,
                times=timer.end(),
                cache_hit=True,
            )
        return res

    # initialize metadata
    metadata = {
        "hash": hash,
        "target": target,
        **options.__dict__,
        **env_vars,
    }
    metadata["triton_version"] = __version__
    # run compilation pipeline and populate metadata
    stages = dict()
    backend.add_stages(stages, options, src.language)
    first_stage = list(stages.keys()).index(src.ext)
    # when the source is an IR file, don't apply the passes related to this stage. This makes it easier to write IR level tests.
    if ir_source:
        first_stage += 1

    # For IRSource, we have already grabbed the context + called both
    # ir.load_dialects and backend.load_dialects.
    if not isinstance(src, IRSource):
        context = ir.context()
        ir.load_dialects(context)
        backend.load_dialects(context)

    codegen_fns = backend.get_codegen_implementation(options)
    module_map = backend.get_module_map()
    try:
        module = src.make_ir(target, options, codegen_fns, module_map, context)
    except Exception as e:
        filter_traceback(e)
        raise

    if ir_source:
        ir_filename = f"{file_name}.{src.ext}"
        metadata_group[ir_filename] = fn_cache_manager.put(module, ir_filename)
    else:
        ir_filename = f"{file_name}.source"
        metadata_group[ir_filename] = fn_cache_manager.put(module, ir_filename)

    use_ir_loc = knobs.compilation.use_ir_loc
    if ir_source and use_ir_loc:
        module.create_location_snapshot(src.path)
        print(f"Creating new locations for {src.path}")

    if compilation_listener:
        timer.finished_ir_initialization()
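    # Lower the module stage by stage (for example ttir -> ttgir -> llir -> ptx -> cubin on
    # NVIDIA backends), caching each intermediate representation as it is produced.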
    for ext, compile_ir in list(stages.items())[first_stage:]:
        next_module = compile_ir(module, metadata)
        ir_filename = f"{file_name}.{ext}"
        if fn_override_manager is None:
            # Users can override kernels at scale by setting `ir_override` in autotune config
            # without TRITON_KERNEL_OVERRIDE
            if (ir_override := metadata.get("ir_override", None)) and ir_override.endswith(f".{ext}"):
                next_module = parse(ir_override, ext, context)
        elif full_name := fn_override_manager.get_file(ir_filename):
            print(f"\nOverriding kernel with file {full_name}")
            next_module = parse(full_name, ext, context)
        # If TRITON_STORE_BINARY_ONLY is 1, only store cubin/hsaco/json
        if (not store_only_binary) or (ext in ("cubin", "hsaco", "json")):
            metadata_group[ir_filename] = fn_cache_manager.put(next_module, ir_filename)
        if fn_dump_manager is not None:
            fn_dump_manager.put(next_module, ir_filename)
            if ext == "cubin":
                sass = get_sass(next_module)
                fn_dump_manager.put(sass, file_name + ".sass")
        # use an env variable to parse ir from file
        if use_ir_loc == ext:
            ir_full_name = fn_cache_manager.get_file(ir_filename)
            next_module.create_location_snapshot(ir_full_name)
            print(f"Creating new locations for {ir_full_name}")
        module = next_module
        if compilation_listener:
            timer.stage_finished(ext)
    # write-back metadata
    metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata, default=vars), metadata_filename,
                                                             binary=False)
    fn_cache_manager.put_group(metadata_filename, metadata_group)
    # Compilation completed, disabling multithreading in context.
    # This is needed to safely finalize threads pool inside context: if current process forks before
    # python GC deletes context object, thread pool in child process will be invalid, which could
    # lead to child crash or hang.
    #
    # However disabling multithreading causes the code to hang if the ASAN pass is enabled
    # this is likely due to the llvm-symbolizer forking a process
    # TODO: Reconcile the difference here between the ASAN and non-ASAN path with enabling
    # multithreading in the MLIR context
    if not knobs.compilation.enable_asan:
        context.disable_multithreading()

    # notify any listener
    if compilation_listener:
        compilation_listener(src=src, metadata=metadata, metadata_group=metadata_group, times=timer.end(),
                             cache_hit=False)
    # return handle to compiled kernel
    return CompiledKernel(src, metadata_group, hash)


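# Resolve the single registered backend whose compiler supports the given GPUTarget.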
def make_backend(target: GPUTarget) -> BaseBackend:
    actives = [x.compiler for x in backends.values() if x.compiler.supports_target(target)]
    if len(actives) != 1:
        raise RuntimeError(
            f"{len(actives)} compatible backends for target ({target.backend}) ({actives}). There should only be one.")
    return actives[0](target)


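# Dict-like container for launch metadata that defers extra (func, args) computations
# until get() is called.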
class LazyDict:

    def __init__(self, data):
        self.data = data
        self.extras = []

    def get(self):
        for func, args in self.extras:
            self.data = self.data | func(*args)
        self.extras.clear()
        return self.data

    def add(self, func, args):
        self.extras.append((func, args))


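# Mapping from IR extension to its text/bytes; "sass" is disassembled lazily from the
# cubin on first access.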
class AsmDict(dict):

    def __missing__(self, key):

        if key == "sass":
            value = get_sass(self["cubin"])
        else:
            raise KeyError("Unknown key: '%s'" % key)

        self[key] = value
        return value


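# Stand-in launcher installed when kernel loading fails: every subsequent call re-raises
# a copy of the recorded error.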
def _raise_error(err, *args, **kwargs):
    raise copy.deepcopy(err)


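# Handle to a compiled kernel: holds the metadata and generated IR/binary from the cache,
# lazily loads the binary onto the current device, and launches it. A minimal launch
# sketch, assuming hypothetical grid and argument values:
#
#     kernel = compile(src)
#     kernel[(1024, 1, 1)](arg0, arg1)
#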
class CompiledKernel:

    def __init__(self, src, metadata_group, hash):
        from collections import namedtuple
        metadata_path = next((Path(p) for c, p in metadata_group.items() if c.endswith(".json")))
        metadata = json.loads(metadata_path.read_text())
        metadata['cluster_dims'] = tuple(metadata['cluster_dims'])
        # JSON serialization dumps the target as a dict. Restore it to a GPUTarget.
        target = metadata['target']
        metadata['target'] = GPUTarget(target['backend'], target['arch'], target['warp_size'])
        KernelMetadata = namedtuple('KernelMetadata', sorted(list(metadata.keys())))
        self.metadata = KernelMetadata(**metadata)
        backend = make_backend(self.metadata.target)
        self.packed_metadata = backend.pack_metadata(self.metadata)
        self.src = src
        self.hash = hash
        self.name = self.metadata.name
        # stores the text of each level of IR that was generated during compilation
        asm_files = [Path(p) for c, p in metadata_group.items() if not c.endswith(".json")]
        binary_ext = backend.binary_ext
        self.asm = AsmDict({
            file.suffix[1:]: file.read_bytes() if file.suffix[1:] == binary_ext else file.read_text()
            for file in asm_files
        })
        self.metadata_group = metadata_group
        self.kernel = self.asm[binary_ext]
        # binaries are lazily initialized
        # because it involves doing runtime things
        # (e.g., checking amount of shared memory on current device)
        self.module = None
        self.function = None
        self._run = None

    def _init_handles(self):
        if self.module is not None:
            return

        def raise_(err):
            # clone the exception object so that the one saved in the closure
            # of the partial function below doesn't get assigned a stack trace
            # after the subsequent raise. otherwise, the CompiledKernel instance
            # saved in the (global) kernel cache will keep references to all the
            # locals in the traceback via the exception instance in the closure.
            cloned_err = copy.deepcopy(err)
            self._run = functools.partial(_raise_error, cloned_err)
            raise err

        device = driver.active.get_current_device()
        # create launcher
        self._run = driver.active.launcher_cls(self.src, self.metadata)
        # not enough shared memory to run the kernel
        max_shared = max_shared_mem(device)
        if self.metadata.shared > max_shared:
            raise_(OutOfResources(self.metadata.shared, max_shared, "shared memory"))
        if hasattr(self.metadata, "tmem_size") and self.metadata.tmem_size is not None:
            # Use blackwell max tmem size for now, this should be moved in device properties
            max_tmem_size = 512  # tmem size in number of columns
            if self.metadata.tmem_size > max_tmem_size:
                raise_(OutOfResources(self.metadata.tmem_size, max_tmem_size, "tensor memory"))
        if knobs.runtime.kernel_load_start_hook is not None:
            knobs.runtime.kernel_load_start_hook(self.module, self.function, self.name, self.metadata_group, self.hash)
        # TODO: n_regs, n_spills should be metadata generated when calling `ptxas`
        self.module, self.function, self.n_regs, self.n_spills, self.n_max_threads = driver.active.utils.load_binary(
            self.name, self.kernel, self.metadata.shared, device)
        warp_size = driver.active.get_current_target().warp_size
        if self.metadata.num_warps * warp_size > self.n_max_threads:
            raise_(OutOfResources(self.metadata.num_warps * warp_size, self.n_max_threads, "threads"))
        if knobs.runtime.kernel_load_end_hook is not None:
            knobs.runtime.kernel_load_end_hook(self.module, self.function, self.name, self.metadata_group, self.hash)

    @property
    def run(self):
        if self._run is None:
            self._init_handles()
        return self._run

    def launch_metadata(self, grid, stream, *args):
        if knobs.runtime.launch_enter_hook is None:
            return None
        self._init_handles()
        ret = LazyDict({"name": self.name, "function": self.function, "stream": stream})
        if not isinstance(self.src, ASTSource) or self.src.fn.launch_metadata is None:
            return ret
        arg_dict = {name: arg for name, arg in zip(self.src.fn.arg_names, args)}
        ret.add(self.src.fn.launch_metadata, (grid, self.metadata, arg_dict))
        return ret

    def __getitem__(self, grid):
        self._init_handles()

        def runner(*args, stream=None):
            if stream is None:
                device = driver.active.get_current_device()
                stream = driver.active.get_current_stream(device)
            launch_metadata = self.launch_metadata(grid, stream, *args)
            self.run(grid[0], grid[1], grid[2], stream, self.function, self.packed_metadata, launch_metadata,
                     knobs.runtime.launch_enter_hook, knobs.runtime.launch_exit_hook, *args)

        return runner