103 lines
3.4 KiB
Python
103 lines
3.4 KiB
Python
from __future__ import annotations
|
|
from triton.compiler.compiler import ASTSource
|
|
from triton.backends.compiler import Language
|
|
from triton.runtime.jit import JITFunction, constexpr_function
|
|
from typing import TypeVar, Optional, Callable, Iterable, Union
|
|
from triton._C.libtriton import ir
|
|
|
|
T = TypeVar("T")
|
|
|
|
__all__ = ["constexpr_function", "jit"]
|
|
|
|
|
|
class GluonASTSource(ASTSource):
    """``ASTSource`` variant tagged with the Gluon language and the ``ttgir`` extension."""

    def __init__(self, fn, signature, constexprs=None, attrs=None) -> None:
        """Forward all arguments to ``ASTSource`` and then mark this source as Gluon.

        :param fn: passed through to ``ASTSource.__init__``
        :param signature: passed through to ``ASTSource.__init__``
        :param constexprs: passed through to ``ASTSource.__init__``
        :param attrs: passed through to ``ASTSource.__init__``
        """
        super().__init__(fn, signature, constexprs, attrs)
        self.language = Language.GLUON
        self.ext = "ttgir"

    def make_ir(self, target, options, codegen_fns, module_map, context):
        """Create a module, stamp the ``ttg.*`` attributes on it, and run code generation into it.

        :param target: handed to ``make_backend`` to obtain the backend
        :param options: supplies ``num_warps``, ``num_ctas``, ``warp_size``,
            ``backend_name`` and ``maxnreg``; also forwarded to ``ast_to_ttir``
        :param codegen_fns: forwarded to ``ast_to_ttir``
        :param module_map: forwarded to ``ast_to_ttir``
        :param context: used to construct the IR builder and forwarded to ``ast_to_ttir``
        :return: whatever ``ast_to_ttir`` returns for the pre-populated module
        """
        from triton.compiler.compiler import make_backend
        from triton.compiler.code_generator import ast_to_ttir

        builder = ir.builder(context)
        module = builder.create_module()

        # The module attributes are set up front, before code generation,
        # because they are needed to verify layouts.
        target_name = make_backend(target).get_target_name(options)
        module.set_attr("ttg.target", builder.get_string_attr(target_name))

        # (module attribute name, name of the option field holding its value)
        int32_attrs = (
            ("ttg.num-warps", "num_warps"),
            ("ttg.num-ctas", "num_ctas"),
            ("ttg.threads-per-warp", "warp_size"),
        )
        for attr_name, option_field in int32_attrs:
            module.set_attr(attr_name, builder.get_int32_attr(getattr(options, option_field)))

        # "ttg.maxnreg" is only written for the "cuda" backend, and only when a value is set.
        if options.backend_name == "cuda" and options.maxnreg is not None:
            module.set_attr("ttg.maxnreg", builder.get_int32_attr(options.maxnreg))

        return ast_to_ttir(
            self.fn,
            self,
            context=context,
            options=options,
            codegen_fns=codegen_fns,
            module_map=module_map,
            module=module,
        )
|
|
|
|
|
|
class GluonJITFunction(JITFunction[T]):
    """``JITFunction`` subclass that uses ``GluonASTSource`` and identifies itself as Gluon."""

    def create_binder(self):
        """Run the base ``create_binder``, then point ``self.ASTSource`` at ``GluonASTSource``.

        The attribute is assigned after the base call; the base call's
        return value is handed back unchanged.
        """
        binder = super().create_binder()
        self.ASTSource = GluonASTSource
        return binder

    def is_gluon(self):
        """Always return ``True``."""
        return True
|
|
|
|
|
|
def jit(
    fn: Optional[T] = None,
    *,
    version=None,
    repr: Optional[Callable] = None,
    launch_metadata: Optional[Callable] = None,
    do_not_specialize: Optional[Iterable[int | str]] = None,
    do_not_specialize_on_alignment: Optional[Iterable[int | str]] = None,
    debug: Optional[bool] = None,
    noinline: Optional[bool] = None,
) -> Union[GluonJITFunction[T], Callable[[T], JITFunction[T]]]:
    """
    Decorator for JIT-compiling a function using the Triton compiler.

    Usable both bare (``@jit``) and with keyword options (``@jit(debug=True)``):
    when ``fn`` is given it is wrapped immediately, otherwise a decorator that
    performs the wrapping is returned. Every keyword option is forwarded
    unchanged to ``GluonJITFunction``.

    :note: When a jit'd function is called, arguments are
        implicitly converted to pointers if they have a :code:`.data_ptr()` method
        and a `.dtype` attribute.

    :note: This function will be compiled and run on the GPU. It will only have access to:

           * python primitives,
           * builtins within the triton package,
           * arguments to this function,
           * other jit'd functions

    :param fn: the function to be jit-compiled
    :type fn: Callable
    """

    def decorator(fn: T) -> JITFunction[T]:
        assert callable(fn)
        # Gather the forwarded options in one place; the values are exactly
        # the keyword arguments given to the enclosing ``jit`` call.
        forwarded = {
            "version": version,
            "do_not_specialize": do_not_specialize,
            "do_not_specialize_on_alignment": do_not_specialize_on_alignment,
            "debug": debug,
            "noinline": noinline,
            "repr": repr,
            "launch_metadata": launch_metadata,
        }
        return GluonJITFunction(fn, **forwarded)

    # Bare ``@jit`` supplies ``fn`` directly; ``@jit(...)`` needs the decorator itself.
    return decorator if fn is None else decorator(fn)
|