# DriverTrac/venv/lib/python3.12/site-packages/triton/profiler/hooks/instrumentation.py
import functools
from typing import Dict, Optional, Union, Any
import triton
from triton._C.libtriton import ir as triton_ir
from triton._C.libtriton import proton as triton_proton
from triton._C.libtriton import amd as triton_amd
from triton._C.libtriton import nvidia as triton_nvidia
from triton._C.libtriton import passes as triton_passes
from triton._C.libproton import proton as libproton
from triton.compiler import LazyDict
from triton.runtime.jit import JITFunction
from triton.runtime._allocation import set_profile_allocator, NullAllocator
from triton.backends import backends
from .hook import Hook
from ..flags import set_instrumentation_on, set_instrumentation_off
from .. import mode

# TODO(fywkevin): add support for major.minor
VERSION = 1


class CudaAllocator:

    def __init__(self, instrumentation_hook):
        self.instrumentation_hook = instrumentation_hook

    def __call__(self, size: int, alignment: int, stream: Optional[int]):
        if alignment != self.instrumentation_hook.profile_buffer_alignment:
            raise RuntimeError(
                f"Alignment mismatch: {alignment} != {self.instrumentation_hook.profile_buffer_alignment}")
        aligned_size = (size + alignment - 1) // alignment * alignment
        # Note: profile_buffer_size may be smaller than the aligned size if the kernel launches many blocks
        # and the host CPU cannot store all profiling data in memory. This streaming mode is not yet implemented.
        # In the future, we should support copying data incrementally from device to host to enable
        # more efficient profiling data processing, rather than relying solely on post-processing.
        aligned_size = max(aligned_size, self.instrumentation_hook.profile_buffer_size)
        # Create the buffer
        import torch
        buffer = torch.empty((aligned_size, ), dtype=torch.uint8, device="cuda")
        self.instrumentation_hook.buffer = buffer
        return buffer
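
# Example (illustrative, not part of the original module): with the default
# profile_buffer_alignment of 128, a request of size=1000 is rounded up to
# (1000 + 127) // 128 * 128 == 1024 bytes, then clamped to at least
# profile_buffer_size before the uint8 CUDA buffer is allocated.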


class Instrumentation:

    def __init__(self, ir_map: Dict[str, Any]):
        self.manager = ir_map

    def register(self, ir: str, func):
        if ir in self.manager:
            raise RuntimeError(f"IR already registered: {ir}")
        self.manager[ir] = func

    def patch(self, ir: str, pm, context):
        self.load_dialects(context)
        if ir in self.manager:
            self.manager[ir](pm)

    def load_dialects(self, ctx):
        triton_proton.load_dialects(ctx)
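
# The backend compiler is expected to invoke Instrumentation.patch(<stage>, pm, context) at each
# lowering stage it exposes; the stages registered by this module are "ttgpuir_to_llvmir" and
# "llvmir_to_llvm" (see InstrumentationHook.activate below). The exact call sites live in the
# backend compiler and are assumed here, not shown in this file.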


def _interpret_mode(mode_obj: Union[str, mode.InstrumentationMode]) -> mode.InstrumentationMode:
    if isinstance(mode_obj, mode.InstrumentationMode):
        return mode_obj
    elif not mode_obj:
        mode_obj = "default"

    parts = mode_obj.split(":")
    mode_name = parts[0]
    opts: Dict[str, str] = {}
    for opt in parts[1:]:
        if "=" in opt:
            key, val = opt.split("=", 1)
            opts[key] = val
        else:
            raise ValueError(f"Malformed instrumentation option: '{opt}'")

    # Fill in option values, using the defaults shown where a key is not given
    options = {
        "metric_type": opts.get("metric_type", "cycle"),
        "buffer_type": opts.get("buffer_type", "shared"),
        "buffer_strategy": opts.get("buffer_strategy", "circular"),
        "buffer_size": int(opts.get("buffer_size", "0")),
        "granularity": opts.get("granularity", "warp"),
        "sampling_strategy": opts.get("sampling_strategy", "none"),
        "sampling_options": opts.get("sampling_options", ""),
        "optimizations": opts.get("optimizations", ""),
    }

    # Helper function to validate and map options to their enum values
    def get_option_value(opt_name, mapping):
        value = options[opt_name]
        if value and value not in mapping:
            raise ValueError(f"Unknown {opt_name}: {value}")
        return mapping[value] if value else value

    # Look up enum values for each option
    options["metric_type"] = get_option_value("metric_type", mode.metric_types)
    options["buffer_type"] = get_option_value("buffer_type", mode.buffer_types)
    options["buffer_strategy"] = get_option_value("buffer_strategy", mode.buffer_strategies)
    options["granularity"] = get_option_value("granularity", mode.granularities)
    options["sampling_strategy"] = get_option_value("sampling_strategy", mode.sampling_strategies)

    values = [value.strip() for value in options["optimizations"].split(",")] if options["optimizations"] else []
    for value in values:
        if value not in mode.optimizations:
            raise ValueError(f"Unknown optimization: {value}")
    options["optimizations"] = [mode.optimizations[value] for value in values]

    # Create the appropriate mode instance
    if mode_name == "default":
        return mode.Default(**options)
    elif mode_name == "mma":
        return mode.MMA(**options)
    else:
        raise ValueError(f"Unknown mode: {mode_obj}")


def _get_backend_name() -> str:
    backend = triton.runtime.driver.active.get_current_target().backend
    if backend == "cuda":
        return "nvidia"
    elif backend == "hip":
        return "amd"
    else:
        raise RuntimeError(f"Unsupported backend: {backend}")


class InstrumentationHook(Hook):
    priority: int = 0
    # It's important to note that only one instance of the instrumentation hook can be active at a time.
    active_count: int = 0
    enable_host_buffer: bool = False
    host_buffer: Optional[Any] = None
    # FIXME(fywkevin): change to a more reasonable value after we have support for periodic buffer dumping.
    profile_buffer_size: int = 1
    profile_buffer_alignment: int = 128

    def __init__(self, mode_obj: Union[None, str, mode.InstrumentationMode]):
        # Mapping of function objects to their scope ID pairs
        self.mode: mode.InstrumentationMode = _interpret_mode(mode_obj)
        self.allocator = CudaAllocator(self)
        self.buffer = None
        self.metadata_path: Dict[Any, Optional[str]] = {}

    def activate(self):
        if InstrumentationHook.active_count > 0:
            raise RuntimeError("Only one instance of the instrumentation hook can be active at a time.")
        InstrumentationHook.active_count += 1

        set_instrumentation_on()

        device = triton.runtime.driver.active.get_current_device()
        max_shared_mem = triton.runtime.driver.active.utils.get_device_properties(device)["max_shared_mem"]
        backend_name = _get_backend_name()

        def to_llvmir_passes(pm):
            is_long_clk = mode.Optimize.CLOCK32 not in self.mode.optimizations
            triton_proton.add_convert_proton_to_protongpu(pm, self.mode.metric_type, self.mode.sampling_strategy,
                                                          self.mode.sampling_options, self.mode.granularity,
                                                          self.mode.buffer_strategy, self.mode.buffer_type,
                                                          self.mode.buffer_size, max_shared_mem,
                                                          self.profile_buffer_size, self.profile_buffer_alignment,
                                                          is_long_clk)
            triton_passes.common.add_cse(pm)
            if mode.Optimize.SCHED_STORES in self.mode.optimizations:
                triton_proton.add_schedule_buffer_store(pm)
            triton_proton.add_allocate_proton_shared_memory(pm)
            if mode.Optimize.SCHED_BARRIERS in self.mode.optimizations and backend_name == "amd":
                triton_proton.add_sched_barriers(pm)

        def to_llvm_passes(pm):
            triton_proton.add_allocate_proton_global_scratch_buffer(pm)
            if backend_name == "nvidia":
                triton_proton.add_convert_proton_nvidia_gpu_to_llvm(pm)
            elif backend_name == "amd":
                arch = triton.runtime.driver.active.utils.get_device_properties(device)["arch"].split(":")[0]
                triton_proton.add_convert_proton_amd_gpu_to_llvm(pm, arch)

        backends[backend_name].compiler.instrumentation = Instrumentation({
            "ttgpuir_to_llvmir": lambda pm: to_llvmir_passes(pm),
            "llvmir_to_llvm": lambda pm: to_llvm_passes(pm),
        })

        # Set up the profiling allocator
        set_profile_allocator(self.allocator)

        original_run = JITFunction.run
        original_mode = self.mode

        @functools.wraps(original_run)
        def instrumented_run(self, *args, **kwargs):
            kwargs["instrumentation_mode"] = str(original_mode)
            return original_run(self, *args, **kwargs)

        JITFunction.run = instrumented_run
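
        # Note: functools.wraps records the original callable on instrumented_run.__wrapped__,
        # which deactivate() relies on below to restore JITFunction.run.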

    def deactivate(self):
        if InstrumentationHook.active_count == 0:
            return
        InstrumentationHook.active_count -= 1

        backend_name = _get_backend_name()
        # No instrumentation passes are registered anymore
        backends[backend_name].compiler.instrumentation = {}
        # No runtime instrumentation hook is active anymore
        set_instrumentation_off()

        # Restore original JIT function run method
        if hasattr(JITFunction.run, "__wrapped__"):
            JITFunction.run = JITFunction.run.__wrapped__

        # Reset profile allocator
        set_profile_allocator(NullAllocator())

        # Reset host memory for external processing
        if InstrumentationHook.enable_host_buffer:
            InstrumentationHook.host_buffer = None

        # Reset the buffer reference
        self.buffer = None

    def init_handle(self, module: Any, function: Any, name: str, metadata_group: Dict[str, str], hash: str) -> None:
        if not function:
            return

        # Find the IR path in metadata
        ir_path = next((path for key, path in metadata_group.items() if key.endswith("ttgir")), None)
        metadata_path = next((path for key, path in metadata_group.items() if key.endswith("json")), None)
        self.metadata_path[function] = metadata_path

        if ir_path:
            context = triton_ir.context()
            triton_ir.load_dialects(context)
            backend_name = _get_backend_name()
            if backend_name == "nvidia":
                triton_nvidia.load_dialects(context)
            elif backend_name == "amd":
                triton_amd.load_dialects(context)
            triton_proton.load_dialects(context)
            module = triton_ir.parse_mlir_module(ir_path, context)
            module.context = context
            scope_id_names = triton_proton.get_scope_id_names(module)
            scope_id_parents = triton_proton.get_scope_id_parents(module)
            libproton.init_function_metadata(function, name, scope_id_names, scope_id_parents, metadata_path)
        else:
            raise RuntimeError(f"IR path not found in metadata for function {function}")

    def _data_ptr(self) -> int:
        return 0 if self.buffer is None else self.buffer.data_ptr()

    def enter(self, metadata: LazyDict) -> None:
        func = metadata.data.get("function")
        stream = metadata.data.get("stream")
        alloc_size = 0 if self.buffer is None else self.buffer.element_size() * self.buffer.numel()
        libproton.enter_instrumented_op(stream, func, self._data_ptr(), alloc_size)
        if InstrumentationHook.enable_host_buffer:
            InstrumentationHook.host_buffer = None

    def exit(self, metadata: LazyDict) -> None:
        func = metadata.data.get("function")
        stream = metadata.data.get("stream")
        alloc_size = 0 if self.buffer is None else self.buffer.element_size() * self.buffer.numel()
        libproton.exit_instrumented_op(stream, func, self._data_ptr(), alloc_size)
        if InstrumentationHook.enable_host_buffer:
            self._populate_host_buffer(func)

    def _populate_host_buffer(self, function: Any) -> None:
        if function and self.metadata_path[function]:
            import torch
            import struct
            import json

            def encode_target(target: Dict[str, Any]) -> int:
                # TODO(fywkevin): also account for `arch`
                if target["backend"] == "cuda":
                    return 1
                elif target["backend"] == "hip":
                    return 2
                return 0

            alloc_size = 0 if self.buffer is None else self.buffer.element_size() * self.buffer.numel()
            sampled_warps = self.mode.sampling_options.strip().split(",")
            data = {}
            with open(self.metadata_path[function], 'r') as file:
                data = json.load(file)
            device_type = encode_target(data["target"])
            scratch_mem_size = data["profile_scratch_size"]
            total_unit = data["num_warps"]
            uid_num = total_unit if self.mode.sampling_strategy == triton_proton.SAMPLING_STRATEGY.NONE else len(
                sampled_warps)
            block_num = int(alloc_size / scratch_mem_size)

            # Binary trace layout:
            # +------------------+
            # | version          | 4 bytes
            # +------------------+
            # | header_offset    | 4 bytes
            # +------------------+
            # | header_size      | 4 bytes
            # +------------------+
            # | payload_offset   | 4 bytes
            # +------------------+
            # | payload_size     | 4 bytes
            # +------------------+
            # | device_type      | 4 bytes
            # +------------------+
            # | block_num        | 4 bytes
            # +------------------+
            # | total_unit       | 4 bytes
            # +------------------+
            # | scratch_mem_size | 4 bytes
            # +------------------+
            # | uid_num          | 4 bytes
            # +------------------+
            # |                  |
            # | uid_vec          | uid_num * 4 bytes
            # |                  |
            # +------------------+
            # |                  |
            # | payload          | payload_size bytes
            # |                  |
            # +------------------+
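            # Illustrative decoding sketch (not part of this module): a consumer of host_buffer
            # could recover the ten fixed header fields with struct.unpack("10I", raw[:40]) and
            # the uid vector with struct.unpack(f"{uid_num}I", raw[40:40 + uid_num * 4]),
            # mirroring the struct.pack format used below.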
            is_all_warps = self.mode.sampling_options == "" and self.mode.granularity == triton_proton.GRANULARITY.WARP
            if is_all_warps:
                uid_vec = [i for i in range(total_unit)]
            else:
                uid_vec = [int(i) for i in sampled_warps]

            header_size = 40 + uid_num * 4
            header_offset = 4
            payload_offset = header_size
            payload_size = alloc_size
            header_values = [
                VERSION, header_offset, header_size, payload_offset, payload_size, device_type, block_num, total_unit,
                scratch_mem_size, uid_num, *uid_vec
            ]
            header_bytes = struct.pack("I" * len(header_values), *header_values)

            InstrumentationHook.host_buffer = torch.empty(header_size + alloc_size, dtype=torch.uint8, device="cpu")
            config_portion = InstrumentationHook.host_buffer[:header_size]
            config_portion.copy_(torch.tensor(list(header_bytes), dtype=torch.uint8))
            data_portion = InstrumentationHook.host_buffer[header_size:].view_as(self.buffer)
            data_portion.copy_(self.buffer.cpu())
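

# Usage sketch (illustrative, not part of the original module; how the hook is registered with a
# profiling session is assumed to happen elsewhere in proton):
#
#   InstrumentationHook.enable_host_buffer = True   # keep a host-side copy of the trace
#   hook = InstrumentationHook("default:metric_type=cycle:granularity=warp")
#   hook.activate()
#   ...  # launch instrumented Triton kernels; enter()/exit() are driven by the launch hooks
#   hook.deactivate()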