45 lines
1.2 KiB
Python
45 lines
1.2 KiB
Python
from typing import Optional, Protocol
|
|
from contextvars import ContextVar
|
|
|
|
|
|
class Buffer(Protocol):
    """Structural type for the objects returned by an ``Allocator``.

    Any object exposing a ``data_ptr()`` method that returns an ``int``
    satisfies this protocol; no inheritance is required.
    """

    def data_ptr(self) -> int:
        """Return the buffer's data pointer as an ``int``.

        NOTE(review): going by the name, this is presumably the address of
        the allocated memory -- confirm against the code that consumes it.
        """
|
|
|
|
|
|
class Allocator(Protocol):
    """Structural type for allocator callables.

    An allocator is any callable taking ``(size, alignment, stream)`` and
    returning an object that satisfies the ``Buffer`` protocol.
    """

    def __call__(self, size: int, alignment: int, stream: Optional[int]) -> Buffer:
        """Allocate and return a ``Buffer``.

        Args:
            size: Requested size (an ``int``).
            alignment: Requested alignment (an ``int``).
            stream: An optional ``int`` stream identifier, or ``None``.

        NOTE(review): units (presumably bytes) and the meaning of ``stream``
        are not visible in this file -- confirm with the caller.
        """
|
|
|
|
|
|
class NullAllocator:
    """Placeholder allocator used when no real allocator has been set.

    It never allocates: every call raises ``RuntimeError`` with a message
    telling the user to register an allocator.
    """

    def __call__(self, size: int, alignment: int, stream: Optional[int]) -> Buffer:
        """Always raise ``RuntimeError``; the arguments are ignored.

        Raises:
            RuntimeError: unconditionally.
        """
        message = (
            "Kernel requires a runtime memory allocation, but no allocator was set. "
            "Use triton.set_allocator to specify an allocator."
        )
        raise RuntimeError(message)
|
|
|
|
|
|
# Context-local slot holding the runtime allocator. Until set_allocator() is
# called in a given context, lookups yield a NullAllocator, which raises on use.
_allocator: ContextVar[Allocator] = ContextVar(
    "_allocator",
    default=NullAllocator(),
)
|
|
|
|
|
|
def set_allocator(allocator: Allocator):
    """
    Register ``allocator`` as the runtime allocator for the current context.

    The allocator is invoked during kernel launch for kernels that need
    additional global memory workspace.
    """
    # ContextVar.set() returns a Token for restoring the previous value; it is
    # deliberately discarded, so this function returns None.
    _allocator.set(allocator)
|
|
|
|
|
|
# Context-local slot holding the profile allocator. It is a ContextVar (not an
# Allocator itself), and set_profile_allocator() may store None in it, hence
# the Optional element type. The variable is given its own name so that it is
# distinguishable from the runtime `_allocator` variable in repr/debug output.
_profile_allocator: ContextVar[Optional[Allocator]] = ContextVar(
    "_profile_allocator",
    default=NullAllocator(),
)
|
|
|
|
|
|
def set_profile_allocator(allocator: Optional[Allocator]):
    """
    Register ``allocator`` as the profile allocator for the current context.

    The profile allocator function is called before kernel launch for kernels
    that require additional global memory workspace.

    Args:
        allocator: The allocator to install. ``None`` is accepted and is
            stored as-is in the context variable.
            NOTE(review): presumably ``None`` disables profile allocation --
            confirm against the code that reads the variable.
    """
    # No `global` declaration is needed: the module-level name is only read
    # here (its .set() method is called), never rebound.
    _profile_allocator.set(allocator)
|