# Copyright 2021 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations from collections.abc import Callable, Sequence import dataclasses import functools import itertools as it from typing import TypeVar, Any, Union import numpy as np from jax._src import ad_checkpoint from jax._src import api from jax._src import api_util from jax._src import callback from jax._src import config from jax._src import core from jax._src import custom_derivatives from jax._src import dtypes from jax._src import effects from jax._src import lax from jax._src import linear_util as lu from jax._src import mesh as mesh_lib from jax._src import numpy as jnp from jax._src import pjit from jax._src import shard_map as jshmap from jax._src import sharding_impls from jax._src import source_info_util from jax._src import traceback_util from jax._src import tree_util as jtu from jax._src.ad_util import SymbolicZero from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe from jax._src.partition_spec import PartitionSpec as P from jax._src.tree_util import tree_flatten from jax._src.tree_util import tree_map from jax._src.tree_util import tree_unflatten from jax._src.typing import Array from jax._src.util import (as_hashable_function, split_list, safe_map, safe_zip, unzip3, weakref_lru_cache, HashableWrapper, foreach) # Backward compatibility: some 
# downstream users implicitly rely on this import,
# and reference jax.experimental.shard_map without an explicit import.
# TODO(yashkatariya): remove this once users are migrated to jax.shard_map.
try:
  import jax.experimental.shard_map as _  # pytype: disable=import-error  # noqa: F401
except ImportError:
  pass

# Register this file with the source-info / traceback exclusion lists
# (presumably so its frames are hidden from user-facing locations).
source_info_util.register_exclusion(__file__)
traceback_util.register_exclusion(__file__)

# Rebind `map`/`zip` to util's safe_map/safe_zip; the builtins stay reachable
# as `unsafe_map`/`unsafe_zip`.
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip

Bool = Union[bool, Array]
Int = Union[int, Array]
ErrorCategory = type['JaxException']
Payload = list[Union[np.ndarray, Array]]
PyTreeDef = jtu.PyTreeDef
Out = TypeVar('Out')

## Utils

def popattr(obj, attrname):
  """Returns `getattr(obj, attrname)` and then deletes that attribute."""
  val = getattr(obj, attrname)
  delattr(obj, attrname)
  return val

def setnewattr(obj, name, val):
  """Sets `obj.<name> = val`, asserting the attribute did not exist before."""
  sentinel = object()
  assert getattr(obj, name, sentinel) is sentinel
  setattr(obj, name, val)

# Concrete errors

class JaxException(Exception):
  """Python exception which can contain an error message with JAX run-time info."""

  def __init__(self, traceback_info):
    self.traceback_info = traceback_info
    # TODO(lenamartens): re-enable tracebacks when they don't leak tracers.
    # self.with_traceback(self.traceback_info)

  def __init_subclass__(cls):
    # Every subclass is automatically registered as a pytree node class, so
    # errors can be flattened into (leaves, treedef) by `assert_func`.
    jtu.register_pytree_node_class(cls)

  def tree_flatten(self):
    # No array leaves; the traceback info is kept as static metadata.
    return ([], self.traceback_info)

  @classmethod
  def tree_unflatten(cls, metadata, payload):
    del payload
    return cls(metadata)

  def get_effect_type(self) -> ErrorEffect:
    """Returns the ErrorEffect identifying this error's category/payload."""
    raise NotImplementedError

@functools.total_ordering
@dataclasses.dataclass(eq=True, frozen=True)
class ErrorEffect(effects.Effect):
  """Effect keyed by an error class and the shapes/dtypes of its payload."""
  error_type: type[JaxException]
  shape_dtypes: tuple[api.ShapeDtypeStruct, ...]

  def __lt__(self, other: ErrorEffect):
    # Order by (stringified error type, ((shape, dtype-string), ...)).
    shape_dtypes = lambda x: tuple((sd.shape, str(sd.dtype))  # dtype is not comparable
                                   for sd in x.shape_dtypes)
    unpack = lambda x: (str(x.error_type), shape_dtypes(x))
    return (unpack(self) < unpack(other))

# Allow ErrorEffect to be lowered and to appear inside control flow,
# custom derivatives and remat.
effects.lowerable_effects.add_type(ErrorEffect)
effects.control_flow_allowed_effects.add_type(ErrorEffect)
effects.custom_derivatives_allowed_effects.add_type(ErrorEffect)
effects.remat_allowed_effects.add_type(ErrorEffect)


class DivisionByZeroError(JaxException):
  """Error recorded when a divisor contains zero."""

  def __str__(self):
    return 'division by zero'

  def get_effect_type(self):
    return ErrorEffect(DivisionByZeroError, ())

class NaNError(JaxException):
  """Error recorded when a primitive's output contains a NaN."""

  def __init__(self, traceback_info, primitive_name):
    super().__init__(traceback_info)
    self.prim = primitive_name

  def tree_flatten(self):
    # No leaves; traceback info and primitive name are static metadata.
    return ([], (self.traceback_info, self.prim))

  @classmethod
  def tree_unflatten(cls, metadata, _):
    return cls(*metadata)

  def get_effect_type(self):
    return ErrorEffect(NaNError, ())

  def __str__(self):
    return f'nan generated by primitive: {self.prim}.'

class OOBError(JaxException):
  """Out-of-bounds indexing error.

  `payload` holds `[index, axis, axis_size]` (an int32 array of shape (3,),
  per `get_effect_type`), which is what `__str__` reports.
  """

  def __init__(self, traceback_info, primitive_name, operand_shape, payload):
    super().__init__(traceback_info)
    self.prim = primitive_name
    self.operand_shape = operand_shape
    self._payload = payload

  def tree_flatten(self):
    # The payload array is the only leaf.
    return ([self._payload], (self.traceback_info, self.prim, self.operand_shape))

  @classmethod
  def tree_unflatten(cls, metadata, payload):
    return cls(*metadata, payload[0])

  def __str__(self):
    return (f'out-of-bounds indexing for array of '
            f'shape {self.operand_shape}: '
            f'index {self._payload[0]} is out of bounds for axis '
            f'{self._payload[1]} with size {self._payload[2]}. ')

  def get_effect_type(self):
    return ErrorEffect(OOBError, (api.ShapeDtypeStruct((3,), np.int32),))

class FailedCheckError(JaxException):
  """Error for a failed `check`; the message is `fmt_string` formatted with
  the positional/keyword values, which are this error's pytree leaves."""

  def __init__(self, traceback_info, fmt_string, *a, **k):
    super().__init__(traceback_info)
    self.fmt_string = fmt_string
    self.args = a
    self.kwargs = k

  def tree_flatten(self):
    return ((self.args, self.kwargs),  # leaves
            (self.traceback_info, self.fmt_string))  # treedef

  @classmethod
  def tree_unflatten(cls, metadata, payload):
    args, kwargs = payload
    return cls(*metadata, *args, **kwargs)

  def __str__(self):
    return (self.fmt_string.format(*self.args, **self.kwargs)
            + ' (`check` failed)')

  def get_effect_type(self):
    # The effect is keyed on the shape/dtype of every formatting value.
    vals = jtu.tree_leaves((self.args, self.kwargs))
    return ErrorEffect(
        FailedCheckError,
        tuple(api.ShapeDtypeStruct(x.shape, x.dtype) for x in vals))

@dataclasses.dataclass
class BatchedError(JaxException):
  """One exception per mapped index, built when error predicates are batched."""
  error_mapping: dict[tuple[int, ...], JaxException]

  def __post_init__(self):
    # Reuse the traceback info of the first recorded error.
    traceback_info = list(self.error_mapping.values())[0].traceback_info
    super().__init__(traceback_info)

  def __str__(self):
    return '\n'.join(f'at mapped index {", ".join(map(str, idx))}: {e}'
                     for idx, e in self.error_mapping.items())


