DriverTrac/venv/lib/python3.12/site-packages/triton/tools/compile.py
2025-11-28 09:08:33 +05:30

211 lines
8.5 KiB
Python

import binascii
import hashlib
import importlib.util
import sys
from argparse import ArgumentParser
from dataclasses import dataclass
from pathlib import Path
from typing import List
import triton
import triton.backends
@dataclass
class CompileArgs:
'''
A class to contain arguments from command-line parser.
'''
path: str = ''
kernel_name: str = ''
signature: str = ''
grid: str = ''
target: str | None = None
num_warps: int = 1
num_stages: int = 3
out_name: str | None = None
out_path: Path | None = None
# Help text passed to ArgumentParser(description=...) in main(). It is a
# runtime string, so its wording is part of the program's output.
desc = """
Triton ahead-of-time compiler:
This program compiles the kernel with name `kernel-name` in the file at the
provided `path` into self-contained C source-code that embeds the `cubin`
data along with utilities to load, unload and launch the kernel.
signature is provided as a list of (optionally divisibility-hinted) types
or constexpr values, e.g.
`compile.py --kernel-name kernel --signature "*fp32:16, i32:16, 1024, i32" --out-name kernel /path/to/kernel.py`
will compile triton.JITFunction of name `kernel` inside the file `/path/to/kernel.py`.
Said kernel will be specialized such that argument 0, 1 are assumed to be multiple of 16,
and argument 2 is assumed to be a compile-time constant of value 1024, i.e. it won't be part of the generated prototype.
The resulting entry point will have signature
CUresult kernel_{specialization_suffix}(CUstream stream, unsigned gX, unsigned gY, unsigned gZ, float* arg0, int32_t arg1, int32_t arg2)
Different such specialized entry points can be combined using the `linker.py` script.
NOTE: when resolving the scope of /path/to/kernel.py, the file will be executed from within its parent directory with the python interpreter
used to run this `compile.py` script
"""
def main():
    """Parse the command line into a ``CompileArgs`` and compile the kernel.

    Arguments are read from ``sys.argv`` by ``ArgumentParser.parse_args``.
    """
    # (flags, add_argument keyword arguments), registered in this order so the
    # generated --help output lists them the same way.
    option_table = [
        (("path", ), {
            "help": "Path to Python source containing desired kernel in its scope. File will be executed.",
        }),
        (("--kernel-name", "-n"), {
            "type": str,
            "default": "",
            "help": "Name of the kernel to compile",
            "required": True,
        }),
        (("--target", "-t"), {
            "type": str,
            "default": None,
            "help": "The target to compile towards, in format of '<backend>:<arch>:<warp-size>'; "
                    "e.g., 'cuda:80:32', 'hip:gfx942:64'. Default to None, which means using current machine's GPU target",
        }),
        (("--num-warps", "-w"), {
            "type": int,
            "default": 1,
            "help": "Number of warps to launch the kernel",
        }),
        (("--num-stages", "-ns"), {
            "type": int,
            "default": 3,
            "help": "Number of stages (meta-parameter of the kernel)",
        }),
        (("--out-name", "-on"), {
            "type": str,
            "default": None,
            "help": "Out name for the compiled kernel",
        }),
        (("--out-path", "-o"), {
            "type": Path,
            "default": None,
            "help": "Out filename",
        }),
        (("--signature", "-s"), {
            "type": str,
            "help": "Signature of the kernel",
            "required": True,
        }),
        (("--grid", "-g"), {
            "type": str,
            "help": "Launch grid of the kernel",
            "required": True,
        }),
    ]

    cli = ArgumentParser(description=desc)
    for flags, settings in option_table:
        cli.add_argument(*flags, **settings)

    namespace = cli.parse_args()
    # Expanding the namespace doubles as a sanity check: it fails if a parser
    # option has no matching field in CompileArgs.
    compile_args = CompileArgs(**vars(namespace))
    compile_kernel(compile_args)
