# DriverTrac/venv/lib/python3.12/site-packages/triton/runtime/cache.py

import json
import os
import uuid
from abc import ABC, abstractmethod
from typing import Dict, List, Optional
import base64
import hashlib
import functools
import sysconfig
from triton import __version__, knobs

class CacheManager(ABC):

    def __init__(self, key, override=False, dump=False):
        pass

    @abstractmethod
    def get_file(self, filename) -> Optional[str]:
        pass

    @abstractmethod
    def put(self, data, filename, binary=True) -> str:
        pass

    @abstractmethod
    def get_group(self, filename: str) -> Optional[Dict[str, str]]:
        pass

    @abstractmethod
    def put_group(self, filename: str, group: Dict[str, str]):
        pass
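
# As an illustration only (not part of Triton), a subclass can wrap an
# existing manager to layer on extra behavior, e.g. logging every write:
#
#     class LoggingCacheManager(FileCacheManager):
#         def put(self, data, filename, binary=True) -> str:
#             path = super().put(data, filename, binary=binary)
#             print(f"cached {filename} at {path}")
#             return path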

class FileCacheManager(CacheManager):

    def __init__(self, key, override=False, dump=False):
        self.key = key
        self.lock_path = None
        if dump:
            self.cache_dir = knobs.cache.dump_dir
            self.cache_dir = os.path.join(self.cache_dir, self.key)
            self.lock_path = os.path.join(self.cache_dir, "lock")
            os.makedirs(self.cache_dir, exist_ok=True)
        elif override:
            self.cache_dir = knobs.cache.override_dir
            self.cache_dir = os.path.join(self.cache_dir, self.key)
        else:
            # create cache directory if it doesn't exist
            self.cache_dir = knobs.cache.dir
            if self.cache_dir:
                self.cache_dir = os.path.join(self.cache_dir, self.key)
                self.lock_path = os.path.join(self.cache_dir, "lock")
                os.makedirs(self.cache_dir, exist_ok=True)
            else:
                raise RuntimeError("Could not create or locate cache dir")

    def _make_path(self, filename) -> str:
        return os.path.join(self.cache_dir, filename)

    def has_file(self, filename) -> bool:
        if not self.cache_dir:
            raise RuntimeError("Could not create or locate cache dir")
        return os.path.exists(self._make_path(filename))

    def get_file(self, filename) -> Optional[str]:
        if self.has_file(filename):
            return self._make_path(filename)
        else:
            return None

    def get_group(self, filename: str) -> Optional[Dict[str, str]]:
        grp_filename = f"__grp__{filename}"
        if not self.has_file(grp_filename):
            return None
        grp_filepath = self._make_path(grp_filename)
        with open(grp_filepath) as f:
            grp_data = json.load(f)
        child_paths = grp_data.get("child_paths", None)
        # Invalid group data.
        if child_paths is None:
            return None
        result = {}
        for c, p in child_paths.items():
            if os.path.exists(p):
                result[c] = p
        return result

    # Record a set of pushed files as being part of a group
    def put_group(self, filename: str, group: Dict[str, str]) -> str:
        if not self.cache_dir:
            raise RuntimeError("Could not create or locate cache dir")
        grp_contents = json.dumps({"child_paths": group})
        grp_filename = f"__grp__{filename}"
        return self.put(grp_contents, grp_filename, binary=False)

    def put(self, data, filename, binary=True) -> str:
        if not self.cache_dir:
            raise RuntimeError("Could not create or locate cache dir")
        # Infer binary mode from the payload (overrides the `binary` argument).
        binary = isinstance(data, bytes)
        if not binary:
            data = str(data)
        assert self.lock_path is not None
        filepath = self._make_path(filename)
        # Random ID to avoid any collisions
        rnd_id = str(uuid.uuid4())
        # Include the PID so that, if several temp dirs accumulate, we can
        # see which process created each one.
        pid = os.getpid()
        # Write to a temp dir first to be robust against program interruptions.
        temp_dir = os.path.join(self.cache_dir, f"tmp.pid_{pid}_{rnd_id}")
        os.makedirs(temp_dir, exist_ok=True)
        temp_path = os.path.join(temp_dir, filename)
        mode = "wb" if binary else "w"
        with open(temp_path, mode) as f:
            f.write(data)
        # os.replace is guaranteed to be atomic on POSIX systems if it
        # succeeds, so filepath can never observe a partial write.
        os.replace(temp_path, filepath)
        os.removedirs(temp_dir)
        return filepath
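
# A sketch of the round trip this class provides, assuming the cache dir
# (e.g. via the TRITON_CACHE_DIR environment variable) is writable and
# using an arbitrary hex key:
#
#     mgr = FileCacheManager(_base32("deadbeef"))
#     path = mgr.put(b"\x00\x01", "kernel.cubin")
#     assert mgr.get_file("kernel.cubin") == path
#     assert mgr.get_file("missing.cubin") is None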

class RemoteCacheBackend:
    """
    A backend implementation for accessing a remote/distributed cache.
    """

    def __init__(self, key: str):
        pass

    @abstractmethod
    def get(self, filenames: List[str]) -> Dict[str, bytes]:
        pass

    @abstractmethod
    def put(self, filename: str, data: bytes):
        pass

class RedisRemoteCacheBackend(RemoteCacheBackend):

    def __init__(self, key):
        import redis
        self._key = key
        self._key_fmt = knobs.cache.redis.key_format
        self._redis = redis.Redis(
            host=knobs.cache.redis.host,
            port=knobs.cache.redis.port,
        )

    def _get_key(self, filename: str) -> str:
        return self._key_fmt.format(key=self._key, filename=filename)

    def get(self, filenames: List[str]) -> Dict[str, bytes]:
        results = self._redis.mget([self._get_key(f) for f in filenames])
        return {filename: result for filename, result in zip(filenames, results) if result is not None}

    def put(self, filename: str, data: bytes) -> None:
        self._redis.set(self._get_key(filename), data)
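
# Sketch of the filename -> Redis key mapping, assuming a key format such
# as "triton:{key}:{filename}" (the actual format comes from
# knobs.cache.redis.key_format):
#
#     backend = RedisRemoteCacheBackend("ABC123")
#     backend.put("kernel.json", b"{}")    # SET triton:ABC123:kernel.json
#     backend.get(["kernel.json"])         # -> {"kernel.json": b"{}"}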

class RemoteCacheManager(CacheManager):

    def __init__(self, key, override=False, dump=False):
        # Set up the backend pointed to by `TRITON_REMOTE_CACHE_BACKEND`.
        remote_cache_cls = knobs.cache.remote_manager_class
        if not remote_cache_cls:
            raise RuntimeError(
                "Unable to instantiate RemoteCacheManager, TRITON_REMOTE_CACHE_BACKEND doesn't point to a valid class")
        self._backend = remote_cache_cls(key)
        self._override = override
        self._dump = dump
        # Use a `FileCacheManager` to materialize remote cache paths locally.
        self._file_cache_manager = FileCacheManager(key, override=override, dump=dump)

    def _materialize(self, filename: str, data: bytes):
        # We use a backing `FileCacheManager` to provide the materialized data.
        return self._file_cache_manager.put(data, filename, binary=True)

    def get_file(self, filename: str) -> Optional[str]:
        # We don't handle the dump/override cases.
        if self._dump or self._override:
            return self._file_cache_manager.get_file(filename)
        # We always check the remote cache backend -- even if our internal
        # file-based cache has the item -- to make sure LRU accounting works
        # as expected.
        results = self._backend.get([filename])
        if len(results) == 0:
            return None
        (_, data), = results.items()
        return self._materialize(filename, data)

    def put(self, data, filename: str, binary=True) -> str:
        # We don't handle the dump/override cases.
        if self._dump or self._override:
            return self._file_cache_manager.put(data, filename, binary=binary)
        if not isinstance(data, bytes):
            data = str(data).encode("utf-8")
        self._backend.put(filename, data)
        return self._materialize(filename, data)

    def get_group(self, filename: str) -> Optional[Dict[str, str]]:
        # We don't handle the dump/override cases.
        if self._dump or self._override:
            return self._file_cache_manager.get_group(filename)
        grp_filename = f"__grp__{filename}"
        grp_filepath = self.get_file(grp_filename)
        if grp_filepath is None:
            return None
        with open(grp_filepath) as f:
            grp_data = json.load(f)
        child_paths = grp_data.get("child_paths", None)
        result = None
        # Found group data.
        if child_paths is not None:
            result = {}
            for child_path, data in self._backend.get(child_paths).items():
                result[child_path] = self._materialize(child_path, data)
        return result

    def put_group(self, filename: str, group: Dict[str, str]):
        # We don't handle the dump/override cases.
        if self._dump or self._override:
            return self._file_cache_manager.put_group(filename, group)
        grp_contents = json.dumps({"child_paths": sorted(list(group.keys()))})
        grp_filename = f"__grp__{filename}"
        return self.put(grp_contents, grp_filename)

def _base32(key):
    # Assume key is a hex string.
    return base64.b32encode(bytes.fromhex(key)).decode("utf-8").rstrip("=")
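
# For example, a hex string is decoded to raw bytes, base32-encoded, and
# stripped of '=' padding:
#
#     _base32("deadbeef")   # -> "32W353Y"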

def get_cache_manager(key) -> CacheManager:
    cls = knobs.cache.manager_class or FileCacheManager
    return cls(_base32(key))


def get_override_manager(key) -> CacheManager:
    cls = knobs.cache.manager_class or FileCacheManager
    return cls(_base32(key), override=True)


def get_dump_manager(key) -> CacheManager:
    cls = knobs.cache.manager_class or FileCacheManager
    return cls(_base32(key), dump=True)
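
# Typical usage: hash the inputs to a hex digest and ask for a manager
# (a sketch; the payload and filename here are arbitrary):
#
#     key = hashlib.sha256(b"compilation inputs").hexdigest()
#     mgr = get_cache_manager(key)
#     mgr.put('{"meta": 1}', "kernel.json", binary=False)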

def make_so_cache_key(version_hash, signature, constants, ids, **kwargs):
    # Get unique key for the compiled code
    signature = {k: 'ptr' if v[0] == '*' else v for k, v in signature.items()}
    key = f"{version_hash}-{''.join(signature.values())}-{constants}-{ids}"
    for kw in kwargs:
        key = f"{key}-{kwargs.get(kw)}"
    key = hashlib.sha256(key.encode("utf-8")).hexdigest()
    return _base32(key)
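
# Note that pointer arguments are normalized before hashing, so e.g. a
# "*fp32" and a "*i8" argument contribute the same "ptr" token (a sketch;
# the values below are arbitrary):
#
#     make_so_cache_key("v1", {0: "*fp32", 1: "i32"}, constants={}, ids=())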

@functools.lru_cache()
def triton_key():
    import pkgutil
    TRITON_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    contents = []
    # frontend
    with open(__file__, "rb") as f:
        contents += [hashlib.sha256(f.read()).hexdigest()]
    # compiler
    path_prefixes = [
        (os.path.join(TRITON_PATH, "compiler"), "triton.compiler."),
        (os.path.join(TRITON_PATH, "backends"), "triton.backends."),
    ]
    for path, prefix in path_prefixes:
        for lib in pkgutil.walk_packages([path], prefix=prefix):
            with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
                contents += [hashlib.sha256(f.read()).hexdigest()]
    # backend
    libtriton_hash = hashlib.sha256()
    ext = sysconfig.get_config_var("EXT_SUFFIX").split(".")[-1]
    with open(os.path.join(TRITON_PATH, "_C", f"libtriton.{ext}"), "rb") as f:
        # Hash the shared library in 1 MiB chunks to bound memory use.
        while True:
            chunk = f.read(1024**2)
            if not chunk:
                break
            libtriton_hash.update(chunk)
    contents.append(libtriton_hash.hexdigest())
    # language
    language_path = os.path.join(TRITON_PATH, 'language')
    for lib in pkgutil.walk_packages([language_path], prefix="triton.language."):
        with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
            contents += [hashlib.sha256(f.read()).hexdigest()]
    return f'{__version__}' + '-'.join(contents)

def get_cache_key(src, backend, backend_options, env_vars):
    key = f"{triton_key()}-{src.hash()}-{backend.hash()}-{backend_options.hash()}-{str(sorted(env_vars.items()))}"
    return key
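
# Sketch of assembling a full key; `src`, `backend`, and `backend_options`
# below are hypothetical stand-ins (the real objects come from the compiler
# and need only expose a `.hash()` method), and `env_vars` is any mapping:
#
#     class _Hashed:
#         def __init__(self, h):
#             self._h = h
#         def hash(self):
#             return self._h
#
#     get_cache_key(_Hashed("src"), _Hashed("backend"), _Hashed("opts"),
#                   {"TRITON_DEBUG": "0"})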