240 lines
7.6 KiB
Python
240 lines
7.6 KiB
Python
# Copyright 2025 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""Pallas helper functions."""
|
|
|
|
from collections.abc import Callable, Mapping, Sequence
|
|
import functools
|
|
|
|
|
|
import jax
|
|
from jax import lax
|
|
from jax._src import api
|
|
from jax._src import checkify
|
|
from jax._src import config
|
|
from jax._src import core as jax_core
|
|
from jax._src.pallas import core as pl_core
|
|
from jax._src.pallas import primitives as pl_primitives
|
|
from jax._src.pallas import utils as pl_utils
|
|
import jax.numpy as jnp
|
|
|
|
|
|
# `lax.empty` wrapped with `jax.named_call`; `empty_like` in this module builds
# its leaves through this alias.
empty = jax.named_call(lax.empty)
|
|
|
|
|
|
@jax.named_call
def empty_like(x: object):
  """Builds a pytree with the structure of ``x`` made of ``empty`` arrays.

  Args:
    x: A pytree whose leaves each expose ``.shape`` and ``.dtype``.

  Returns:
    A pytree with the same structure as ``x`` in which every leaf has been
    replaced by ``empty(leaf.shape, leaf.dtype)``.
  """

  def _empty_for(leaf):
    # Only the leaf's shape and dtype are consulted.
    return empty(leaf.shape, leaf.dtype)

  return jax.tree.map(_empty_for, x)
|
|
|
|
|
|
def empty_ref_like(x: object) -> jax.Array:
  """Returns an empty array Ref with same shape/dtype/memory space as x.

  Args:
    x: Either a ``pl_core.MemoryRef`` (its ``memory_space`` is reused) or a
      ``jax.ShapeDtypeStruct`` (``pl_core.MemorySpace.ANY`` is used).

  Returns:
    The result of ``jax_core.new_ref`` applied to ``empty_like(x)`` with the
    selected memory space.

  Raises:
    ValueError: If ``x`` is neither of the supported types.
  """
  # The checks are ordered exactly like the original pattern match: MemoryRef
  # first, then ShapeDtypeStruct, then the catch-all error.
  if isinstance(x, pl_core.MemoryRef):
    memory_space = x.memory_space
  elif isinstance(x, jax.ShapeDtypeStruct):
    memory_space = pl_core.MemorySpace.ANY
  else:
    raise ValueError(f'empty_ref_like does not support {type(x)}')

  initial_value = empty_like(x)
  return jax_core.new_ref(initial_value, memory_space=memory_space)
|
|
|
|
|
|
def when(
|
|
condition: bool | jax.typing.ArrayLike, /
|
|
) -> Callable[[Callable[[], None]], Callable[[], None]]:
|
|
"""Calls the decorated function when the condition is met.
|
|
|
|
Args:
|
|
condition: If a boolean, this is equivalent to ``if condition: f()``. If an
|
|
array, ``when`` produces a :func:`jax.lax.cond` with the decorated
|
|
function as the true branch.
|
|
|
|
Returns:
|
|
A decorator.
|
|
"""
|
|
def _wrapped(f):
|
|
if isinstance(condition, bool):
|
|
if condition:
|
|
f()
|
|
else:
|
|
jax.lax.cond(condition, f, lambda: None)
|
|
return _wrapped
|
|
|
|
|
|
def loop(
    lower: jax.typing.ArrayLike,
    upper: jax.typing.ArrayLike,
    *,
    step: jax.typing.ArrayLike = 1,
    unroll: int | bool | None = None,
) -> Callable[[Callable[[jax.Array], None]], None]:
  """Returns a decorator that runs the decorated function in a loop.

  The decorated function is invoked through :func:`jax.lax.fori_loop` with the
  values ``lower + i * step`` for ``i`` in ``[0, cdiv(upper - lower, step))``.

  Args:
    lower: First index value passed to the body.
    upper: Bound used to compute the trip count.
    step: Increment between consecutive index values.
    unroll: Forwarded unchanged to ``jax.lax.fori_loop``.

  Returns:
    A decorator. Applying it runs the loop immediately and returns ``None``.
  """
  bounds = (lower, upper, step)
  start: jax.typing.ArrayLike = 0
  if any(not jax_core.is_concrete(bound) for bound in bounds):
    # At least one bound is not concrete: bring all three (and the start
    # index) to one common dtype before building the loop.
    index_dtype = jnp.result_type(*bounds)
    lower, upper, step = (
        jax.lax.convert_element_type(bound, index_dtype) for bound in bounds
    )
    start = jnp.array(0, dtype=index_dtype)

  def decorator(body):
    def _step(idx, carry):
      # The carry is always None; the body only receives the mapped index.
      del carry
      return body(lower + idx * step)

    trip_count = pl_utils.cdiv(upper - lower, step)
    jax.lax.fori_loop(start, trip_count, _step, init_val=None, unroll=unroll)

  return decorator