def compile_kernel(args: CompileArgs):
    """Compile one specialization of a Triton kernel and render its AOT stub files.

    The Python file at ``args.path`` is executed and the attribute named
    ``args.kernel_name`` is taken from it as the kernel. ``args.signature`` is
    split on commas; an entry that parses as an int or float makes the
    corresponding kernel argument a constexpr, and an entry of the form
    ``<type>:<hint>`` attaches hint 1 (argument specialized to the constant 1)
    or hint 16 (``tt.divisibility`` attribute of 16). The kernel is compiled
    for ``args.target``, or for the active driver's current target when that
    is unset. Every ``compile.*`` template in the ``extra/<backend>``
    directory next to this file is then formatted with the compilation results
    and written out.

    Args:
        args: Parsed compile options. ``out_name`` defaults to
            ``kernel_name`` and ``out_path`` defaults to ``Path(out_name)``.

    Returns:
        A tuple ``(func_name, output_files)``: the generated entry-point name
        (``<out_name>_<sig_hash>_<suffix>``) and the list of files written,
        each being ``out_path`` with its suffix replaced by
        ``.<sig_hash>_<suffix><template extension>``.

    Raises:
        AssertionError: if ``args.grid`` does not have exactly three
            comma-separated entries, or a signature hint is not 1 or 16.
        ImportError: if no module loader can be found for ``args.path``.
        RuntimeError: if the compiled kernel reports a non-zero global or
            profile scratch size.
    """
    out_name = args.out_name if args.out_name else args.kernel_name
    out_path = args.out_path if args.out_path else Path(out_name)

    # Checked before the (potentially slow) module execution. Raised
    # explicitly rather than via `assert` so the check survives `python -O`,
    # while keeping the exception type the assert used to produce.
    grid = args.grid.split(",")
    if len(grid) != 3:
        raise AssertionError(f"grid must have exactly 3 comma-separated entries, got {len(grid)}: {args.grid!r}")

    # execute python sources and extract functions wrapped in JITFunction
    arg_path = Path(args.path)
    parent_dir = str(arg_path.parent)
    # The kernel file's directory goes first on sys.path so the file can import
    # its siblings; skip the insert when it is already first so repeated calls
    # do not keep growing sys.path.
    if not sys.path or sys.path[0] != parent_dir:
        sys.path.insert(0, parent_dir)
    spec = importlib.util.spec_from_file_location(arg_path.stem, arg_path)
    if spec is None or spec.loader is None:
        # spec_from_file_location returns None when no loader matches the file.
        raise ImportError(f"Cannot load a Python module from '{arg_path}'")
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    kernel = getattr(mod, args.kernel_name)

    # validate and parse signature
    signature = [entry.strip(" ") for entry in args.signature.split(",")]

    def hash_signature(signature: List[str]):
        m = hashlib.sha256()
        m.update(" ".join(signature).encode())
        return m.hexdigest()[:8]

    meta_sig = f"warps{args.num_warps}xstages{args.num_stages}"
    sig_hash = hash_signature(signature + [meta_sig])

    def constexpr(s):
        # Return `s` parsed as an int, else as a float, else None.
        for parse in (int, float):
            try:
                return parse(s)
            except ValueError:
                continue
        return None

    # Hints are keyed by a 1-tuple of the argument index.
    hints = {}
    for i, s in enumerate(signature):
        if ":" in s:
            hint = constexpr(s.split(":")[1])
            if hint is not None:
                hints[(i, )] = hint
    for h in hints.values():
        if h not in (1, 16):
            raise AssertionError(f"Only 1 and 16 are valid hints, got {h}")

    # Literal entries come first, then arguments hinted as 1; this insertion
    # order determines `const_sig` and the docstring below.
    constants = {kernel.arg_names[i]: constexpr(s) for i, s in enumerate(signature)}
    constants = {k: v for k, v in constants.items() if v is not None}
    for key, value in hints.items():
        if value == 1:
            constants[kernel.arg_names[key[0]]] = value
    signature = {kernel.arg_names[i]: s.split(":")[0] for i, s in enumerate(signature)}
    for key in constants:
        signature[key] = 'constexpr'
    const_sig = 'x'.join(str(v) for v in constants.values())
    doc_string = [f"{k}={v}" for k, v in constants.items()]
    doc_string += [f"num_warps={args.num_warps}", f"num_stages={args.num_stages}"]

    # compile ast into cubin
    attrs = {k: [["tt.divisibility", 16]] for k, v in hints.items() if v == 16}
    src = triton.compiler.ASTSource(fn=kernel, constexprs=constants, signature=signature, attrs=attrs)
    if args.target:
        target_fields = args.target.split(":")
        if len(target_fields) == 3:
            # '<backend>:<arch>:<warp-size>' arrives as three strings. A purely
            # numeric arch (e.g. '80') and the warp size are converted to ints;
            # a named arch such as 'gfx942' stays a string.
            # NOTE(review): assumes GPUTarget declares arch as int-or-str and
            # warp_size as int -- confirm against triton.backends.compiler.
            backend_id, arch, warp_size = target_fields
            target_fields = [backend_id, int(arch) if arch.isdigit() else arch, int(warp_size)]
        target = triton.backends.compiler.GPUTarget(*target_fields)
    else:
        target = triton.runtime.driver.active.get_current_target()
    backend = triton.compiler.make_backend(target)
    kwargs = {"num_warps": args.num_warps, "num_stages": args.num_stages}
    options = backend.parse_options(kwargs)
    ccinfo = triton.compile(src, target=target, options=options.__dict__)

    if getattr(ccinfo.metadata, "global_scratch_size", 0) > 0:
        raise RuntimeError("AOT compiling kernels with global scratch requirements is not yet implemented")
    # Read defensively, like global_scratch_size above, so metadata that lacks
    # this field does not fail with AttributeError.
    if getattr(ccinfo.metadata, "profile_scratch_size", 0) > 0:
        raise RuntimeError("AOT compiling kernels with profile scratch requirements is not yet implemented")

    # `*_not_1` lists hold only the non-constant arguments. The plain lists
    # additionally keep arguments specialized to 1, typed as i32, and feed the
    # "full_signature" template parameter.
    arg_names = []
    arg_types = []
    arg_names_not_1 = []
    arg_types_not_1 = []
    for i, arg_name in enumerate(kernel.arg_names):
        if arg_name not in constants:
            arg_names.append(arg_name)
            arg_types.append(signature[arg_name])
            arg_names_not_1.append(arg_name)
            arg_types_not_1.append(signature[arg_name])
        elif hints.get((i, ), None) == 1:
            arg_names.append(arg_name)
            arg_types.append("i32")

    # dump C stub code
    # The suffix lists every argument index, followed by 'c' when the argument
    # is hinted as 1 and 'd' when it is hinted as 16.
    hint_markers = {1: 'c', 16: 'd'}
    suffix = ''.join(f"{i}{hint_markers.get(hints.get((i, )), '')}" for i in range(len(signature)))
    func_name = '_'.join([out_name, sig_hash, suffix])
    asm = ccinfo.asm[backend.binary_ext]  # store binary data once
    hex_ = binascii.hexlify(asm).decode("ascii")

    ty_to_cpp = triton.runtime.driver.active.map_python_to_cpp_type

    params = {
        "kernel_name": func_name,
        "triton_kernel_name": args.kernel_name,
        "bin_size": len(asm),
        # One "0xNN" literal per two hex digits.
        "bin_data": ", ".join(f"0x{hex_[i:i + 2]}" for i in range(0, len(hex_), 2)),
        "signature": ", ".join([f"{ty_to_cpp(ty)} {name}" for name, ty in zip(arg_names_not_1, arg_types_not_1)]),
        "full_signature": ", ".join([f"{ty_to_cpp(ty)} {name}" for name, ty in zip(arg_names, arg_types)]),
        "arg_pointers": ", ".join([f"&{arg}" for arg in arg_names_not_1] + ["&global_scratch"] + ["&profile_scratch"]),
        "num_args": len(arg_names_not_1) + 2,  # +2 for global and profile scratch
        "kernel_docstring": doc_string,
        "shared": ccinfo.metadata.shared,
        "num_warps": args.num_warps,
        "algo_info": "_".join([const_sig, meta_sig]),
        "gridX": grid[0],
        "gridY": grid[1],
        "gridZ": grid[2],
        "_placeholder": "",
    }
    output_files = []
    backend_name = target.backend
    template_dir = Path(__file__).parent / "extra" / backend_name
    # Sorted so the returned file list has a deterministic order.
    for template_path in sorted(template_dir.glob('compile.*')):
        ext = template_path.suffix
        output_file = out_path.with_suffix(f".{sig_hash}_{suffix}{ext}")
        # Render before opening the output so a template formatting error does
        # not leave behind an empty, truncated file.
        rendered = template_path.read_text().format(**params)
        with output_file.open("w") as fp:
            fp.write(rendered)
        output_files.append(output_file)

    return func_name, output_files
# Run the command-line entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()