DriverTrac/venv/lib/python3.12/site-packages/triton/profiler/hooks/hook.py
2025-11-28 09:08:33 +05:30

129 lines
4.7 KiB
Python

from triton.compiler import LazyDict
from abc import abstractmethod
from typing import Dict, Any, Optional
from collections import defaultdict
import triton.knobs as knobs
class Hook:
priority: int = 0
@abstractmethod
def init_handle(self, module: Any, function: Any, name: str, metadata_group: Dict[str, str],
hash: str) -> None: # noqa: D401
raise NotImplementedError
@abstractmethod
def enter(self, metadata: LazyDict) -> None:
raise NotImplementedError
@abstractmethod
def exit(self, metadata: LazyDict) -> None:
raise NotImplementedError
@abstractmethod
def activate(self) -> None:
raise NotImplementedError
@abstractmethod
def deactivate(self) -> None:
raise NotImplementedError
class HookManager:
# active hooks
active_hooks: list[Hook] = []
# session_id -> (hook_type -> active)
session_hooks: Dict[int, Dict[Hook, bool]] = defaultdict(lambda: defaultdict(bool))
@staticmethod
def init_handle(module: Any, function: Any, name: str, metadata_group: Dict[str, str], hash: str) -> None:
for hook in HookManager.active_hooks:
hook.init_handle(module, function, name, metadata_group, hash)
@staticmethod
def enter(metadata: LazyDict) -> None:
for hook in HookManager.active_hooks:
hook.enter(metadata)
@staticmethod
def exit(metadata: LazyDict) -> None:
# It's important to reverse the order of hooks so that we keep the first in last out order
for hook in reversed(HookManager.active_hooks):
hook.exit(metadata)
@staticmethod
def activate(session: Optional[int] = None) -> None:
if session is None:
sessions = HookManager.session_hooks.keys()
else:
sessions = [session]
for session in sessions:
for hook in HookManager.session_hooks[session]:
if hook not in HookManager.active_hooks:
hook.activate()
HookManager.active_hooks.append(hook)
HookManager.session_hooks[session][hook] = True
# Sort active_hooks by priority
HookManager.active_hooks.sort(key=lambda x: x.priority, reverse=True)
@staticmethod
def deactivate(session: Optional[int] = None) -> None:
if session is None:
sessions = HookManager.session_hooks.keys()
else:
sessions = [session]
deactivated_hooks = set()
for session in sessions:
for hook in HookManager.session_hooks[session]:
if hook in HookManager.active_hooks:
deactivated_hooks.add(hook)
HookManager.session_hooks[session][hook] = False
# Check if any other sessions rely on this hook
for hook in deactivated_hooks:
if not any(session_hooks[hook] for session_hooks in HookManager.session_hooks.values()):
hook.deactivate()
HookManager.active_hooks.remove(hook)
@staticmethod
def register(hook: Hook, session: int) -> None:
HookManager.session_hooks[session][hook] = True
if hook not in HookManager.active_hooks:
hook.activate()
HookManager.active_hooks.append(hook)
# Sort active_hooks by priority
HookManager.active_hooks.sort(key=lambda x: x.priority, reverse=True)
# Register the heads
knobs.runtime.kernel_load_end_hook.add(HookManager.init_handle)
knobs.runtime.launch_enter_hook.add(HookManager.enter)
knobs.runtime.launch_exit_hook.add(HookManager.exit)
@staticmethod
def unregister(session: Optional[int] = None) -> None:
if session is not None and session not in HookManager.session_hooks:
return
if session is None:
for hook in HookManager.active_hooks:
hook.deactivate()
HookManager.active_hooks.clear()
HookManager.session_hooks.clear()
else:
popped_hooks = HookManager.session_hooks.pop(session)
# Deactivate hooks that are not used by any other session
for hook, active in popped_hooks.items():
if not active:
continue
if not any(session_hooks[hook] for session_hooks in HookManager.session_hooks.values()):
hook.deactivate()
HookManager.active_hooks.remove(hook)
# Unregister the heads
if not HookManager.active_hooks:
knobs.runtime.kernel_load_end_hook.remove(HookManager.init_handle)
knobs.runtime.launch_enter_hook.remove(HookManager.enter)
knobs.runtime.launch_exit_hook.remove(HookManager.exit)