DriverTrac/venv/lib/python3.12/site-packages/triton/backends/driver.py

67 lines
1.8 KiB
Python

from abc import ABCMeta, abstractmethod
from typing import Callable, List, Protocol, Sequence
class Benchmarker(Protocol):
def __call__(self, kernel_call: Callable, *, quantiles: List[float], **kwargs) -> Sequence[float]:
pass
class DriverBase(metaclass=ABCMeta):
@classmethod
@abstractmethod
def is_active(self):
pass
@abstractmethod
def map_python_to_cpp_type(self, ty: str) -> str:
"""
Converts a Triton type string to its corresponding C++ type string for this backend.
Args:
ty (str): The Triton type string. e.g., 'i32', '*fp16', 'fp32'.
Returns:
str: The C++ type string.
"""
pass
@abstractmethod
def get_current_target(self):
pass
@abstractmethod
def get_active_torch_device(self):
pass
@abstractmethod
def get_benchmarker(self) -> Benchmarker:
"""
Return the benchmarking function that this backend should use by default.
"""
raise NotImplementedError
def __init__(self) -> None:
pass
class GPUDriver(DriverBase):
def __init__(self):
# TODO: support other frameworks than torch
import torch
self.get_device_capability = torch.cuda.get_device_capability
try:
from torch._C import _cuda_getCurrentRawStream
self.get_current_stream = _cuda_getCurrentRawStream
except ImportError:
self.get_current_stream = lambda idx: torch.cuda.current_stream(idx).cuda_stream
self.get_current_device = torch.cuda.current_device
self.set_current_device = torch.cuda.set_device
# TODO: remove once TMA is cleaned up
def assemble_tensormap_to_arg(self, tensormaps_info, args):
return args