# Source path: DriverTrac/venv/lib/python3.12/site-packages/jax/experimental/roofline/roofline.py
# File-viewer metadata: 386 lines, 12 KiB, Python

# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Protocol
from collections.abc import Callable, Sequence
import numpy as np
from absl import logging
import jax.numpy as jnp
from jax.sharding import NamedSharding
from jax._src import api
from jax._src import core
from jax._src import prng
from jax._src import source_info_util
from jax._src import traceback_util
from jax._src import util
from jax._src.api import make_jaxpr
from jax._src.interpreters.partial_eval import dce_jaxpr
from jax._src.mesh import AbstractMesh, Mesh
from jax._src.tree_util import broadcast_prefix, tree_flatten, tree_unflatten, tree_map
from jax._src.util import foreach
from jax._src.shard_map import shard_map, shard_map_p
# Alias used in the return annotations of `roofline` / `roofline_and_grad` for
# the (possibly nested) structure of output shape/dtype descriptions.
ShapeDtypeStructTree = Any
# Alias for the `in_specs` / `out_specs` arguments; deliberately untyped here.
Specs = Any
# Dtype classes accepted by `RooflineShape.from_aval` (checked with
# `isinstance`).
ValidRooflineDtype = np.dtype | prng.KeyTy
# NOTE: this shadows the builtin `map` for the rest of the module. Code below
# iterates some `map(...)` results more than once (e.g. `outvar_shapes` in
# `_roofline_interpreter`), so it relies on `util.safe_map` returning a
# materialized sequence rather than a lazy iterator.
map = util.safe_map
@dataclass(frozen=True, slots=True, kw_only=True)
class RooflineRuleContext:
name_stack: source_info_util.NameStack
primitive: core.Primitive
avals_in: Sequence[core.AbstractValue]
avals_out: Sequence[core.AbstractValue]
jaxpr_eqn_ctx: core.JaxprEqnContext
mesh: Mesh | AbstractMesh | None
pin_lhs_in_vmem: bool
pin_rhs_in_vmem: bool
@dataclass(frozen=True, slots=True, kw_only=True)
class RooflineShape:
shape: tuple[int, ...]
dtype: ValidRooflineDtype
@classmethod
def from_aval(cls, aval: core.AbstractValue) -> RooflineShape:
if not isinstance(aval, core.ShapedArray):
raise TypeError(f"Expected ShapedArray, got {type(aval)}.")
if not isinstance(aval.dtype, ValidRooflineDtype):
raise TypeError(
f"Expected numpy or prng.KeyTy dtype, got {type(aval.dtype)}."
)
return cls(shape=aval.shape, dtype=aval.dtype)
@property
def size(self) -> int:
return int(np.prod(self.shape))
@property
def bytes(self) -> int:
return int(self.size * self.dtype.itemsize)
@classmethod
def total_bytes(cls, avals: Sequence[core.AbstractValue]) -> int:
return sum(cls.from_aval(aval).bytes for aval in avals)
@dataclass(frozen=True, slots=True, kw_only=True)
class RooflineResult:
flops: int = 0
unfused_flops: int = 0
ici_bytes: dict[str, int] = field(default_factory=dict)
ici_latency: dict[str, int] = field(default_factory=dict)
hbm_bytes: int = 0
peak_hbm_bytes: int = 0
unfused_hbm_bytes: int = 0
@classmethod
def zeros(cls) -> RooflineResult:
return cls()
def __add__(self, other: RooflineResult) -> RooflineResult:
def merge_ici_dicts(d1: dict[str, int], d2: dict[str, int]) -> dict[str, int]:
return {k: d1.get(k, 0) + d2.get(k, 0) for k in set(d1) | set(d2)}
return RooflineResult(
flops=self.flops + other.flops,
unfused_flops=self.unfused_flops + other.unfused_flops,
ici_bytes=merge_ici_dicts(self.ici_bytes, other.ici_bytes),
ici_latency=merge_ici_dicts(self.ici_latency, other.ici_latency),
hbm_bytes=self.hbm_bytes + other.hbm_bytes,
peak_hbm_bytes=max(self.peak_hbm_bytes, other.peak_hbm_bytes),
unfused_hbm_bytes=self.unfused_hbm_bytes + other.unfused_hbm_bytes,
)
def __mul__(self, constant: int | float) -> RooflineResult:
return RooflineResult(
flops=int(self.flops * constant),
unfused_flops=int(self.unfused_flops * constant),
ici_bytes={k: int(v * constant) for k, v in self.ici_bytes.items()},
ici_latency={k: int(v * constant) for k, v in self.ici_latency.items()},
hbm_bytes=int(self.hbm_bytes * constant),
peak_hbm_bytes=int(self.peak_hbm_bytes * constant),
unfused_hbm_bytes=int(self.unfused_hbm_bytes * constant),
)
def __rmul__(self, constant: int | float) -> RooflineResult:
return self.__mul__(constant)
class _RooflineRule(Protocol):
  """Call signature expected of every rule stored in `_rooflines`.

  `_roofline_interpreter` invokes a rule with the equation's
  `RooflineRuleContext`, one `RooflineShape` per equation input as positional
  arguments, and the equation's params as keyword arguments. The rule returns
  the `RooflineResult` contributed by that single equation.
  """

  def __call__(
      self,
      ctx: RooflineRuleContext,
      *args: RooflineShape,
      **kw,
  ) -> RooflineResult:
    ...
