DriverTrac/venv/lib/python3.12/site-packages/jax/_src/pallas/helpers.py

240 lines
7.6 KiB
Python

# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pallas helper functions."""
from collections.abc import Callable, Mapping, Sequence
import functools
import jax
from jax import lax
from jax._src import api
from jax._src import checkify
from jax._src import config
from jax._src import core as jax_core
from jax._src.pallas import core as pl_core
from jax._src.pallas import primitives as pl_primitives
from jax._src.pallas import utils as pl_utils
import jax.numpy as jnp
empty = jax.named_call(lax.empty)
@jax.named_call
def empty_like(x: object):
return jax.tree.map(lambda leaf: empty(leaf.shape, leaf.dtype), x)
def empty_ref_like(x: object) -> jax.Array:
  """Creates an uninitialized Ref mirroring the shape/dtype/memory space of x.

  A ``pl_core.MemoryRef`` keeps its own ``memory_space``; a
  ``jax.ShapeDtypeStruct`` is placed in ``pl_core.MemorySpace.ANY``.

  Raises:
    ValueError: if ``x`` is neither of the supported types.
  """
  if isinstance(x, pl_core.MemoryRef):
    target_space = x.memory_space
  elif isinstance(x, jax.ShapeDtypeStruct):
    target_space = pl_core.MemorySpace.ANY
  else:
    raise ValueError(f'empty_ref_like does not support {type(x)}')
  # The buffer itself is uninitialized; only its layout comes from ``x``.
  return jax_core.new_ref(empty_like(x), memory_space=target_space)
def when(
condition: bool | jax.typing.ArrayLike, /
) -> Callable[[Callable[[], None]], Callable[[], None]]:
"""Calls the decorated function when the condition is met.
Args:
condition: If a boolean, this is equivalent to ``if condition: f()``. If an
array, ``when`` produces a :func:`jax.lax.cond` with the decorated
function as the true branch.
Returns:
A decorator.
"""
def _wrapped(f):
if isinstance(condition, bool):
if condition:
f()
else:
jax.lax.cond(condition, f, lambda: None)
return _wrapped
def loop(
lower: jax.typing.ArrayLike,
upper: jax.typing.ArrayLike,
*,
step: jax.typing.ArrayLike = 1,
unroll: int | bool | None = None,
) -> Callable[[Callable[[jax.Array], None]], None]:
"""Returns a decorator that calls the decorated function in a loop."""
zero: jax.typing.ArrayLike
if not all(map(jax_core.is_concrete, (lower, upper, step))):
idx_type = jnp.result_type(lower, upper, step)
lower = jax.lax.convert_element_type(lower, idx_type)
upper = jax.lax.convert_element_type(upper, idx_type)
step = jax.lax.convert_element_type(step, idx_type)
zero = jnp.array(0, dtype=idx_type)
else:
zero = 0
def decorator(body):
jax.lax.fori_loop(
zero,
pl_utils.cdiv(upper - lower, step),
lambda idx, _: body(lower + idx * step),
init_val=None,
unroll=unroll,
)
return decorator
# Boolean JAX config option "jax_pallas_enable_debug_checks" (default False).
# Per its help text, ``pl.debug_check`` calls are only checked at runtime when
# this is set; otherwise they are a no-op.
_ENABLE_DEBUG_CHECKS = config.bool_state(
    "jax_pallas_enable_debug_checks",
    default=False,
    help=(
        "If set, ``pl.debug_check`` calls are checked at runtime. Otherwise,"
        " they are a noop."
    ),
)
# Public alias for the config state above. NOTE(review): ``debug_check``'s
# docstring refers to it as ``jax.experimental.pallas.enable_debug_checks``,
# so it is presumably re-exported there; confirm against the package's
# __init__.
enable_debug_checks = _ENABLE_DEBUG_CHECKS
def debug_checks_enabled() -> bool:
  """Returns whether Pallas runtime debug checks are enabled.

  This reads the current value of the ``jax_pallas_enable_debug_checks``
  config option.
  """
  return _ENABLE_DEBUG_CHECKS.value
def debug_check(condition, message):
  """Checks ``condition`` at runtime when debug checks are enabled.

  This is a thin wrapper that forwards ``condition`` and ``message`` to
  ``checkify.debug_check`` and returns its result. According to the
  ``jax_pallas_enable_debug_checks`` flag (see
  :func:`~jax.experimental.pallas.enable_debug_checks`), such checks only take
  effect when the flag is set and are a no-op otherwise. The gating itself
  happens outside this function.
  """
  outcome = checkify.debug_check(condition, message)
  return outcome
