# Copyright 2025 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Pallas helper functions.""" from collections.abc import Callable, Mapping, Sequence import functools import jax from jax import lax from jax._src import api from jax._src import checkify from jax._src import config from jax._src import core as jax_core from jax._src.pallas import core as pl_core from jax._src.pallas import primitives as pl_primitives from jax._src.pallas import utils as pl_utils import jax.numpy as jnp empty = jax.named_call(lax.empty) @jax.named_call def empty_like(x: object): return jax.tree.map(lambda leaf: empty(leaf.shape, leaf.dtype), x) def empty_ref_like(x: object) -> jax.Array: """Returns an empty array Ref with same shape/dtype/memory space as x.""" match x: case pl_core.MemoryRef(): memory_space = x.memory_space case jax.ShapeDtypeStruct(): memory_space = pl_core.MemorySpace.ANY case _: raise ValueError(f'empty_ref_like does not support {type(x)}') return jax_core.new_ref(empty_like(x), memory_space=memory_space) def when( condition: bool | jax.typing.ArrayLike, / ) -> Callable[[Callable[[], None]], Callable[[], None]]: """Calls the decorated function when the condition is met. Args: condition: If a boolean, this is equivalent to ``if condition: f()``. If an array, ``when`` produces a :func:`jax.lax.cond` with the decorated function as the true branch. Returns: A decorator. """ def _wrapped(f): if isinstance(condition, bool): if condition: f() else: jax.lax.cond(condition, f, lambda: None) return _wrapped def loop( lower: jax.typing.ArrayLike, upper: jax.typing.ArrayLike, *, step: jax.typing.ArrayLike = 1, unroll: int | bool | None = None, ) -> Callable[[Callable[[jax.Array], None]], None]: """Returns a decorator that calls the decorated function in a loop.""" zero: jax.typing.ArrayLike if not all(map(jax_core.is_concrete, (lower, upper, step))): idx_type = jnp.result_type(lower, upper, step) lower = jax.lax.convert_element_type(lower, idx_type) upper = jax.lax.convert_element_type(upper, idx_type) step = jax.lax.convert_element_type(step, idx_type) zero = jnp.array(0, dtype=idx_type) else: zero = 0 def decorator(body): jax.lax.fori_loop( zero, pl_utils.cdiv(upper - lower, step), lambda idx, _: body(lower + idx * step), init_val=None, unroll=unroll, ) return decorator _ENABLE_DEBUG_CHECKS = config.bool_state( "jax_pallas_enable_debug_checks", default=False, help=( "If set, ``pl.debug_check`` calls are checked at runtime. Otherwise," " they are a noop." ), ) enable_debug_checks = _ENABLE_DEBUG_CHECKS def debug_checks_enabled() -> bool: """Returns runtime checks are enabled.""" return _ENABLE_DEBUG_CHECKS.value def debug_check(condition, message): """Check the condition if :func:`~jax.experimental.pallas.enable_debug_checks` is set, otherwise do nothing. """ return checkify.debug_check(condition, message) def _make_kernel(body, out_shape: object, mesh: pl_core.Mesh, scratch_shapes: pl_core.ScratchShapeTree = (), **mesh_kwargs ): if unwrap_out := not isinstance(out_shape, (tuple, list)): out_shape = (out_shape,) @jax.jit def wrapper(*operands): arg_refs = jax.tree.map(jax_core.new_ref, operands) out_refs = jax.tree.map( lambda out: jax_core.new_ref( lax.empty(out.shape, out.dtype), memory_space=( ms if hasattr(out, "memory_space") and not isinstance( ms := out.memory_space, jax_core.MemorySpace ) else None ), ), out_shape, ) @pl_core.core_map(mesh, **mesh_kwargs) def _(): return pl_primitives.run_scoped( functools.partial(body, *arg_refs, *out_refs), *scratch_shapes if isinstance(scratch_shapes, Sequence) else (), **scratch_shapes if isinstance(scratch_shapes, Mapping) else {}, ) outs = jax.tree.map(lambda ref: ref[...], out_refs) return outs[0] if unwrap_out else outs return wrapper def kernel(body: Callable | api.NotSpecified = api.NotSpecified(), # pylint: disable=g-bare-generic out_shape: object | None = None, *, mesh: pl_core.Mesh, scratch_shapes: pl_core.ScratchShapeTree = (), compiler_params: pl_core.CompilerParams | None = None, interpret: bool = False, cost_estimate: pl_core.CostEstimate | None = None, debug: bool = False, name: str | None = None, metadata: dict[str, str] | None = None, ): """Entry point for creating a Pallas kernel. This is a convenience wrapper around ``core_map`` for executing a kernel over a mesh and ``run_scoped`` for allocating scratch memory. If ``body`` is provided, this function behaves as a decorator:: def kernel_body(in_ref, out_ref): ... kernel = pl.kernel(kernel_body, out_shape=...) If ``body`` is omitted, this function behaves as a decorator factory and will return a decorator that can be used to annotate a kernel body:: @pl.kernel(out_shape=...) def kernel(in_ref, out_ref): ... Args: body: The body of the kernel. If provided, this function behaves as a decorator, and if omitted, this function behaves as a decorator factory. out_shape: The shape of the output. Should be a PyTree of ``jax.ShapeDtypeStruct`` or ``jax.Array``s. mesh: The mesh to run the kernel on. scratch_shapes: The shapes of the scratch arrays. compiler_params: The compiler parameters to pass to the backend. interpret: Whether to run the function in interpret mode. debug: Whether or not to out helpful debugging information. cost_estimate: The cost estimate of the function. name: The (optional) name of the kernel. metadata: Optional dictionary of information about the kernel that will be serialized as JSON in the HLO. Can be used for debugging and analysis. Returns: If ``body`` is provided, returns a function that runs the kernel. It should take any number of input operands and returns an output with the same PyTree structure as `out_shape`. If ``body`` is omitted, returns a decorator that can be used to annotate a kernel body. """ # Note we default out_shape to None to allow `body` to come before it # in the function signature, but `body` itself is optional. if out_shape is None: raise ValueError('out_shape must be provided.') kwds = dict( out_shape=out_shape, mesh=mesh, scratch_shapes=scratch_shapes, compiler_params=compiler_params, interpret=interpret, cost_estimate=cost_estimate, debug=debug, name=name, metadata=metadata) if isinstance(body, api.NotSpecified): return lambda fun: _make_kernel(fun, **kwds) # type: ignore[arg-type] else: return _make_kernel(body, **kwds) # type: ignore[arg-type]