# Registry mapping a primitive to its roofline rule. Filled in by
# `register_roofline` / `register_standard_roofline` and consulted by
# `_roofline_interpreter`; primitives without an entry are skipped with a
# warning.
_rooflines: dict[core.Primitive, _RooflineRule] = {}
def _roofline_interpreter(
    f_name: str,
    jaxpr: core.Jaxpr | core.ClosedJaxpr,
    mesh: Mesh | AbstractMesh | None,
    *,
    pin_lhs_in_vmem: bool = False,
    pin_rhs_in_vmem: bool = False,
) -> RooflineResult:
  """Walks `jaxpr` equation by equation and accumulates a `RooflineResult`.

  Each equation is handled by the first case that applies:

  * its params contain a nested jaxpr under "jaxpr" or "call_jaxpr" (checked
    in that order): the nested jaxpr is interpreted recursively and its result
    is added;
  * its primitive has no entry in `_rooflines`: a warning listing the
    equation's public attributes is logged and nothing is added;
  * otherwise the registered rule is called with a `RooflineRuleContext`, the
    input `RooflineShape`s and the equation params, and its result is added.

  Independently of the rules, a running count of live bytes is kept. It starts
  at the total size of the jaxpr's constvars and invars. After each equation
  the sizes of its outputs are added and the sizes of the inputs whose last
  use is that equation are subtracted. The largest value seen (sampled once
  per equation, after the subtraction) is merged into the returned result as
  `peak_hbm_bytes`.

  Args:
    f_name: name used to build the name stack and the names of nested calls.
    jaxpr: jaxpr to interpret; a `ClosedJaxpr` is unwrapped to its `.jaxpr`.
    mesh: forwarded to every rule through the context (may be None).
    pin_lhs_in_vmem: forwarded to every rule through the context.
    pin_rhs_in_vmem: forwarded to every rule through the context.

  Returns:
    The sum of all per-equation results plus this jaxpr's peak live bytes.

  Raises:
    TypeError: propagated from `RooflineShape.from_aval` when a variable's
      aval is not a `ShapedArray` with a supported dtype.
  """
  name_stack = source_info_util.new_name_stack(
      util.wrap_name("roofline", f_name)
  )

  result = RooflineResult.zeros()

  # Shapes of all variables that are currently live.
  env: dict[core.Var, RooflineShape] = {}

  def write(v: core.Var, node: RooflineShape):
    assert node is not None
    env[v] = node

  def read(v: core.Atom) -> RooflineShape:
    if type(v) is core.Literal:
      return RooflineShape.from_aval(core.abstractify(v.val))
    else:
      assert isinstance(v, core.Var)
      return env[v]

  def aval(v: core.Atom) -> core.AbstractValue:
    if type(v) is core.Literal:
      return core.abstractify(v.val)
    else:
      return v.aval

  def sum_bytes(shapes: Sequence[RooflineShape]) -> int:
    return sum(shape.bytes for shape in shapes)

  def make_roofline_shape(x: core.Atom) -> RooflineShape:
    return RooflineShape.from_aval(aval(x))

  jaxpr = jaxpr.jaxpr if isinstance(jaxpr, core.ClosedJaxpr) else jaxpr
  foreach(write, jaxpr.constvars, map(make_roofline_shape, jaxpr.constvars))
  foreach(write, jaxpr.invars, map(make_roofline_shape, jaxpr.invars))
  last_used = core.last_used(jaxpr)

  current_hbm_bytes = sum_bytes(list(env.values()))
  peak_hbm_bytes = current_hbm_bytes

  for eqn in jaxpr.eqns:
    source_info = eqn.source_info.replace(
        name_stack=name_stack + eqn.source_info.name_stack
    )
    with source_info_util.user_context(
        eqn.source_info.traceback, name_stack=source_info.name_stack
    ):
      # "jaxpr" takes precedence over "call_jaxpr"; the latter is used for
      # custom_jvp_call_p. Either way the nested jaxpr is interpreted
      # recursively so every primitive inside it is accounted for.
      nested_key = next(
          (key for key in ("jaxpr", "call_jaxpr") if key in eqn.params), None
      )
      if nested_key is not None:
        result += _roofline_interpreter(
            util.wrap_name(eqn.primitive.name, f_name),
            eqn.params[nested_key],
            mesh,
            pin_lhs_in_vmem=pin_lhs_in_vmem,
            pin_rhs_in_vmem=pin_rhs_in_vmem,
        )
      elif eqn.primitive not in _rooflines:
        msg = f"No roofline rule for {eqn.primitive}, skipping..."
        msg += "".join(
            f"\n{attr}: {getattr(eqn, attr)}"
            for attr in dir(eqn)
            if not attr.startswith("_")
        )
        logging.warning(msg)
      else:
        rule = _rooflines[eqn.primitive]
        result += rule(
            RooflineRuleContext(
                name_stack=source_info.name_stack,
                primitive=eqn.primitive,
                avals_in=map(aval, eqn.invars),
                avals_out=map(aval, eqn.outvars),
                jaxpr_eqn_ctx=eqn.ctx,
                mesh=mesh,
                pin_lhs_in_vmem=pin_lhs_in_vmem,
                pin_rhs_in_vmem=pin_rhs_in_vmem,
            ),
            *map(read, eqn.invars),
            **eqn.params,
        )

      # Add bytes for the newly-created output variables.
      outvar_shapes = map(make_roofline_shape, eqn.outvars)
      current_hbm_bytes += sum_bytes(outvar_shapes)
      foreach(write, eqn.outvars, outvar_shapes)

      # Remove bytes for the no-longer-needed input variables. A variable may
      # appear several times in `eqn.invars` (e.g. `mul a a`) but occupies
      # memory only once, so de-duplicate before subtracting its size.
      dying_vars = {
          v for v in eqn.invars
          if not isinstance(v, core.Literal) and last_used[v] is eqn
      }
      current_hbm_bytes -= sum_bytes([env[v] for v in dying_vars])
      core.clean_up_dead_vars(eqn, env, last_used)

      peak_hbm_bytes = max(peak_hbm_bytes, current_hbm_bytes)

  # `RooflineResult.__add__` takes the max of the two peaks, so this merges the
  # peak measured here with any peak reported by rules or nested jaxprs.
  result += RooflineResult(peak_hbm_bytes=peak_hbm_bytes)
  return result
