91 lines
2.8 KiB
Python
91 lines
2.8 KiB
Python
from abc import ABCMeta, abstractmethod
|
|
from dataclasses import dataclass
|
|
from enum import Enum
|
|
from typing import Dict, Union
|
|
from types import ModuleType
|
|
|
|
|
|
@dataclass(frozen=True)
class GPUTarget:
    """
    Immutable description of the device a backend compiles for.

    Instances are frozen, so they compare by value and are hashable.
    """

    # Name of the backend that owns this target (for example "cuda" or "hip").
    backend: str
    # Architecture identifier: an int such as 90 for a CUDA compute capability,
    # or a string such as "gfx940" for HIP.
    arch: Union[int, str]
    # Warp size of the target (presumably threads per warp -- confirm with the backend).
    warp_size: int
|
|
|
|
|
|
class Language(Enum):
    """Source language that a backend is asked to compile."""

    # The numeric values are part of the enum's identity; keep them stable.
    TRITON = 0
    GLUON = 1
|
|
|
|
|
|
class BaseBackend(metaclass=ABCMeta):
    """
    Abstract interface implemented by every compiler backend.

    A backend instance is bound to one `GPUTarget` at construction time.
    Subclasses must implement all abstract hooks below; `parse_attr` and
    `get_arg_specialization` are static helpers with default implementations.
    """

    def __init__(self, target: "GPUTarget") -> None:
        """
        Bind this backend to `target`.

        Raises `AssertionError` when `supports_target(target)` is falsy.
        """
        self.target = target
        # Deliberately still an `assert` so the failure type seen by callers
        # remains AssertionError (note: asserts are skipped under `python -O`).
        assert self.supports_target(target), (f"{type(self).__name__} does not support target {target!r}")

    @staticmethod
    @abstractmethod
    def supports_target(target: "GPUTarget"):
        """Return a truthy value if this backend can handle `target`."""
        raise NotImplementedError

    @abstractmethod
    def hash(self) -> str:
        """Returns a unique identifier for this backend"""
        raise NotImplementedError

    @abstractmethod
    def parse_options(self, options: dict) -> object:
        """
        Converts an `options` dictionary into an arbitrary object and returns it.
        This function may contain target-specific heuristics and check the legality of the provided options
        """
        raise NotImplementedError

    @abstractmethod
    def add_stages(self, stages: dict, options: object) -> None:
        """
        Populates `stages` dictionary with entries of the form:
        ir_name [str] => Function[(src: str, metadata: dict) -> str|bytes]
        The value of each entry may populate a `metadata` dictionary.
        Stages will be run sequentially (in insertion order) and can communicate using `metadata`.
        All stages are expected to return a `str` object, except for the last stage which returns
        a `bytes` object for execution by the launcher.
        """
        raise NotImplementedError

    @abstractmethod
    def load_dialects(self, context):
        """
        Load additional MLIR dialects into the provided `context`
        """
        raise NotImplementedError

    @abstractmethod
    def get_module_map(self) -> Dict[str, ModuleType]:
        """
        Return a map of interface modules to their device-specific implementations
        """
        raise NotImplementedError

    @staticmethod
    def parse_attr(desc):
        """
        Translate a specialization string into a list of `[name, value]` attribute pairs.

        A `desc` containing "D" (the marker returned by `get_arg_specialization`)
        yields `[["tt.divisibility", 16]]`; any other string yields `[]`.
        `desc` must be a `str` (checked with an assertion).
        """
        assert isinstance(desc, str)
        return [["tt.divisibility", 16]] if "D" in desc else []

    @staticmethod
    def get_arg_specialization(arg, ty, **kwargs):
        """
        Return a string unique to each possible specialization of the argument

        Returns "D" when the `align` keyword is truthy and either
        `ty == "int"` with `arg` divisible by 16, or `ty == "tensor"` with
        `arg.data_ptr()` divisible by 16. Returns "" in every other case.
        """
        # Without `align` no specialization is possible, so bail out before
        # touching `arg` (avoids a needless modulo / `data_ptr()` call).
        if not kwargs.get("align", False):
            return ""
        if ty == "int" and arg % 16 == 0:
            return "D"
        if ty == "tensor" and arg.data_ptr() % 16 == 0:
            return "D"
        return ""
|