56 lines
1.7 KiB
Python
56 lines
1.7 KiB
Python
from __future__ import annotations
|
|
from typing import Callable, Optional
|
|
from concurrent.futures import Executor, as_completed, Future
|
|
from contextvars import ContextVar
|
|
|
|
# Holds the AsyncCompileMode that is currently entered in this context, or
# None when no mode is active.
active_mode: ContextVar[Optional[AsyncCompileMode]] = ContextVar(
    "async_compile_active_mode",
    default=None,
)
|
|
|
|
|
|
class FutureKernel:
    """Handle for a kernel whose compile result is delivered by a future.

    The first successful call to :meth:`result` waits on the future, passes
    the compiled kernel to ``finalize_compile`` and caches it; later calls
    return the cached kernel without finalizing again.
    """

    def __init__(self, finalize_compile: Callable, future: Future):
        """Wrap *future* so its result is finalized on first access.

        Args:
            finalize_compile: Callable invoked once with the compiled kernel.
            future: Future whose result is the compiled kernel.
        """
        self.finalize_compile = finalize_compile
        # Stays None until result() has finalized the compiled kernel.
        self.kernel = None
        self.future = future
        # Tracked separately from ``kernel`` so that a compile which yields
        # None is still finalized exactly once.
        self._finalized = False

    def result(self):
        """Return the finalized kernel, waiting for the compile if needed.

        Returns:
            The value produced by the future, after ``finalize_compile`` has
            been called on it.

        Raises:
            Whatever ``future.result()`` or ``finalize_compile`` raises; in
            that case nothing is cached and the next call tries again.
        """
        # The ``kernel is not None`` check keeps the original short-circuit
        # for callers that assign ``kernel`` directly.
        if self._finalized or self.kernel is not None:
            return self.kernel

        kernel = self.future.result()
        self.finalize_compile(kernel)
        self.kernel = kernel
        self._finalized = True
        return kernel
|
|
|
|
|
|
class AsyncCompileMode:
    """Context manager that schedules kernel compiles on an executor.

    While entered, the instance is published through the module-level
    ``active_mode`` context variable. :meth:`submit` schedules a compile and
    returns a ``FutureKernel``; leaving the ``with`` block finalizes every
    compile that was submitted.
    """

    def __init__(self, executor: Executor):
        """Create a mode that submits compile jobs to *executor*."""
        self.executor = executor
        # Raw executor futures, in submission order; drained in __exit__.
        self.raw_futures = []
        # Maps each submitted key to its FutureKernel.
        self.future_kernels = {}

    def submit(self, key, compile_fn, finalize_fn):
        """Schedule ``compile_fn`` once per *key* and return its FutureKernel.

        Args:
            key: Identifier used to deduplicate submissions.
            compile_fn: Zero-argument callable run on the executor.
            finalize_fn: Callable the FutureKernel invokes with the compiled
                kernel when its result is first requested.

        Returns:
            The FutureKernel for *key*; a repeated key returns the existing
            one without scheduling another compile.
        """
        cached = self.future_kernels.get(key)
        if cached is not None:
            return cached

        future = self.executor.submit(compile_fn)
        # Tag the raw future so __exit__ can map it back to its FutureKernel.
        future._key = key
        self.raw_futures.append(future)
        future_kernel = FutureKernel(finalize_fn, future)
        self.future_kernels[key] = future_kernel
        return future_kernel

    def __enter__(self):
        """Mark this mode as active.

        Raises:
            RuntimeError: If a mode is already active in the current context.
        """
        if active_mode.get() is not None:
            raise RuntimeError("Another AsyncCompileMode is already active")
        active_mode.set(self)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Finalize all submitted compiles, then clear the active mode."""
        try:
            # Finalize any outstanding compiles, in completion order.
            for future in as_completed(self.raw_futures):
                self.future_kernels[future._key].result()
        finally:
            # Always clear the marker: a failed compile or finalize must not
            # leave this mode registered as active.
            active_mode.set(None)
|