def _f_with_vjp(f: Callable):
  """Wraps `f` so that a call evaluates its VJP instead of `f` itself.

  The wrapper runs `api.vjp(f, *args)` and applies the resulting pullback to
  the primal outputs after mapping `jnp.bfloat16` over every leaf, i.e. the
  outputs themselves (converted to bfloat16) stand in for the cotangents.
  It returns whatever the pullback returns.
  """

  @util.wraps(f)
  def wrapped(*args):
    outputs, pullback = api.vjp(f, *args)
    cotangents = tree_map(jnp.bfloat16, outputs)
    return pullback(cotangents)

  return wrapped
def roofline(
    f: Callable,
    mesh: Mesh | AbstractMesh | None = None,
    in_specs: Specs | None = None,
    out_specs: Specs | None = None,
    *,
    pin_lhs_in_vmem: bool = False,
    pin_rhs_in_vmem: bool = False,
    vjp: bool = False,
    print_jaxpr: bool = False,
) -> Callable[..., tuple[ShapeDtypeStructTree, RooflineResult]]:
  """Wraps `f` so that calling it yields output shapes and a roofline estimate.

  The returned function accepts positional arguments only. It traces `f` with
  `make_jaxpr`, removes dead code from the jaxpr, and runs
  `_roofline_interpreter` over what remains.

  Args:
    f: function to analyze.
    mesh: optional mesh. Together with `in_specs` and `out_specs` it causes `f`
      to be wrapped in `shard_map` before tracing.
    in_specs: input specs for `shard_map`.
    out_specs: output specs for `shard_map`. When `mesh` and `out_specs` are
      both given, each returned `ShapeDtypeStruct` carries
      `NamedSharding(mesh, spec)` for its (prefix-broadcast) spec.
    pin_lhs_in_vmem: forwarded to the interpreter and from there to the rules.
    pin_rhs_in_vmem: forwarded to the interpreter and from there to the rules.
    vjp: if True, `_f_with_vjp(f)` is analyzed instead of `f`.
    print_jaxpr: if True, print the jaxpr that is about to be interpreted.

  Returns:
    A function returning `(out_shapes, roofline_result)`.

  Raises:
    ValueError: if a top-level `shard_map` equation has no "jaxpr" param.
  """

  @util.wraps(f)
  @traceback_util.api_boundary
  def wrapped(*args):
    have_mesh = mesh is not None

    traced_f = f
    if in_specs is not None and out_specs is not None and have_mesh:
      traced_f = shard_map(
          traced_f, mesh=mesh, in_specs=in_specs, out_specs=out_specs
      )
    if vjp:
      traced_f = _f_with_vjp(traced_f)

    closed_jaxpr, out_shapes = make_jaxpr(traced_f, return_shape=True)(*args)

    def with_named_sharding(
        shape: api.ShapeDtypeStruct, out_spec: Specs
    ) -> api.ShapeDtypeStruct:
      return api.ShapeDtypeStruct(
          shape.shape, shape.dtype, sharding=NamedSharding(mesh, out_spec)  # type: ignore
      )

    if out_specs is not None and have_mesh:
      # Expand `out_specs` to one spec per output leaf, then rebuild the
      # output tree with a sharding attached to every leaf.
      leaf_specs = broadcast_prefix(out_specs, out_shapes)
      leaves, treedef = tree_flatten(out_shapes)
      out_shapes = tree_unflatten(
          treedef, map(with_named_sharding, leaves, leaf_specs)
      )

    # Treat every output as used; DCE then only drops equations that do not
    # contribute to any output.
    keep_all_outputs = (True,) * len(closed_jaxpr.jaxpr.outvars)
    jaxpr, _ = dce_jaxpr(closed_jaxpr.jaxpr, keep_all_outputs)

    # When the traced function contains top-level shard_map equations, only
    # the body of the last one is interpreted.
    shard_map_eqns = [eqn for eqn in jaxpr.eqns if eqn.primitive == shard_map_p]
    if shard_map_eqns:
      last_shard_map = shard_map_eqns[-1]
      try:
        jaxpr = last_shard_map.params["jaxpr"]
      except KeyError:
        raise ValueError(f"Missing shard_map jaxpr in {jaxpr}.")

    if print_jaxpr:
      print(jaxpr)

    roofline_result = _roofline_interpreter(
        util.fun_qual_name(f),
        jaxpr,
        mesh,
        pin_lhs_in_vmem=pin_lhs_in_vmem,
        pin_rhs_in_vmem=pin_rhs_in_vmem,
    )
    return out_shapes, roofline_result

  return wrapped
