DriverTrac/venv/lib/python3.12/site-packages/jax/_src/pallas/helpers.py
2025-11-28 09:08:33 +05:30

147 lines
3.9 KiB
Python

# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pallas helper functions."""
from collections.abc import Callable
import jax
import jax.numpy as jnp
from jax._src import core as jax_core
from jax._src import checkify
from jax._src import config
from jax._src.pallas import core as pl_core
from jax._src.pallas import utils as pl_utils
from jax._src.pallas import pallas_call
@jax.named_call
def empty(
    shape: tuple[int, ...],
    dtype: jax.typing.DTypeLike,
    *,
    memory_space: object | None = None,
    interpret: bool = False,
    backend: pl_core.Backend | None = None,
):
  """Returns an uninitialized array with the given ``shape`` and ``dtype``.

  This is a convenience wrapper: the shape/dtype pair is packed into a
  ``jax.ShapeDtypeStruct`` and handed to :func:`empty_like`, together with
  all keyword options unchanged.

  Args:
    shape: Shape of the array to create.
    dtype: Element type of the array to create.
    memory_space: Memory space for the output; ``None`` lets
      :func:`empty_like` pick its default.
    interpret: Forwarded to :func:`empty_like`.
    backend: Forwarded to :func:`empty_like`.
  """
  template = jax.ShapeDtypeStruct(shape, dtype)
  options = dict(
      memory_space=memory_space, interpret=interpret, backend=backend
  )
  return empty_like(template, **options)
@jax.named_call
def empty_like(
    x: object,
    *,
    memory_space: object | None = None,
    interpret: bool = False,
    backend: pl_core.Backend | None = None,
):
  """Returns an uninitialized array (or pytree of arrays) shaped like ``x``.

  A ``pallas_call`` is issued whose kernel never touches its output refs, so
  the returned buffers are left uninitialized.

  Args:
    x: Pytree used as ``out_shape`` for the ``pallas_call``. If ``x`` itself
      has a ``memory_space`` attribute (e.g. a MemoryRef), that memory space
      is used for the output.
    memory_space: Memory space for every output block when ``x`` carries none;
      ``None`` means ``pl_core.MemorySpace.ANY``.
    interpret: Forwarded to ``pallas_call``.
    backend: Forwarded to ``pallas_call``.

  Raises:
    ValueError: If ``memory_space`` is passed while ``x`` has its own
      ``memory_space`` attribute.
  """
  if hasattr(x, 'memory_space'):
    if memory_space is not None:
      raise ValueError(
          'memory_space cannot be specified for a MemoryRef object.'
      )
    memory_space = x.memory_space
  space = pl_core.MemorySpace.ANY if memory_space is None else memory_space

  def _spec_for_leaf(_):
    # Every output leaf gets a whole-array block in the chosen memory space.
    return pl_core.BlockSpec(memory_space=space)

  out_specs = jax.tree.map(_spec_for_leaf, x)
  kernel = pallas_call.pallas_call(
      lambda *_: None,  # No-op body, so the out_ref is never initialized.
      out_specs=out_specs,
      out_shape=x,
      interpret=interpret,
      backend=backend,
  )
  return kernel()
def empty_ref_like(
x: object, *, backend: pl_core.Backend | None = None
) -> jax.Array:
"""Returns an empty array Ref with same shape/dtype/memory space as x."""
match x:
case pl_core.MemoryRef():
memory_space = x.memory_space
case jax.ShapeDtypeStruct():
memory_space = pl_core.MemorySpace.ANY
case _:
raise ValueError(f'alloc_ref does not support {type(x)}')
out = empty_like(x, backend=backend)
return jax_core.mutable_array(out, memory_space=memory_space)
def when(condition):
  """Returns a decorator that runs the decorated function only if ``condition``.

  A Python ``bool`` condition is resolved immediately in Python. Any other
  value is passed to ``jax.lax.cond`` with a no-op false branch. The decorated
  function takes no arguments. The decorator returns ``None``, so the
  decorated name ends up bound to ``None``.
  """
  def run_if(fn):
    if not isinstance(condition, bool):
      jax.lax.cond(condition, fn, lambda: None)
    elif condition:
      fn()
  return run_if
def loop(
lower: jax.typing.ArrayLike,
upper: jax.typing.ArrayLike,
*,
step: jax.typing.ArrayLike = 1,
unroll: int | bool | None = None,
) -> Callable[[Callable[[jax.Array], None]], None]:
"""Returns a decorator that calls the decorated function in a loop."""
idx_type = jnp.result_type(lower, upper, step)
lower = jax.lax.convert_element_type(lower, idx_type)
upper = jax.lax.convert_element_type(upper, idx_type)
step = jax.lax.convert_element_type(step, idx_type)
def decorator(body):
jax.lax.fori_loop(
0,
pl_utils.cdiv(upper - lower, step),
lambda idx, _: body(lower + idx * step),
init_val=None,
unroll=unroll,
)
return decorator
# Boolean JAX config option controlling whether ``pl.debug_check`` calls are
# checked at runtime (see the help text below); it defaults to off.
# NOTE(review): ``config.bool_state`` presumably registers a global option
# named "jax_pallas_enable_debug_checks" when this module is imported, so this
# statement should run only once per process — confirm against jax._src.config.
_ENABLE_DEBUG_CHECKS = config.bool_state(
"jax_pallas_enable_debug_checks",
default=False,
help=(
"If set, ``pl.debug_check`` calls are checked at runtime. Otherwise,"
" they are a noop."
),
)
# Public alias for the same flag object; ``debug_check``'s docstring refers to
# it as ``jax.experimental.pallas.enable_debug_checks``.
enable_debug_checks = _ENABLE_DEBUG_CHECKS
def debug_checks_enabled() -> bool:
  """Returns whether Pallas debug checks are currently enabled.

  Reads the current value of the ``jax_pallas_enable_debug_checks`` flag,
  which decides whether ``pl.debug_check`` calls are checked at runtime.
  """
  flag = _ENABLE_DEBUG_CHECKS
  return flag.value
def debug_check(condition, message):
  """Checks ``condition`` if Pallas debug checks are enabled.

  The check is controlled by the ``jax_pallas_enable_debug_checks`` flag,
  exposed as :func:`~jax.experimental.pallas.enable_debug_checks`. When the
  flag is set, the check runs at runtime; otherwise the call is a no-op. The
  work is delegated to ``checkify.debug_check`` with ``message`` as the
  failure message.
  """
  check = checkify.debug_check
  return check(condition, message)