94 lines
3.8 KiB
Python
94 lines
3.8 KiB
Python
from __future__ import annotations
|
|
|
|
import functools
|
|
import hashlib
|
|
import importlib.util
|
|
import logging
|
|
import os
|
|
import shutil
|
|
import subprocess
|
|
import sysconfig
|
|
import tempfile
|
|
|
|
from types import ModuleType
|
|
|
|
from .cache import get_cache_manager
|
|
from .. import knobs
|
|
|
|
|
|
def _build(name: str, src: str, srcdir: str, library_dirs: list[str], include_dirs: list[str], libraries: list[str],
           ccflags: list[str]) -> str:
    """Compile the C file ``src`` into a shared object inside ``srcdir``.

    When ``knobs.build.impl`` is set, the whole build is delegated to it
    (note that ``ccflags`` is not forwarded to that hook). Otherwise the
    compiler named by the ``CC`` environment variable is used, falling back
    to ``gcc`` and then ``clang`` as found on ``PATH``.

    Args:
        name: Base name of the output file; the interpreter's ``EXT_SUFFIX``
            is appended to it.
        src: Path of the C source file to compile.
        srcdir: Directory receiving the output; also added to the include path.
        library_dirs: Directories passed to the compiler as ``-L``.
        include_dirs: Directories passed to the compiler as ``-I``.
        libraries: Library names passed to the compiler as ``-l``.
        ccflags: Extra arguments appended verbatim to the compiler command.

    Returns:
        Path of the shared object (or whatever the custom hook returns).

    Raises:
        RuntimeError: If no C compiler can be located.
        subprocess.CalledProcessError: If the compiler exits with a non-zero status.
    """
    custom_impl = knobs.build.impl
    if custom_impl:
        return custom_impl(name, src, srcdir, library_dirs, include_dirs, libraries)

    ext_suffix = sysconfig.get_config_var('EXT_SUFFIX')
    so_path = os.path.join(srcdir, f"{name}{ext_suffix}")

    compiler = os.environ.get("CC")
    if compiler is None:
        clang_path = shutil.which("clang")
        gcc_path = shutil.which("gcc")
        # gcc wins when both compilers are installed.
        compiler = clang_path if gcc_path is None else gcc_path
        if compiler is None:
            raise RuntimeError(
                "Failed to find C compiler. Please specify via CC environment variable or set triton.knobs.build.impl.")

    # The scheme getter was renamed and made public in Python 3.10.
    scheme_getter = getattr(sysconfig, 'get_default_scheme', None)
    if scheme_getter is None:
        scheme_getter = sysconfig._get_default_scheme  # type: ignore
    install_scheme = scheme_getter()
    # 'posix_local' is a custom scheme on Debian. However, starting Python 3.10, the default install
    # path changes to include 'local'. This change is required to use triton with system-wide python.
    if install_scheme == 'posix_local':
        install_scheme = 'posix_prefix'
    python_headers = sysconfig.get_paths(scheme=install_scheme)["include"]

    backend_dirs = knobs.build.backend_dirs
    header_dirs = include_dirs + [srcdir, python_headers, *backend_dirs]

    # for -Wno-psabi, see https://gcc.gnu.org/bugzilla/show_bug.cgi?id=111047
    command = [compiler, src, "-O3", "-shared", "-fPIC", "-Wno-psabi", "-o", so_path]
    command.extend(f'-l{lib}' for lib in libraries)
    command.extend(f"-L{path}" for path in library_dirs)
    # Entries that are None are skipped rather than emitted as "-INone".
    command.extend(f"-I{path}" for path in header_dirs if path is not None)
    command += ccflags

    # Compiler stdout is discarded; stderr is left attached to the parent.
    subprocess.check_call(command, stdout=subprocess.DEVNULL)
    return so_path
|
|
|
|
|
|
@functools.lru_cache
def platform_key() -> str:
    """Return a comma-separated string identifying the host platform.

    The fields are, in order, ``platform.machine()``, ``platform.system()``
    and the two items returned by ``platform.architecture()``. The result is
    memoized, so the platform is only queried on the first call.
    """
    import platform

    fields = [platform.machine(), platform.system()]
    fields.extend(platform.architecture())
    return ",".join(fields)
|
|
|
|
|
|
def _load_module_from_path(name: str, path: str) -> ModuleType:
|
|
spec = importlib.util.spec_from_file_location(name, path)
|
|
if not spec or not spec.loader:
|
|
raise RuntimeError(f"Failed to load newly compiled {name} from {path}")
|
|
mod = importlib.util.module_from_spec(spec)
|
|
spec.loader.exec_module(mod)
|
|
return mod
|
|
|
|
|
|
def compile_module_from_src(src: str, name: str, library_dirs: list[str] | None = None,
                            include_dirs: list[str] | None = None, libraries: list[str] | None = None,
                            ccflags: list[str] | None = None) -> ModuleType:
    """Compile C source text into an extension module and import it, with caching.

    The cache entry is keyed by the SHA-256 of ``src`` concatenated with
    ``platform_key()``. If a cached binary exists and loads, it is returned
    without invoking the compiler; if it exists but fails to load, a warning
    is logged and the module is rebuilt.

    Args:
        src: C source code of the module.
        name: Module name; also used for the temporary ``.c`` file and the
            cached binary's file name.
        library_dirs: Directories forwarded to ``_build``; ``None`` means none.
        include_dirs: Directories forwarded to ``_build``; ``None`` means none.
        libraries: Library names forwarded to ``_build``; ``None`` means none.
        ccflags: Extra compiler flags forwarded to ``_build``; ``None`` means none.

    Returns:
        The imported module.

    Raises:
        RuntimeError: If the freshly built binary cannot be loaded (raised by
            ``_load_module_from_path``). Errors from ``_build`` propagate unchanged.
    """
    # NOTE(review): the key covers only the source text and the platform; builds that
    # differ solely in library_dirs/include_dirs/libraries/ccflags share a cache entry.
    key = hashlib.sha256((src + platform_key()).encode("utf-8")).hexdigest()
    cache = get_cache_manager(key)
    suffix = sysconfig.get_config_var("EXT_SUFFIX")
    cache_path = cache.get_file(f"{name}{suffix}")

    if cache_path is not None:
        try:
            return _load_module_from_path(name, cache_path)
        except (RuntimeError, ImportError):
            # A cached binary that cannot be loaded is not fatal: fall through and rebuild.
            log = logging.getLogger(__name__)
            log.warning(f"Triton cache error: compiled module {name}.so could not be loaded")

    with tempfile.TemporaryDirectory() as tmpdir:
        src_path = os.path.join(tmpdir, name + ".c")
        # Write the source as UTF-8 explicitly, matching the encoding used for the cache
        # key above; the locale default could mis-encode or reject non-ASCII source text.
        with open(src_path, "w", encoding="utf-8") as f:
            f.write(src)
        so = _build(name, src_path, tmpdir, library_dirs or [], include_dirs or [], libraries or [], ccflags or [])
        # Copy the binary into the cache before the temporary directory is removed.
        with open(so, "rb") as f:
            cache_path = cache.put(f.read(), f"{name}{suffix}", binary=True)

    return _load_module_from_path(name, cache_path)
|