# Copyright 2024 The JAX Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from collections.abc import Callable
import contextlib
import dataclasses
import itertools
import json
import math
import os
import tempfile
from typing import Literal, ParamSpec, TypeVar, overload
import warnings

import jax
from jax._src import stages
from jax._src import util
import jax.numpy as jnp
from jaxlib.mlir import ir
from jaxlib.mlir.dialects import arith
from jaxlib.mlir.dialects import gpu
from jaxlib.mlir.dialects import llvm
from jaxlib.mlir.dialects import memref
from jaxlib.mlir.dialects import vector
import numpy as np

from .utils import *  # noqa: F403

try:
  from jax._src.lib import mosaic_gpu as mosaic_gpu_lib
except ImportError:
  mosaic_gpu_lib = None  # type: ignore[assignment]

# `dialect` must provide the Mosaic GPU MLIR dialect bindings
# (`CustomPrimitiveOp`, `return_`) used below. The import path here is an
# assumption and may need to be adjusted to match the installed jaxlib.
from jax._src.lib import mosaic_gpu_dialect as dialect  # type: ignore

# ruff: noqa: F405

T = TypeVar("T")
P = ParamSpec("P")


@dataclasses.dataclass(frozen=True, kw_only=True)
class Cupti:
  """CUPTI-based profiler."""

  # If `True`, detach CUPTI from the process after measurement.
  finalize: bool = True

  def measure(
      self,
      f,
      *,
      aggregate: bool = True,
      iterations: int = 1,
  ):
    if not isinstance(f, (stages.Wrapped, stages.Compiled)):
      f = jax.jit(f)

    def wrapper(*args, **kwargs):
      if mosaic_gpu_lib is None:
        raise RuntimeError(
            "CUPTI profiling is not supported on this platform"
        )
      jax.block_until_ready(f(*args, **kwargs))  # Warmup.
      ext = mosaic_gpu_lib._mosaic_gpu_ext
      ext._cupti_init()
      try:
        all_results = [f(*args, **kwargs) for _ in range(iterations)]
        for r in all_results:
          jax.block_until_ready(r)
        results = all_results[0]
      finally:
        timings = ext._cupti_get_timings(self.finalize)

      if not timings:
        return results, None
      if len(timings) % iterations != 0:
        raise RuntimeError(
            "The number of kernel launches is not divisible by the number of"
            " iterations"
        )
      kernels_per_iter = len(timings) // iterations
      iter_timings = util.split_list(
          timings, [kernels_per_iter] * (iterations - 1)
      )
      for kernel_idx, (kernel_name, _) in enumerate(iter_timings[0]):
        for i in range(1, iterations):
          if iter_timings[i][kernel_idx][0] != kernel_name:
            raise RuntimeError(
                "Kernel names are not consistent across iterations"
            )
      if aggregate:
        iter_timings = [
            sum(item[1] for item in timings) for timings in iter_timings
        ]
      return results, (
          iter_timings[0] if len(iter_timings) == 1 else iter_timings
      )

    return wrapper


@overload
def measure(
    f: Callable[P, T],
    *,
    aggregate: Literal[True] = ...,
    iterations: Literal[1] = ...,
) -> Callable[P, tuple[T, float | None]]:
  ...


@overload
def measure(
    f: Callable[P, T],
    *,
    aggregate: Literal[False] = ...,
    iterations: Literal[1] = ...,
) -> Callable[P, tuple[T, list[tuple[str, float]] | None]]:
  ...


@overload
def measure(
    f: Callable[P, T],
    *,
    aggregate: Literal[True] = ...,
    iterations: int = ...,
) -> Callable[P, tuple[T, list[float] | None]]:
  ...


@overload
def measure(
    f: Callable[P, T],
    *,
    aggregate: Literal[False] = ...,
    iterations: int = ...,
) -> Callable[P, tuple[T, list[list[tuple[str, float]]] | None]]:
  ...
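
# The overloads above only refine the static return type of the wrapped
# callable; runtime behavior is unchanged. With the defaults (aggregate=True,
# iterations=1) it reports a single float (total runtime in ms);
# aggregate=False reports per-kernel (name, runtime) tuples; iterations > 1
# wraps either shape in a per-iteration list. In all cases the timing is None
# if no kernels were launched.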


def measure(
    f,
    *,
    aggregate: bool = True,
    iterations: int = 1,
):
  """Measures the GPU runtime of a function using CUPTI.

  ``measure`` is a higher-order function that wraps a function ``f`` to return
  GPU runtime in milliseconds, in addition to its regular outputs.

  Args:
    f: The function to measure.
    aggregate: Whether to report an aggregate runtime. When ``False``, the
      per-kernel timings are returned as a list of
      ``(kernel_name, runtime_ms)`` tuples.
    iterations: How many times to run the function. When greater than 1, the
      returned timings become a list with one entry per iteration.

  Returns:
    A function that accepts the same inputs as ``f`` and returns
    ``(f_outputs, timings)``, where ``f_outputs`` are the outputs of ``f`` and
    ``timings`` is a float, a list of ``(kernel_name, runtime_ms)`` tuples, or
    a per-iteration list thereof, depending on ``aggregate`` and
    ``iterations``. If no kernels are launched, ``timings`` is ``None``.

  Notes:
    `CUPTI (CUDA Profiling Tools Interface) <https://docs.nvidia.com/cupti/>`_
    is a high-accuracy profiling API used by Nsight Systems and Nsight
    Compute. The CUPTI API only allows a single subscriber, so ``measure``
    cannot be used together with other CUPTI-based tools such as CUDA-GDB,
    Compute Sanitizer, Nsight Systems, or Nsight Compute.
  """  # fmt: skip
  if iterations < 1:
    raise ValueError(f"{iterations=} must be positive")
  return Cupti().measure(f, aggregate=aggregate, iterations=iterations)


class ProfilerSpec:
  ENTER = 0
  EXIT = 1 << 31

  def __init__(self, entries_per_warpgroup: int, dump_path: str = "sponge"):
    self.entries_per_warpgroup = entries_per_warpgroup
    self.interned_names: dict[str, int] = {}
    if dump_path == "sponge":
      self.dump_path = os.getenv(
          "TEST_UNDECLARED_OUTPUTS_DIR", tempfile.gettempdir()
      )
    else:
      self.dump_path = dump_path

  def _num_warpgroups(
      self, grid: tuple[int, ...], block: tuple[int, ...]
  ) -> int:
    if math.prod(block) % WARPGROUP_SIZE:
      raise ValueError("Block size is not a multiple of warpgroup size")
    return math.prod(grid) * math.prod(block) // WARPGROUP_SIZE

  def mlir_buffer_type(
      self, grid: tuple[int, ...], block: tuple[int, ...]
  ) -> ir.Type:
    return ir.MemRefType.get(
        (self._num_warpgroups(grid, block) * self.entries_per_warpgroup,),
        ir.IntegerType.get_signless(32),
    )

  def jax_buffer_type(
      self, grid: tuple[int, ...], block: tuple[int, ...]
  ) -> jax.ShapeDtypeStruct:
    return jax.ShapeDtypeStruct(
        (self._num_warpgroups(grid, block) * self.entries_per_warpgroup,),
        jnp.uint32,
    )

  def smem_i32_elements(self, block: tuple[int, ...]):
    num_warpgroups = self._num_warpgroups((), block)
    return int(num_warpgroups * self.entries_per_warpgroup)

  def smem_bytes(self, block: tuple[int, ...]):
    bytes_per_entry = 4
    return self.smem_i32_elements(block) * bytes_per_entry

  def intern_name(self, name: str) -> int:
    if (name_id := self.interned_names.get(name, None)) is not None:
      return name_id
    name_id = self.interned_names[name] = len(self.interned_names)
    if name_id & self.EXIT:
      raise RuntimeError("Allocated too many names")
    return name_id

  def dump(self, buffer, f, grid: tuple[int, ...], block: tuple[int, ...]):
    buffer = np.asarray(buffer)
    num_blocks = math.prod(grid)
    warpgroups_per_block = self._num_warpgroups((), block)
    entries = buffer.reshape(
        num_blocks, warpgroups_per_block, self.entries_per_warpgroup
    )
    start_times = entries[..., 0]
    sm_ids = entries[..., 1]
    traces_used = entries[..., 2]
    entries_used = traces_used + 3
    if np.any(entries_used > self.entries_per_warpgroup):
      raise RuntimeError("Insufficient space to capture a full trace")
    traces = entries[..., 3:]
    # Estimate the overhead of profiling.
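    # Each trace is a flat sequence of (tag, timestamp) pairs, so timestamps
    # sit at odd positions. The smallest gap between consecutive timestamps in
    # a warpgroup bounds the cost of the instrumentation itself; it is capped
    # at 12 cycles, reduced by one, and subtracted once per preceding event in
    # the loop below.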
    time_events = traces[:, :, 1::2]
    valid_times_mask = np.arange(traces.shape[-1])[1::2] < traces_used[..., None]
    # 12 cycles is a ballpark estimate for H100.
    profiling_overhead = (time_events[:, :, 1:] - time_events[:, :, :-1]).min(
        where=valid_times_mask[:, :, 1:], initial=12
    )
    profiling_overhead = max(0, profiling_overhead - 1)
    unintern = {v: k for k, v in self.interned_names.items()}
    events = []
    for block_idx, wg_idx in np.ndindex(num_blocks, warpgroups_per_block):
      valid_entries = traces_used[block_idx, wg_idx]
      local_clock_offset = None
      assert valid_entries % 2 == 0, valid_entries
      start_time = start_times[block_idx, wg_idx]
      block_events = []
      last_time = float("-inf")
      for i in range(0, valid_entries, 2):
        tag = traces[block_idx, wg_idx, i]
        time = traces[block_idx, wg_idx, i + 1]
        if local_clock_offset is None:
          local_clock_offset = time
        time -= local_clock_offset
        # Account for the overhead of profiling.
        time -= (i // 2) * profiling_overhead
        if time < 0:
          break  # Detect a timer wraparound.
        name_id = tag
        begin = True
        if name_id & ProfilerSpec.EXIT:
          name_id = name_id ^ ProfilerSpec.EXIT
          begin = False
        name = unintern[name_id]
        if last_time >= time:
          if last_time - time > 10:
            warnings.warn(
                "Profiler clock went significantly backwards for event"
                f" {'start' if begin else 'end'} `{name}`: {last_time} ->"
                f" {time}"
            )
          time = last_time + 1
        last_time = time
        block_events.append({
            "name": name,
            "ph": "B" if begin else "E",
            "ts": float(start_time + time) / 1e3,
            "pid": 1 + int(sm_ids[block_idx, wg_idx]),
            "tid": 1 + wg_idx + warpgroups_per_block * block_idx,
        })
      else:  # If we didn't break.
        if block_events:
          events.append(block_events)
    events = sorted(events, key=lambda x: x[0]["ts"])
    flat_events = list(itertools.chain.from_iterable(events))
    return json.dump({"displayTimeUnit": "ns", "traceEvents": flat_events}, f)


@dataclasses.dataclass(frozen=True)
class _ProfilerCtx:
  """Set of IR values referenced by the profiler logic.

  The profiler logic is implemented using `CustomPrimitiveOp`, which requires
  that all IR values referenced in its body be passed as operands to the op.
  """

  start: ir.Value
  is_profiling_thread: ir.Value
  smem_buffer: ir.Value
  gmem_buffer: ir.Value
  offset: ir.Value


class OnDeviceProfiler:

  def __init__(
      self,
      spec: ProfilerSpec,
      smem_buffer: ir.Value,
      gmem_buffer: ir.Value,
      wrap_in_custom_primitive: bool,
  ):
    i32 = ir.IntegerType.get_signless(32)
    index = ir.IndexType.get()
    self.spec = spec
    self.entries_per_wg = spec.entries_per_warpgroup
    self.wrap_in_custom_primitive = wrap_in_custom_primitive
    wg_idx = warpgroup_idx(sync=False)
    wg_offset = arith.index_cast(
        index, arith.muli(wg_idx, c(self.entries_per_wg, i32))
    )
    smem_buffer = memref_slice(smem_buffer, ds(wg_offset, self.entries_per_wg))
    is_profiling_thread = arith.cmpi(
        arith.CmpIPredicate.eq,
        arith.remui(thread_idx(), c(WARPGROUP_SIZE, i32)),
        c(0, i32),
    )
    # Hopefully mem2reg will remove the allocation.
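    # The running write position into this warpgroup's SMEM slice is kept in a
    # 0-d index memref (rather than as a Python-level value) so that it can be
    # threaded through `_ProfilerCtx` as an operand of the custom primitive
    # and updated from within `record`.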
    offset = memref.alloca(ir.MemRefType.get((), index), [], [])
    memref.store(c(0, index), offset, [])
    self.ctx = _ProfilerCtx(
        start=globaltimer("low"),
        is_profiling_thread=is_profiling_thread,
        smem_buffer=smem_buffer,
        gmem_buffer=gmem_buffer,
        offset=offset,
    )

  @contextlib.contextmanager
  def _profiler_ctx(self):
    if not self.wrap_in_custom_primitive:
      yield self.ctx
      return

    def fields(obj) -> list[ir.Value]:
      return [getattr(obj, field.name) for field in dataclasses.fields(obj)]

    op = dialect.CustomPrimitiveOp(
        result=[],
        operands_=fields(self.ctx),
        in_layouts=[],
        in_transforms=[ir.ArrayAttr.get([])],
        out_layouts=[],
    )
    args_ty = [arg.type for arg in op.operands_]
    block = op.body.blocks.append(*args_ty)
    with ir.InsertionPoint(block):
      yield _ProfilerCtx(*block.arguments)
      dialect.return_([])

  @contextlib.contextmanager
  def record(self, name: str):
    i32 = ir.IntegerType.get_signless(32)
    index = ir.IndexType.get()
    name_id = self.spec.intern_name(name)

    def store(modifier):
      with self._profiler_ctx() as ctx:
        # smem_buffer[offset] = modifier | name_id
        # smem_buffer[offset + 1] = %clock
        # offset += 2
        offset = memref.load(ctx.offset, [])
        base_ref = memref_slice(ctx.smem_buffer, offset)
        base_ptr = memref_ptr(base_ref, memory_space=3)
        i64 = ir.IntegerType.get_signless(64)
        base_addr = llvm.ptrtoint(i64, base_ptr)
        llvm.inline_asm(
            ir.Type.parse("!llvm.void"),
            [ctx.is_profiling_thread, base_addr, c(modifier | name_id, i32)],
            """
            @$0 st.shared.v2.u32 [$1], {$2, %clock};
            """,
            "b,l,r",
            has_side_effects=True,
        )
        new_offset = arith.addi(offset, c(2, index))
        memref.store(new_offset, ctx.offset, [])

    store(ProfilerSpec.ENTER)
    yield
    store(ProfilerSpec.EXIT)

  def finalize(self, grid: tuple[int, ...], block: tuple[int, ...]):
    index = ir.IndexType.get()
    i32 = ir.IntegerType.get_signless(32)

    with self._profiler_ctx() as ctx:
      gpu.barrier()  # Make sure all warpgroups are done.

      block_idx = c(0, index)
      for dim in gpu.Dimension:  # pytype: disable=wrong-arg-types
        block_idx = arith.addi(
            arith.muli(block_idx, gpu.grid_dim(dim)), gpu.block_id(dim)
        )
      wg_idx = warpgroup_idx(sync=False)
      wg_per_block = math.prod(block) // WARPGROUP_SIZE
      global_wg_idx = arith.addi(
          arith.muli(block_idx, c(wg_per_block, index)),
          arith.index_cast(index, wg_idx),
      )
      start_offset = arith.muli(global_wg_idx, c(self.entries_per_wg, index))
      wg_gmem_buffer = memref_slice(
          ctx.gmem_buffer, ds(start_offset, self.entries_per_wg)
      )
      with when(ctx.is_profiling_thread):
        memref.store(ctx.start, wg_gmem_buffer, [c(0, index)])
        memref.store(smid(), wg_gmem_buffer, [c(1, index)])
        num_traces = arith.index_cast(i32, memref.load(ctx.offset, []))
        memref.store(num_traces, wg_gmem_buffer, [c(2, index)])
        traces = vector.load(
            ir.VectorType.get((self.entries_per_wg - 3,), i32),
            ctx.smem_buffer,
            [c(0, index)],
        )
        vector.store(traces, wg_gmem_buffer, [c(3, index)])
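

if __name__ == "__main__":
  # Illustrative usage sketch of `measure`; not part of the library API. It
  # assumes a CUDA GPU with CUPTI support is available, and the function and
  # shapes below are purely illustrative.
  def _matmul(a, b):
    return a @ b

  x = jnp.ones((1024, 1024), jnp.float32)
  y = jnp.ones((1024, 1024), jnp.float32)

  # Aggregate runtime in milliseconds (or None if no kernels were launched).
  _, runtime_ms = measure(_matmul)(x, y)
  print(f"matmul runtime: {runtime_ms} ms")

  # Per-kernel timings as (kernel_name, runtime_ms) tuples.
  _, per_kernel = measure(_matmul, aggregate=False)(x, y)
  print(per_kernel)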