import binascii
import hashlib
import importlib.util
import sys
from argparse import ArgumentParser
from dataclasses import dataclass
from pathlib import Path
from typing import List

import triton
import triton.backends


@dataclass
class CompileArgs:
    '''
    Container for the arguments produced by the command-line parser.
    '''
    path: str = ''
    kernel_name: str = ''
    signature: str = ''
    grid: str = ''
    target: str | None = None
    num_warps: int = 1
    num_stages: int = 3
    out_name: str | None = None
    out_path: Path | None = None


desc = """
Triton ahead-of-time compiler:

This program compiles the kernel with name `kernel-name` in the file at the
provided `path` into self-contained C source code that embeds the `cubin`
data along with utilities to load, unload and launch the kernel.

The signature is provided as a comma-separated list of (optionally divisibility-hinted) types
or constexpr values, e.g.

`compile.py --kernel-name kernel --signature "*fp32:16, i32:16, 1024, i32" --out-name kernel /path/to/kernel.py`

will compile the triton.JITFunction named `kernel` inside the file `/path/to/kernel.py`.
The kernel will be specialized such that arguments 0 and 1 are assumed to be multiples of 16,
and argument 2 is assumed to be a compile-time constant of value 1024, i.e. it won't be part of the generated prototype.

The resulting entry point will have the signature

CUresult kernel_{specialization_suffix}(CUstream stream, unsigned gX, unsigned gY, unsigned gZ, float* arg0, int32_t arg1, int32_t arg2)

Different such specialized entry points can be combined using the `linker.py` script.

NOTE: when resolving the scope of /path/to/kernel.py, the file will be executed from within its parent directory with the Python interpreter
used to run this `compile.py` script.
"""


def main():
    # command-line arguments
    parser = ArgumentParser(description=desc)
    parser.add_argument("path",
                        help="Path to Python source containing desired kernel in its scope. File will be executed.")
    parser.add_argument("--kernel-name", "-n", type=str, default="", help="Name of the kernel to compile",
                        required=True)
    parser.add_argument(
        "--target", "-t", type=str, default=None,
        help="The target to compile towards, in the format '<backend>:<arch>:<warp-size>'; "
        "e.g., 'cuda:80:32', 'hip:gfx942:64'. Defaults to None, which means the current machine's GPU target is used")
    parser.add_argument("--num-warps", "-w", type=int, default=1, help="Number of warps to launch the kernel with")
    parser.add_argument("--num-stages", "-ns", type=int, default=3,
                        help="Number of stages (meta-parameter of the kernel)")
    parser.add_argument("--out-name", "-on", type=str, default=None, help="Output name for the compiled kernel")
    parser.add_argument("--out-path", "-o", type=Path, default=None, help="Output file path")
    parser.add_argument("--signature", "-s", type=str, help="Signature of the kernel", required=True)
    parser.add_argument("--grid", "-g", type=str, help="Launch grid of the kernel", required=True)
    cli_args = parser.parse_args()
    args = CompileArgs(**vars(cli_args))  # Also a sanity check that CompileArgs stays in sync with the parser.
    compile_kernel(args)


def compile_kernel(args: CompileArgs):
    out_name = args.out_name if args.out_name else args.kernel_name
    out_path = args.out_path if args.out_path else Path(out_name)

    # execute python sources and extract functions wrapped in JITFunction
    arg_path = Path(args.path)
    sys.path.insert(0, str(arg_path.parent))
    spec = importlib.util.spec_from_file_location(arg_path.stem, arg_path)
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    kernel = getattr(mod, args.kernel_name)
    grid = args.grid.split(",")
    assert len(grid) == 3, "grid must have exactly three comma-separated entries"
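    # Note: the grid entries are not evaluated here; they are substituted verbatim
    # into the generated C launcher (see "gridX"/"gridY"/"gridZ" in `params` below),
    # so each entry can be an integer literal or, presumably, a C expression over the
    # kernel's scalar arguments (e.g. something like "M/16").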

    # validate and parse signature
    signature = list(map(lambda s: s.strip(" "), args.signature.split(",")))

    def hash_signature(signature: List[str]):
        m = hashlib.sha256()
        m.update(" ".join(signature).encode())
        return m.hexdigest()[:8]

    meta_sig = f"warps{args.num_warps}xstages{args.num_stages}"
    sig_hash = hash_signature(signature + [meta_sig])
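    # The 8-hex-digit hash of the signature plus meta-parameters disambiguates
    # different specializations of the same kernel: it is embedded in both the
    # generated C symbol name and the output file names below.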

    def constexpr(s):
        try:
            ret = int(s)
            return ret
        except ValueError:
            pass
        try:
            ret = float(s)
            return ret
        except ValueError:
            pass
        return None

    hints = {(i, ): constexpr(s.split(":")[1]) for i, s in enumerate(signature) if ":" in s}
    hints = {k: v for k, v in hints.items() if v is not None}
    constants = {kernel.arg_names[i]: constexpr(s) for i, s in enumerate(signature)}
    constants = {k: v for k, v in constants.items() if v is not None}
    for key, value in hints.items():
        if value == 1:
            constants[kernel.arg_names[key[0]]] = value
    signature = {kernel.arg_names[i]: s.split(":")[0] for i, s in enumerate(signature)}
    for key in constants:
        signature[key] = 'constexpr'
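    # Worked example for the docstring's signature "*fp32:16, i32:16, 1024, i32" on a
    # kernel whose arg_names are (hypothetically) ('arg0', 'arg1', 'arg2', 'arg3'):
    #   hints     -> {(0,): 16, (1,): 16}
    #   constants -> {'arg2': 1024}
    #   signature -> {'arg0': '*fp32', 'arg1': 'i32', 'arg2': 'constexpr', 'arg3': 'i32'}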
    const_sig = 'x'.join([str(v) for v in constants.values()])
    doc_string = [f"{k}={v}" for k, v in constants.items()]
    doc_string += [f"num_warps={args.num_warps}", f"num_stages={args.num_stages}"]
    # compile ast into cubin
    for h in hints.values():
        assert h in [1, 16], f"Only 1 and 16 are valid hints, got {h}"
    attrs = {k: [["tt.divisibility", 16]] for k, v in hints.items() if v == 16}
    src = triton.compiler.ASTSource(fn=kernel, constexprs=constants, signature=signature, attrs=attrs)

    target = triton.backends.compiler.GPUTarget(*args.target.split(":")) \
        if args.target else triton.runtime.driver.active.get_current_target()
    backend = triton.compiler.make_backend(target)
    kwargs = {"num_warps": args.num_warps, "num_stages": args.num_stages}
    options = backend.parse_options(kwargs)
    ccinfo = triton.compile(src, target=target, options=options.__dict__)

    if getattr(ccinfo.metadata, "global_scratch_size", 0) > 0:
        raise RuntimeError("AOT compiling kernels with global scratch requirements is not yet implemented")
    if ccinfo.metadata.profile_scratch_size > 0:
        raise RuntimeError("AOT compiling kernels with profile scratch requirements is not yet implemented")

    arg_names = []
    arg_types = []
    arg_names_not_1 = []
    arg_types_not_1 = []
    for i, arg_name in enumerate(kernel.arg_names):
        if arg_name not in constants:
            arg_names.append(arg_name)
            arg_types.append(signature[arg_name])
            arg_names_not_1.append(arg_name)
            arg_types_not_1.append(signature[arg_name])
        elif hints.get((i, ), None) == 1:
            arg_names.append(arg_name)
            arg_types.append("i32")

    # dump C stub code
    suffix = ''
    for i, ty in enumerate(signature.values()):
        suffix += str(i)
        if hints.get((i, ), None) == 1:
            suffix += 'c'
        if hints.get((i, ), None) == 16:
            suffix += 'd'
    func_name = '_'.join([out_name, sig_hash, suffix])
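    # For the worked example above, suffix would be "0d1d23" ('d' marks a
    # divisibility-16 hint, 'c' a constant-1 specialization), so func_name becomes
    # something like "kernel_<sig_hash>_0d1d23".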
    asm = ccinfo.asm[backend.binary_ext]  # store binary data once

    hex_ = str(binascii.hexlify(asm))[2:-1]
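    # hex_ holds two hex characters per byte of the binary; the pairs are joined
    # pairwise below into a comma-separated list of C byte literals that initializes
    # the embedded kernel image in the generated source.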

    ty_to_cpp = triton.runtime.driver.active.map_python_to_cpp_type

    params = {
        "kernel_name": func_name,
        "triton_kernel_name": args.kernel_name,
        "bin_size": len(asm),
        "bin_data": ", ".join([f"0x{x}{y}" for x, y in zip(hex_[::2], hex_[1::2])]),
        "signature": ", ".join([f"{ty_to_cpp(ty)} {name}" for name, ty in zip(arg_names_not_1, arg_types_not_1)]),
        "full_signature": ", ".join([f"{ty_to_cpp(ty)} {name}" for name, ty in zip(arg_names, arg_types)]),
        "arg_pointers": ", ".join([f"&{arg}" for arg in arg_names_not_1] + ["&global_scratch"] + ["&profile_scratch"]),
        "num_args": len(arg_names_not_1) + 2,  # +2 for global and profile scratch
        "kernel_docstring": doc_string,
        "shared": ccinfo.metadata.shared,
        "num_warps": args.num_warps,
        "algo_info": "_".join([const_sig, meta_sig]),
        "gridX": grid[0],
        "gridY": grid[1],
        "gridZ": grid[2],
        "_placeholder": "",
    }
    output_files = []
    backend_name = target.backend
    template_dir = Path(__file__).parent / "extra" / backend_name
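    # Each backend ships `compile.*` templates next to this script under
    # extra/<backend>/ (typically a C source / header pair); every placeholder in a
    # template is filled in from `params` below.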
    for template_path in template_dir.glob('compile.*'):
        ext = template_path.suffix
        output_file = out_path.with_suffix(f".{sig_hash}_{suffix}{ext}")
        with output_file.open("w") as fp:
            fp.write(template_path.read_text().format(**params))
        output_files.append(output_file)

    return func_name, output_files


if __name__ == "__main__":
    main()
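
# Illustrative only: `compile_kernel` can also be driven programmatically instead of
# via the CLI. A minimal sketch, assuming a file kernel.py that defines a JIT function
# named `kernel` (paths and values here are hypothetical):
#
#     args = CompileArgs(path="kernel.py", kernel_name="kernel",
#                        signature="*fp32:16, i32:16, 1024, i32",
#                        grid="1024,1,1", out_name="kernel", num_warps=4)
#     func_name, files = compile_kernel(args)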