# Error Value
@jtu.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class Error:
  """Functional error value threaded through checkify.

  Per ErrorEffect it stores whether an error occurred (`_pred`), the code of
  the first error recorded for that effect (`_code`), and that error's
  payload leaves (`_payload`). `_metadata` maps codes to the treedef needed
  to rebuild the corresponding JaxException.
  """
  _pred: dict[ErrorEffect, Bool]
  _code: dict[ErrorEffect, Int]
  _metadata: dict[Int, PyTreeDef]  # mapping of code to JaxException treedef.
  _payload: dict[ErrorEffect, Payload]

  def get(self) -> str | None:
    """Returns error message if error happened, None if no error happened."""
    exp = self.get_exception()
    if exp is not None:
      return str(exp)
    return None

  def get_exception(self) -> JaxException | None:
    """Returns Python exception if error happened, None if no error happened."""
    if any(map(np.shape, self._pred.values())):
      # Some predicate has a non-scalar shape: report per mapped index.
      return self._get_batched_exception()
    else:
      # Among effects whose predicate is set, report the smallest code.
      min_code = None
      cur_effect = None
      for error_effect, code in self._code.items():
        if self._pred[error_effect]:
          if min_code is None or code < min_code:
            min_code = code
            cur_effect = error_effect

      if cur_effect is not None:
        return tree_unflatten(self._metadata[int(min_code)],  # type: ignore
                              self._payload[cur_effect])
      return None

  def throw(self):
    """Raises if this Error holds an error (delegates to `_check_error`)."""
    _check_error(self)

  def __str__(self):
    return f'Error({self.get()})'

  # Internal helpers

  def _get_batched_exception(self) -> BatchedError | None:
    # The index space is taken from the first predicate's shape.
    shape = np.shape(list(self._pred.values())[0])
    error_mapping = {}
    for idx in np.ndindex(*shape):
      min_code = None
      cur_effect = None
      for error_effect, code in self._code.items():
        if self._pred[error_effect][idx]:  # type: ignore
          if min_code is None or code[idx] < min_code:  # type: ignore[index]
            min_code = code[idx]  # type: ignore
            cur_effect = error_effect

      if cur_effect is not None:
        payload = tree_map(lambda x, i=idx: x[i], self._payload[cur_effect])
        jax_error = tree_unflatten(self._metadata[int(min_code)], payload)  # type: ignore
        error_mapping[idx] = jax_error

    if error_mapping:
      return BatchedError(error_mapping)
    else:
      return None

  def _update(self, effect_type: ErrorEffect, pred, code, metadata, payload):
    """Returns a new Error with `effect_type`'s entries replaced and
    `metadata` merged in (this Error is not modified)."""
    new_errs = {**self._pred, **{effect_type: pred}}
    new_codes = {**self._code, **{effect_type: code}}
    new_payload = {**self._payload, **{effect_type: payload}}
    new_metadata = {**self._metadata, **metadata}
    return Error(new_errs, new_codes, new_metadata, new_payload)

  def _add_placeholder_effects(self, effects: set[ErrorEffect]):
    """Fill out Error with `effects` and np.ones arrays of their payloads."""
    new_err = self._pred.copy()
    new_code = self._code.copy()
    new_payload = self._payload.copy()
    for effect in effects:
      if effect not in self._pred.keys():
        new_err[effect] = False
        new_payload[effect] = list(
            tree_map(lambda a: jnp.ones(a.shape, a.dtype), effect.shape_dtypes))
        # The error value associated with this effect will never become True, so
        # we don't need to set a meaningful code.
        new_code[effect] = -1
    return Error(new_err, new_code, self._metadata, new_payload)

  def _replace(self, *args, **kwargs):
    return dataclasses.replace(self, *args, **kwargs)

  # PyTree methods

  def tree_flatten(self):
    # pred/code/payload dicts are children; metadata is static aux data.
    return ((self._pred, self._code, self._payload), (self._metadata))

  @classmethod
  def tree_unflatten(cls, metadata, data):
    pred, code, payload = data
    return cls(pred, code, metadata, payload)

init_error = Error({}, {}, {}, {})  # value used as initial (empty) error.
next_code = it.count(1).__next__  # globally unique ids, could be uuid4


def assert_func(error: Error, pred: Bool, new_error: JaxException) -> Error:
  """Returns `error` updated to record `new_error` when `pred` is true.

  The new error gets a fresh code from `next_code`, and its treedef is stored
  in the metadata under that code.
  """
  code = next_code()
  effect_type = new_error.get_effect_type()
  new_payload, new_metadata = tree_flatten(new_error)
  return update_error(error, pred, code, {code: new_metadata}, new_payload,
                      effect_type)

def update_error(error, pred, code, metadata, payload, effect_type):
  """Merges (pred, code, payload) for `effect_type` into `error`.

  If `error` already has this effect set, its existing code and payload are
  kept (selected via `lax.select`); otherwise the new ones are used.
  """
  err_of_type = error._pred.get(effect_type, False)
  out_err = err_of_type | pred
  out_code = lax.select(err_of_type, error._code.get(effect_type, -1), code)
  cur_payload = error._payload.get(effect_type, None)
  if cur_payload is not None:
    out_payload = tree_map(functools.partial(lax.select, err_of_type), cur_payload, payload)
  else:
    out_payload = payload
  return error._update(effect_type, out_err, out_code, metadata, out_payload)


## Checkify transformation for plumbing functional error values.
@lu.transformation_with_aux2
def _flatten_and_get_error_metadata_thunk(f, store, *invals):
  # Flattens the (error, out) pair returned by `f`; the aux data stored is
  # (output treedef, set of error effects present in the output error).
  error, out = f(*invals)
  out_vals, out_tree = jtu.tree_flatten((error, out))
  store.store((out_tree, set(error._pred.keys())))
  return out_vals

def default_checkify_rule(primitive: core.Primitive, error: Error,
                          enabled_errors, *invals: core.Value,
                          **params: Any) -> tuple[Error, Sequence[core.Value]]:
  """Default rule for primitives in `checkify` interpreter.

  Primitives without a `call_jaxpr` param are bound unchanged. Call and map
  primitives get their inner jaxpr checkified, with the flattened error
  leaves prepended to inputs and outputs.
  """
  if 'call_jaxpr' not in params:
    # Default non-HOP case: just call primitive and don't update error.
    return error, primitive.bind(*invals, **params)

  # Code below handles call- and map-primitives, by recursively calling
  # checkify_jaxpr.
  err_vals, err_tree = jtu.tree_flatten(error)
  num_error_vals = len(err_vals)
  if 'donated_invars' in params:
    # Error inputs are never donated.
    params = dict(params, donated_invars=(*[False]*num_error_vals,
                                          *params['donated_invars']))

  # call_jaxpr handling
  call_jaxpr = params.pop('call_jaxpr')
  if isinstance(call_jaxpr, core.ClosedJaxpr):  # handle closed_call_p
    jaxpr, consts = call_jaxpr.jaxpr, call_jaxpr.consts
  else:
    jaxpr, consts = call_jaxpr, ()
  # Consts are wrapped so they can be passed through lu.hashable_partial.
  consts_ = tuple(HashableWrapper(c) for c in consts)
  partial_checkify = lu.hashable_partial(
      lu.wrap_init(checkify_jaxpr_flat_hashable, debug_info=jaxpr.debug_info),
      jaxpr, consts_, enabled_errors, err_tree)
  partial_checkify, metadata = _flatten_and_get_error_metadata_thunk(
      partial_checkify)

  # map-specific params handling.
  if isinstance(primitive, core.MapPrimitive):
    # Update `in_axes` and `out_axes_thunk` params for map primitive.
    out_val_axes = params.pop('out_axes')

    @as_hashable_function(closure=out_val_axes)
    def out_axes_thunk():
      # Error outputs are mapped along axis 0; their count is only known
      # once the inner function has been traced (via `metadata`).
      out_err_num = metadata()[0].num_leaves - len(out_val_axes)
      return (*(0,)*out_err_num, *out_val_axes)

    # Error inputs are unmapped (None).
    params = dict(params,
                  in_axes=(*(None,)*num_error_vals, *params['in_axes']),
                  out_axes_thunk=out_axes_thunk)

  all_vals = primitive.bind(partial_checkify, *err_vals, *invals, **params)

  out_tree, _ = metadata()
  error, out_vals = tree_unflatten(out_tree, all_vals)
  if isinstance(primitive, core.MapPrimitive):
    # The mapped error has a leading mapped axis; reduce it to one error.
    error = _reduce_any_error(error)
  return error, out_vals

def checkify_jaxpr(jaxpr: core.ClosedJaxpr, enabled_errors,
                   error: Error, *args) -> tuple[Error, list[core.Value]]:
  """Evaluates a closed jaxpr under checkify, threading `error` through it."""
  err_vals, err_tree = jtu.tree_flatten(error)
  return checkify_jaxpr_flat(jaxpr.jaxpr, jaxpr.consts,
                             enabled_errors, err_tree, *err_vals, *args)

def checkify_jaxpr_flat(jaxpr: core.Jaxpr, consts: Sequence[core.Value],
                        enabled_errors, err_tree: PyTreeDef,
                        *args: core.Value) -> tuple[Error, list[Any]]:
  """Interprets `jaxpr` eqn by eqn, applying each primitive's checkify rule.

  `args` holds the flattened error leaves (described by `err_tree`) followed
  by the jaxpr's inputs. Returns the final Error and the jaxpr outputs.
  """
  env: dict[core.Var, Any] = {}
  err_vals, in_args = split_list(args, [err_tree.num_leaves])
  error = jtu.tree_unflatten(err_tree, err_vals)

  last_used = core.last_used(jaxpr)

  def read_env(var: core.Atom):
    if isinstance(var, core.Literal):
      return var.val
    return env[var]

  def write_env(var: core.Var, val: Any):
    env[var] = val

  foreach(write_env, jaxpr.constvars, consts)
  foreach(write_env, jaxpr.invars, in_args)

  # interpreter loop
  for eqn in jaxpr.eqns:
    invals = map(read_env, eqn.invars)
    # Primitives without a registered rule fall back to default_checkify_rule.
    checkify_rule = error_checks.get(
        eqn.primitive, functools.partial(default_checkify_rule, eqn.primitive))
    name_stack = source_info_util.current_name_stack() + eqn.source_info.name_stack
    with source_info_util.user_context(eqn.source_info.traceback,
                                       name_stack=name_stack):
      error, outvals = checkify_rule(error, enabled_errors,
                                     *invals, **eqn.params)
    if eqn.primitive.multiple_results:
      foreach(write_env, eqn.outvars, outvals)
    else:
      write_env(eqn.outvars[0], outvals)
    # Drop env entries that are no longer needed.
    core.clean_up_dead_vars(eqn, env, last_used)

  return error, map(read_env, jaxpr.outvars)