def _make_kernel(body,
                 out_shape: object,
                 mesh: pl_core.Mesh,
                 scratch_shapes: pl_core.ScratchShapeTree = (),
                 **mesh_kwargs
                 ):
  """Builds a jitted function that runs ``body`` over ``mesh``.

  The returned function:
    1. wraps every operand in a Ref (``jax_core.new_ref``);
    2. allocates one uninitialized Ref per ``out_shape`` leaf;
    3. runs ``body(*arg_refs, *out_refs, *scratch)`` under
       ``pl_core.core_map(mesh, **mesh_kwargs)`` inside
       ``pl_primitives.run_scoped``, which allocates ``scratch_shapes``;
    4. reads the output Refs back into arrays.

  If ``out_shape`` is not a tuple or list, it is treated as a single output,
  and the single result is returned unwrapped.
  """
  single_output = not isinstance(out_shape, (tuple, list))
  if single_output:
    out_shape = (out_shape,)

  # run_scoped takes sequence-style scratch shapes positionally and
  # mapping-style scratch shapes by keyword.
  scratch_args = scratch_shapes if isinstance(scratch_shapes, Sequence) else ()
  scratch_kwargs = scratch_shapes if isinstance(scratch_shapes, Mapping) else {}

  def _new_output_ref(out):
    buffer = lax.empty(out.shape, out.dtype)
    # Only forward memory spaces that are not JAX-level ``MemorySpace``
    # values; everything else gets the default (None).
    memory_space = None
    if hasattr(out, "memory_space"):
      candidate = out.memory_space
      if not isinstance(candidate, jax_core.MemorySpace):
        memory_space = candidate
    return jax_core.new_ref(buffer, memory_space=memory_space)

  @jax.jit
  def wrapper(*operands):
    arg_refs = jax.tree.map(jax_core.new_ref, operands)
    out_refs = jax.tree.map(_new_output_ref, out_shape)

    @pl_core.core_map(mesh, **mesh_kwargs)
    def _():
      kernel_body = functools.partial(body, *arg_refs, *out_refs)
      return pl_primitives.run_scoped(
          kernel_body, *scratch_args, **scratch_kwargs
      )

    results = jax.tree.map(lambda ref: ref[...], out_refs)
    if single_output:
      return results[0]
    return results

  return wrapper
def kernel(body: Callable | api.NotSpecified = api.NotSpecified(), # pylint: disable=g-bare-generic
out_shape: object | None = None,
*,
mesh: pl_core.Mesh,
scratch_shapes: pl_core.ScratchShapeTree = (),
compiler_params: pl_core.CompilerParams | None = None,
interpret: bool = False,
cost_estimate: pl_core.CostEstimate | None = None,
debug: bool = False,
name: str | None = None,
metadata: dict[str, str] | None = None,
):
"""Entry point for creating a Pallas kernel.
This is a convenience wrapper around ``core_map`` for executing a kernel
over a mesh and ``run_scoped`` for allocating scratch memory.
If ``body`` is provided, this function behaves as a decorator::
def kernel_body(in_ref, out_ref):
...
kernel = pl.kernel(kernel_body, out_shape=...)
If ``body`` is omitted, this function behaves as a decorator factory and
will return a decorator that can be used to annotate a kernel body::
@pl.kernel(out_shape=...)
def kernel(in_ref, out_ref):
...
Args:
body: The body of the kernel. If provided, this function behaves as a
decorator, and if omitted, this function behaves as a decorator factory.
out_shape: The shape of the output. Should be a PyTree of
``jax.ShapeDtypeStruct`` or ``jax.Array``s.
mesh: The mesh to run the kernel on.
scratch_shapes: The shapes of the scratch arrays.
compiler_params: The compiler parameters to pass to the backend.
interpret: Whether to run the function in interpret mode.
debug: Whether or not to out helpful debugging information.
cost_estimate: The cost estimate of the function.
name: The (optional) name of the kernel.
metadata: Optional dictionary of information about the kernel that will be
serialized as JSON in the HLO. Can be used for debugging and analysis.
Returns:
If ``body`` is provided, returns a function that runs the kernel.
It should take any number of input operands and returns an output with the
same PyTree structure as `out_shape`.
If ``body`` is omitted, returns a decorator that can be used to annotate
a kernel body.
"""
# Note we default out_shape to None to allow `body` to come before it
# in the function signature, but `body` itself is optional.
if out_shape is None:
raise ValueError('out_shape must be provided.')
kwds = dict(
out_shape=out_shape,
mesh=mesh,
scratch_shapes=scratch_shapes,
compiler_params=compiler_params,
interpret=interpret,
cost_estimate=cost_estimate,
debug=debug,
name=name,
metadata=metadata)
if isinstance(body, api.NotSpecified):
return lambda fun: _make_kernel(fun, **kwds) # type: ignore[arg-type]
else:
return _make_kernel(body, **kwds) # type: ignore[arg-type]