98 lines
3.1 KiB
Python
98 lines
3.1 KiB
Python
import functools
|
|
import os
|
|
import inspect
|
|
import subprocess
|
|
import tempfile
|
|
|
|
import triton
|
|
from triton.compiler import ASTSource, make_backend
|
|
from triton.backends.compiler import GPUTarget
|
|
from triton.experimental.gluon._runtime import GluonASTSource
|
|
from triton.runtime.jit import create_function_from_signature
|
|
from triton._C.libtriton import ir
|
|
|
|
# ===-----------------------------------------------------------------------===#
|
|
# filecheck_test
|
|
# ===-----------------------------------------------------------------------===#
|
|
|
|
# Placeholder GPU target, used as the default when only the frontend is
# being exercised.
stub_target = GPUTarget("cuda", 100, 32)

# The FileCheck executable is looked up in the directory that contains this
# module.
triton_dir = os.path.dirname(__file__)
filecheck_path = os.path.join(triton_dir, "FileCheck")
|
|
|
|
|
|
class MatchError(ValueError):
    """A ``ValueError`` that additionally carries the text of a module.

    The module text is stored on ``module_str`` and is appended, on its own
    line, to the message whenever the exception is rendered with ``str()``.
    """

    def __init__(self, message, module_str):
        """Store ``message`` as the exception argument and keep ``module_str``."""
        super().__init__(message)
        self.module_str = module_str

    def __str__(self):
        """Return the base message, a newline, then the stored module text."""
        headline = super().__str__()
        return f"{headline}\n{self.module_str}"
|
|
|
|
|
|
def run_filecheck(name, module_str, check_template):
    """Run the FileCheck executable on ``module_str`` against ``check_template``.

    Both strings are written to files inside a temporary directory. FileCheck
    is then invoked with the template as its check file and the module text as
    ``--input-file``.

    Args:
        name: Currently unused; kept so existing callers keep working.
        module_str: Text to be checked.
        check_template: Text containing the FileCheck directives.

    Raises:
        ValueError: If FileCheck exits with a non-zero status. The message is
            FileCheck's combined stdout/stderr output.
    """
    with tempfile.TemporaryDirectory() as tempdir:
        module_path = os.path.join(tempdir, "module")
        with open(module_path, "w") as module_file:
            module_file.write(module_str)

        expected_path = os.path.join(tempdir, "expected")
        with open(expected_path, "w") as expected_file:
            expected_file.write(check_template)

        command = [filecheck_path, expected_path, "--input-file", module_path, "--dump-input-context=50"]
        try:
            subprocess.check_output(command, stderr=subprocess.STDOUT)
        except subprocess.CalledProcessError as error:
            raw_output = error.output
            try:
                decoded = raw_output.decode('unicode_escape')
            except UnicodeDecodeError:
                # A malformed backslash sequence in the tool output (for
                # example a truncated "\x" or "\U") must not hide the actual
                # FileCheck failure behind a decoding error.
                decoded = raw_output.decode("utf-8", errors="replace")
            raise ValueError(decoded) from error
|
|
|
|
|
|
def run_parser(kernel_fn, args=(), kwargs=None, target=stub_target):
    """Build and verify the IR module for ``kernel_fn``.

    The arguments are bound through a binder created from the kernel's
    signature, packed with ``kernel_fn._pack_args``, wrapped in a
    ``GluonASTSource`` or ``ASTSource`` (depending on ``kernel_fn.is_gluon()``)
    and turned into a module with ``make_ir``.

    Args:
        kernel_fn: Kernel object providing ``signature``, ``params``,
            ``_pack_args`` and ``is_gluon``.
        args: Positional arguments forwarded to the binder.
        kwargs: Keyword arguments forwarded to the binder. ``None`` is treated
            as an empty mapping. If ``"sanitize_overflow"`` is absent it is
            set to ``False`` on a copy; the caller's mapping is not modified
            by this function.
        target: Target handed to ``make_backend`` and ``make_ir``.

    Returns:
        The module produced by ``make_ir``.

    Raises:
        AssertionError: If ``module.verify()`` returns a falsy value.
    """
    if kwargs is None:
        kwargs = {}
    if "sanitize_overflow" not in kwargs:
        # Build a new dict so the caller's mapping is left untouched.
        kwargs = {**kwargs, "sanitize_overflow": False}

    backend = make_backend(target)
    binder = create_function_from_signature(
        kernel_fn.signature,
        kernel_fn.params,
        backend,
    )

    bound_args, specialization, options = binder(*args, **kwargs)
    options, signature, constexprs, attrs = kernel_fn._pack_args(backend, kwargs, bound_args, specialization, options)
    source_cls = GluonASTSource if kernel_fn.is_gluon() else ASTSource
    src = source_cls(kernel_fn, signature, constexprs, attrs)

    context = ir.context()
    ir.load_dialects(context)
    backend.load_dialects(context)

    codegen_fns = backend.get_codegen_implementation(options)
    module_map = backend.get_module_map()
    module = src.make_ir(target, options, codegen_fns, module_map, context)
    # Checked explicitly rather than with `assert`, so that verification is
    # not skipped when Python runs with -O.
    if not module.verify():
        raise AssertionError("module verification failed")
    return module
|
|
|
|
|
|
def run_filecheck_test(kernel_fn):
    """Check the IR generated for ``kernel_fn`` against its own source text.

    The source of ``kernel_fn.fn`` is used as the FileCheck template, and the
    module returned by ``run_parser`` (rendered with ``str_nodebug()``) is the
    input being checked.

    Args:
        kernel_fn: A ``triton.runtime.JITFunction``.

    Raises:
        AssertionError: If ``kernel_fn`` is not a ``JITFunction``.
        ValueError: If the source of the kernel cannot be retrieved, or if
            ``run_filecheck`` reports a failure.
    """
    assert isinstance(kernel_fn, triton.runtime.JITFunction)
    try:
        check_template = inspect.getsource(kernel_fn.fn)
    except (OSError, TypeError) as error:
        # inspect.getsource never returns None; it raises when the source
        # cannot be found, so that is the failure that has to be reported.
        raise ValueError("kernel function source is unavailable; it is needed as the FileCheck template") from error
    mlir_module = run_parser(kernel_fn)

    run_filecheck("placeholder", mlir_module.str_nodebug(), check_template)
|
|
|
|
|
|
def filecheck_test(fn):
    """Decorator turning ``fn`` into a zero-argument test function.

    The returned callable takes no arguments and calls
    ``run_filecheck_test(fn)``; the metadata of ``fn`` is copied onto it with
    ``functools.wraps``.
    """

    def test_fn():
        run_filecheck_test(fn)

    # Same effect as decorating `test_fn` with `@functools.wraps(fn)`.
    copy_metadata = functools.wraps(fn)
    return copy_metadata(test_fn)
|