def checkify_jaxpr_flat_hashable(jaxpr, hashable_consts, enabled_errors,
                                 err_tree, *args):
  """Unwraps HashableWrapper consts and calls `checkify_jaxpr_flat`."""
  consts = tuple(c.x for c in hashable_consts)
  return checkify_jaxpr_flat(jaxpr, consts, enabled_errors, err_tree, *args)

@lu.transformation_with_aux2
def flatten_fun_output(f, store, *args):
  # Flattens `f`'s output and stores its treedef as aux data.
  ans = f(*args)
  ans, out_tree = tree_flatten(ans)
  store.store(out_tree)
  return ans


def _reduce_any_error(error: Error):
  """Collapses an Error whose leaves have a leading mapped axis.

  For each effect, the entry at `argsort(preds)[-1]` is kept, i.e. a
  position whose predicate is True whenever any is. Metadata is carried over.
  """
  out_error = init_error
  for error_effect in error._pred.keys():
    errs, codes, payloads = (error._pred[error_effect],
                             error._code[error_effect],
                             error._payload[error_effect])
    reduced_idx = jnp.argsort(errs)[-1]
    pred, code, payload = tree_map(lambda x, idx=reduced_idx: x[idx],
                                   (errs, codes, payloads))
    out_error = out_error._update(error_effect, pred, code, {}, payload)

  out_error = out_error._replace(_metadata=error._metadata)
  return out_error

## check_p primitive

check_p = core.Primitive('check')
check_p.is_effectful = lambda _: True  # type: ignore
check_p.multiple_results = True  # zero results


def _pp_check(eqn, context, settings) -> core.pp.Doc:
  """Pretty-prints a check_p eqn, omitting the `err_tree` param."""
  annotation = (source_info_util.summarize(eqn.source_info)
                if settings.source_info else None)
  name_stack_annotation = (f'[{eqn.source_info.name_stack}]'
                           if settings.name_stack else None)
  trimmed_params = sorted((k, v) for (k, v) in eqn.params.items()
                          if k != "err_tree")
  rhs = [core.pp.text(eqn.primitive.name, annotation=name_stack_annotation),
         core.pp_kv_pairs(trimmed_params, context, settings),
         core.pp.text(" ") + core.pp_vars(eqn.invars, context)]
  return core.pp.concat([core.pp.text("", annotation), *rhs])

core.pp_eqn_rules[check_p] = _pp_check

# TODO(lenamartens): inherit from Exception instead of ValueError.
class JaxRuntimeError(ValueError):
  """Raised by check_p's impl when the error value it receives holds an error."""


@check_p.def_impl
def check_impl(*args, err_tree, debug):
  """Eager impl of check_p.

  `args` are the flattened leaves of an Error described by `err_tree`. If
  that Error holds an error, a JaxRuntimeError carrying its message is raised
  from the underlying JaxException.
  """
  if debug:
    # NOOP (check will only trigger when discharged)
    return []
  error = tree_unflatten(err_tree, args)
  exc = error.get_exception()
  # Compare against None explicitly instead of relying on exception truthiness.
  if exc is not None:
    filtered_tb = traceback_util.filter_traceback(
        exc.traceback_info.as_python_traceback())
    exc.with_traceback(filtered_tb)
    raise JaxRuntimeError(str(exc)) from exc
  return []

@check_p.def_effectful_abstract_eval
def check_abstract_eval(*args, err_tree, debug):
  """No outputs; the effects are the error effects present in the Error."""
  del debug
  return [], set(tree_unflatten(err_tree, args)._pred.keys())

# TODO(lenamartens) add in-depth error explanation to link to in module docs.
_FUNCTIONALIZATION_ERROR_MSG = (
    'Cannot abstractly evaluate a checkify.check which was not'
    ' functionalized. This probably means you tried to stage'
    ' (jit/scan/pmap/...) a `check` without functionalizing it'
    ' through `checkify.checkify`.'
)
# Kept for backward compatibility with code referencing this name. It is no
# longer raised directly: re-raising one shared exception instance keeps
# extending its __traceback__ (and overwriting __context__), which pins
# lowering frames and their locals for the lifetime of the process.
functionalization_error = ValueError(_FUNCTIONALIZATION_ERROR_MSG)


def _new_functionalization_error() -> ValueError:
  """Returns a fresh ValueError with the un-functionalized-check message."""
  return ValueError(_FUNCTIONALIZATION_ERROR_MSG)


def check_lowering_rule(ctx, *args, err_tree, debug):
  """Lowers check_p to a side-effecting host callback that raises on error.

  Only permitted when `config.xla_runtime_errors` is enabled; otherwise a
  ValueError explaining the missing functionalization is raised.
  """
  if debug:
    # NOOP (check will only trigger when discharged)
    return []
  if not config.xla_runtime_errors.value:
    raise _new_functionalization_error()

  out_op, _, _ = callback.emit_python_callback(
      ctx, callback=functools.partial(python_err, err_tree),
      token=None,
      operands=args,
      operand_avals=list(ctx.avals_in),
      result_avals=list(ctx.avals_out),
      has_side_effect=True)
  return out_op

def check_lowering_rule_unsupported(*a, debug, **k):
  """Lowering for platforms without callback support: only `debug` is a no-op."""
  if debug:
    return []
  raise _new_functionalization_error()

def python_err(err_tree, *args):
  """Host callback body: rebuilds the Error and raises if it holds one."""
  error = tree_unflatten(err_tree, args)
  _check_error(error)
  return []

mlir.register_lowering(check_p, check_lowering_rule_unsupported,
                       platform='tpu')
mlir.register_lowering(check_p, check_lowering_rule,
                       platform='cpu')
mlir.register_lowering(check_p, check_lowering_rule,
                       platform='gpu')

def check_batching_rule(batched_args, batch_dims, *, err_tree, debug):
  """Moves every arg's batch dim to the front and checks the batched Error."""
  size = next(x.shape[dim] for x, dim in zip(batched_args, batch_dims)
              if dim is not batching.not_mapped)
  batched_args = (batching.bdim_at_front(a, d, size)
                  for a, d in zip(batched_args, batch_dims))
  err = tree_unflatten(err_tree, batched_args)
  _check_error(err, debug=debug)
  return [], []
batching.primitive_batchers[check_p] = check_batching_rule

def check_jvp_rule(primals, _, *, err_tree, debug):
  # Check primals, discard tangents.
  check_p.bind(*primals, err_tree=err_tree, debug=debug)
  return [], []
ad.primitive_jvps[check_p] = check_jvp_rule

## checkify rules

ErrorCheckRule = Callable  # (Error, FrozenSet[ErrorCategory], *in_vals, **params) -> (Any, Error)
error_checks: dict[core.Primitive, ErrorCheckRule] = {}


def get_traceback():
  """Returns the traceback of the current source-info context."""
  return source_info_util.current().traceback

def nan_error_check(prim, error, enabled_errors, *in_vals, **params):
  """Binds `prim`, then records a NaNError if its output contains a NaN."""
  out = prim.bind(*in_vals, **params)
  err = check_nans(prim, error, enabled_errors, out)
  return err, out

def check_nans(prim, error, enabled_errors, out):
  """Returns `error` updated with a NaNError if any output has a NaN.

  No-op unless NaNError is enabled. PRNG key arrays are never considered NaN.
  `out` is a sequence when `prim.multiple_results`, else a single array.
  """
  if NaNError not in enabled_errors:
    return error

  def isnan(x):
    if dtypes.issubdtype(x.dtype, dtypes.prng_key):
      return False
    return jnp.any(jnp.isnan(x))

  any_nans = (jnp.any(jnp.array([isnan(x) for x in out]))
              if prim.multiple_results else isnan(out))
  return assert_func(error, any_nans, NaNError(get_traceback(), prim.name))


