# DriverTrac/venv/lib/python3.12/site-packages/triton/profiler/scope.py
# 2025-11-28 09:08:33 +05:30
#
# 130 lines
# 3.8 KiB
# Python

import threading
import time
from functools import wraps
from typing import Optional, Union
from .flags import get_profiling_on
from triton._C.libproton import proton as libproton
# Per-thread state: enter_scope()/exit_scope() keep a stack of
# (scope_id, name) tuples on its ``scopes`` attribute, so each thread only
# sees the scopes it opened itself.
thread_local_scopes: threading.local = threading.local()

# Numeric types accepted as metric values.
MetricValueType = Union[float, int]
class scope:
    """
    A context manager and decorator for entering and exiting a scope.

    Usage:
        context manager:
        ```python
        with proton.scope("test0", {metric_name: metric_value}):
            foo[1,](x, y)
        ```

        decorator:
        ```python
        @proton.scope("test0", {metric_name: metric_value})
        def foo(x, y):
            ...
        ```

    When used as a decorator, every call of the wrapped function enters and
    exits its own copy of the scope. Recursive calls and calls from several
    threads therefore do not overwrite each other's scope id.

    Args:
        name (str): The name of the scope.
        metrics (dict[str, float | int], optional): The metrics of the scope. Default is None.
    """

    def __init__(self, name: str, metrics: Optional[dict[str, MetricValueType]] = None) -> None:
        self.name = name
        self.metrics = metrics
        # Id from libproton.record_scope() for the scope opened by the most
        # recent entry; None when that entry opened nothing (profiling off).
        self.id = None

    def _enter_scope(self):
        # Forget any id left by a previous use of this object, so _exit_scope
        # never closes a scope that this entry did not open.
        self.id = None
        if not get_profiling_on():
            return
        self.id = libproton.record_scope()
        libproton.enter_scope(self.id, self.name)
        if self.metrics:
            libproton.add_metrics(self.id, self.metrics)

    def _exit_scope(self):
        # Nothing to close if profiling is off now or no scope was opened.
        if not get_profiling_on() or self.id is None:
            return
        libproton.exit_scope(self.id, self.name)

    def __enter__(self):
        self._enter_scope()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Returns None, so exceptions raised in the body propagate.
        self._exit_scope()

    def __call__(self, func):
        import copy

        @wraps(func)
        def wrapper(*args, **kwargs):
            # Use a per-call shallow copy. Sharing ``self`` between
            # overlapping calls (recursion, threads) would let an inner call
            # overwrite the state the outer call needs in order to exit.
            active = copy.copy(self)
            active._enter_scope()
            try:
                return func(*args, **kwargs)
            finally:
                active._exit_scope()

        return wrapper
class cpu_timed_scope(scope):
    """
    A scope that measures elapsed time (cpu_time).

    The time between entering and exiting the scope is measured on the host
    with ``time.perf_counter_ns()``. It is attached to the scope as the metric
    ``"cpu_time (ns)(exc)"``.

    Args:
        name (str): The name of the scope.
        metrics (dict[str, float | int], optional): Additional metrics to add. Default is None.

    Raises:
        ValueError: If ``metrics`` uses a reserved name (``"cpu_time"`` or
            ``"cpu_time (ns)(exc)"``).
    """

    # Metric name under which the measured elapsed time is reported.
    _CPU_TIME_METRIC = "cpu_time (ns)(exc)"

    def __init__(self, name: str, metrics: Optional[dict[str, MetricValueType]] = None) -> None:
        super().__init__(name, metrics)
        # perf_counter_ns() reading taken on entry, or None if no scope was opened.
        self.start_time = None
        if metrics:
            # Reserve the key that is actually written on exit as well as the
            # historical "cpu_time" name, so user metrics cannot collide with
            # the measured value.
            for reserved in ("cpu_time", self._CPU_TIME_METRIC):
                if reserved in metrics:
                    raise ValueError(f"The metric name '{reserved}' is reserved.")

    def _enter_scope(self):
        # Drop state from any previous use so that _exit_scope only reports
        # on a scope opened by this entry.
        self.id = None
        self.start_time = None
        if not get_profiling_on():
            return
        # Monotonic, high-resolution clock: unlike time.time_ns() it is not
        # affected by system clock adjustments.
        self.start_time = time.perf_counter_ns()
        super()._enter_scope()

    def _exit_scope(self):
        if not get_profiling_on() or self.id is None:
            return
        # Read the clock before closing, so the exit call itself is not timed.
        elapsed = None if self.start_time is None else time.perf_counter_ns() - self.start_time
        super()._exit_scope()
        if elapsed is not None:
            libproton.add_metrics(self.id, {self._CPU_TIME_METRIC: elapsed})
def enter_scope(name: str, *, metrics: Optional[dict[str, MetricValueType]] = None) -> Optional[int]:
    """
    Enter a named scope on the current thread and return its id.

    Scopes opened this way are tracked on a per-thread stack. ``exit_scope``
    closes them in LIFO order.

    Args:
        name (str): The name of the scope.
        metrics (dict[str, float | int], optional): Metrics to attach to the scope.

    Returns:
        The id returned by ``libproton.record_scope()``, or None when profiling
        is off (in which case nothing is recorded or pushed).
    """
    if not get_profiling_on():
        return None
    scope_id = libproton.record_scope()
    libproton.enter_scope(scope_id, name)
    # Push only after the profiler accepted the scope. A failing enter then
    # leaves no stack entry without a matching profiler scope.
    stack = getattr(thread_local_scopes, "scopes", None)
    if stack is None:
        stack = thread_local_scopes.scopes = []
    stack.append((scope_id, name))
    if metrics:
        libproton.add_metrics(scope_id, metrics)
    return scope_id
def exit_scope(name: Optional[str] = None, *, metrics: Optional[dict[str, MetricValueType]] = None) -> Optional[int]:
    """
    Exit the innermost scope opened by ``enter_scope`` on the current thread.

    Args:
        name (str, optional): Expected name of the innermost open scope. It is
            accepted only so the API mirrors ``enter_scope`` and
            ``proton.language.exit_scope``. When given, it must match the
            name the scope was entered with.
        metrics (dict[str, float | int], optional): Metrics to attach to the
            scope after it is exited.

    Returns:
        The id of the exited scope. Returns None if profiling is off, or if no
        scope opened by ``enter_scope`` is open on this thread.

    Raises:
        ValueError: If ``name`` is given and differs from the innermost open
            scope's name. The scope stays open, so a correct exit can follow.
    """
    if not get_profiling_on():
        return None
    stack = getattr(thread_local_scopes, "scopes", None)
    if not stack:
        # Nothing to close on this thread. For example, profiling was turned
        # on after the matching enter_scope() ran, and that call recorded
        # nothing. This mirrors scope._exit_scope, which ignores exits
        # without a recorded id.
        return None
    scope_id, open_name = stack[-1]
    # Validate before popping. Otherwise a mismatch would drop the entry and
    # leave the profiler scope open with no way to close it.
    if name and name != open_name:
        raise ValueError(f"Scope name mismatch: {name} != {open_name}")
    stack.pop()
    libproton.exit_scope(scope_id, open_name)
    if metrics:
        libproton.add_metrics(scope_id, metrics)
    return scope_id