DriverTrac/venv/lib/python3.12/site-packages/triton/profiler/mode.py
2025-11-28 09:08:33 +05:30

124 lines
4.3 KiB
Python

from dataclasses import dataclass, field
from triton._C.libtriton import proton as triton_proton
from typing import List
from enum import Enum
# String-keyed lookup tables: each maps a lowercase option name to the
# matching enum member exposed by the proton extension module.

# Supported metric kinds.
metric_types = {
    "cycle": triton_proton.METRIC_TYPE.CYCLE,
}

# Supported buffer strategies.
buffer_strategies = {
    "circular": triton_proton.BUFFER_STRATEGY.CIRCULAR,
    "flush": triton_proton.BUFFER_STRATEGY.FLUSH,
}

# Supported buffer types.
buffer_types = {
    "shared": triton_proton.BUFFER_TYPE.SHARED,
    "global": triton_proton.BUFFER_TYPE.GLOBAL,
}

# Supported sampling strategies.
sampling_strategies = {
    "none": triton_proton.SAMPLING_STRATEGY.NONE,
    "selective": triton_proton.SAMPLING_STRATEGY.SELECTIVE,
}

# Supported granularities.
granularities = {
    "cta": triton_proton.GRANULARITY.CTA,
    "warp": triton_proton.GRANULARITY.WARP,
    "warp_2": triton_proton.GRANULARITY.WARP_2,
    "warp_4": triton_proton.GRANULARITY.WARP_4,
    "warp_8": triton_proton.GRANULARITY.WARP_8,
    "warp_group": triton_proton.GRANULARITY.WARP_GROUP,
    "warp_group_2": triton_proton.GRANULARITY.WARP_GROUP_2,
    "warp_group_4": triton_proton.GRANULARITY.WARP_GROUP_4,
    "warp_group_8": triton_proton.GRANULARITY.WARP_GROUP_8,
}
class Optimize(Enum):
    """Named optimization flags, each identified by a string token.

    The member value is the token itself, and ``str(member)`` returns that
    token rather than the default ``ClassName.MEMBER`` form.
    """

    TIMESHIFT = "time_shift"
    SCHED_STORES = "sched_stores"
    SCHED_BARRIERS = "sched_barriers"
    CLOCK32 = "clock32"

    def __str__(self) -> str:
        """Return the member's string token (its value)."""
        return self.value
# Lookup table from each optimization's string token to its Optimize member.
# Every member is keyed by its own value, in definition order.
optimizations = {member.value: member for member in Optimize}
@dataclass(frozen=True)
class BaseMode:
    """Immutable base record for a mode; it carries only the mode's name."""

    # Identifier of the mode; subclasses in this module pin it to a fixed value.
    name: str
@dataclass(frozen=True)
class PCSampling(BaseMode):
    """Mode named ``"pcsampling"`` that is configured by a sampling interval.

    ``name`` is fixed and excluded from ``__init__``; only ``interval`` can be
    supplied by the caller.

    Raises:
        ValueError: if ``interval`` is zero or negative.
    """

    name: str = field(default="pcsampling", init=False)
    # Must be strictly positive; the unit is not established in this module.
    interval: int = 1000

    def __post_init__(self) -> None:
        """Reject non-positive intervals right after construction."""
        if self.interval <= 0:
            raise ValueError("Interval must be a positive integer.")

    def __str__(self) -> str:
        """Return the mode as ``<name>:interval=<interval>``."""
        return f"{self.name}:interval={self.interval}"
@dataclass(frozen=True)
class InstrumentationMode(BaseMode):
    """Common base class for instrumentation modes with shared configuration.

    The enum-typed fields may be given either as enum members or as the string
    keys of the module-level lookup tables; strings are converted to enum
    members in ``__post_init__``. ``optimizations`` may likewise be given as a
    comma-separated string of optimization tokens.

    Raises:
        ValueError: if a string value is not a key of its lookup table, or if
            an optimization token is unknown.
    """

    metric_type: triton_proton.METRIC_TYPE = triton_proton.METRIC_TYPE.CYCLE
    sampling_strategy: triton_proton.SAMPLING_STRATEGY = triton_proton.SAMPLING_STRATEGY.NONE
    sampling_options: str = ""
    granularity: triton_proton.GRANULARITY = triton_proton.GRANULARITY.WARP
    buffer_strategy: triton_proton.BUFFER_STRATEGY = triton_proton.BUFFER_STRATEGY.CIRCULAR
    buffer_type: triton_proton.BUFFER_TYPE = triton_proton.BUFFER_TYPE.SHARED
    buffer_size: int = 0
    optimizations: List[Optimize] = field(default_factory=list)

    def __post_init__(self):
        """Normalize string-valued fields into their enum equivalents."""
        # Field name -> module-level table used to translate a string value.
        lookups = {
            "metric_type": metric_types,
            "sampling_strategy": sampling_strategies,
            "granularity": granularities,
            "buffer_strategy": buffer_strategies,
            "buffer_type": buffer_types,
        }
        for field_name, lookup in lookups.items():
            value = getattr(self, field_name)
            if not isinstance(value, str):
                # Already an enum member (or other non-string); leave untouched.
                continue
            if value not in lookup:
                raise ValueError(f"Unknown {field_name}: {value}")
            # The dataclass is frozen, so normal assignment is blocked; write
            # through object.__setattr__ instead.
            object.__setattr__(self, field_name, lookup[value])

        raw = self.optimizations
        if isinstance(raw, str):
            # An empty string means "no optimizations"; otherwise split on
            # commas and trim surrounding whitespace from each token.
            tokens = [token.strip() for token in raw.split(",")] if raw else []
            resolved = []
            for value in tokens:
                # Inside a method, `optimizations` resolves to the module-level
                # lookup table, not to the dataclass field of the same name.
                if value not in optimizations:
                    raise ValueError(f"Unknown optimization: {value}")
                resolved.append(optimizations[value])
            object.__setattr__(self, "optimizations", resolved)

    def __str__(self):
        """Return the mode as colon-separated ``key=value`` pairs after the name."""
        optimizations_str = ",".join(map(str, self.optimizations))
        segments = (
            f"{self.name}",
            f"metric_type={self.metric_type}",
            f"sampling_strategy={self.sampling_strategy}",
            f"sampling_options={self.sampling_options}",
            f"granularity={self.granularity}",
            f"buffer_strategy={self.buffer_strategy}",
            f"buffer_type={self.buffer_type}",
            f"buffer_size={self.buffer_size}",
            f"optimizations={optimizations_str}",
        )
        return ":".join(segments)
@dataclass(frozen=True)
class Default(InstrumentationMode):
    """Instrumentation mode whose ``name`` is fixed to ``"default"``.

    ``name`` is excluded from ``__init__``; every other setting is inherited
    from ``InstrumentationMode``.
    """

    name: str = field(default="default", init=False)
@dataclass(frozen=True)
class MMA(InstrumentationMode):
    """Instrumentation mode whose ``name`` is fixed to ``"mma"``.

    ``name`` is excluded from ``__init__``; every other setting is inherited
    from ``InstrumentationMode``.
    """

    name: str = field(default="mma", init=False)