# All primitives which can generate a NaN.
nan_primitives = [lax.acos_p, lax.acosh_p, lax.add_p, lax.asin_p, lax.asinh_p,
                  lax.atan2_p, lax.atan_p, lax.atanh_p, lax.bessel_i0e_p,
                  lax.bessel_i1e_p, lax.cbrt_p, lax.conv_general_dilated_p,
                  lax.cos_p, lax.cosh_p, lax.cumlogsumexp_p, lax.cummax_p,
                  lax.cummin_p, lax.cumprod_p, lax.cumsum_p, lax.digamma_p,
                  lax.dot_general_p, lax.erf_inv_p, lax.erf_p, lax.erfc_p,
                  lax.exp_p, lax.expm1_p, lax.fft_p, lax.igamma_grad_a_p,
                  lax.igamma_p, lax.igammac_p, lax.integer_pow_p, lax.lgamma_p,
                  lax.linear_solve_p, lax.log1p_p, lax.log_p, lax.logistic_p,
                  lax.mul_p, lax.pad_p, lax.pow_p, lax.psum_p,
                  lax.reduce_p, lax.reduce_prod_p,
                  lax.reduce_sum_p, lax.reduce_window_p,
                  lax.reduce_window_sum_p, lax.regularized_incomplete_beta_p,
                  lax.rem_p, lax.rng_uniform_p, lax.rsqrt_p, lax.sin_p,
                  lax.sinh_p, lax.sqrt_p, lax.sub_p, lax.tan_p, lax.tanh_p]

for _prim in nan_primitives:
  error_checks[_prim] = functools.partial(nan_error_check, _prim)


def dynamic_slice_error_check(error, enabled_errors, operand, *start_indices, slice_sizes):
  """Binds dynamic_slice; if OOBError is enabled, flags any start index that
  is negative or would make the slice extend past the operand's extent."""
  out = lax.dynamic_slice_p.bind(operand, *start_indices, slice_sizes=slice_sizes)

  if OOBError not in enabled_errors:
    return error, out

  start_indices = jnp.array(start_indices)
  operand_dims = np.array(operand.shape, dtype=start_indices.dtype)
  slice_sizes = np.array(slice_sizes, dtype=start_indices.dtype)
  oob_mask = (start_indices < 0) | (start_indices + slice_sizes > operand_dims)

  payload = oob_payload(oob_mask, start_indices, range(operand.ndim), operand.shape)
  error = assert_func(error, jnp.any(oob_mask), OOBError(get_traceback(), "dynamic_slice", operand.shape, payload))
  return error, out
error_checks[lax.dynamic_slice_p] = dynamic_slice_error_check

def dynamic_update_slice_error_check(error, enabled_errors, operand, update, *start_indices):
  """Binds dynamic_update_slice; if OOBError is enabled, flags any start index
  that is negative or would make `update` extend past the operand."""
  out = lax.dynamic_update_slice_p.bind(operand, update, *start_indices)

  if OOBError not in enabled_errors:
    return error, out

  operand_dims = np.array(operand.shape)
  update_dims = np.array(update.shape)
  start_indices = jnp.array(start_indices)
  oob_mask = (start_indices < 0) | (start_indices + update_dims > operand_dims)

  payload = oob_payload(oob_mask, start_indices, range(operand.ndim), operand.shape)
  error = assert_func(error, jnp.any(oob_mask), OOBError(get_traceback(), "dynamic_update_slice", operand.shape, payload))
  return error, out
error_checks[lax.dynamic_update_slice_p] = dynamic_update_slice_error_check

def gather_error_check(error, enabled_errors, operand, start_indices, *,
                       dimension_numbers, slice_sizes, unique_indices,
                       indices_are_sorted, mode, fill_value):
  """Binds gather; if OOBError is enabled, flags any index that is negative or
  exceeds `operand_dim - slice_size` along its `start_index_map` axis."""
  out = lax.gather_p.bind(
      operand, start_indices, dimension_numbers=dimension_numbers,
      slice_sizes=slice_sizes, unique_indices=unique_indices,
      indices_are_sorted=indices_are_sorted, mode=mode, fill_value=fill_value)

  if OOBError not in enabled_errors:
    return error, out

  # compare to OOB masking logic in lax._gather_translation_rule
  dnums = dimension_numbers
  # Pin integer dtypes: np.array(()) would otherwise be float64, which cannot
  # be used as an index array nor subtracted in place from an int array.
  operand_dims = np.array(operand.shape, dtype=np.int64)
  start_index_map = np.array(dnums.start_index_map, dtype=np.int64)
  num_batch_dims = len(start_indices.shape) - 1

  upper_bound = operand_dims[start_index_map]
  upper_bound -= np.array(slice_sizes, dtype=np.int64)[start_index_map]
  upper_bound = jnp.expand_dims(upper_bound, axis=tuple(range(num_batch_dims)))
  oob_mask = (start_indices < 0) | (start_indices > upper_bound.astype(start_indices.dtype))

  payload = oob_payload(oob_mask, start_indices, dnums.start_index_map, operand.shape)
  error = assert_func(error, jnp.any(oob_mask), OOBError(get_traceback(), "gather", operand.shape, payload))
  return error, out
error_checks[lax.gather_p] = gather_error_check

def div_error_check(error, enabled_errors, x, y):
  """Checks for division by zero and NaN."""
  if DivisionByZeroError in enabled_errors:
    any_zero = jnp.any(jnp.equal(y, 0))
    error = assert_func(error, any_zero, DivisionByZeroError(get_traceback()))
  return nan_error_check(lax.div_p, error, enabled_errors, x, y)
error_checks[lax.div_p] = div_error_check

def oob_payload(oob_mask, indices, dims_map, operand_shape):
  """Builds the int32 OOBError payload `[index, axis, axis_size]`.

  Describes the first position (row-major) where `oob_mask` is True; when no
  position is set, position 0 is described (the payload is then unused).
  """
  if oob_mask.size == 0:
    # Nothing is indexed (e.g. a 0-d operand), so nothing can be out of
    # bounds, and `argmin` below would raise on an empty array. Return a
    # dummy payload with the (3,) int32 shape OOBError's effect expects.
    return jnp.zeros((3,), dtype=np.int32)
  # Get first OOB index, axis and axis size so it can be added to the error msg.
  flat_idx = jnp.argmin(jnp.logical_not(oob_mask))
  multi_idx = jnp.unravel_index(flat_idx, indices.shape)
  # The last component of the multi-index is the position in the index
  # vector, which `dims_map` maps to an operand axis.
  oob_axis = jnp.array(dims_map)[multi_idx[-1]]
  oob_axis_size = jnp.array(operand_shape)[oob_axis]
  oob_index = jnp.ravel(indices)[flat_idx]
  payload = jnp.array([oob_index, oob_axis, oob_axis_size], dtype=np.int32)
  return payload

def scatter_oob(operand, indices, updates, dnums):
  """Returns `(any_out_of_bounds, oob_payload)` for a scatter's indices."""
  # Ref: see clamping code used in scatter_translation_rule
  slice_sizes = []
  pos = 0
  for i in range(len(operand.shape)):
    if i in dnums.inserted_window_dims:
      slice_sizes.append(1)
    else:
      slice_sizes.append(updates.shape[dnums.update_window_dims[pos]])
      pos += 1

  upper_bound = np.array([operand.shape[i] - slice_sizes[i]
                          for i in dnums.scatter_dims_to_operand_dims],
                         np.int64)
  upper_bound = np.minimum(upper_bound, np.iinfo(indices.dtype).max)
  upper_bound = lax.broadcast_in_dim(upper_bound, indices.shape,
                                     (len(indices.shape) - 1,))

  lower_oob = jnp.less(indices, 0)
  upper_oob = jnp.greater(indices, upper_bound.astype(indices.dtype))
  oob_mask = jnp.logical_or(lower_oob, upper_oob)
  payload = oob_payload(oob_mask, indices,
                        dnums.scatter_dims_to_operand_dims, operand.shape)
  return jnp.any(oob_mask), payload

def scatter_error_check(prim, error, enabled_errors, operand, indices, updates,
                        *, update_jaxpr, update_consts, dimension_numbers,
                        indices_are_sorted, unique_indices, mode):
  """Checks if indices are within bounds and update does not generate NaN.

  Each check runs only when its error category is enabled; the two checks
  are independent (as in `div_error_check`).
  """
  out = prim.bind(
      operand, indices, updates, update_jaxpr=update_jaxpr,
      update_consts=update_consts, dimension_numbers=dimension_numbers,
      indices_are_sorted=indices_are_sorted, unique_indices=unique_indices,
      mode=mode)

  if OOBError in enabled_errors:
    out_of_bounds, payload = scatter_oob(operand, indices, updates,
                                         dimension_numbers)
    oob_error = OOBError(get_traceback(), prim.name, operand.shape, payload)
    error = assert_func(error, out_of_bounds, oob_error)
  # Previously this was skipped whenever OOBError was disabled, so scatter
  # NaNs went undetected under NaN-only checking. check_nans is itself a
  # no-op when NaNError is not enabled.
  error = check_nans(prim, error, enabled_errors, out)
  return error, out
error_checks[lax.scatter_p] = functools.partial(scatter_error_check, lax.scatter_p)
error_checks[lax.scatter_add_p] = functools.partial(scatter_error_check,
                                                    lax.scatter_add_p)
error_checks[lax.scatter_mul_p] = functools.partial(scatter_error_check,
                                                    lax.scatter_mul_p)
error_checks[lax.scatter_min_p] = functools.partial(scatter_error_check,
                                                    lax.scatter_min_p)
error_checks[lax.scatter_max_p] = functools.partial(scatter_error_check,
                                                    lax.scatter_max_p)