|
|
|
|
|
|
# Boolean config state registered as "jax_pallas_enable_debug_checks"
# (default: False). `debug_checks_enabled` below reads its `.value`.
_ENABLE_DEBUG_CHECKS = config.bool_state(
    "jax_pallas_enable_debug_checks",
    default=False,
    help=(
        "If set, ``pl.debug_check`` calls are checked at runtime. Otherwise,"
        " they are a noop."
    ),
)

# Non-underscored alias bound to the very same config state object.
enable_debug_checks = _ENABLE_DEBUG_CHECKS
|
|
|
|
|
|
def debug_checks_enabled() -> bool:
  """Returns whether runtime debug checks are enabled.

  This is the current value of the ``jax_pallas_enable_debug_checks`` config
  state.
  """
  enabled = _ENABLE_DEBUG_CHECKS.value
  return enabled
|
|
|
|
|
|
def debug_check(condition, message):
  """Checks ``condition`` if debug checks are enabled, otherwise does nothing.

  Debug checks are controlled by
  :func:`~jax.experimental.pallas.enable_debug_checks`.

  NOTE(review): this function forwards to ``checkify.debug_check``
  unconditionally and does not read the enable flag itself, so the gating
  presumably happens downstream of this call; confirm.

  Args:
    condition: The predicate handed to ``checkify.debug_check``.
    message: The message handed to ``checkify.debug_check``.

  Returns:
    Whatever ``checkify.debug_check`` returns.
  """
  outcome = checkify.debug_check(condition, message)
  return outcome
|
|
|
|
|
|
def _make_kernel(body,
                 out_shape: object,
                 mesh: pl_core.Mesh,
                 scratch_shapes: pl_core.ScratchShapeTree = (),
                 **mesh_kwargs
                 ):
  """Builds a jitted callable that runs ``body`` over ``mesh``.

  Args:
    body: The kernel body. It is called with one Ref per input operand leaf,
      followed by one Ref per output, followed by the scratch allocations
      provided by ``run_scoped``.
    out_shape: Either a single output description or a tuple/list of them.
      Each entry must expose ``.shape`` and ``.dtype`` and may expose
      ``.memory_space``.
    mesh: The mesh handed to ``pl_core.core_map``.
    scratch_shapes: A ``Sequence`` (passed positionally to ``run_scoped``) or
      a ``Mapping`` (passed as keyword arguments to ``run_scoped``).
    **mesh_kwargs: Extra keyword arguments forwarded to ``pl_core.core_map``.

  Returns:
    A ``jax.jit``-wrapped function taking the input operands and returning the
    outputs: a single value if ``out_shape`` was not a tuple/list, otherwise
    a collection with the structure of ``out_shape``.
  """
  # Normalise a lone output description to a 1-tuple and remember to unwrap
  # the result again at the end.
  unwrap_out = not isinstance(out_shape, (tuple, list))
  if unwrap_out:
    out_shape = (out_shape,)

  def _new_out_ref(out):
    # Forward the output's memory space only if it has one and it is not a
    # `jax_core.MemorySpace`; otherwise pass None.
    memory_space = None
    if hasattr(out, "memory_space"):
      candidate = out.memory_space
      if not isinstance(candidate, jax_core.MemorySpace):
        memory_space = candidate
    return jax_core.new_ref(
        lax.empty(out.shape, out.dtype), memory_space=memory_space
    )

  @jax.jit
  def wrapper(*operands):
    arg_refs = jax.tree.map(jax_core.new_ref, operands)
    out_refs = jax.tree.map(_new_out_ref, out_shape)

    @pl_core.core_map(mesh, **mesh_kwargs)
    def _():
      # A Sequence of scratch shapes is splatted positionally, a Mapping as
      # keywords; any other kind contributes nothing.
      scratch_args = scratch_shapes if isinstance(scratch_shapes, Sequence) else ()
      scratch_kwargs = scratch_shapes if isinstance(scratch_shapes, Mapping) else {}
      bound_body = functools.partial(body, *arg_refs, *out_refs)
      return pl_primitives.run_scoped(bound_body, *scratch_args, **scratch_kwargs)

    # Read the full contents back out of every output Ref.
    outs = jax.tree.map(lambda ref: ref[...], out_refs)
    if unwrap_out:
      return outs[0]
    return outs

  return wrapper
