# Copyright 2024 The JAX Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from collections.abc import Callable
import contextlib
import dataclasses
import itertools
import json
import math
import os
import tempfile
from typing import Literal, ParamSpec, TypeVar, overload
import warnings
import jax
from jax._src import stages
from jax._src import util
import jax.numpy as jnp
from jaxlib.mlir import ir
from jaxlib.mlir.dialects import arith
from jaxlib.mlir.dialects import gpu
from jaxlib.mlir.dialects import llvm
from jaxlib.mlir.dialects import memref
from jaxlib.mlir.dialects import vector
import numpy as np
# Assumed import for the Mosaic GPU MLIR dialect used below (CustomPrimitiveOp,
# return_); the exact import path may differ across JAX versions.
from jax._src.lib import mosaic_gpu_dialect as dialect
from .utils import * # noqa: F403
try:
from jax._src.lib import mosaic_gpu as mosaic_gpu_lib
except ImportError:
mosaic_gpu_lib = None # type: ignore[assignment]
# ruff: noqa: F405
T = TypeVar("T")
P = ParamSpec("P")


@dataclasses.dataclass(frozen=True, kw_only=True)
class Cupti:
"""CUPTI-based profiler."""
# If `True`, detach CUPTI from the process after measurement.
finalize: bool = True
def measure(
self, f, *, aggregate: bool = True, iterations: int = 1,
):
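    """Returns a wrapper around ``f`` that also reports CUPTI kernel timings.

    See the module-level ``measure`` below for the full contract; this class
    implements it by initializing CUPTI around the timed iterations.
    """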
if not isinstance(f, (stages.Wrapped, stages.Compiled)):
f = jax.jit(f)
def wrapper(*args, **kwargs):
if mosaic_gpu_lib is None:
raise RuntimeError("CUPTI profiling is not supported on this platform")
jax.block_until_ready(f(*args, **kwargs)) # Warmup.
ext = mosaic_gpu_lib._mosaic_gpu_ext
ext._cupti_init()
try:
all_results = [f(*args, **kwargs) for _ in range(iterations)]
for r in all_results:
jax.block_until_ready(r)
results = all_results[0]
finally:
timings = ext._cupti_get_timings(self.finalize)
if not timings:
return results, None
if len(timings) % iterations != 0:
raise RuntimeError(
"The number of kernel launches is not divisible by the number of"
" iterations"
)
kernels_per_iter = len(timings) // iterations
iter_timings = util.split_list(
timings, [kernels_per_iter] * (iterations - 1)
)
for kernel_idx, (kernel_name, _) in enumerate(iter_timings[0]):
for i in range(1, iterations):
if iter_timings[i][kernel_idx][0] != kernel_name:
raise RuntimeError("Kernel names are not consistent across iterations")
if aggregate:
iter_timings = [
sum(item[1] for item in timings) for timings in iter_timings
]
return results, iter_timings[0] if len(iter_timings) == 1 else iter_timings
return wrapper


@overload
def measure(
f: Callable[P, T],
*,
aggregate: Literal[True] = ...,
iterations: Literal[1] = ...,
) -> Callable[P, tuple[T, float | None]]:
...
@overload
def measure(
f: Callable[P, T],
*,
aggregate: Literal[False] = ...,
iterations: Literal[1] = ...,
) -> Callable[P, tuple[T, list[tuple[str, float]] | None]]:
...
@overload
def measure(
f: Callable[P, T],
*,
aggregate: Literal[True] = ...,
iterations: int = ...,
) -> Callable[P, tuple[T, list[float] | None]]:
...
@overload
def measure(
f: Callable[P, T],
*,
aggregate: Literal[False] = ...,
iterations: int = ...,
) -> Callable[P, tuple[T, list[list[tuple[str, float]]] | None]]:
...
def measure(
f, *, aggregate: bool = True, iterations: int = 1,
):
"""Measures the GPU runtime of a function using CUPTI.
``measure`` is a higher-order function that wraps a function ``f`` to
return GPU runtime in milliseconds, in addition to its regular outputs.
Args:
f: The function to measure.
aggregate: Whether to report an aggregate runtime. When ``False`` (only
supported by ``mode="cupti"``), the per-kernel timings are returned as a
list of tuples ``(<kernel name>, <runtime in ms>)``.
iterations: How many times to run the function. Only supported by
``mode="cupti"``. When greater than 1, the return type will become a list
of measurements.
Returns:
A function that accepts the same inputs as ``f`` and returns
``(f_outputs, timings)``, where ``f_outputs`` are the outputs of ``f``,
and ``timings`` is either a float or a list of tuples, depending on
``aggregate``. If no kernels are launched, ``timings`` is ``None``.
Notes:
`CUPTI (CUDA Profiling Tools Interface)
<https://docs.nvidia.com/cupti/index.html>`_ is a high-accuracy profiling
API used by Nsight Systems and Nsight Compute. The CUPTI API only allows a
single subscriber, so ``measure`` cannot be used with other CUPTI-based
tools like CUDA-GDB, Compute Sanitizer, Nsight Systems, or Nsight
Compute.
""" # fmt: skip
if iterations < 1:
raise ValueError(f"{iterations=} must be positive")
return Cupti().measure(f, aggregate=aggregate, iterations=iterations)
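
# Example usage of `measure` (illustrative sketch, not part of the library;
# assumes a CUDA GPU with CUPTI available and no other CUPTI client attached;
# the names `a`, `b`, `f`, and `x` are hypothetical):
#
#   timed_matmul = measure(jax.jit(lambda a, b: a @ b))
#   out, runtime_ms = timed_matmul(a, b)               # total GPU time in ms
#   out, per_kernel = measure(f, aggregate=False)(x)   # [(kernel name, ms), ...]
#
# `runtime_ms` is None if the call launches no GPU kernels.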


class ProfilerSpec:
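  """Configuration for the on-device (in-kernel) profiler.

  Each warpgroup gets ``entries_per_warpgroup`` 32-bit slots: the first three
  hold the start timestamp, the SM id, and the number of trace words used; the
  rest store interleaved ``(event tag, %clock)`` pairs recorded by
  ``OnDeviceProfiler``. ``dump`` decodes such a buffer into a Chrome Trace
  Event Format JSON file.
  """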
ENTER = 0
EXIT = 1 << 31
def __init__(self, entries_per_warpgroup: int, dump_path: str = "sponge"):
self.entries_per_warpgroup = entries_per_warpgroup
self.interned_names: dict[str, int] = {}
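    # "sponge" (the default) redirects dumps to the test outputs directory
    # when available, falling back to the system temp directory.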
if dump_path == "sponge":
self.dump_path = os.getenv(
"TEST_UNDECLARED_OUTPUTS_DIR", tempfile.gettempdir()
)
else:
self.dump_path = dump_path
def _num_warpgroups(
self, grid: tuple[int, ...], block: tuple[int, ...]
) -> int:
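    """Returns the total number of warpgroups launched by ``grid``/``block``."""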
if math.prod(block) % WARPGROUP_SIZE:
raise ValueError("Block size is not a multiple of warpgroup size")
return math.prod(grid) * math.prod(block) // WARPGROUP_SIZE
def mlir_buffer_type(
self, grid: tuple[int, ...], block: tuple[int, ...]
) -> ir.Type:
return ir.MemRefType.get(
(self._num_warpgroups(grid, block) * self.entries_per_warpgroup,),
ir.IntegerType.get_signless(32),
)
  def jax_buffer_type(
      self, grid: tuple[int, ...], block: tuple[int, ...]
  ) -> jax.ShapeDtypeStruct:
return jax.ShapeDtypeStruct(
(self._num_warpgroups(grid, block) * self.entries_per_warpgroup,),
jnp.uint32,
)
def smem_i32_elements(self, block: tuple[int, ...]):
num_warpgroups = self._num_warpgroups((), block)
return int(num_warpgroups * self.entries_per_warpgroup)
def smem_bytes(self, block: tuple[int, ...]):
bytes_per_entry = 4
return self.smem_i32_elements(block) * bytes_per_entry
def intern_name(self, name: str) -> int:
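    """Returns a stable integer id for ``name``, allocating one on first use.

    The top bit is reserved for the ``EXIT`` marker, which bounds how many
    distinct names can be interned.
    """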
if (name_id := self.interned_names.get(name, None)) is not None:
return name_id
name_id = self.interned_names[name] = len(self.interned_names)
if name_id & self.EXIT:
raise RuntimeError("Allocated too many names")
return name_id
def dump(self, buffer, f, grid: tuple[int, ...], block: tuple[int, ...]):
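    """Decodes a profiler ``buffer`` and writes a Chrome trace JSON to ``f``.

    Events are emitted per (block, warpgroup) pair, with a per-entry
    correction for the estimated overhead of the profiling instructions. The
    output can be opened in a Trace Event Format viewer such as
    chrome://tracing or Perfetto.
    """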
buffer = np.asarray(buffer)
num_blocks = math.prod(grid)
warpgroups_per_block = self._num_warpgroups((), block)
entries = buffer.reshape(
num_blocks, warpgroups_per_block, self.entries_per_warpgroup
)
start_times = entries[..., 0]
sm_ids = entries[..., 1]
traces_used = entries[..., 2]
entries_used = traces_used + 3
if np.any(entries_used > self.entries_per_warpgroup):
raise RuntimeError("Insufficient space to capture a full trace")
traces = entries[..., 3:]
# Estimate the overhead of profiling.
time_events = traces[:, :, 1::2]
valid_times_mask = np.arange(traces.shape[-1])[1::2] < traces_used[..., None]
# 12 cycles is a ballpark estimate for H100
profiling_overhead = (time_events[:, :, 1:] - time_events[:, :, :-1]).min(
where=valid_times_mask[:, :, 1:], initial=12
)
profiling_overhead = max(0, profiling_overhead - 1)
unintern = {v: k for k, v in self.interned_names.items()}
events = []
for block_idx, wg_idx in np.ndindex(num_blocks, warpgroups_per_block):
valid_entries = traces_used[block_idx, wg_idx]
local_clock_offset = None
assert valid_entries % 2 == 0, valid_entries
start_time = start_times[block_idx, wg_idx]
block_events = []
last_time = float("-inf")
for i in range(0, valid_entries, 2):
tag = traces[block_idx, wg_idx, i]
time = traces[block_idx, wg_idx, i + 1]
if local_clock_offset is None:
local_clock_offset = time
time -= local_clock_offset
time -= (i // 2) * profiling_overhead # Account for the overhead of profiling.
if time < 0:
break # Detect a timer wraparound
name_id = tag
begin = True
if name_id & ProfilerSpec.EXIT:
name_id = name_id ^ ProfilerSpec.EXIT
begin = False
name = unintern[name_id]
if last_time >= time:
if last_time - time > 10:
warnings.warn(
"Profiler clock went significantly backwards for event"
f" {'start' if begin else 'end'} `{name}`: {last_time} ->"
f" {time}"
)
time = last_time + 1
last_time = time
block_events.append({
"name": name,
"ph": "B" if begin else "E",
"ts": float(start_time + time) / 1e3,
"pid": 1 + int(sm_ids[block_idx, wg_idx]),
"tid": 1 + wg_idx + warpgroups_per_block * block_idx,
})
else: # If we didn't break
if block_events:
events.append(block_events)
events = sorted(events, key=lambda x: x[0]["ts"])
flat_events = list(itertools.chain.from_iterable(events))
return json.dump({"displayTimeUnit": "ns", "traceEvents": flat_events}, f)


@dataclasses.dataclass(frozen=True)
class _ProfilerCtx:
"""Set of IR values referenced by the profiler logic.
The profiler logic is implemented using `CustomPrimitiveOp` which requires
that all IR values referenced in its body be passed as operands to the op.
"""
start: ir.Value
is_profiling_thread: ir.Value
smem_buffer: ir.Value
gmem_buffer: ir.Value
offset: ir.Value


class OnDeviceProfiler:
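  """Emits in-kernel instrumentation that records per-warpgroup timelines.

  ``record`` brackets a region of generated code with ``(tag, %clock)`` events
  stored in a per-warpgroup SMEM slice; ``finalize`` copies each slice, along
  with the start timestamp and SM id, into the GMEM buffer consumed by
  ``ProfilerSpec.dump``.
  """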
def __init__(
self,
spec: ProfilerSpec,
smem_buffer: ir.Value,
gmem_buffer: ir.Value,
wrap_in_custom_primitive: bool,
):
i32 = ir.IntegerType.get_signless(32)
index = ir.IndexType.get()
self.spec = spec
self.entries_per_wg = spec.entries_per_warpgroup
self.wrap_in_custom_primitive = wrap_in_custom_primitive
wg_idx = warpgroup_idx(sync=False)
wg_offset = arith.index_cast(
index, arith.muli(wg_idx, c(self.entries_per_wg, i32))
)
smem_buffer = memref_slice(smem_buffer, ds(wg_offset, self.entries_per_wg))
is_profiling_thread = arith.cmpi(
arith.CmpIPredicate.eq,
arith.remui(thread_idx(), c(WARPGROUP_SIZE, i32)),
c(0, i32),
)
# Hopefully mem2reg will remove the allocation.
offset = memref.alloca(ir.MemRefType.get((), index), [], [])
memref.store(c(0, index), offset, [])
self.ctx = _ProfilerCtx(
start=globaltimer("low"),
is_profiling_thread=is_profiling_thread,
smem_buffer=smem_buffer,
gmem_buffer=gmem_buffer,
offset=offset,
)
@contextlib.contextmanager
def _profiler_ctx(self):
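    """Yields a ``_ProfilerCtx`` usable at the current insertion point.

    When ``wrap_in_custom_primitive`` is set, the profiling code is emitted
    inside the body of a ``CustomPrimitiveOp`` and the context values are
    rebound to the op's block arguments.
    """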
if not self.wrap_in_custom_primitive:
yield self.ctx
return
def fields(obj) -> list[ir.Value]:
return [getattr(obj, field.name) for field in dataclasses.fields(obj)]
op = dialect.CustomPrimitiveOp(
result=[],
operands_=fields(self.ctx),
in_layouts=[],
in_transforms=[ir.ArrayAttr.get([])],
out_layouts=[],
)
args_ty = [arg.type for arg in op.operands_]
block = op.body.blocks.append(*args_ty)
with ir.InsertionPoint(block):
yield _ProfilerCtx(*block.arguments)
dialect.return_([])
@contextlib.contextmanager
def record(self, name: str):
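    """Records the code emitted within this context as an event named ``name``.

    An ENTER event is stored on entry and an EXIT event on exit, each as a
    ``(tag, %clock)`` pair written to the warpgroup's SMEM trace. Illustrative
    use during kernel lowering (names hypothetical)::

      with profiler.record("tma_load"):
        ...  # emit the instructions to be timed
    """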
i32 = ir.IntegerType.get_signless(32)
index = ir.IndexType.get()
name_id = self.spec.intern_name(name)
def store(modifier):
with self._profiler_ctx() as ctx:
# smem_buffer[offset] = modifier | name_id
# smem_buffer[offset + 1] = %clock
# offset += 2
offset = memref.load(ctx.offset, [])
base_ref = memref_slice(ctx.smem_buffer, offset)
base_ptr = memref_ptr(base_ref, memory_space=3)
i64 = ir.IntegerType.get_signless(64)
base_addr = llvm.ptrtoint(i64, base_ptr)
llvm.inline_asm(
ir.Type.parse("!llvm.void"),
[ctx.is_profiling_thread, base_addr, c(modifier | name_id, i32)],
"""
@$0 st.shared.v2.u32 [$1], {$2, %clock};
""",
"b,l,r",
has_side_effects=True,
)
new_offset = arith.addi(offset, c(2, index))
memref.store(new_offset, ctx.offset, [])
store(ProfilerSpec.ENTER)
yield
store(ProfilerSpec.EXIT)
def finalize(self, grid: tuple[int, ...], block: tuple[int, ...]):
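    """Copies the per-warpgroup SMEM traces into the GMEM profiler buffer.

    Each warpgroup's slice is prefixed with its start timestamp, SM id, and
    trace word count, matching the layout expected by ``ProfilerSpec.dump``.
    """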
index = ir.IndexType.get()
i32 = ir.IntegerType.get_signless(32)
with self._profiler_ctx() as ctx:
gpu.barrier() # Make sure all warpgroups are done.
block_idx = c(0, index)
for dim in gpu.Dimension: # pytype: disable=wrong-arg-types
block_idx = arith.addi(
arith.muli(block_idx, gpu.grid_dim(dim)), gpu.block_id(dim)
)
wg_idx = warpgroup_idx(sync=False)
wg_per_block = math.prod(block) // WARPGROUP_SIZE
global_wg_idx = arith.addi(
arith.muli(block_idx, c(wg_per_block, index)),
arith.index_cast(index, wg_idx),
)
start_offset = arith.muli(global_wg_idx, c(self.entries_per_wg, index))
wg_gmem_buffer = memref_slice(
ctx.gmem_buffer, ds(start_offset, self.entries_per_wg)
)
with when(ctx.is_profiling_thread):
memref.store(ctx.start, wg_gmem_buffer, [c(0, index)])
memref.store(smid(), wg_gmem_buffer, [c(1, index)])
num_traces = arith.index_cast(i32, memref.load(ctx.offset, []))
memref.store(num_traces, wg_gmem_buffer, [c(2, index)])
traces = vector.load(
ir.VectorType.get((self.entries_per_wg - 3,), i32),
ctx.smem_buffer,
[c(0, index)],
)
vector.store(traces, wg_gmem_buffer, [c(3, index)])