# HOP error check rules

@weakref_lru_cache
def jaxpr_to_checkify_jaxpr(
    jaxpr: core.ClosedJaxpr, enabled_errors, err_tree: PyTreeDef,
    *flat_err_and_in_vals) -> tuple[core.ClosedJaxpr, PyTreeDef, set[ErrorEffect]]:
  """Traces a checkified version of `jaxpr` for the given abstract values.

  `flat_err_and_in_vals` are the flattened error leaves followed by the
  jaxpr inputs. Returns (checked jaxpr, output treedef of (error, outs),
  set of error effects in the output error). Results are cached.
  """
  checkify_jaxpr_partial = functools.partial(checkify_jaxpr_flat, jaxpr.jaxpr,
                                             jaxpr.consts, enabled_errors,
                                             err_tree)
  fun = lu.wrap_init(checkify_jaxpr_partial,
                     debug_info=jaxpr.jaxpr.debug_info.with_unknown_names())
  fun, metadata = _flatten_and_get_error_metadata_thunk(fun)

  new_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(fun, flat_err_and_in_vals)
  checked_jaxpr = core.ClosedJaxpr(new_jaxpr, consts)
  out_tree, error_effects = metadata()
  return checked_jaxpr, out_tree, error_effects

def cond_error_check(error: Error, enabled_errors, index, *ops,
                     branches, **params):
  """Checkify rule for cond: all branches share one merged error structure."""
  # Get the error-effects out of all branches so the cond can be called with
  # a merged error with all these effects.
  err_vals, err_tree = jtu.tree_flatten(error)
  in_avals = map(core.get_aval, [*err_vals, *ops])
  def get_error_effects_from_jaxpr(jxpr):
    _, _, effects = jaxpr_to_checkify_jaxpr(jxpr, enabled_errors, err_tree,
                                            *in_avals)
    return effects
  effects = [get_error_effects_from_jaxpr(jxpr) for jxpr in branches]
  merged_error = error._add_placeholder_effects(set().union(*effects))
  err_vals, err_tree = jtu.tree_flatten(merged_error)

  # Update branch jaxprs to be checkified jaxprs.
  in_avals = map(core.get_aval, [*err_vals, *ops])
  new_branches, out_trees, _ = unzip3(
      jaxpr_to_checkify_jaxpr(
          jxpr, enabled_errors, err_tree, *in_avals) for jxpr in branches)

  err_and_outs = lax.cond_p.bind(
      index, *err_vals, *ops,
      branches=tuple(new_branches), **params)

  # we need to merge metadata across out_trees (a tuple)
  err0, out = tree_unflatten(out_trees[0], err_and_outs)
  merged_metadata = err0._metadata
  for tr in out_trees[1:]:
    err, _ = tree_unflatten(tr, err_and_outs)
    merged_metadata = {**merged_metadata, **err._metadata}
  return err0._replace(_metadata=merged_metadata), out
error_checks[lax.cond_p] = cond_error_check

def scan_error_check(error, enabled_errors, *in_flat, reverse, length, jaxpr,
                     num_consts, num_carry, linear, unroll, _split_transpose):
  """Checkify rule for scan: the error leaves are carried through the loop."""

  consts, carry, xs = split_list(in_flat, [num_consts, num_carry])
  xs_mapped = [core.mapped_aval(length, 0, core.get_aval(val)) for val in xs]
  # Query body effects to create a merged error containing all effects (such
  # that in and out carried error are of the same type).
  err_vals, err_tree = jtu.tree_flatten(error)
  new_in_aval = map(core.get_aval, [*err_vals, *consts, *carry]) + xs_mapped
  _, _, effects = jaxpr_to_checkify_jaxpr(jaxpr, enabled_errors,
                                          err_tree, *new_in_aval)

  merged_error = error._add_placeholder_effects(effects)
  err_vals, err_tree = jtu.tree_flatten(merged_error)

  # Create checked-jaxpr, with the needed pre-processing on the inputs.
  new_in_aval = map(core.get_aval, [*err_vals, *consts, *carry]) + xs_mapped
  checked_jaxpr_, out_tree, _ = jaxpr_to_checkify_jaxpr(jaxpr, enabled_errors,
                                                        err_tree, *new_in_aval)

  # The checked jaxpr takes (err, consts, carry, xs); scan wants consts
  # first, so move the const binders to the front.
  tomove = ([False] * len(err_vals) + [True] * len(consts)
            + [False] * (len(carry) + len(xs)))
  checked_jaxpr = pe.move_binders_to_front(checked_jaxpr_, tomove)
  new_in_flat = [*consts, *err_vals, *carry, *xs]
  new_linear = (*[False] * len(err_vals), *linear)
  err_and_out = lax.scan_p.bind(
      *new_in_flat, reverse=reverse, length=length, jaxpr=checked_jaxpr,
      num_consts=len(consts), num_carry=len(carry)+len(err_vals),
      linear=new_linear, unroll=unroll,
      _split_transpose=_split_transpose)
  err, out = tree_unflatten(out_tree, err_and_out)
  return err, out

error_checks[lax.scan_p] = scan_error_check

def checkify_while_body_jaxpr(
    cond_jaxpr: core.ClosedJaxpr, body_jaxpr: core.ClosedJaxpr,
    enabled_errors, error: Error,
    c_consts_num: int) -> tuple[core.ClosedJaxpr, PyTreeDef, set[ErrorEffect]]:
  """Checkifies a while body that also evaluates `cond` on the new carry.

  Evaluating cond inside the body lets errors raised by the next cond
  application be recorded in the carried error. The resulting jaxpr takes
  (error leaves, cond consts, body inputs).
  """
  cond_f = core.jaxpr_as_fun(cond_jaxpr)
  body_f = core.jaxpr_as_fun(body_jaxpr)
  def new_body_f(*c_consts_and_vals):
    c_consts, vals = split_list(c_consts_and_vals, [c_consts_num])
    out = body_f(*vals)
    # This checks if the next cond application will error
    lax.dce_sink(cond_f(*c_consts, *out))
    return out
  new_body_f_ = lu.wrap_init(
      new_body_f,
      debug_info=body_jaxpr.jaxpr.debug_info.with_unknown_names())
  c_consts_avals = cond_jaxpr.in_avals[:c_consts_num]
  jaxpr, _, () = pe.trace_to_jaxpr_dynamic(
      new_body_f_, [*c_consts_avals, *body_jaxpr.in_avals])
  closed_jaxpr = pe.close_jaxpr(jaxpr)
  err_vals, err_tree = jtu.tree_flatten(error)
  err_vals = map(core.get_aval, err_vals)
  flat_err_and_in_vals = [*err_vals, *c_consts_avals, *body_jaxpr.in_avals]
  jaxpr, out_tree, error_effects = jaxpr_to_checkify_jaxpr(
      closed_jaxpr, enabled_errors, err_tree, *flat_err_and_in_vals)
  return jaxpr, out_tree, error_effects

@weakref_lru_cache
def ignore_error_output_jaxpr(jaxpr, num_error_vals: int):
  """Constructs a checked jaxpr which does not output its error value."""
  consts = jaxpr.consts
  jaxpr = jaxpr.jaxpr
  new_jaxpr = jaxpr.replace(outvars=jaxpr.outvars[num_error_vals:])
  return core.ClosedJaxpr(new_jaxpr, consts)

def while_loop_error_check(error, enabled_errors, *in_flat, cond_nconsts,
                           cond_jaxpr, body_nconsts, body_jaxpr):
  """Checkify rule for while loops.

  Raises ValueError when the loop predicate is batched (non-scalar).
  """
  if cond_jaxpr.out_avals[0].shape:
    # TODO(lenamartens, sharadmv): support batched while.
    raise ValueError('Checkify does not support batched while-loops '
                     '(checkify-of-vmap-of-while). \nHint: if possible, move '
                     'the vmap to the outer level to get '
                     'vmap-of-checkify-of-while.')

  c_consts, b_consts, carry = split_list(in_flat, [cond_nconsts, body_nconsts])

  # Check if the first cond application will error.
  error, _ = checkify_jaxpr(cond_jaxpr, enabled_errors, error, *c_consts, *carry)

  _, _, error_effects = checkify_while_body_jaxpr(cond_jaxpr, body_jaxpr,
                                                  enabled_errors, error,
                                                  cond_nconsts)
  # merged error!
  error = error._add_placeholder_effects(error_effects)
  err_vals, err_tree = jtu.tree_flatten(error)
  checked_body_jaxpr_, body_out_tree, _ = checkify_while_body_jaxpr(
      cond_jaxpr, body_jaxpr, enabled_errors, error, cond_nconsts)
  num_error_vals = len(err_vals)
  to_move = ([False] * num_error_vals + [True] * cond_nconsts
             + [True] * body_nconsts + [False] * len(carry))
  checked_body_jaxpr = pe.move_binders_to_front(checked_body_jaxpr_, to_move)

  cond_in_flat = [*err_vals, *c_consts, *carry]
  cond_in_flat = map(core.get_aval, cond_in_flat)
  checked_cond_jaxpr, _, _ = jaxpr_to_checkify_jaxpr(cond_jaxpr, enabled_errors,
                                                     err_tree, *cond_in_flat)
  compat_cond_jaxpr_ = ignore_error_output_jaxpr(checked_cond_jaxpr, num_error_vals)
  to_move = [False] * num_error_vals + [True] * cond_nconsts + [False] * len(carry)
  compat_cond_jaxpr = pe.move_binders_to_front(compat_cond_jaxpr_, to_move)

  # cond consts go to cond; the body receives cond consts + body consts.
  new_in_flat = [*c_consts, *c_consts, *b_consts, *err_vals, *carry]
  all_out_vals = lax.while_p.bind(
      *new_in_flat, cond_nconsts=cond_nconsts, cond_jaxpr=compat_cond_jaxpr,
      body_nconsts=cond_nconsts+body_nconsts, body_jaxpr=checked_body_jaxpr)
  # body_out_tree will have all the metadata of cond because it executes a cond!
  error, out = tree_unflatten(body_out_tree, all_out_vals)
  return error, out
