DriverTrac/venv/lib/python3.12/site-packages/triton/runtime/build.py

94 lines
3.8 KiB
Python

from __future__ import annotations
import functools
import hashlib
import importlib.util
import logging
import os
import shutil
import subprocess
import sysconfig
import tempfile
from types import ModuleType
from .cache import get_cache_manager
from .. import knobs
def _build(name: str, src: str, srcdir: str, library_dirs: list[str], include_dirs: list[str], libraries: list[str],
           ccflags: list[str]) -> str:
    """Compile the C file ``src`` into a Python extension shared object.

    When ``knobs.build.impl`` is set, the whole build is delegated to it and
    its result is returned; note that ``ccflags`` is not forwarded to that hook.

    Args:
        name: Base name of the output file (the interpreter's ``EXT_SUFFIX``
            is appended).
        src: Path of the C source file to compile.
        srcdir: Directory that receives the output; also added as an include
            directory.
        library_dirs: Directories passed to the compiler as ``-L``.
        include_dirs: Directories passed to the compiler as ``-I``
            (``None`` entries are skipped).
        libraries: Library names passed to the compiler as ``-l``.
        ccflags: Extra arguments appended verbatim to the compiler command.

    Returns:
        Path of the shared object written inside ``srcdir``.

    Raises:
        RuntimeError: If ``CC`` is unset and neither gcc nor clang is on PATH.
        subprocess.CalledProcessError: If the compiler exits with a non-zero
            status.
    """
    custom_impl = knobs.build.impl
    if custom_impl:
        return custom_impl(name, src, srcdir, library_dirs, include_dirs, libraries)

    ext_suffix = sysconfig.get_config_var('EXT_SUFFIX')
    output_path = os.path.join(srcdir, f"{name}{ext_suffix}")

    # An explicit CC wins; otherwise prefer gcc and fall back to clang.
    compiler = os.environ.get("CC")
    if compiler is None:
        clang_path = shutil.which("clang")
        gcc_path = shutil.which("gcc")
        compiler = clang_path if gcc_path is None else gcc_path
        if compiler is None:
            raise RuntimeError(
                "Failed to find C compiler. Please specify via CC environment variable or set triton.knobs.build.impl.")

    # This function was renamed and made public in Python 3.10
    scheme = (sysconfig.get_default_scheme()
              if hasattr(sysconfig, 'get_default_scheme') else sysconfig._get_default_scheme())  # type: ignore
    # 'posix_local' is a custom scheme on Debian. However, starting Python 3.10, the default install
    # path changes to include 'local'. This change is required to use triton with system-wide python.
    if scheme == 'posix_local':
        scheme = 'posix_prefix'
    python_include = sysconfig.get_paths(scheme=scheme)["include"]
    extra_backend_dirs = knobs.build.backend_dirs
    all_include_dirs = include_dirs + [srcdir, python_include, *extra_backend_dirs]

    # for -Wno-psabi, see https://gcc.gnu.org/bugzilla/show_bug.cgi?id=111047
    command = [
        compiler, src, "-O3", "-shared", "-fPIC", "-Wno-psabi", "-o", output_path,
        *(f'-l{lib}' for lib in libraries),
        *(f"-L{lib_dir}" for lib_dir in library_dirs),
        *(f"-I{inc_dir}" for inc_dir in all_include_dirs if inc_dir is not None),
        *ccflags,
    ]
    subprocess.check_call(command, stdout=subprocess.DEVNULL)
    return output_path
@functools.lru_cache
def platform_key() -> str:
    """Return a comma-separated string identifying the host platform.

    The fields are, in order, ``platform.machine()``, ``platform.system()``
    and the elements returned by ``platform.architecture()``. The result is
    memoized by ``functools.lru_cache``, so the lookups run only on the
    first call.
    """
    import platform

    fields = [platform.machine(), platform.system()]
    fields.extend(platform.architecture())
    return ",".join(fields)
def _load_module_from_path(name: str, path: str) -> ModuleType:
    """Import the file at ``path`` as a module named ``name`` and return it.

    Args:
        name: Module name given to the import spec.
        path: Filesystem location of the module to load.

    Returns:
        The module object after its loader has executed it.

    Raises:
        RuntimeError: If no import spec, or no loader, can be obtained for
            ``path``.
    """
    module_spec = importlib.util.spec_from_file_location(name, path)
    if module_spec and module_spec.loader:
        module = importlib.util.module_from_spec(module_spec)
        module_spec.loader.exec_module(module)
        return module
    raise RuntimeError(f"Failed to load newly compiled {name} from {path}")
def compile_module_from_src(src: str, name: str, library_dirs: list[str] | None = None,
                            include_dirs: list[str] | None = None, libraries: list[str] | None = None,
                            ccflags: list[str] | None = None) -> ModuleType:
    """Build C source text into an extension module and import it, using the cache.

    The cache key is the SHA-256 of ``src`` concatenated with
    ``platform_key()``. If the cache manager already holds a file named
    ``{name}{EXT_SUFFIX}`` for that key, it is loaded directly. If that load
    fails with ``RuntimeError`` or ``ImportError``, a warning is logged and
    the module is rebuilt. Otherwise the source is written to a temporary
    directory, compiled with ``_build``, stored in the cache, and loaded
    from the path returned by the cache.

    Args:
        src: C source code of the module.
        name: Module name; also the base name of the source and output files.
        library_dirs: Optional ``-L`` directories (``None`` means none).
        include_dirs: Optional ``-I`` directories (``None`` means none).
        libraries: Optional ``-l`` library names (``None`` means none).
        ccflags: Optional extra compiler arguments (``None`` means none).

    Returns:
        The imported extension module.
    """
    key = hashlib.sha256((src + platform_key()).encode("utf-8")).hexdigest()
    cache = get_cache_manager(key)
    suffix = sysconfig.get_config_var("EXT_SUFFIX")
    cache_path = cache.get_file(f"{name}{suffix}")

    if cache_path is not None:
        try:
            return _load_module_from_path(name, cache_path)
        except (RuntimeError, ImportError):
            # Best effort: an unloadable cached binary is not fatal, fall
            # through and rebuild it from source.
            log = logging.getLogger(__name__)
            log.warning(f"Triton cache error: compiled module {name}.so could not be loaded")

    with tempfile.TemporaryDirectory() as tmpdir:
        src_path = os.path.join(tmpdir, name + ".c")
        # Write with the same encoding used for the cache key above, so that
        # non-ASCII source does not depend on (or fail under) the locale's
        # default encoding.
        with open(src_path, "w", encoding="utf-8") as f:
            f.write(src)
        so = _build(name, src_path, tmpdir, library_dirs or [], include_dirs or [], libraries or [], ccflags or [])
        # Copy the binary into the cache before the temporary directory is removed.
        with open(so, "rb") as f:
            cache_path = cache.put(f.read(), f"{name}{suffix}", binary=True)

    return _load_module_from_path(name, cache_path)