67 lines
1.8 KiB
Python
67 lines
1.8 KiB
Python
from abc import ABCMeta, abstractmethod
|
|
from typing import Callable, List, Protocol, Sequence
|
|
|
|
|
|
class Benchmarker(Protocol):
    """Structural type for callables that benchmark a kernel invocation.

    Any object whose ``__call__`` matches the signature below satisfies this
    protocol; no inheritance is required.
    """

    def __call__(self, kernel_call: Callable, *, quantiles: List[float], **kwargs) -> Sequence[float]:
        """Benchmark ``kernel_call`` and return a sequence of floats.

        Args:
            kernel_call: The callable to benchmark.
            quantiles: Keyword-only list of quantiles to report.
            **kwargs: Additional implementation-specific options.

        Returns:
            A sequence of floats. NOTE(review): presumably one measurement
            per requested quantile -- confirm against implementations.
        """
        ...
|
|
|
|
|
|
class DriverBase(metaclass=ABCMeta):
    """Abstract interface that backend drivers implement.

    A subclass can only be instantiated once it overrides every method
    marked abstract below.
    """

    def __init__(self) -> None:
        """Initialize the driver; the base class itself keeps no state."""

    @classmethod
    @abstractmethod
    def is_active(cls):
        """Report whether this driver is active.

        Declared as a class method so it can be queried without creating an
        instance.

        NOTE(review): presumably returns a bool -- confirm against the
        concrete implementations.
        """

    @abstractmethod
    def map_python_to_cpp_type(self, ty: str) -> str:
        """
        Converts a Triton type string to its corresponding C++ type string for this backend.

        Args:
            ty (str): The Triton type string. e.g., 'i32', '*fp16', 'fp32'.

        Returns:
            str: The C++ type string.
        """

    @abstractmethod
    def get_current_target(self):
        """Return the current target for this driver.

        The concrete return type is defined by the implementing backend.
        """

    @abstractmethod
    def get_active_torch_device(self):
        """Return the torch device that is active for this driver.

        NOTE(review): presumably a ``torch.device`` -- confirm against the
        concrete implementations.
        """

    @abstractmethod
    def get_benchmarker(self) -> Benchmarker:
        """
        Return the benchmarking function that this backend should use by default.

        Raises:
            NotImplementedError: Always, when the base implementation is
                reached (e.g. through ``super()``).
        """
        raise NotImplementedError
|
|
|
|
|
|
class GPUDriver(DriverBase):
    """Driver base that exposes torch's CUDA runtime helpers as attributes.

    After construction the instance carries ``get_device_capability``,
    ``get_current_stream``, ``get_current_device`` and ``set_current_device``
    callables, all bound from ``torch``.
    """

    def __init__(self):
        # TODO: support other frameworks than torch
        import torch

        cuda = torch.cuda
        self.get_device_capability = cuda.get_device_capability

        try:
            from torch._C import _cuda_getCurrentRawStream as raw_stream_getter
        except ImportError:
            # The private raw-stream accessor is unavailable in this torch
            # build; fall back to the public stream object's handle.
            self.get_current_stream = lambda idx: torch.cuda.current_stream(idx).cuda_stream
        else:
            self.get_current_stream = raw_stream_getter

        self.get_current_device = cuda.current_device
        self.set_current_device = cuda.set_device

    # TODO: remove once TMA is cleaned up
    def assemble_tensormap_to_arg(self, tensormaps_info, args):
        """Return ``args`` unchanged; ``tensormaps_info`` is ignored here."""
        return args
|