error_checks[lax.while_p] = while_loop_error_check

def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
                     in_shardings, out_shardings,
                     in_layouts, out_layouts,
                     donated_invars, ctx_mesh, name, inline, keep_unused,
                     compiler_options_kvs):
  """Checkify rule for jit_p.

  Error leaves are prepended to inputs and outputs with unspecified sharding
  and no layout, and error inputs are never donated.
  """
  # jaxpr to checked_jaxpr
  err_vals, err_tree = jtu.tree_flatten(error)
  new_vals_in = [*err_vals, *vals_in]
  in_avals = tuple(map(core.get_aval, new_vals_in))
  checked_jaxpr, out_tree, _ = jaxpr_to_checkify_jaxpr(jaxpr, enabled_errors,
                                                       err_tree, *in_avals)

  # Update pjit params to account for extra error values.
  num_error_vals = len(err_vals)
  num_out_error_vals = out_tree.num_leaves - len(out_shardings)
  sharding = sharding_impls.UNSPECIFIED

  new_in_shardings = (*[sharding] * num_error_vals, *in_shardings)
  new_in_layouts = (*[None] * num_error_vals, *in_layouts)
  new_donated_invars = (*[False] * num_error_vals, *donated_invars)

  new_out_shardings = (*[sharding] * num_out_error_vals, *out_shardings)
  new_out_layouts = (*[None] * num_out_error_vals, *out_layouts)

  err_and_out = pjit.jit_p.bind(
      *new_vals_in,
      jaxpr=checked_jaxpr,
      in_shardings=new_in_shardings,
      out_shardings=new_out_shardings,
      in_layouts=new_in_layouts,
      out_layouts=new_out_layouts,
      donated_invars=new_donated_invars,
      ctx_mesh=ctx_mesh,
      name=name,
      inline=inline,
      keep_unused=keep_unused,
      compiler_options_kvs=compiler_options_kvs,
  )
  return tree_unflatten(out_tree, err_and_out)
error_checks[pjit.jit_p] = pjit_error_check


def remat_error_check(error, enabled_errors, *vals_in, jaxpr, **params):
  """Checkify rule for remat: error leaves are prepended to the inputs."""
  err_vals, err_tree = jtu.tree_flatten(error)
  new_vals_in = [*err_vals, *vals_in]
  in_avals = tuple(map(core.get_aval, new_vals_in))
  checked_jaxpr_, out_tree, _ = jaxpr_to_checkify_jaxpr(
      pe.close_jaxpr(jaxpr), enabled_errors, err_tree, *in_avals)
  # remat_p takes an open jaxpr; the checked jaxpr must have no consts.
  checked_jaxpr, () = checked_jaxpr_.jaxpr, checked_jaxpr_.consts
  err_and_out = ad_checkpoint.remat_p.bind(*new_vals_in, jaxpr=checked_jaxpr,
                                           **params)
  return tree_unflatten(out_tree, err_and_out)
error_checks[ad_checkpoint.remat_p] = remat_error_check


def shard_map_error_check(
    error: Error, enabled_errors, *vals_in,
    jaxpr: core.Jaxpr, in_specs, out_specs, **kwargs
):
  if (mesh := kwargs.get('mesh')) is None:
    raise ValueError('Mesh must be provided for shard_map with checkify.')

  err_vals, err_tree = jtu.tree_flatten(error)
  num_error_vals = len(err_vals)
  # Replicated sharding for in errors.
  new_in_specs = (*([P()] * num_error_vals), *in_specs)
  new_vals_in = [*err_vals, *vals_in]
  in_avals = list(map(core.get_aval, new_vals_in))
  manual_axes = kwargs.get('manual_axes')
  check_vma = kwargs.get('check_vma')
  for i, v in enumerate(in_avals):
    if not (sharder := core.shard_aval_handlers.get(type(v))):
      raise ValueError(f'Unsupported aval type: {type(v)}')
    in_avals[i] = sharder(mesh, manual_axes, check_vma, new_in_specs[i], v)

  with (jshmap._extend_axis_env(mesh, manual_axes),
        mesh_lib.use_abstract_mesh(jshmap._as_manual_mesh(mesh, manual_axes)),  # type: ignore[arg-type]
        config._check_vma(check_vma)):
    # jaxpr to checked_jaxpr
    checked_jaxpr, out_tree, _ = jaxpr_to_checkify_jaxpr(
        pe.close_jaxpr(jaxpr), enabled_errors, err_tree, *in_avals
    )
  num_out_error_vals = out_tree.num_leaves - len(out_specs)

  def expand_errors_leading_dim(*xs):
    outs = core.eval_jaxpr(checked_jaxpr.jaxpr, checked_jaxpr.consts, *xs)
    errs, outs = split_list(outs, [num_out_error_vals])
    errs = [lax.expand_dims(e, [0]) for e in errs]
    return *errs, *outs

  with core.extend_axis_env_nd(mesh.shape.items()), config._check_vma(check_vma):
    jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
        lu.wrap_init(expand_errors_leading_dim,
                     debug_info=checked_jaxpr.jaxpr.debug_info),
        checked_jaxpr.in_avals
    )
  checked_jaxpr = core.ClosedJaxpr(jaxpr, consts)

  # Update shard_map params to account for extra error values.
  # Use fully sharded partitioning for out errors.