def register_roofline(prim: core.Primitive):
  """Returns a decorator that registers a roofline rule for `prim`.

  The decorated rule is stored in `_rooflines` under `prim`, replacing any rule
  already registered for that primitive, and is returned unchanged so the
  decorated name still refers to the original function.
  """

  def install(rule: _RooflineRule):
    _rooflines.update({prim: rule})
    return rule

  return install
def register_standard_roofline(prim: core.Primitive):
  """Registers a rule for `prim` that always reports an all-zero result.

  This keeps `_roofline_interpreter` from logging a "no roofline rule" warning
  for `prim` while contributing nothing to the totals. Returns None.
  """

  def standard_rule(
      ctx: RooflineRuleContext, *args, **kwargs
  ) -> RooflineResult:
    # Inputs are ignored on purpose: the result does not depend on them.
    del ctx, args, kwargs
    return RooflineResult.zeros()

  _rooflines[prim] = standard_rule
def roofline_and_grad(
    f: Callable,
    mesh: Mesh | AbstractMesh,
    in_specs: Specs,
    out_specs: Specs,
    *,
    pin_lhs_in_vmem: bool = False,
    pin_rhs_in_vmem: bool = False,
    print_jaxpr: bool = False,
) -> Callable[..., tuple[ShapeDtypeStructTree, RooflineResult, RooflineResult]]:
  """Wraps `f` to report roofline estimates for both `f` and its VJP.

  The returned function runs `roofline` twice: once on `f` with the caller's
  arguments, and once with `vjp=True` on abstract stand-ins for those
  arguments. Each stand-in keeps the argument's shape and sharding; its dtype
  stays int32 if the argument is int32 and becomes bfloat16 otherwise.

  Args:
    f: function to analyze.
    mesh: forwarded to `roofline`.
    in_specs: forwarded to `roofline`.
    out_specs: forwarded to `roofline`.
    pin_lhs_in_vmem: forwarded to `roofline`.
    pin_rhs_in_vmem: forwarded to `roofline`.
    print_jaxpr: forwarded to `roofline` for both passes.

  Returns:
    A function returning `(primal_shapes, fwd_result, bwd_result)`, where
    `primal_shapes` comes from the forward pass.
  """

  @util.wraps(f)
  @traceback_util.api_boundary
  def wrapped(*args):
    def abstract_vjp_input(x):
      # Every leaf must expose `shape`, `dtype` and `sharding`.
      dtype = jnp.int32 if x.dtype == jnp.int32 else jnp.bfloat16
      return api.ShapeDtypeStruct(x.shape, dtype, sharding=x.sharding)

    forward = roofline(
        f,
        mesh,
        in_specs,
        out_specs,
        pin_lhs_in_vmem=pin_lhs_in_vmem,
        pin_rhs_in_vmem=pin_rhs_in_vmem,
        print_jaxpr=print_jaxpr,
    )
    primal_shapes, fwd_result = forward(*args)

    backward = roofline(
        f,
        mesh,
        in_specs,
        out_specs,
        pin_lhs_in_vmem=pin_lhs_in_vmem,
        pin_rhs_in_vmem=pin_rhs_in_vmem,
        vjp=True,
        print_jaxpr=print_jaxpr,
    )
    # Only the roofline result of the backward pass is kept; its output shapes
    # are discarded.
    bwd_result = backward(*tree_map(abstract_vjp_input, args))[1]

    return (primal_shapes, fwd_result, bwd_result)

  return wrapped