DriverTrac/venv/lib/python3.12/site-packages/triton/experimental/gluon/_runtime.py
2025-11-28 09:08:33 +05:30

103 lines
3.4 KiB
Python

from __future__ import annotations
from triton.compiler.compiler import ASTSource
from triton.backends.compiler import Language
from triton.runtime.jit import JITFunction, constexpr_function
from typing import TypeVar, Optional, Callable, Iterable, Union
from triton._C.libtriton import ir
T = TypeVar("T")
__all__ = ["constexpr_function", "jit"]
class GluonASTSource(ASTSource):
    """``ASTSource`` variant tagged as Gluon that lowers into a pre-built module."""

    def __init__(self, fn, signature, constexprs=None, attrs=None) -> None:
        """Initialise the base source, then tag it as Gluon with a ``ttgir`` extension."""
        super().__init__(fn, signature, constexprs, attrs)
        self.language = Language.GLUON
        self.ext = "ttgir"

    def make_ir(self, target, options, codegen_fns, module_map, context):
        """Create a module, stamp its ``ttg.*`` attributes, and lower ``self.fn`` into it.

        :param target: target description handed to ``make_backend``
        :param options: compile options; ``num_warps``, ``num_ctas``, ``warp_size``,
            ``backend_name`` and ``maxnreg`` are read here
        :param codegen_fns: forwarded unchanged to ``ast_to_ttir``
        :param module_map: forwarded unchanged to ``ast_to_ttir``
        :param context: IR context used for the builder and for ``ast_to_ttir``
        :return: whatever ``ast_to_ttir`` returns for the pre-populated module
        """
        # Imported lazily, exactly as the original does, rather than at module load.
        from triton.compiler.compiler import make_backend
        from triton.compiler.code_generator import ast_to_ttir

        builder = ir.builder(context)
        module = builder.create_module()

        def set_int32(attr_name, value):
            # Wrap ``value`` as an int32 attribute and attach it to the module.
            module.set_attr(attr_name, builder.get_int32_attr(value))

        # The attributes are attached before code generation because layout
        # verification needs them to be present on the module.
        target_name = make_backend(target).get_target_name(options)
        module.set_attr("ttg.target", builder.get_string_attr(target_name))
        set_int32("ttg.num-warps", options.num_warps)
        set_int32("ttg.num-ctas", options.num_ctas)
        set_int32("ttg.threads-per-warp", options.warp_size)

        # ``maxnreg`` is only recorded for the CUDA backend, and only when set.
        if options.backend_name == "cuda" and options.maxnreg is not None:
            set_int32("ttg.maxnreg", options.maxnreg)

        return ast_to_ttir(
            self.fn,
            self,
            context=context,
            options=options,
            codegen_fns=codegen_fns,
            module_map=module_map,
            module=module,
        )
class GluonJITFunction(JITFunction[T]):
    """``JITFunction`` subclass that uses ``GluonASTSource`` as its AST source."""

    def create_binder(self):
        """Build the binder via the base class, then install ``GluonASTSource``.

        The assignment happens after the ``super()`` call, so it takes effect
        regardless of what the base implementation stored in ``self.ASTSource``.
        """
        binder = super().create_binder()
        self.ASTSource = GluonASTSource
        return binder

    def is_gluon(self):
        """Return ``True``: this function is a Gluon function."""
        return True
def jit(
    fn: Optional[T] = None,
    *,
    version=None,
    repr: Optional[Callable] = None,
    launch_metadata: Optional[Callable] = None,
    do_not_specialize: Optional[Iterable[int | str]] = None,
    do_not_specialize_on_alignment: Optional[Iterable[int | str]] = None,
    debug: Optional[bool] = None,
    noinline: Optional[bool] = None,
) -> Union[GluonJITFunction[T], Callable[[T], JITFunction[T]]]:
    """
    Decorator that wraps a function in a :code:`GluonJITFunction` so it is
    JIT-compiled by the Triton compiler.

    Usable both bare (:code:`@jit`) and with keyword options
    (:code:`@jit(debug=True)`): when :code:`fn` is omitted, a decorator is
    returned instead of a wrapped function. Every keyword option is forwarded
    unchanged to :code:`GluonJITFunction`.

    :note: When a jit'd function is called, arguments are
        implicitly converted to pointers if they have a :code:`.data_ptr()` method
        and a `.dtype` attribute.

    :note: This function will be compiled and run on the GPU. It will only have access to:

           * python primitives,
           * builtins within the triton package,
           * arguments to this function,
           * other jit'd functions

    :param fn: the function to be jit-compiled
    :type fn: Callable
    """

    def decorator(fn: T) -> JITFunction[T]:
        # Non-callables are rejected before any wrapper is constructed.
        assert callable(fn)
        return GluonJITFunction(
            fn,
            version=version,
            repr=repr,
            launch_metadata=launch_metadata,
            do_not_specialize=do_not_specialize,
            do_not_specialize_on_alignment=do_not_specialize_on_alignment,
            debug=debug,
            noinline=noinline,
        )

    # No function supplied: hand back the decorator so the options can be
    # given first and the function applied afterwards.
    return decorator if fn is None else decorator(fn)