new_out_specs = (*([P(mesh.axis_names)] * num_out_error_vals), *out_specs) subfun = lu.hashable_partial( lu.wrap_init(core.eval_jaxpr, debug_info=checked_jaxpr.jaxpr.debug_info), checked_jaxpr.jaxpr, checked_jaxpr.consts ) new_params = dict( jaxpr=checked_jaxpr.jaxpr, in_specs=new_in_specs, out_specs=new_out_specs, **kwargs, ) _, new_params = jshmap.shard_map_p.get_bind_params(new_params) err_and_out = jshmap.shard_map_p.bind(subfun, *new_vals_in, **new_params) return tree_unflatten(out_tree, err_and_out) error_checks[jshmap.shard_map_p] = shard_map_error_check def custom_jvp_call_rule(in_err: Error, enabled_errors: set, *in_vals, num_consts, jvp_jaxpr_fun: lu.WrappedFun, call_jaxpr: core.ClosedJaxpr, **params): # The types to have in mind are: # jvp : (a -> b) -> (a, T a) -> (b, T b) # checkify : (a -> b) -> a -> Err b # jvp-of-checkify : (a -> b) -> (a, T a) -> (Err b, T (Err b)) # where because Err is a pytree, we necessarily have T (Err b) = Err' (T b) # where the other Err' components are trivial (of float0 dtype). # Semantically, we don't add checks to the JVP rule. To check the result of a # JVP rule, one must instead use checkify-of-jvp. Thus this implementation # just forwards the input error and code (and trivial tangents) to the output. 
err_vals, err_tree = jtu.tree_flatten(in_err) partial_checkify = lu.wrap_init( functools.partial(checkify_jaxpr_flat, call_jaxpr.jaxpr, call_jaxpr.consts, enabled_errors, err_tree), debug_info=call_jaxpr.jaxpr.debug_info) partial_checkify, f_metadata = _flatten_and_get_error_metadata_thunk( partial_checkify) jvp = lift_jvp(err_tree.num_leaves, num_consts, jvp_jaxpr_fun) jvp, jvp_out_tree = flatten_fun_output(jvp) all_outs = custom_derivatives.custom_jvp_call_p.bind( partial_checkify, jvp, *err_vals, *in_vals, **params) fst, out_metadata = lu.merge_linear_aux(f_metadata, jvp_out_tree) if fst: err_and_out_tree, _ = out_metadata out_err, out_vals = tree_unflatten(err_and_out_tree, all_outs) else: err_vals, out_vals = split_list(all_outs, [len(err_vals)]) # forward input error to output out_err = jtu.tree_unflatten(err_tree, err_vals) return out_err, out_vals error_checks[custom_derivatives.custom_jvp_call_p] = custom_jvp_call_rule # Compared to custom_derivatives.lift_jvp, we're handling the extra inputs and # outputs that checkify adds (just forwarding the error data's primal and # tangent components). The jaxpr in jvp_jaxpr_fun doesn't expect those. # TODO(mattjj): can we simplify this, or dedup with custom_derivatives.lift_jvp? # Adding another layer of lu.transformation was tricky, though maybe doable. 
def lift_jvp(num_errs: int, num_consts: int,
             jvp_jaxpr_fun: lu.WrappedFun) -> lu.WrappedFun:
  """Wraps a custom_jvp rule so that it also forwards checkify's error values.

  The returned function takes ``[*primals, *tangents]`` where each half is laid
  out as ``[*errs, *consts, *args]``: ``custom_jvp_call_rule`` binds
  ``*err_vals`` ahead of ``*in_vals``, and ``in_vals`` starts with the
  ``num_consts`` hoisted constants. Errors and constants are stripped before
  evaluating the user's JVP jaxpr; the incoming error primals and tangents are
  passed through unchanged ahead of the rule's own outputs.

  Args:
    num_errs: number of flattened error leaves leading each half of the inputs.
    num_consts: number of constants following the error leaves in each half.
    jvp_jaxpr_fun: called with the per-tangent symbolic-zero mask, returns
      ``(jvp_jaxpr, jvp_consts, out_zeros)``.

  Returns:
    A ``WrappedFun`` returning
    ``[*primal_errs, *out_primals, *tangent_errs, *out_tangents]``.
  """
  def jvp(*xs):
    n, ragged = divmod(len(xs), 2)
    assert not ragged
    primals_in, tangents_in = xs[:n], xs[n:]
    # The user's JVP jaxpr expects neither the error leaves nor the consts.
    skip = num_errs + num_consts
    primals, tangents = primals_in[skip:], tangents_in[skip:]
    zeros = [type(t) is SymbolicZero for t in tangents]
    jvp_jaxpr, jvp_consts, out_zeros = jvp_jaxpr_fun.call_wrapped(*zeros)
    nonzero_tangents = [t for t in tangents if type(t) is not SymbolicZero]
    out = core.eval_jaxpr(jvp_jaxpr, jvp_consts, *primals, *nonzero_tangents)
    out_primals, nz_out_tangents = split_list(out, [len(out_zeros)])
    nz_out_tangents_ = iter(nz_out_tangents)
    out_tangents = [SymbolicZero(core.get_aval(p).to_tangent_aval())
                    if z else next(nz_out_tangents_)
                    for p, z in zip(out_primals, out_zeros)]
    assert next(nz_out_tangents_, None) is None
    # Forward the incoming error data unchanged. The error leaves come first in
    # each half (before the consts), so slice them from the front; slicing at
    # offset num_consts would pick up consts whenever num_consts > 0.
    primal_errs = primals_in[:num_errs]
    tangent_errs = tangents_in[:num_errs]
    return [*primal_errs, *out_primals, *tangent_errs, *out_tangents]
  return lu.wrap_init(jvp, debug_info=jvp_jaxpr_fun.debug_info)

def custom_vjp_call_rule(in_err, enabled_errors, *in_vals,
                         call_jaxpr: core.ClosedJaxpr,
                         fwd_jaxpr_thunk, num_consts,
                         bwd: lu.WrappedFun, out_trees,
                         symbolic_zeros: bool):
  """Checkify rule for ``custom_vjp_call``.

  The primal function is checkified. The fwd rule is evaluated without adding
  checks (it drops the error and const inputs), and the bwd rule is wrapped to
  return ``None`` cotangents for the error inputs. If the checkified primal
  function supplied the outputs, the output error is taken from it; otherwise
  ``in_err`` is returned unchanged.
  """
  err_vals, err_tree = jtu.tree_flatten(in_err)
  num_errs = err_tree.num_leaves
  checkified_fun = lu.wrap_init(
      functools.partial(checkify_jaxpr_flat, call_jaxpr.jaxpr,
                        call_jaxpr.consts, enabled_errors, err_tree),
      debug_info=call_jaxpr.jaxpr.debug_info)
  checkified_fun, fun_metadata = _flatten_and_get_error_metadata_thunk(
      checkified_fun)

  def checkified_fwd(*args):
    # TODO(lenamartens, sharadmv): why not checkify here?
    # args alternate between input values and their zero flags.
    xs, zeros = args[::2], args[1::2]
    # Inputs are laid out as [*errs, *consts, *args]: drop the errors first...
    xs, zeros = xs[num_errs:], zeros[num_errs:]
    fwd_jaxpr, fwd_consts = fwd_jaxpr_thunk.call_wrapped(*zeros)
    # ...then the consts, which the fwd jaxpr does not take as inputs.
    xs_without_consts = xs[num_consts:]
    return core.eval_jaxpr(fwd_jaxpr, fwd_consts, *xs_without_consts)

  # TODO(necula): the fwd result_paths are not quite the same as fun_jaxpr
  checkified_fwd_wrapped = lu.wrap_init(checkified_fwd,
                                        debug_info=fwd_jaxpr_thunk.debug_info)
  # The error inputs get no cotangents.
  bwd_ = lu.wrap_init(lambda *args: (*(None,)*num_errs, *bwd.call_wrapped(*args)),
                      debug_info=bwd.debug_info)
  checkified_fwd_wrapped, fwd_out_tree = flatten_fun_output(checkified_fwd_wrapped)
  all_outs = custom_derivatives.custom_vjp_call_p.bind(
      checkified_fun, checkified_fwd_wrapped,
      bwd_, *err_vals, *in_vals, out_trees=out_trees,
      symbolic_zeros=symbolic_zeros)
  fst, out_metadata = lu.merge_linear_aux(fun_metadata, fwd_out_tree)
  if fst:
    err_and_out_tree, _ = out_metadata
    out_err, out_vals = tree_unflatten(err_and_out_tree, all_outs)
  else:
    out_err, out_vals = in_err, all_outs
  return out_err, out_vals
error_checks[custom_derivatives.custom_vjp_call_p] = custom_vjp_call_rule


def check_discharge_rule(error, enabled_errors, *args, err_tree, debug):
  """Checkify rule for ``check_p``.

  For every effect carried by the check, the error is merged into ``error`` if
  its error type is in ``enabled_errors``; otherwise it is collected into
  ``recharged_error``, which is currently not re-raised (see the TODO below).
  Returns the updated error and no outputs.
  """
  del debug
  new_error = tree_unflatten(err_tree, args)
  # Split up new_error into error to be functionalized if it's included in
  # enabled_errors (=discharged_error) and an error to be defunctionalized if
  # it's not included (=recharged_error)
  discharged_error = error
  recharged_error = init_error
  for error_effect, pred in new_error._pred.items():
    code = new_error._code[error_effect]
    payload = new_error._payload[error_effect]
    if error_effect.error_type in enabled_errors:
      discharged_error = update_error(discharged_error, pred, code, {}, payload,
                                      error_effect)
    else:
      recharged_error = update_error(recharged_error, pred, code, {}, payload,
                                     error_effect)

  discharged_error = discharged_error._replace(
      _metadata={**new_error._metadata, **discharged_error._metadata})
  recharged_error = recharged_error._replace(_metadata=new_error._metadata)
  # TODO(lenamartens): we actually need to recharge, but this would be a
  # breaking API change so leaving for a follow-up.
  # check_error(recharged_error)
  return discharged_error, []
error_checks[check_p] = check_discharge_rule

## checkify public api

user_checks = frozenset({FailedCheckError})
nan_checks = frozenset({NaNError})
index_checks = frozenset({OOBError})
div_checks = frozenset({DivisionByZeroError})
float_checks = nan_checks | div_checks
automatic_checks = float_checks | index_checks
all_checks = automatic_checks | user_checks