|
|
|
|
|
|
def kernel(body: Callable | api.NotSpecified = api.NotSpecified(),  # pylint: disable=g-bare-generic
           out_shape: object | None = None,
           *,
           mesh: pl_core.Mesh,
           scratch_shapes: pl_core.ScratchShapeTree = (),
           compiler_params: pl_core.CompilerParams | None = None,
           interpret: bool = False,
           cost_estimate: pl_core.CostEstimate | None = None,
           debug: bool = False,
           name: str | None = None,
           metadata: dict[str, str] | None = None,
           ):
  """Entry point for creating a Pallas kernel.

  This is a convenience wrapper around ``core_map`` for executing a kernel
  over a mesh and ``run_scoped`` for allocating scratch memory.

  If ``body`` is provided, the kernel is built directly from it::

    def kernel_body(in_ref, out_ref):
      ...
    kernel = pl.kernel(kernel_body, out_shape=...)

  If ``body`` is omitted, this function behaves as a decorator factory and
  returns a decorator that can be used to annotate a kernel body::

    @pl.kernel(out_shape=...)
    def kernel(in_ref, out_ref):
      ...

  Args:
    body: The body of the kernel. If omitted, a decorator expecting the body
      is returned instead of the kernel itself.
    out_shape: The shape of the output. Should be a PyTree of
      ``jax.ShapeDtypeStruct`` or ``jax.Array``s. Required, even though it
      defaults to ``None`` in the signature.
    mesh: The mesh to run the kernel on.
    scratch_shapes: The shapes of the scratch arrays.
    compiler_params: The compiler parameters to pass to the backend.
    interpret: Whether to run the function in interpret mode.
    cost_estimate: The cost estimate of the function.
    debug: Whether or not to output helpful debugging information.
    name: The (optional) name of the kernel.
    metadata: Optional dictionary of information about the kernel that will be
      serialized as JSON in the HLO. Can be used for debugging and analysis.

  Returns:
    If ``body`` is provided, returns a function that runs the kernel.
    It should take any number of input operands and returns an output with the
    same PyTree structure as `out_shape`.
    If ``body`` is omitted, returns a decorator that can be used to annotate
    a kernel body.

  Raises:
    ValueError: If ``out_shape`` is ``None``.
  """
  # `out_shape` only defaults to None so that the optional `body` can come
  # before it in the signature; it is mandatory in practice.
  if out_shape is None:
    raise ValueError('out_shape must be provided.')

  def _build(fun):
    return _make_kernel(  # type: ignore[arg-type]
        fun,
        out_shape=out_shape,
        mesh=mesh,
        scratch_shapes=scratch_shapes,
        compiler_params=compiler_params,
        interpret=interpret,
        cost_estimate=cost_estimate,
        debug=debug,
        name=name,
        metadata=metadata,
    )

  if not isinstance(body, api.NotSpecified):
    return _build(body)
  # No body yet: hand back a decorator that waits for it.
  return _build
|