# Copyright 2022 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Module for discharging state primitives.""" from __future__ import annotations from collections.abc import Callable, Sequence import dataclasses from functools import partial import math import operator from typing import Any, Protocol, TypeVar from jax._src import ad_util from jax._src import api_util from jax._src import config from jax._src import core from jax._src import literals from jax._src import linear_util as lu from jax._src import pjit from jax._src import sharding_impls from jax._src import source_info_util from jax._src import tree_util from jax._src import custom_derivatives from jax._src.interpreters import ad from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe from jax._src.lax import lax from jax._src.lax import slicing as lax_slicing from jax._src.state import indexing from jax._src.state.primitives import addupdate_p, get_p, swap_p, pin, unpin from jax._src.state.types import ( AbstractRef, RefBitcaster, RefEffect, RefReshaper, get_ref_aval_from_value, uninitialized,) from jax._src.state.utils import bitcast, hoist_consts_to_refs from jax._src.typing import Array from jax._src.util import (foreach, safe_map, safe_zip, split_list, unzip2, weakref_lru_cache) import numpy as np ## JAX utilities map, unsafe_map = safe_map, map zip, unsafe_zip = safe_zip, zip PyTreeDef = tree_util.PyTreeDef ## Discharging state def discharge_state( jaxpr: core.Jaxpr, consts: Sequence[Any], *, should_discharge: bool | Sequence[bool] = True, ) -> tuple[core.Jaxpr, Sequence[Any]]: """Converts a stateful jaxpr into a pure one. Discharging replaces ``Ref`` inputs with regular values, threads updates through the computation, and returns updated ``Ref``s as additional outputs. Args: jaxpr: A stateful jaxpr with ``Ref`` inputs. consts: Constants for the jaxpr. should_discharge: Whether to discharge each ``Ref`` input. If a single bool, applies to all inputs. Returns: A tuple of ``(new_jaxpr, new_consts)`` where ``new_jaxpr`` is a jaxpr with no ``Read``/``Write``/``Accum`` effects. Discharged ``Ref`` inputs become regular value inputs, and their updated values are appended to the outputs. """ if isinstance(should_discharge, bool): should_discharge = [should_discharge] * len(jaxpr.invars) in_avals = [v.aval.inner_aval if isinstance(v.aval, AbstractRef) and d else v.aval for v, d in zip(jaxpr.invars, should_discharge)] eval_jaxpr = lu.wrap_init(partial(_eval_jaxpr_discharge_state, jaxpr, should_discharge, consts), debug_info=jaxpr.debug_info.with_unknown_names()) new_jaxpr, _ , new_consts = pe.trace_to_jaxpr_dynamic(eval_jaxpr, in_avals) return new_jaxpr, new_consts # TODO(mattjj): migrate callers to discharge_state2 for caching def discharge_state2(jaxpr: core.ClosedJaxpr, should_discharge: bool | Sequence[bool] = True, ) -> core.ClosedJaxpr: if isinstance(should_discharge, bool): should_discharge = (should_discharge,) * len(jaxpr.in_avals) return _discharge_state2(jaxpr, tuple(should_discharge)) @weakref_lru_cache def _discharge_state2(jaxpr: core.ClosedJaxpr, should_discharge: tuple[bool, ...], ) -> core.ClosedJaxpr: jaxpr_, consts = discharge_state(jaxpr.jaxpr, jaxpr.consts, should_discharge=should_discharge) return core.ClosedJaxpr(jaxpr_, consts) @dataclasses.dataclass class Environment: env: dict[core.Var, Any] def read(self, v: core.Atom) -> Any: if type(v) is core.Literal: return v.val assert isinstance(v, core.Var) return self.env[v] def write(self, v: core.Var, val: Any) -> None: self.env[v] = val class DischargeRule(Protocol): def __call__( self, in_avals: Sequence[core.AbstractValue], out_avals: Sequence[core.AbstractValue], *args: Any, **params: Any, ) -> tuple[Sequence[Any | None], Any | Sequence[Any]]: """Discharge rule for a primitive. See :func:`discharge_state` for an explanation of what discharge means. Args: in_avals: Input abstract values. out_avals: Output abstract values. *args: Input values. **params: Primitive parameters. Returns: A tuple of ``(new_invals, new_outvals)`` where: * ``new_invals`` contains updated values for discharged ``Ref`` inputs, or ``None`` if the input is not a ``Ref`` or was not updated. * ``new_outvals`` is the primitive's output. A sequence if the primitive has multiple results, otherwise a single value. """ _discharge_rules: dict[core.Primitive, DischargeRule] = {} def register_discharge_rule(prim: core.Primitive): def register(f: DischargeRule): _discharge_rules[prim] = f return register class PartialDischargeRule(Protocol): """Discharge rule that supports selective discharging of ``Ref`` inputs. Generalizes :class:`DischargeRule` by accepting a ``should_discharge`` argument that specifies which ``Ref`` inputs to discharge. The returned ``new_invals`` must contain a non-``None`` value if and only if the corresponding ``Ref`` was discharged. """ def __call__( self, should_discharge: Sequence[bool], in_avals: Sequence[core.AbstractValue], out_avals: Sequence[core.AbstractValue], *args: Any, **params: Any, ) -> tuple[Sequence[Any | None], Any | Sequence[Any]]: ... _partial_discharge_rules: dict[core.Primitive, PartialDischargeRule] = {} def register_partial_discharge_rule(prim: core.Primitive): def register(f: PartialDischargeRule): _partial_discharge_rules[prim] = f return register def _eval_jaxpr_discharge_state( jaxpr: core.Jaxpr, should_discharge: Sequence[bool], consts: Sequence[Any], *args: Any): env = Environment({}) foreach(env.write, jaxpr.constvars, consts) # Here some args may correspond to `Ref` avals but they'll be treated like # regular values in this interpreter. foreach(env.write, jaxpr.invars, args) refs_to_discharge = {id(v.aval) for v, d in zip(jaxpr.invars, should_discharge) if d and isinstance(v.aval, AbstractRef)} for eqn in jaxpr.eqns: name_stack = source_info_util.current_name_stack() + eqn.source_info.name_stack traceback = eqn.source_info.traceback with source_info_util.user_context( traceback, name_stack=name_stack), eqn.ctx.manager: should_discharge = [id(v.aval) in refs_to_discharge for v in eqn.invars] if eqn.primitive is core.ref_p: [invar], [outvar] = eqn.invars, eqn.outvars ans = env.read(invar) if config.refs_to_pins.value: ans = pin(ans) refs_to_discharge.add(id(outvar.aval)) elif eqn.primitive is core.freeze_p: [invar], [outvar] = eqn.invars, eqn.outvars ans = env.read(invar) if config.refs_to_pins.value: ans = unpin(ans) refs_to_discharge.remove(id(invar.aval)) elif any(should_discharge) or core.internal_mutable_array_effect in eqn.effects: if eqn.primitive in _partial_discharge_rules: rule: DischargeRule = partial(_partial_discharge_rules[eqn.primitive], should_discharge) elif eqn.primitive in _discharge_rules: rule = _discharge_rules[eqn.primitive] else: raise NotImplementedError( f"No state discharge rule implemented for primitive: {eqn.primitive}") invals = map(env.read, eqn.invars) in_avals = [v.aval for v in eqn.invars] out_avals = [v.aval for v in eqn.outvars] new_invals, ans = rule( in_avals, out_avals, *invals, **eqn.params) for invar, should, new_inval in zip(eqn.invars, should_discharge, new_invals): if new_inval is not None: if not should: raise ValueError( f"Did not ask for inval to be discharged but it was. ({invar=}," f" {new_inval=})" ) env.write(invar, new_inval) # type: ignore[arg-type] else: # Default primitive rule, similar to `core.eval_jaxpr`. Note that here # we assume any higher-order primitives inside of the jaxpr are *not* # stateful. subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params) ans = eqn.primitive.bind(*subfuns, *map(env.read, eqn.invars), **bind_params) if eqn.primitive.multiple_results: foreach(env.write, eqn.outvars, ans) else: env.write(eqn.outvars[0], ans) # By convention, we return the outputs of the jaxpr first and then the final # values of the `Ref`s. Callers to this function should be able to split # them up by looking at `len(jaxpr.outvars)`. out_vals = map(env.read, jaxpr.outvars) ref_vals = map( env.read, [v for v in jaxpr.invars if id(v.aval) in refs_to_discharge]) return out_vals + ref_vals def _is_trivial_indexer(indexer: indexing.NDIndexer): """Returns whether the indexer selects the entire shape.""" for s, idx in zip(indexer.shape, indexer.indices): if not isinstance(idx, indexing.Slice): return False if idx.is_dynamic_start or idx.is_dynamic_size: return False if idx.start != 0 or idx.size != s: return False return True def _maybe_convert_to_slice( indexer: indexing.NDIndexer ) -> list[tuple[int, int, int]] | None: args = [] for i in indexer.indices: if not isinstance(i, indexing.Slice): return None start = i.start end = i.start + (i.size - 1) * i.stride + 1 stride = i.stride # cannot convert to static `slice` if `start` or `end` is dynamic if not isinstance(start, int) or not isinstance(end, int): return None args.append((start, end, stride)) return args def _maybe_convert_to_dynamic_slice( indexer: indexing.NDIndexer, ) -> ( tuple[tuple[Array | int, ...], tuple[Array | int, ...], tuple[int, ...]] | None ): # An NDIndexer only corresponds to a `dynamic_slice` or `dynamic_update_slice` # if each of the indexers is a `Slice` or a ()-shaped value. if not all(isinstance(i, indexing.Slice) or not np.shape(i) for i in indexer.indices): return None # `lax.dynamic_slice` does not handle striding for i in indexer.indices: if isinstance(i, indexing.Slice) and i.stride > 1: return None _convert_i32 = lambda x: lax.convert_element_type(x, np.dtype("int32")) starts = tuple( _convert_i32(i.start) if isinstance(i, indexing.Slice) else _convert_i32(i) for i in indexer.indices ) sizes = tuple( i.size if isinstance(i, indexing.Slice) else 1 for i in indexer.indices ) squeeze_dims = tuple( i for i, idx in enumerate(indexer.indices) if not isinstance(idx, indexing.Slice) ) return starts, sizes, squeeze_dims # In this code, indexing is handled in three ways: `slice`, `dynamic_slice`, and # gather. For the gather case, the goal is to create a gather array, which means # that we need to convert all other types of indexers into integer array # indexers. This is done by looping over all indexers and checking if they are # not integer array indexers, and if not, performing the conversion. However, # during this process, the indexing semantics may change. Specifically, # according to the indexing rules of NumPy, when there are integer array # indexers separated by other indexers, the axes corresponding to the integer # array indexers need to be moved to the front. After we convert all other # indexers to integer array indexers, the distinction between integer array # indexers and other types of indexers is lost. As a result, it becomes # impossible to determine which axes should be moved to the front. In this case, # we need to transpose the target array before the gather operation. We also # need to transpose the target array back after the gather operation, if it is # used in subsequent computations. def _maybe_transpose_before_gather( indexer: indexing.NDIndexer ) -> tuple[int, ...] | None: is_int_indexing, _, _ = indexing.unpack_ndindexer(indexer) int_indexers_contiguous = bool( np.all(np.diff(np.where(is_int_indexing)[0]) == 1) ) if int_indexers_contiguous: return None # no transpose needed int_indexer_idxs: list[int] = [] non_int_indexer_idxs: list[int] = [] for i, is_int_index in enumerate(is_int_indexing): (int_indexer_idxs if is_int_index else non_int_indexer_idxs).append(i) transpose_order = (*int_indexer_idxs, *non_int_indexer_idxs) return transpose_order def _perform_transpose_before_gather( target_arr: Array, indexer: indexing.NDIndexer, transpose_order: tuple[int, ...], ) -> tuple[Array, indexing.NDIndexer]: new_target_arr = target_arr.transpose(transpose_order) reordered_indices = tuple(indexer.indices[i] for i in transpose_order) new_indexer = indexing.NDIndexer( indices=reordered_indices, shape=indexer.shape, int_indexer_shape=indexer.int_indexer_shape, ) return new_target_arr, new_indexer def _convert_to_gather_arrays(indexer: indexing.NDIndexer) -> tuple[Array, ...]: # This is the general gather case. We need to create the gather arrays. total_shape = indexer.get_indexer_shape() is_int_indexing, _, _ = indexing.unpack_ndindexer(indexer) if any(is_int_indexing): n_idxers = len(indexer.indices) int_indexer_shape = indexer.int_indexer_shape n_int_indexers = sum(1 for p in is_int_indexing if p) last_int_index_idx = n_idxers - 1 - is_int_indexing[::-1].index(True) n_slice_index_dims_after_int = n_idxers - last_int_index_idx - 1 def get_idx_in_shape_after_indexing(i): if not any(is_int_indexing): return i if i < n_idxers - n_slice_index_dims_after_int - n_int_indexers: return i if i < n_idxers - n_slice_index_dims_after_int: raise ValueError return i - n_int_indexers + len(int_indexer_shape) arrs = [] for i, idxer in enumerate(indexer.indices): if isinstance(idxer, indexing.Slice): idx_in_shape_after_indexing = get_idx_in_shape_after_indexing(i) arr = ( lax.iota(np.int32, total_shape[idx_in_shape_after_indexing]) * idxer.stride + idxer.start ) diff = len(total_shape) - idx_in_shape_after_indexing - 1 arr = arr.reshape(arr.shape + (1,) * diff) arrs.append(arr) elif isinstance(idxer, (np.ndarray, Array, literals.TypedNdArray)): diff = n_idxers - 1 - last_int_index_idx arr = idxer.reshape(idxer.shape + (1,) * diff) arrs.append(arr) else: raise ValueError(f"Invalid type of idxer: {type(idxer).__name__}") return tuple(arrs) @register_discharge_rule(get_p) def _get_discharge_rule( in_avals: Sequence[core.AbstractValue], out_avals: Sequence[core.AbstractValue], x, *idx, tree): del in_avals, out_avals y = _get_discharge(x, idx, tree) return (None,) * (len(idx) + 1), y def _index_array(x, indexer: indexing.NDIndexer): if _is_trivial_indexer(indexer): return x # Try the three APIs in the following order: `lax.slice`, # `lax.dynamic_slice` and gather if maybe_slice := _maybe_convert_to_slice(indexer): x = lax_slicing.slice(x, *zip(*maybe_slice)) # If everything in the indexer is a slice or ()-shaped, we can also # use `lax.dynamic_slice` with 1-sized slices for ()-shaped indices. # We need to squeeze out the 1-sized slices at the end. elif maybe_dynamic_slice := _maybe_convert_to_dynamic_slice(indexer): starts, sizes, squeeze_dims = maybe_dynamic_slice y = lax_slicing.dynamic_slice(x, starts, sizes) x = lax.squeeze(y, squeeze_dims) else: transpose_order = _maybe_transpose_before_gather(indexer) if transpose_order is not None: x, indexer = _perform_transpose_before_gather(x, indexer, transpose_order) arrays = _convert_to_gather_arrays(indexer) x = x[arrays] return x def transform_array(x, transforms): if transforms is None: transforms = [] result = x for transform in transforms: if transform is None: continue match transform: case indexing.NDIndexer(): result = _index_array(result, transform) case RefBitcaster(): result = bitcast(result, transform.dtype) case RefReshaper(): result = result.reshape(transform.shape) case _: raise NotImplementedError(f"Unsupported transform: {transform}") return result def transform_swap_array(x, transforms, val): if transforms is None: transforms = [] # Will hold the value read from `x` before the swap, and will have the same # shape as `val`. new_val = x # List of intermediate results by transforming `x`. intermediates = [x] # Read phase (forward loop) for transform in transforms: match transform: case indexing.NDIndexer(): indexer = transform if _is_trivial_indexer(indexer): intermediates.append(intermediates[-1]) continue # If everything in the indexer is a slice or ()-shaped, we can also # use `lax.dynamic_slice` with 1-sized slices for ()-shaped indices. # We need to squeeze out the 1-sized slices at the end. if maybe_slice := _maybe_convert_to_dynamic_slice(indexer): starts, sizes, squeeze_dims = maybe_slice new_val = lax.squeeze( lax_slicing.dynamic_slice(new_val, starts, sizes), squeeze_dims ) else: transpose_order = _maybe_transpose_before_gather(indexer) if transpose_order is not None: new_val, indexer = _perform_transpose_before_gather( new_val, indexer, transpose_order ) arrays = _convert_to_gather_arrays(indexer) new_val = new_val[arrays] # Here, we don't need to transpose `new_val` back because it now holds # the result of the indexing, and is no longer the original array that # was indexed into. intermediates.append(new_val) case RefBitcaster(): intermediates.append(bitcast(new_val, transform.dtype)) case RefReshaper(): intermediates.append(new_val.reshape(transform.shape)) case _: raise NotImplementedError(f"Unsupported transform: {transform}") # Will hold the final state of the `x` after `val` has been written to the # transformed location, and will have the same shape as `x`. new_x = val # Write phase (reversed loop) for intermediate, transform in reversed(zip(intermediates[:-1], transforms)): if isinstance(transform, indexing.NDIndexer): indexer = transform if _is_trivial_indexer(indexer): continue if maybe_slice := _maybe_convert_to_dynamic_slice(indexer): starts, _, squeeze_dims = maybe_slice new_x = lax_slicing.dynamic_update_slice( intermediate, lax.expand_dims(new_x, squeeze_dims), starts ) else: transpose_order = _maybe_transpose_before_gather(indexer) if transpose_order is not None: intermediate, indexer = _perform_transpose_before_gather( intermediate, indexer, transpose_order ) arrays = _convert_to_gather_arrays(indexer) new_x = intermediate.at[arrays].set(new_x) # pytype: disable=attribute-error if transpose_order is not None: transpose_order_inversed = np.argsort(transpose_order) new_x = new_x.transpose(transpose_order_inversed) else: raise NotImplementedError(f"Unsupported transform: {transform}") return new_val, new_x def _get_discharge(x, idx, tree): transforms = tree_util.tree_unflatten(tree, idx) return transform_array(x, transforms) @register_discharge_rule(swap_p) def _swap_discharge_rule( in_avals: Sequence[core.AbstractValue], out_avals: Sequence[core.AbstractValue], x, val, *idx, tree): del in_avals, out_avals z, x_new = _swap_discharge(x, val, idx, tree) return (x_new, None) + (None,) * len(idx), z def _swap_discharge(x, val, idx, tree): transforms = tree_util.tree_unflatten(tree, idx) return transform_swap_array(x, transforms, val) @register_discharge_rule(addupdate_p) def _addupdate_discharge_rule( in_avals: Sequence[core.AbstractValue], out_avals: Sequence[core.AbstractValue], x, val, *idx, tree): del in_avals, out_avals ans = _addupdate_discharge(x, val, idx, tree) return (ans, None) + (None,) * len(idx), [] def _addupdate_discharge(x, val, idx, tree): transforms = tree_util.tree_unflatten(tree, idx) if not transforms: return x + val if len(transforms) > 1: raise NotImplementedError("Only single indexer is supported.") indexer = transforms[0] if _is_trivial_indexer(indexer): return x + val # If everything in the indexer is a slice or ()-shaped, we can also # use `lax.dynamic_slice` with 1-sized slices for ()-shaped indices. # We need to squeeze out the 1-sized slices at the end. if maybe_slice := _maybe_convert_to_dynamic_slice(indexer): starts, sizes, squeeze_dims = maybe_slice x_old = lax_slicing.dynamic_slice(x, starts, sizes) val = lax.expand_dims(val, squeeze_dims) y = lax_slicing.dynamic_update_slice(x, x_old + val, starts) return y transpose_order = _maybe_transpose_before_gather(indexer) if transpose_order is not None: x, indexer = _perform_transpose_before_gather(x, indexer, transpose_order) arrays = _convert_to_gather_arrays(indexer) x = x.at[arrays].add(val) if transpose_order is not None: transpose_order_inversed = np.argsort(transpose_order) x = x.transpose(transpose_order_inversed) return x @weakref_lru_cache def _cached_closed_jaxpr_discharge(closed_jaxpr: core.ClosedJaxpr): jaxpr, consts = closed_jaxpr.jaxpr, closed_jaxpr.consts num_outs = len(jaxpr.outvars) discharged_jaxpr, discharged_consts = discharge_state(jaxpr, consts) discharged_closed_jaxpr = core.ClosedJaxpr(discharged_jaxpr, discharged_consts) fun = lu.wrap_init(core.jaxpr_as_fun(discharged_closed_jaxpr), debug_info=discharged_jaxpr.debug_info) return discharged_closed_jaxpr, num_outs, fun @register_discharge_rule(core.closed_call_p) def _closed_call_discharge_rule( in_avals: Sequence[core.AbstractValue], _,*args, call_jaxpr: core.ClosedJaxpr): discharged_closed_jaxpr, num_outs, fun = _cached_closed_jaxpr_discharge(call_jaxpr) out_and_ref_vals = core.closed_call_p.bind(fun, *args, call_jaxpr=discharged_closed_jaxpr) out_vals, ref_vals = split_list(out_and_ref_vals, [num_outs]) ref_vals_iter = iter(ref_vals) new_invals = tuple(next(ref_vals_iter) if isinstance(aval, AbstractRef) else None for aval in in_avals) sentinel = object() assert next(ref_vals_iter, sentinel) is sentinel return new_invals, out_vals # # `run_state` run_state_p = core.Primitive("run_state") run_state_p.multiple_results = True def _run_state_is_high(*_, jaxpr, **__): return jaxpr.is_high run_state_p.is_high = _run_state_is_high # type: ignore def _run_state_to_lojax(*args, jaxpr, is_initialized, **params): assert not jaxpr.constvars closed_jaxpr = core.ClosedJaxpr(jaxpr, ()) arg_avals = map(core.typeof, args) args, is_initialized = unzip2( (lo_val, is_init) for a, x, is_init in zip(arg_avals, args, is_initialized) for lo_val in (a.read_loval(x) if a.has_qdd else a.lower_val(x))) lo_jaxpr = pe.lower_jaxpr(closed_jaxpr) all_outs = run_state_p.bind(*lo_jaxpr.consts, *args, jaxpr=lo_jaxpr.jaxpr, is_initialized=is_initialized, **params) out_mut, lo_outs = split_list(all_outs, [pe.num_himuts_out(jaxpr)]) pe.apply_himut(jaxpr, args, out_mut) return pe.raise_lo_outs(arg_avals, lo_outs) run_state_p.to_lojax = _run_state_to_lojax def _default_initialization(x): assert hasattr(x, 'shape') assert hasattr(x, 'dtype') dtype = np.dtype(x) if np.issubdtype(dtype, np.integer): value = np.iinfo(dtype).min else: value = math.nan return lax.full(x.shape, value, dtype) def _run_state_impl(*args: Any, jaxpr: core.Jaxpr, which_linear: tuple[bool, ...], is_initialized: tuple[bool, ...]): del which_linear discharged_jaxpr, consts = discharge_state(jaxpr, ()) # Initialize the args that are not initialized. args_it = iter(args) args = tuple( next(args_it) if is_init else _default_initialization(var.aval) for is_init, var in zip(is_initialized, discharged_jaxpr.invars) ) return core.eval_jaxpr(discharged_jaxpr, consts, *args) run_state_p.def_impl(_run_state_impl) mlir.register_lowering(run_state_p, mlir.lower_fun(_run_state_impl)) def _run_state_abstract_eval(*avals: core.AbstractValue, jaxpr: core.Jaxpr, which_linear: tuple[bool, ...], is_initialized: tuple[bool, ...]): del which_linear assert sum(is_initialized) == len(avals) # When we abstractly evaluate `run_state`, we want to keep track of which # input avals are `Ref`s and which are not. If an aval is a `Ref`, we want to # "propagate" out its inner effects. Otherwise, the effects are local to this # `run_state`. inner_to_outer_aval_mapping = {} outer_ref_index = 0 for i, is_init in enumerate(is_initialized): if not is_init: pass inner_to_outer_aval_mapping[i] = outer_ref_index outer_ref_index += 1 nonlocal_effects = set() is_ref = {i for i, aval in enumerate(avals) if isinstance(aval, AbstractRef)} for eff in jaxpr.effects: if not isinstance(eff, RefEffect): nonlocal_effects.add(eff) continue if eff.input_index not in inner_to_outer_aval_mapping: # This means that this effect corresponds to an uninitialized Ref and # should not propagate out of the primitive. continue # If we do propagate the effect, we need to update the input index to # correspond to the outer index. outer_index = inner_to_outer_aval_mapping[eff.input_index] if outer_index in is_ref: # This means that the effect corresponds to a Ref from an outside scope. nonlocal_effects.add( eff.replace(input_index=inner_to_outer_aval_mapping[eff.input_index]) ) assert len(jaxpr.invars) == len(is_initialized) if not all(is_initialized): raise NotImplementedError # Uninitialized refs are not in avals. return avals, nonlocal_effects run_state_p.def_effectful_abstract_eval(_run_state_abstract_eval) def _run_state_jvp(primals: Sequence[Any], tangents: Sequence[Any], *, jaxpr: core.Jaxpr, which_linear: tuple[bool, ...], is_initialized: tuple[bool, ...]): if not all(is_initialized): raise NotImplementedError("Uninitialized Refs are not supported in jvp.") nonzero_tangents = [not isinstance(t, ad_util.Zero) for t in tangents] discharged_jaxpr, body_consts = discharge_state(jaxpr, ()) for _ in range(len(nonzero_tangents)): _, out_nonzero_tangents = ad.jvp_jaxpr( core.ClosedJaxpr(discharged_jaxpr, body_consts), nonzero_tangents, instantiate=nonzero_tangents) if out_nonzero_tangents == nonzero_tangents: break nonzero_tangents = map(operator.or_, nonzero_tangents, out_nonzero_tangents) else: raise Exception("Invalid fixpoint") del discharged_jaxpr, body_consts, out_nonzero_tangents tangents = [ad.instantiate_zeros(t) if inst else t for t, inst in zip(tangents, nonzero_tangents)] tangents = [t for t in tangents if type(t) is not ad_util.Zero] closed_jvp_jaxpr, _ = ad.jvp_jaxpr(pe.close_jaxpr(jaxpr), nonzero_tangents, []) jvp_jaxpr_, jvp_consts = closed_jvp_jaxpr.jaxpr, closed_jvp_jaxpr.consts jvp_jaxpr = hoist_consts_to_refs(jvp_jaxpr_) jvp_which_linear = (*(False,) * len(jvp_consts), *which_linear, *(True,) * len(tangents)) out = run_state_p.bind(*jvp_consts, *primals, *tangents, jaxpr=jvp_jaxpr, which_linear=jvp_which_linear, # TODO(sharadmv): compute this in the general case is_initialized=(True,) * len(jvp_jaxpr.invars)) out_consts, out_primals, out_tangents = split_list(out, [len(jvp_consts), len(primals)]) del out_consts out_tangents_iter = iter(out_tangents) out_tangents = [next(out_tangents_iter) if nz else ad_util.Zero.from_primal_value(p) for p, nz in zip(out_primals, nonzero_tangents)] return out_primals, out_tangents ad.primitive_jvps[run_state_p] = _run_state_jvp @register_discharge_rule(run_state_p) def _run_state_discharge_rule(in_avals: Sequence[core.AbstractValue], out_avals: Sequence[core.AbstractValue], *args: Any, jaxpr: core.Jaxpr, which_linear: Sequence[bool], is_initialized: tuple[bool, ...]): if not all(is_initialized): raise NotImplementedError( "Uninitialized Refs are not supported in discharge." ) del out_avals out_vals = run_state_p.bind(*args, jaxpr=jaxpr, which_linear=which_linear, is_initialized=is_initialized) new_invals = [] for aval, out_val in zip(in_avals, out_vals): new_invals.append(out_val if isinstance(aval, AbstractRef) else None) return new_invals, out_vals def initial_style_jaxpr( fun: Callable, in_tree: PyTreeDef, in_avals: Sequence[core.AbstractValue], dbg: core.DebugInfo, ) -> tuple[core.Jaxpr, list[Any], PyTreeDef]: return _initial_style_jaxpr(fun, in_tree, tuple(in_avals), dbg) @weakref_lru_cache def _initial_style_jaxpr(fun: Callable, in_tree: api_util.PyTreeDef, in_avals: Sequence[core.AbstractValue], debug: core.DebugInfo): fun_, out_tree_thunk = api_util.flatten_fun_nokwargs( lu.wrap_init(fun, debug_info=debug), tree_util.treedef_tuple((in_tree,))) jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(fun_, in_avals) return jaxpr, consts, out_tree_thunk() T = TypeVar('T') def run_state(f: Callable[..., None]) -> Callable[[T], T]: def wrapped(args): dbg = api_util.debug_info("run_state", f, (args,), {}) flat_args, in_tree = tree_util.tree_flatten(args) ref_avals, ref_args = unzip2(map(get_ref_aval_from_value, flat_args)) # There may be some uninitialized values here in ref_args. jaxpr_, consts, _ = initial_style_jaxpr(f, in_tree, ref_avals, dbg) jaxpr = hoist_consts_to_refs(jaxpr_) which_linear = (False,) * (len(consts) + len(ref_args)) refs_is_initialized = tuple(r is not uninitialized for r in ref_args) init_args = tuple(r for r in ref_args if r is not uninitialized) # Consts are always initialized. is_initialized = (True,) * len(consts) + refs_is_initialized out_const_flat = run_state_p.bind(*consts, *init_args, jaxpr=jaxpr, which_linear=which_linear, is_initialized=is_initialized) _, out_flat = split_list(out_const_flat, [len(consts)]) return in_tree.unflatten(out_flat) return wrapped def run_state_reference(f: Callable[..., None]): def wrapped(args): dbg = api_util.debug_info("run_state", f, (args,), {}) flat_args, in_tree = tree_util.tree_flatten(args) ref_avals, ref_args = unzip2(map(get_ref_aval_from_value, flat_args)) jaxpr_, consts, _ = initial_style_jaxpr(f, in_tree, ref_avals, dbg) jaxpr = hoist_consts_to_refs(jaxpr_) discharged_jaxpr, discharged_consts = discharge_state(jaxpr, ()) # Initialize any uninitialized values here in ref_args in the reference. ref_args = [ _default_initialization(aval) if r is uninitialized else r for r, aval in zip(ref_args, ref_avals) ] out_const_flat = core.eval_jaxpr(discharged_jaxpr, discharged_consts, *consts, *ref_args) _, out_flat = split_list(out_const_flat, [len(consts)]) return in_tree.unflatten(out_flat) return wrapped @register_discharge_rule(pjit.jit_p) def _pjit_state_discharge_rule( in_avals, out_avals, *args, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, **params): if not (any(isinstance(e, RefEffect) for e in jaxpr.effects) or any(isinstance(a, AbstractRef) for a in jaxpr.in_avals)): # Only internal ref effects jaxpr_ = discharge_state2(jaxpr) out = pjit.jit_p.bind( *args, jaxpr=jaxpr_, in_shardings=in_shardings, out_shardings=out_shardings, in_layouts=in_layouts, out_layouts=out_layouts, **params, ) new_invals = [None] * len(in_avals) return new_invals, out if not all(isinstance(s, sharding_impls.UnspecifiedValue) for s in (*in_shardings, *out_shardings)): raise NotImplementedError if not (all(l is None for l in in_layouts) and all(l is None for l in out_layouts)): raise NotImplementedError discharged_jaxpr = discharge_state2(jaxpr) new_in_shardings = (sharding_impls.UNSPECIFIED,) * len(discharged_jaxpr.in_avals) new_out_shardings = (sharding_impls.UNSPECIFIED,) * len(discharged_jaxpr.out_avals) new_in_layouts = (None,) * len(discharged_jaxpr.in_avals) new_out_layouts = (None,) * len(discharged_jaxpr.out_avals) out_and_ref_vals = pjit.jit_p.bind( *args, jaxpr=discharged_jaxpr, in_shardings=new_in_shardings, out_shardings=new_out_shardings, in_layouts=new_in_layouts, out_layouts=new_out_layouts, **params) out_vals, ref_vals = split_list(out_and_ref_vals, [len(jaxpr.out_avals)]) ref_vals_iter = iter(ref_vals) new_invals = tuple(next(ref_vals_iter) if isinstance(aval, AbstractRef) else None for aval in in_avals) sentinel = object() assert next(ref_vals_iter, sentinel) is sentinel return new_invals, out_vals @register_discharge_rule(custom_derivatives.custom_vjp_call_p) def custom_vjp_call_discharge(in_avals, out_avals, *args, call_jaxpr, fwd_jaxpr_thunk, bwd, out_trees, symbolic_zeros, num_consts): # Discharge happens after all AD is done, so we can discard the AD rules. del fwd_jaxpr_thunk, bwd, out_trees, symbolic_zeros, num_consts dis_jaxpr, dis_consts = discharge_state(call_jaxpr.jaxpr, call_jaxpr.consts) outs = _eval_jaxpr_ad_error(dis_jaxpr, dis_consts, args) out_vals, ref_vals = split_list(outs, [len(call_jaxpr.out_avals)]) ref_vals_ = iter(ref_vals) new_invals = [next(ref_vals_) if isinstance(aval, AbstractRef) else None for aval in in_avals] assert next(ref_vals_, None) is None return new_invals, out_vals @partial(custom_derivatives.custom_jvp, nondiff_argnums=(0,)) def _eval_jaxpr_ad_error(dis_jaxpr, consts, args): return core.eval_jaxpr(dis_jaxpr, consts, *args) @_eval_jaxpr_ad_error.defjvp def _eval_jaxpr_ad_error_jvp(*_): raise Exception("should be unreachable, AD after discharge")