def checkify(f: Callable[..., Out],
             errors: frozenset[ErrorCategory] = user_checks
             ) -> Callable[..., tuple[Error, Out]]:
  """Functionalize `check` calls in `f`, and optionally add run-time error checks.

  Run-time errors are either user-added :func:`~check` assertions, or
  automatically added checks like NaN checks, depending on the ``errors``
  argument.

  The returned function will return an Error object `err` along with the output
  of the original function. ``err.get()`` will either return ``None`` (if no
  error occurred) or a string containing an error message. This error message
  will correspond to the first error which occurred. ``err.throw()`` will raise
  a ValueError with the error message if an error occurred.

  By default only user-added :func:`~check` assertions are enabled. You can
  enable automatic checks through the ``errors`` argument.

  The automatic check sets which can be enabled, and when an error is generated:
    - ``user_checks``: a :func:`~check` evaluated to False.
    - ``nan_checks``: a floating-point operation generated a NaN value
      as output.
    - ``div_checks``: a division by zero.
    - ``index_checks``: an index was out-of-bounds.

  Multiple categories can be enabled together by passing in an error `Set` (eg.
  ``errors=nan_checks``). Multiple sets can be re-combined (eg.
  ``errors=float_checks|user_checks``)

  Args:
    f: Callable which can contain user checks (see :func:`~check`).
    errors: A set of ErrorCategory values which defines the set of enabled
      checks. By default only explicit ``checks`` are enabled
      (``user_checks``). You can also for example enable NAN and
      DIV errors by passing the ``float_checks`` set, or for
      example combine multiple sets through set operations
      (``float_checks | user_checks``)
  Returns:
    A function which accepts the same arguments as ``f`` and returns as output
    a pair where the first element is an ``Error`` value, representing the first
    failed :func:`~check`, and the second element is the original output of
    ``f``.

  For example:

    >>> import jax
    >>> import jax.numpy as jnp
    >>> from jax.experimental import checkify
    >>>
    >>> @jax.jit
    ... def f(x):
    ...   y = jnp.sin(x)
    ...   return x+y
    >>> err, out = checkify.checkify(f, errors=checkify.float_checks)(jnp.inf)
    >>> err.throw()  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
      ...
    jax._src.checkify.JaxRuntimeError: nan generated by primitive: sin
  """
  @traceback_util.api_boundary
  def checked_fun(*args, **kwargs):
    # close over all arguments so they're not turned into abstract values.
    in_tree = jtu.tree_structure(((), {}))
    closed_f = lambda: f(*args, **kwargs)
    # stage:
    debug = api_util.debug_info("checkify", f, args, kwargs)
    fun_, out_tree = api_util.flatten_fun(
        lu.wrap_init(closed_f, debug_info=debug.with_unknown_names()), in_tree)
    jaxpr_, _, consts = pe.trace_to_jaxpr_dynamic(fun_, ())
    jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr_))
    # checkify:
    error, out_flat = checkify_jaxpr(jaxpr, errors, init_error, *consts)
    return error, jtu.tree_unflatten(out_tree(), out_flat)
  return checked_fun


def check(pred: Bool, msg: str,
          *fmt_args,
          debug: bool = False,
          **fmt_kwargs,
          ) -> None:
  """Check a predicate, add an error with msg if predicate is False.

  This is an effectful operation, and can't be staged (jitted/scanned/...).
  Before staging a function with checks, :func:`~checkify` it!

  Args:
    pred: if False, a FailedCheckError error is added.
    msg: error message if error is added. Can be a format string.
    debug: Whether to turn on debugging mode. If True, the check behaves like
      :func:`~debug_check`: it only runs when the function is transformed by
      :func:`~checkify` and is dropped otherwise. If False, the check must be
      functionalized using checkify.checkify.
    fmt_args, fmt_kwargs: Positional and keyword formatting arguments for
      `msg`, eg.:
      ``check(.., "check failed on values {} and {named_arg}", x, named_arg=y)``
      Note that these arguments can be traced values allowing you to add
      run-time values to the error message.
      Note that tracking these run-time arrays will increase your memory usage,
      even if no error happens.

  For example:

    >>> import jax
    >>> import jax.numpy as jnp
    >>> from jax.experimental import checkify
    >>> def f(x):
    ...   checkify.check(x>0, "{x} needs to be positive!", x=x)
    ...   return 1/x
    >>> checked_f = checkify.checkify(f)
    >>> err, out = jax.jit(checked_f)(-3.)
    >>> err.throw()  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
      ...
    jax._src.checkify.JaxRuntimeError: -3. needs to be positive!

  """
  _check(pred, msg, debug, *fmt_args, **fmt_kwargs)

def _check(pred, msg, debug, *fmt_args, **fmt_kwargs):
  """Shared implementation of `check` and `debug_check`.

  Validates ``pred`` and the formatting arguments, builds a FailedCheckError
  that fires when ``pred`` is False, and binds it via `_check_error`.
  """
  if not is_scalar_pred(pred):
    prim_name = 'debug_check' if debug else 'check'
    raise TypeError(f'{prim_name} takes a scalar pred as argument, got {pred}')
  for arg in jtu.tree_leaves((fmt_args, fmt_kwargs)):
    if not isinstance(arg, (Array, np.ndarray)):
      raise TypeError('Formatting arguments to checkify.check need to be '
                      'PyTrees of arrays, but got '
                      f'{arg!r} of type {type(arg)}.')
  new_error = FailedCheckError(get_traceback(), msg, *fmt_args, **fmt_kwargs)
  error = assert_func(init_error, jnp.logical_not(pred), new_error)
  _check_error(error, debug=debug)

def _check_error(error, *, debug=False):
  """Binds ``check_p`` on ``error``'s flattened leaves.

  Non-scalar predicates are first reduced with `_reduce_any_error`.
  """
  if any(map(np.shape, error._pred.values())):
    error = _reduce_any_error(error)
  err_args, tree_def = tree_flatten(error)

  return check_p.bind(*err_args, err_tree=tree_def, debug=debug)


def is_scalar_pred(pred) -> bool:
  """Returns True if ``pred`` is a Python bool or a 0-d boolean ``Array``."""
  return (isinstance(pred, bool) or
          (isinstance(pred, Array) and pred.shape == () and
           pred.dtype == np.dtype('bool')))


def debug_check(pred: Bool, msg: str, *fmt_args, **fmt_kwargs) -> None:
  """Check a predicate when running under checkify, otherwise is a no-op.

  A `debug_check` will only be run if it is transformed by :func:`~checkify`,
  otherwise the check will be dropped.

  Args:
    pred: if False, a FailedCheckError error is added.
    msg: error message if error is added.
    fmt_args, fmt_kwargs: Positional and keyword formatting arguments for
      `msg`, eg.:
      ``debug_check(.., "check failed on values {} and {named}", x, named=y)``
      Note that these arguments can be traced values allowing you to add
      run-time values to the error message.
      Note that tracking these run-time arrays will increase your memory usage,
      even if no error happens.

  For example:

    >>> import jax
    >>> import jax.numpy as jnp
    >>> from jax.experimental import checkify
    >>> def f(x):
    ...   checkify.debug_check(x!=0, "cannot be zero!")
    ...   return x
    >>> _ = f(0)  # running without checkify means no debug_check is run.
    >>> checked_f = checkify.checkify(f)
    >>> err, out = jax.jit(checked_f)(0)  # running with checkify runs debug_check.
    >>> err.throw()  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
      ...
    jax._src.checkify.JaxRuntimeError: cannot be zero!

  """
  _check(pred, msg, True, *fmt_args, **fmt_kwargs)


def check_error(error: Error) -> None:
  """Raise an Exception if ``error`` represents a failure. Functionalized by :func:`~checkify`.

  The semantics of this function are equivalent to:

  >>> def check_error(err: Error) -> None:
  ...   err.throw()  # can raise ValueError

  But unlike that implementation, ``check_error`` can be functionalized using
  the :func:`~checkify` transformation.

  This function is similar to :func:`~check` but with a different signature:
  whereas :func:`~check` takes as arguments a boolean predicate and a new error
  message string, this function takes an ``Error`` value as argument. Both
  :func:`~check` and this function raise a Python Exception on failure (a
  side-effect), and thus cannot be staged out by :func:`~jax.jit`,
  :func:`~jax.pmap`, :func:`~jax.lax.scan`, etc. Both also can be
  functionalized by using :func:`~checkify`.

  But unlike :func:`~check`, this function is like a direct inverse of
  :func:`~checkify`: whereas :func:`~checkify` takes as input a function which
  can raise a Python Exception and produces a new function without that effect
  but which produces an ``Error`` value as output, this ``check_error`` function
  can accept an ``Error`` value as input and can produce the side-effect of
  raising an Exception. That is, while :func:`~checkify` goes from
  functionalizable Exception effect to error value, this ``check_error`` goes
  from error value to functionalizable Exception effect.

  ``check_error`` is useful when you want to turn checks represented by an
  ``Error`` value (produced by functionalizing ``checks`` via
  :func:`~checkify`) back into Python Exceptions.

  Args:
    error: Error to check.

  For example, you might want to functionalize part of your program through
  checkify, stage out your functionalized code through :func:`~jax.jit`, then
  re-inject your error value outside of the :func:`~jax.jit`:

  >>> import jax
  >>> from jax.experimental import checkify
  >>> def f(x):
  ...   checkify.check(x>0, "must be positive!")
  ...   return x
  >>> def with_inner_jit(x):
  ...   checked_f = checkify.checkify(f)
  ...   # a checkified function can be jitted
  ...   error, out = jax.jit(checked_f)(x)
  ...   checkify.check_error(error)
  ...   return out
  >>> _ = with_inner_jit(1)  # no failed check
  >>> with_inner_jit(-1)  # doctest: +IGNORE_EXCEPTION_DETAIL
  Traceback (most recent call last):
    ...
  jax._src.checkify.JaxRuntimeError: must be positive!
  >>> # can re-checkify
  >>> error, _ = checkify.checkify(with_inner_jit)(-1)
  """
  if not isinstance(error, Error):
    raise TypeError('check_error takes an Error as argument, '
                     f'got type {type(error)} instead.')
  _check_error(error, debug=False)