import functools
from typing import Dict, Optional, Union, Any

import triton
from triton._C.libtriton import ir as triton_ir
from triton._C.libtriton import proton as triton_proton
from triton._C.libtriton import amd as triton_amd
from triton._C.libtriton import nvidia as triton_nvidia
from triton._C.libtriton import passes as triton_passes
from triton._C.libproton import proton as libproton
from triton.compiler import LazyDict
from triton.runtime.jit import JITFunction
from triton.runtime._allocation import set_profile_allocator, NullAllocator
from triton.backends import backends

from .hook import Hook
from ..flags import set_instrumentation_on, set_instrumentation_off
from .. import mode

# TODO(fywkevin): add support for major.minor
VERSION = 1


class CudaAllocator:

    def __init__(self, instrumentation_hook):
        self.instrumentation_hook = instrumentation_hook

    def __call__(self, size: int, alignment: int, stream: Optional[int]):
        if alignment != self.instrumentation_hook.profile_buffer_alignment:
            raise RuntimeError(
                f"Alignment mismatch: {alignment} != {self.instrumentation_hook.profile_buffer_alignment}")
        aligned_size = (size + alignment - 1) // alignment * alignment
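        # The line above rounds `size` up to the nearest multiple of `alignment`,
        # e.g. size=100, alignment=128 yields (100 + 127) // 128 * 128 == 128.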
        # Note: profile_buffer_size may be smaller than the aligned size if the kernel launches many blocks
        # and the host CPU cannot store all profiling data in memory. This streaming mode is not yet implemented.
        # In the future, we should support copying data incrementally from device to host to enable
        # more efficient profiling data processing, rather than relying solely on post-processing.
        aligned_size = max(aligned_size, self.instrumentation_hook.profile_buffer_size)

        # Create the buffer
        import torch
        buffer = torch.empty((aligned_size, ), dtype=torch.uint8, device="cuda")
        self.instrumentation_hook.buffer = buffer
        return buffer


class Instrumentation:
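    """Registry mapping compiler stage names to pass-installation callbacks.

    A minimal usage sketch (`add_my_passes` is a hypothetical callback that
    appends passes to the pass manager `pm`; it is not part of this module):

        inst = Instrumentation({})
        inst.register("ttgpuir_to_llvmir", add_my_passes)
        # Later, during compilation:
        inst.patch("ttgpuir_to_llvmir", pm, context)
    """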

    def __init__(self, ir_map: Dict[str, Any]):
        self.manager = ir_map

    def register(self, ir: str, func):
        if ir in self.manager:
            raise RuntimeError(f"IR already registered: {ir}")
        self.manager[ir] = func

    def patch(self, ir: str, pm, context):
        self.load_dialects(context)
        if ir in self.manager:
            self.manager[ir](pm)

    def load_dialects(self, ctx):
        triton_proton.load_dialects(ctx)


def _interpret_mode(mode_obj: Union[str, mode.InstrumentationMode]) -> mode.InstrumentationMode:
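    """Normalize `mode_obj` into a `mode.InstrumentationMode` instance.

    Strings use a mode name followed by colon-separated `key=value` options,
    as parsed below. A minimal sketch using options defined in this module:

        _interpret_mode("default:metric_type=cycle:granularity=warp")

    which returns `mode.Default` with each option mapped to its enum value.
    """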
    if isinstance(mode_obj, mode.InstrumentationMode):
        return mode_obj
    elif not mode_obj:
        mode_obj = "default"

    parts = mode_obj.split(":")
    mode_name = parts[0]
    opts: Dict[str, str] = {}
    for opt in parts[1:]:
        if "=" in opt:
            key, val = opt.split("=", 1)
            opts[key] = val
        else:
            raise ValueError(f"Malformed instrumentation option: '{opt}'")

    # Get option values, falling back to the defaults
    options = {
        "metric_type": opts.get("metric_type", "cycle"),
        "buffer_type": opts.get("buffer_type", "shared"),
        "buffer_strategy": opts.get("buffer_strategy", "circular"),
        "buffer_size": int(opts.get("buffer_size", "0")),
        "granularity": opts.get("granularity", "warp"),
        "sampling_strategy": opts.get("sampling_strategy", "none"),
        "sampling_options": opts.get("sampling_options", ""),
        "optimizations": opts.get("optimizations", ""),
    }

    # Helper function to validate and map options to their enum values
    def get_option_value(opt_name, mapping):
        value = options[opt_name]
        if value and value not in mapping:
            raise ValueError(f"Unknown {opt_name}: {value}")
        return mapping[value] if value else value

    # Look up enum values for each option
    options["metric_type"] = get_option_value("metric_type", mode.metric_types)
    options["buffer_type"] = get_option_value("buffer_type", mode.buffer_types)
    options["buffer_strategy"] = get_option_value("buffer_strategy", mode.buffer_strategies)
    options["granularity"] = get_option_value("granularity", mode.granularities)
    options["sampling_strategy"] = get_option_value("sampling_strategy", mode.sampling_strategies)

    values = [value.strip() for value in options["optimizations"].split(",")] if options["optimizations"] else []
    for value in values:
        if value not in mode.optimizations:
            raise ValueError(f"Unknown optimization: {value}")
    options["optimizations"] = [mode.optimizations[value] for value in values]

    # Create the appropriate mode instance
    if mode_name == "default":
        return mode.Default(**options)
    elif mode_name == "mma":
        return mode.MMA(**options)
    else:
        raise ValueError(f"Unknown mode: {mode_obj}")


def _get_backend_name() -> str:
    backend = triton.runtime.driver.active.get_current_target().backend
    if backend == "cuda":
        return "nvidia"
    elif backend == "hip":
        return "amd"
    else:
        raise RuntimeError(f"Unsupported backend: {backend}")


class InstrumentationHook(Hook):
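    """Hook that drives intra-kernel instrumentation for Triton kernels.

    A minimal lifecycle sketch (`my_kernel` and the mode string are
    illustrative, not part of this module):

        hook = InstrumentationHook("default:metric_type=cycle")
        hook.activate()        # install passes, allocator, and the JIT wrapper
        my_kernel[grid](...)   # runs with `instrumentation_mode` forwarded
        hook.deactivate()      # restore the original state
    """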
    priority: int = 0
    # Only one instance of the instrumentation hook can be active at a time.
    active_count: int = 0
    enable_host_buffer: bool = False
    host_buffer: Optional[Any] = None
    # FIXME(fywkevin): change to a more reasonable value after we have support for periodic buffer dumping.
    profile_buffer_size: int = 1
    profile_buffer_alignment: int = 128

    def __init__(self, mode_obj: Union[None, str, mode.InstrumentationMode]):
        self.mode: mode.InstrumentationMode = _interpret_mode(mode_obj)

        self.allocator = CudaAllocator(self)
        self.buffer = None
        # Mapping of function objects to their metadata file paths
        self.metadata_path: Dict[Any, Optional[str]] = {}

    def activate(self):
        if InstrumentationHook.active_count > 0:
            raise RuntimeError("Only one instance of the instrumentation hook can be active at a time.")

        InstrumentationHook.active_count += 1

        set_instrumentation_on()

        device = triton.runtime.driver.active.get_current_device()
        max_shared_mem = triton.runtime.driver.active.utils.get_device_properties(device)["max_shared_mem"]
        backend_name = _get_backend_name()

        def to_llvmir_passes(pm):
            is_long_clk = mode.Optimize.CLOCK32 not in self.mode.optimizations
            triton_proton.add_convert_proton_to_protongpu(pm, self.mode.metric_type, self.mode.sampling_strategy,
                                                          self.mode.sampling_options, self.mode.granularity,
                                                          self.mode.buffer_strategy, self.mode.buffer_type,
                                                          self.mode.buffer_size, max_shared_mem,
                                                          self.profile_buffer_size, self.profile_buffer_alignment,
                                                          is_long_clk)
            triton_passes.common.add_cse(pm)

            if mode.Optimize.SCHED_STORES in self.mode.optimizations:
                triton_proton.add_schedule_buffer_store(pm)

            triton_proton.add_allocate_proton_shared_memory(pm)

            if mode.Optimize.SCHED_BARRIERS in self.mode.optimizations and backend_name == "amd":
                triton_proton.add_sched_barriers(pm)

        def to_llvm_passes(pm):
            triton_proton.add_allocate_proton_global_scratch_buffer(pm)
            if backend_name == "nvidia":
                triton_proton.add_convert_proton_nvidia_gpu_to_llvm(pm)
            elif backend_name == "amd":
                arch = triton.runtime.driver.active.utils.get_device_properties(device)["arch"].split(":")[0]
                triton_proton.add_convert_proton_amd_gpu_to_llvm(pm, arch)

        backends[backend_name].compiler.instrumentation = Instrumentation({
            "ttgpuir_to_llvmir": to_llvmir_passes,
            "llvmir_to_llvm": to_llvm_passes,
        })

        # Set up the profiling allocator
        set_profile_allocator(self.allocator)

        original_run = JITFunction.run
        original_mode = self.mode

        @functools.wraps(original_run)
        def instrumented_run(self, *args, **kwargs):
            kwargs["instrumentation_mode"] = str(original_mode)
            return original_run(self, *args, **kwargs)
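
        # functools.wraps stores the original function on the wrapper as
        # `__wrapped__`; deactivate() relies on this to restore JITFunction.run.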
        JITFunction.run = instrumented_run

    def deactivate(self):
        if InstrumentationHook.active_count == 0:
            return

        InstrumentationHook.active_count -= 1

        backend_name = _get_backend_name()

        # No instrumentation passes are registered anymore
        backends[backend_name].compiler.instrumentation = {}

        # No runtime instrumentation hook is active anymore
        set_instrumentation_off()

        # Restore the original JIT function run method
        if hasattr(JITFunction.run, "__wrapped__"):
            JITFunction.run = JITFunction.run.__wrapped__

        # Reset the profile allocator
        set_profile_allocator(NullAllocator())

        # Reset host memory for external processing
        if InstrumentationHook.enable_host_buffer:
            InstrumentationHook.host_buffer = None

        # Reset the buffer reference
        self.buffer = None

    def init_handle(self, module: Any, function: Any, name: str, metadata_group: Dict[str, str], hash: str) -> None:
        if not function:
            return

        # Find the IR and metadata paths in the metadata group
        ir_path = next((path for key, path in metadata_group.items() if key.endswith("ttgir")), None)
        metadata_path = next((path for key, path in metadata_group.items() if key.endswith("json")), None)
        self.metadata_path[function] = metadata_path

        if ir_path:
            context = triton_ir.context()
            triton_ir.load_dialects(context)
            backend_name = _get_backend_name()
            if backend_name == "nvidia":
                triton_nvidia.load_dialects(context)
            elif backend_name == "amd":
                triton_amd.load_dialects(context)
            triton_proton.load_dialects(context)
            module = triton_ir.parse_mlir_module(ir_path, context)
            module.context = context

            scope_id_names = triton_proton.get_scope_id_names(module)
            scope_id_parents = triton_proton.get_scope_id_parents(module)
            libproton.init_function_metadata(function, name, scope_id_names, scope_id_parents, metadata_path)
        else:
            raise RuntimeError(f"IR path not found in metadata for function {function}")

    def _data_ptr(self) -> int:
        return 0 if self.buffer is None else self.buffer.data_ptr()

    def enter(self, metadata: LazyDict) -> None:
        func = metadata.data.get("function")
        stream = metadata.data.get("stream")
        alloc_size = 0 if self.buffer is None else self.buffer.element_size() * self.buffer.numel()
        libproton.enter_instrumented_op(stream, func, self._data_ptr(), alloc_size)
        if InstrumentationHook.enable_host_buffer:
            InstrumentationHook.host_buffer = None

    def exit(self, metadata: LazyDict) -> None:
        func = metadata.data.get("function")
        stream = metadata.data.get("stream")
        alloc_size = 0 if self.buffer is None else self.buffer.element_size() * self.buffer.numel()
        libproton.exit_instrumented_op(stream, func, self._data_ptr(), alloc_size)

        if InstrumentationHook.enable_host_buffer:
            self._populate_host_buffer(func)

    def _populate_host_buffer(self, function: Any) -> None:
        if function and self.metadata_path[function]:
            import torch
            import struct
            import json

            def encode_target(target: Dict[str, Any]) -> int:
                # TODO(fywkevin): also account for `arch`
                if target["backend"] == "cuda":
                    return 1
                elif target["backend"] == "hip":
                    return 2
                return 0

            alloc_size = 0 if self.buffer is None else self.buffer.element_size() * self.buffer.numel()
            sampled_warps = self.mode.sampling_options.strip().split(",")
            with open(self.metadata_path[function], 'r') as file:
                data = json.load(file)

            device_type = encode_target(data["target"])
            scratch_mem_size = data["profile_scratch_size"]
            total_unit = data["num_warps"]
            uid_num = (total_unit if self.mode.sampling_strategy == triton_proton.SAMPLING_STRATEGY.NONE else
                       len(sampled_warps))
            block_num = alloc_size // scratch_mem_size

            # Binary trace layout:
            # +------------------+
            # | version          | 4 bytes
            # +------------------+
            # | header_offset    | 4 bytes
            # +------------------+
            # | header_size      | 4 bytes
            # +------------------+
            # | payload_offset   | 4 bytes
            # +------------------+
            # | payload_size     | 4 bytes
            # +------------------+
            # | device_type      | 4 bytes
            # +------------------+
            # | block_num        | 4 bytes
            # +------------------+
            # | total_unit       | 4 bytes
            # +------------------+
            # | scratch_mem_size | 4 bytes
            # +------------------+
            # | uid_num          | 4 bytes
            # +------------------+
            # |                  |
            # | uid_vec          | uid_num * 4 bytes
            # |                  |
            # +------------------+
            # |                  |
            # | payload          | payload_size bytes
            # |                  |
            # +------------------+

            is_all_warps = self.mode.sampling_options == "" and self.mode.granularity == triton_proton.GRANULARITY.WARP
            if is_all_warps:
                uid_vec = list(range(total_unit))
            else:
                uid_vec = [int(i) for i in sampled_warps]
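
            # The ten fixed 4-byte header fields (version through uid_num)
            # account for the 40 bytes below; uid_vec adds 4 bytes per entry.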
            header_size = 40 + uid_num * 4
            header_offset = 4
            payload_offset = header_size
            payload_size = alloc_size
            header_values = [
                VERSION, header_offset, header_size, payload_offset, payload_size, device_type, block_num, total_unit,
                scratch_mem_size, uid_num, *uid_vec
            ]
            header_bytes = struct.pack("I" * len(header_values), *header_values)

            InstrumentationHook.host_buffer = torch.empty(header_size + alloc_size, dtype=torch.uint8, device="cpu")
            config_portion = InstrumentationHook.host_buffer[:header_size]
            config_portion.copy_(torch.tensor(list(header_bytes), dtype=torch.uint8))
            data_portion = InstrumentationHook.host_buffer[header_size:].view_as(self.buffer)
            data_portion.copy_(self.buffer.cpu())
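

# A minimal sketch of decoding the host buffer written above, following the
# binary trace layout documented in _populate_host_buffer (assumes the hook ran
# with enable_host_buffer set and the buffer was populated):
#
#   import struct
#   raw = InstrumentationHook.host_buffer.numpy().tobytes()
#   (version, header_offset, header_size, payload_offset, payload_size,
#    device_type, block_num, total_unit, scratch_mem_size,
#    uid_num) = struct.unpack("10I", raw[:40])
#   uid_vec = struct.unpack(f"{uid_num}I", raw[40:40 + 4 * uid_num])
#   payload = raw[payload_offset:payload_offset + payload_size]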