import json
import os
import uuid
from abc import ABC, abstractmethod
from typing import Dict, List, Optional
import base64
import hashlib
import functools
import sysconfig

from triton import __version__, knobs


class CacheManager(ABC):

    def __init__(self, key, override=False, dump=False):
        pass

    @abstractmethod
    def get_file(self, filename) -> Optional[str]:
        pass

    @abstractmethod
    def put(self, data, filename, binary=True) -> str:
        pass

    @abstractmethod
    def get_group(self, filename: str) -> Optional[Dict[str, str]]:
        pass

    @abstractmethod
    def put_group(self, filename: str, group: Dict[str, str]):
        pass


class FileCacheManager(CacheManager):

    def __init__(self, key, override=False, dump=False):
        self.key = key
        self.lock_path = None
        if dump:
            self.cache_dir = knobs.cache.dump_dir
            self.cache_dir = os.path.join(self.cache_dir, self.key)
            self.lock_path = os.path.join(self.cache_dir, "lock")
            os.makedirs(self.cache_dir, exist_ok=True)
        elif override:
            self.cache_dir = knobs.cache.override_dir
            self.cache_dir = os.path.join(self.cache_dir, self.key)
        else:
            # create cache directory if it doesn't exist
            self.cache_dir = knobs.cache.dir
            if self.cache_dir:
                self.cache_dir = os.path.join(self.cache_dir, self.key)
                self.lock_path = os.path.join(self.cache_dir, "lock")
                os.makedirs(self.cache_dir, exist_ok=True)
            else:
                raise RuntimeError("Could not create or locate cache dir")

    def _make_path(self, filename) -> str:
        return os.path.join(self.cache_dir, filename)

    def has_file(self, filename) -> bool:
        if not self.cache_dir:
            raise RuntimeError("Could not create or locate cache dir")
        return os.path.exists(self._make_path(filename))

    def get_file(self, filename) -> Optional[str]:
        if self.has_file(filename):
            return self._make_path(filename)
        else:
            return None

    def get_group(self, filename: str) -> Optional[Dict[str, str]]:
        grp_filename = f"__grp__{filename}"
        if not self.has_file(grp_filename):
            return None
        grp_filepath = self._make_path(grp_filename)
        with open(grp_filepath) as f:
            grp_data = json.load(f)
        child_paths = grp_data.get("child_paths", None)
        # Invalid group data.
        if child_paths is None:
            return None
        result = {}
        for c, p in child_paths.items():
            if os.path.exists(p):
                result[c] = p
        return result

    # Record a set of previously put files as being members of one group
    def put_group(self, filename: str, group: Dict[str, str]) -> str:
        if not self.cache_dir:
            raise RuntimeError("Could not create or locate cache dir")
        grp_contents = json.dumps({"child_paths": group})
        grp_filename = f"__grp__{filename}"
        return self.put(grp_contents, grp_filename, binary=False)

    def put(self, data, filename, binary=True) -> str:
        if not self.cache_dir:
            raise RuntimeError("Could not create or locate cache dir")
        binary = isinstance(data, bytes)
        if not binary:
            data = str(data)
        assert self.lock_path is not None
        filepath = self._make_path(filename)
        # Random ID to avoid any collisions
        rnd_id = str(uuid.uuid4())
        # Include the PID so that, if several of these temp dirs are left behind,
        # we can see which process created each one.
        pid = os.getpid()
        # Write into a temp dir first so we are robust against program interruptions.
        temp_dir = os.path.join(self.cache_dir, f"tmp.pid_{pid}_{rnd_id}")
        os.makedirs(temp_dir, exist_ok=True)
        temp_path = os.path.join(temp_dir, filename)

        mode = "wb" if binary else "w"
        with open(temp_path, mode) as f:
            f.write(data)
        # os.replace is guaranteed to be atomic on POSIX systems if it succeeds,
        # so filepath can never be observed in a partially written state.
        os.replace(temp_path, filepath)
        os.removedirs(temp_dir)
        return filepath
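
    # For reference, `put_group` above writes a small JSON index next to the
    # cached artifacts.  A hypothetical group {"kernel.ptx": "/path/kernel.ptx"}
    # (made-up names, purely illustrative) would be stored under
    # `__grp__<filename>` roughly as:
    #
    #     {"child_paths": {"kernel.ptx": "/path/kernel.ptx"}}
    #
    # and `get_group` later returns only the children whose paths still exist.

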
class RemoteCacheBackend:
    """
    A backend implementation for accessing a remote/distributed cache.
    """

    def __init__(self, key: str):
        pass

    @abstractmethod
    def get(self, filenames: List[str]) -> Dict[str, bytes]:
        pass

    @abstractmethod
    def put(self, filename: str, data: bytes):
        pass
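

# A minimal sketch of a custom backend implementing the interface above, shown
# purely for illustration and not part of Triton: it keeps blobs in a
# process-local dictionary instead of talking to a real remote service.
class _InMemoryRemoteCacheBackend(RemoteCacheBackend):

    def __init__(self, key: str):
        super().__init__(key)
        self._store: Dict[str, bytes] = {}

    def get(self, filenames: List[str]) -> Dict[str, bytes]:
        # Return only the entries that exist, mirroring the Redis backend below.
        return {f: self._store[f] for f in filenames if f in self._store}

    def put(self, filename: str, data: bytes):
        self._store[filename] = data

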
class RedisRemoteCacheBackend(RemoteCacheBackend):

    def __init__(self, key):
        import redis
        self._key = key
        self._key_fmt = knobs.cache.redis.key_format
        self._redis = redis.Redis(
            host=knobs.cache.redis.host,
            port=knobs.cache.redis.port,
        )

    def _get_key(self, filename: str) -> str:
        return self._key_fmt.format(key=self._key, filename=filename)

    def get(self, filenames: List[str]) -> Dict[str, bytes]:
        results = self._redis.mget([self._get_key(f) for f in filenames])
        return {filename: result for filename, result in zip(filenames, results) if result is not None}

    def put(self, filename: str, data: bytes):
        self._redis.set(self._get_key(filename), data)


class RemoteCacheManager(CacheManager):

    def __init__(self, key, override=False, dump=False):
        # Set up the backend pointed to by `TRITON_REMOTE_CACHE_BACKEND`.
        remote_cache_cls = knobs.cache.remote_manager_class
        if not remote_cache_cls:
            raise RuntimeError(
                "Unable to instantiate RemoteCacheManager, TRITON_REMOTE_CACHE_BACKEND doesn't point to a valid class")
        self._backend = remote_cache_cls(key)

        self._override = override
        self._dump = dump

        # Use a `FileCacheManager` to materialize remote cache paths locally.
        self._file_cache_manager = FileCacheManager(key, override=override, dump=dump)

    def _materialize(self, filename: str, data: bytes):
        # We use a backing `FileCacheManager` to provide the materialized data.
        return self._file_cache_manager.put(data, filename, binary=True)

    def get_file(self, filename: str) -> Optional[str]:
        # We don't handle the dump/override cases.
        if self._dump or self._override:
            return self._file_cache_manager.get_file(filename)

        # We always check the remote cache backend -- even if our internal file-
        # based cache has the item -- to make sure LRU accounting works as
        # expected.
        results = self._backend.get([filename])
        if len(results) == 0:
            return None
        (_, data), = results.items()
        return self._materialize(filename, data)

    def put(self, data, filename: str, binary=True) -> str:
        # We don't handle the dump/override cases.
        if self._dump or self._override:
            return self._file_cache_manager.put(data, filename, binary=binary)

        if not isinstance(data, bytes):
            data = str(data).encode("utf-8")
        self._backend.put(filename, data)
        return self._materialize(filename, data)

    def get_group(self, filename: str) -> Optional[Dict[str, str]]:
        # We don't handle the dump/override cases.
        if self._dump or self._override:
            return self._file_cache_manager.get_group(filename)

        grp_filename = f"__grp__{filename}"
        grp_filepath = self.get_file(grp_filename)
        if grp_filepath is None:
            return None
        with open(grp_filepath) as f:
            grp_data = json.load(f)
        child_paths = grp_data.get("child_paths", None)

        result = None

        # Found group data.
        if child_paths is not None:
            result = {}
            for child_path, data in self._backend.get(child_paths).items():
                result[child_path] = self._materialize(child_path, data)

        return result

    def put_group(self, filename: str, group: Dict[str, str]):
        # We don't handle the dump/override cases.
        if self._dump or self._override:
            return self._file_cache_manager.put_group(filename, group)

        grp_contents = json.dumps({"child_paths": sorted(list(group.keys()))})
        grp_filename = f"__grp__{filename}"
        return self.put(grp_contents, grp_filename)


def _base32(key):
    # Assume key is a hex string.
    return base64.b32encode(bytes.fromhex(key)).decode("utf-8").rstrip("=")


def get_cache_manager(key) -> CacheManager:
    cls = knobs.cache.manager_class or FileCacheManager
    return cls(_base32(key))


def get_override_manager(key) -> CacheManager:
    cls = knobs.cache.manager_class or FileCacheManager
    return cls(_base32(key), override=True)


def get_dump_manager(key) -> CacheManager:
    cls = knobs.cache.manager_class or FileCacheManager
    return cls(_base32(key), dump=True)
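

# Illustrative (hypothetical) usage of the helpers above; the key and filename
# values are made up for the example.  Keys are expected to be hex strings,
# which is why a sha256 hexdigest works here:
#
#     key = hashlib.sha256(b"some kernel source").hexdigest()
#     cache = get_cache_manager(key)
#     path = cache.put(b"...compiled binary...", "kernel.cubin", binary=True)
#     assert cache.get_file("kernel.cubin") == path

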
def make_so_cache_key(version_hash, signature, constants, ids, **kwargs):
    # Get unique key for the compiled code
    signature = {k: 'ptr' if v[0] == '*' else v for k, v in signature.items()}
    key = f"{version_hash}-{''.join(signature.values())}-{constants}-{ids}"
    for kw in kwargs:
        key = f"{key}-{kwargs.get(kw)}"
    key = hashlib.sha256(key.encode("utf-8")).hexdigest()
    return _base32(key)


@functools.lru_cache()
def triton_key():
    import pkgutil
    TRITON_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    contents = []
    # frontend
    with open(__file__, "rb") as f:
        contents += [hashlib.sha256(f.read()).hexdigest()]
    # compiler
    path_prefixes = [
        (os.path.join(TRITON_PATH, "compiler"), "triton.compiler."),
        (os.path.join(TRITON_PATH, "backends"), "triton.backends."),
    ]
    for path, prefix in path_prefixes:
        for lib in pkgutil.walk_packages([path], prefix=prefix):
            with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
                contents += [hashlib.sha256(f.read()).hexdigest()]

    # backend
    libtriton_hash = hashlib.sha256()
    ext = sysconfig.get_config_var("EXT_SUFFIX").split(".")[-1]
    with open(os.path.join(TRITON_PATH, "_C", f"libtriton.{ext}"), "rb") as f:
        while True:
            chunk = f.read(1024**2)
            if not chunk:
                break
            libtriton_hash.update(chunk)
    contents.append(libtriton_hash.hexdigest())
    # language
    language_path = os.path.join(TRITON_PATH, 'language')
    for lib in pkgutil.walk_packages([language_path], prefix="triton.language."):
        with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
            contents += [hashlib.sha256(f.read()).hexdigest()]
    return f'{__version__}' + '-'.join(contents)


def get_cache_key(src, backend, backend_options, env_vars):
    key = f"{triton_key()}-{src.hash()}-{backend.hash()}-{backend_options.hash()}-{str(sorted(env_vars.items()))}"
    return key