# Copyright 2020 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Provides JAX and TensorFlow interoperation APIs."""

from __future__ import annotations

from collections.abc import Callable, Sequence
from functools import partial
import contextlib
import math
import os
import threading
from typing import Any, Union
import warnings

from absl import logging
import numpy as np

import jax
from jax import tree_util
from jax import export
from jax._src import api
from jax._src import api_util
from jax._src import config
from jax._src import core
from jax._src import dtypes
from jax._src import op_shardings
from jax._src import source_info_util
from jax._src import util
from jax._src.export import _export
from jax._src.export import shape_poly
from jax._src.lib import xla_client

import tensorflow as tf

# These don't have public equivalents.
# pylint: disable=g-direct-tensorflow-import
from tensorflow.compiler.tf2xla.python import xla as tfxla
from tensorflow.compiler.xla import xla_data_pb2
try:
  from tensorflow.python.compiler.xla.experimental import xla_sharding
except ModuleNotFoundError:
  # This can be removed when TF 2.10 support is no longer needed.
  from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding
from tensorflow.python.eager import context as tf_context
# pylint: enable=g-direct-tensorflow-import

NameStack = source_info_util.NameStack
PolyShape = shape_poly.PolyShape  # TODO: deprecate
DType = Any
DisabledSafetyCheck = export.DisabledSafetyCheck

map = util.safe_map
zip = util.safe_zip

# A value suitable in a TF tracing context: tf.Tensor, tf.Variable,
# or Python scalar or numpy.ndarray. (A tf.EagerTensor is a tf.Tensor.)
TfVal = Any
PrecisionType = int  # Enum xla_data.PrecisionConfig.Precision


def _is_tfval(v: TfVal) -> bool:
  if isinstance(v, (tf.Tensor, tf.Variable)):
    return True
  try:
    # Include all convertible types, even if not supported on accelerators.
    with tf.device("CPU"):
      tf.constant(v)
    return True
  except:
    return False


class _DefaultNativeSerialization:
  pass
DEFAULT_NATIVE_SERIALIZATION = _DefaultNativeSerialization()


# In order to ensure that JAX picks up the proper user-frame for source
# locations we will register the TensorFlow source path as an internal
# path with source_info_util. The typical stack when a JAX primitive
# conversion happens is:
#    jax2tf.process_primitive  (top of stack)
#    jax tracing machinery ...
#    tf.custom_gradient machinery ...
#    jax2tf.converted_fun
#    tf function machinery ...
#    user code invokes the converted function on TF tensors
#
# We need to skip over not only JAX internal frames, but TF internal frames
# also.
#
# We register the TensorFlow source path lazily.
_has_registered_tf_source_path = False


class _ThreadLocalState(threading.local):
  def __init__(self):
    # Keep track of whether we are inside a call_tf. In that context we
    # disable the safety check that we are not inside JAX transformations.
    self.inside_call_tf = False

    # Maps dimension variables to TF expressions, for non-native lowering.
    self.shape_env: Sequence[tuple[str, TfVal]] = ()

    # A list collecting all tf concrete_functions called by stablehlo.custom_call.
    # This is used only by native serialization (unlike all the other
    # thread-local state).
    self.call_tf_concrete_function_list: list[Any] | None = None

_thread_local_state = _ThreadLocalState()


@contextlib.contextmanager
def inside_call_tf():
  # Set the inside_call_tf flag for a context.
  prev = _thread_local_state.inside_call_tf
  _thread_local_state.inside_call_tf = True
  try:
    yield
  finally:
    _thread_local_state.inside_call_tf = prev


def get_thread_local_state_call_tf_concrete_function_list() -> (
    list[Any] | None
):
  return _thread_local_state.call_tf_concrete_function_list


@partial(api_util.api_hook, tag="jax2tf_convert")
def convert(fun_jax: Callable,
            *,
            polymorphic_shapes: str | PolyShape | None | Sequence[str | PolyShape | None] = None,
            polymorphic_constraints: Sequence[str] = (),
            with_gradient: bool = True,
            enable_xla: bool = DEFAULT_NATIVE_SERIALIZATION,  # type: ignore
            native_serialization: bool | _DefaultNativeSerialization = DEFAULT_NATIVE_SERIALIZATION,  # type: ignore
            native_serialization_platforms: Sequence[str] | None = None,
            native_serialization_disabled_checks: Sequence[DisabledSafetyCheck] = (),
            ) -> Callable:
  """Allows calling a JAX function from a TensorFlow program.

  See
  [README](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md)
  for more details about usage and common problems.

  Args:
    fun_jax: target JAX function to be called. Its arguments and return value
      should be JAX arrays, or nested standard Python containers
      (tuple/list/dict) thereof (pytrees).
    polymorphic_shapes: Specifies input shapes to be treated polymorphically
      during lowering.

      .. warning:: The shape-polymorphic lowering is an experimental feature.
        It is meant to be sound, but it is known to reject some JAX programs
        that are shape polymorphic. The details of this feature can change.

      It should be `None` (all arguments are monomorphic), a single PolyShape
      or string (applies to all arguments), or a tuple/list of the same length
      as the function arguments. For each argument the shape specification
      should be `None` (monomorphic argument), or a Python object with the
      same pytree structure as the argument. See [how optional parameters are
      matched to arguments](https://docs.jax.dev/en/latest/pytrees.html#applying-optional-parameters-to-pytrees).

      A shape specification for an array argument should be an object
      `PolyShape(dim0, dim1, ..., dimn)` where each `dim` is a dimension
      specification: a positive integer denoting a monomorphic dimension of
      the given size, or a string denoting a dimension variable assumed to
      range over non-zero dimension sizes, or the special placeholder string
      "_" denoting a monomorphic dimension whose size is given by the actual
      argument. As a shortcut, an Ellipsis suffix in the list of dimension
      specifications stands for a list of "_" placeholders.

      For convenience, a shape specification can also be given as a string
      representation, e.g.: "batch, ...", "batch, height, width, _", possibly
      with surrounding parentheses: "(batch, ...)".

      The lowering fails if it cannot ensure that it would produce the same
      sequence of TF ops for any non-zero values of the dimension variables.

      polymorphic_shapes are only supported for positional arguments; shape
      polymorphism is not supported for keyword arguments.
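
      For example (an illustrative sketch; `f_jax` and the image shape below
      are hypothetical and not defined in this module):

        # `b` is a dimension variable standing for the batch size.
        f_tf = jax2tf.convert(f_jax, polymorphic_shapes=["(b, 28, 28)"])
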
      See [the README](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#shape-polymorphic-conversion)
      for more details.

    polymorphic_constraints: a sequence of constraints on symbolic dimension
      expressions, of the form `e1 >= e2` or `e1 <= e2`. See more details at
      https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints.
    with_gradient: if set (default), add a tf.custom_gradient to the lowered
      function, obtained by converting ``jax.vjp(fun)``. This means that
      reverse-mode TensorFlow AD is supported for the output TensorFlow
      function, and the value of the gradient will be JAX-accurate.
    native_serialization_platforms: Specifies the platform(s) for which to
      lower the code. Must be a tuple of strings, including a subset of:
      'cpu', 'cuda', 'rocm', 'tpu'. The default (`None`) specifies the JAX
      default backend on the machine where the lowering is done.
    native_serialization_disabled_checks: Disables the specified safety checks.
      See docstring of `DisabledSafetyCheck`.

  Returns:
    A version of `fun_jax` that expects TfVals as arguments (or
    tuples/lists/dicts thereof), and returns TfVals as outputs, and uses
    only TensorFlow ops and thus can be called from a TensorFlow program.
  """
  if native_serialization is not DEFAULT_NATIVE_SERIALIZATION:
    warnings.warn(
        "The `native_serialization` parameter is deprecated and "
        "will be removed in a future version of JAX.",
        DeprecationWarning, stacklevel=2)
  del native_serialization

  if enable_xla is not DEFAULT_NATIVE_SERIALIZATION:
    warnings.warn(
        "The `enable_xla` parameter is deprecated and "
        "will be removed in a future version of JAX.",
        DeprecationWarning, stacklevel=2)
  del enable_xla

  if native_serialization_platforms:
    if (not isinstance(native_serialization_platforms, (list, tuple)) or
        not all(p in ["cpu", "cuda", "rocm", "tpu"]
                for p in native_serialization_platforms)):
      raise ValueError(
          "native_serialization_platforms must be a sequence "
          "containing a subset of {'cpu', 'cuda', 'rocm', 'tpu'}. "
          f"Got: {native_serialization_platforms}")
    native_serialization_platforms = tuple(native_serialization_platforms)

  api.check_callable(fun_jax)

  def converted_fun_tf(*args_tf: TfVal, **kwargs_tf: TfVal) -> TfVal:

    # TODO: is there a better way to check if we are inside a transformation?
    if not core.trace_state_clean() and not _thread_local_state.inside_call_tf:
      # It is Ok to nest convert when we are inside a call_tf.
      raise ValueError(
          "convert must be used outside all JAX transformations. "
          f"Trace state: {core.trace_ctx}")

    global _has_registered_tf_source_path
    if not _has_registered_tf_source_path:
      source_info_util.register_exclusion(os.path.dirname(tf.__file__))
      _has_registered_tf_source_path = True

    def jax_arg_spec_from_tf(a: TfVal) -> jax.ShapeDtypeStruct:
      # The shape and JAX dtype for a TF argument.
      tf_arg_shape = np.shape(a)
      # Fix the shape for TF1.
      tf_arg_shape = tuple(d.value
                           if isinstance(d, tf.compat.v1.Dimension) else d
                           for d in tf_arg_shape)
      _, a_jax_dtype = _tfval_to_tensor_jax_dtype(a)
      # We count on the fact that jax.ShapeDtypeStruct allows shapes that
      # contain None.
      return jax.ShapeDtypeStruct(tf_arg_shape, a_jax_dtype)

    args_jax_specs = tree_util.tree_map(jax_arg_spec_from_tf, args_tf)
    args_specs = export.symbolic_args_specs(
        args_jax_specs, polymorphic_shapes,
        constraints=polymorphic_constraints)
    # The polymorphic_shapes argument refers to positional arguments only.
    # We assume None for the kwargs.
    kwargs_jax_specs = tree_util.tree_map(jax_arg_spec_from_tf, kwargs_tf)
    kwargs_specs = export.symbolic_args_specs(
        kwargs_jax_specs, None)
    combined_args_tf = (args_tf, kwargs_tf)
    args_flat_tf: Sequence[TfVal]
    args_flat_tf, args_kwargs_tree = tree_util.tree_flatten(combined_args_tf)
    args_flat_tf = tuple(
        map(preprocess_arg_tf, range(len(args_flat_tf)), args_flat_tf))

    impl = NativeSerializationImpl(
        fun_jax,
        args_specs=args_specs, kwargs_specs=kwargs_specs,
        native_serialization_platforms=native_serialization_platforms,
        native_serialization_disabled_checks=native_serialization_disabled_checks)
    try:
      impl.before_conversion()

      outs_tree: tree_util.PyTreeDef = None  # type: ignore
      if with_gradient:
        @tf.custom_gradient
        def converted_fun_flat_with_custom_gradient_tf(*args_flat_tf: TfVal) -> TfVal:
          nonlocal outs_tree
          outs_tf, outs_avals, outs_tree = impl.run_fun_tf(args_flat_tf)
          return (tuple(outs_tf),
                  _make_custom_gradient_fn_tf(
                      fun_jax,
                      impl=impl,
                      with_gradient=with_gradient,
                      args_specs=args_specs, kwargs_specs=kwargs_specs,
                      args_tf=args_flat_tf,
                      outs_avals=outs_avals,
                      outs_tf=outs_tf))

        outs_flat_tf = converted_fun_flat_with_custom_gradient_tf(*args_flat_tf)
      else:
        outs_tf, _, outs_tree = impl.run_fun_tf(args_flat_tf)
        message = ("The jax2tf-converted function does not support gradients. "
                   "Use `with_gradient` parameter to enable gradients")
        # We use PreventGradient, which is propagated through a SavedModel.
        outs_flat_tf = [
            tf.raw_ops.PreventGradient(input=o, message=message)
            for o in outs_tf
        ]
    finally:
      impl.after_conversion()

    outs_flat_tf = [tf.identity(x, "jax2tf_out") for x in outs_flat_tf]
    out_tf = tree_util.tree_unflatten(outs_tree, outs_flat_tf)
    return out_tf

  return converted_fun_tf


class NativeSerializationImpl:
  def __init__(self, fun_jax, *,
               args_specs, kwargs_specs,
               native_serialization_platforms: Sequence[str] | None,
               native_serialization_disabled_checks: Sequence[DisabledSafetyCheck]):
    self.convert_kwargs = dict(
        native_serialization_platforms=native_serialization_platforms,
        native_serialization_disabled_checks=native_serialization_disabled_checks)
    if hasattr(fun_jax, "trace"):
      # If we have a pjit or pmap already we do not wrap with another, and we
      # allow shardings.
      fun_jit = fun_jax
    else:
      # We support convert(pjit(f_jax)) and convert(jit(f_jax)) but also
      # convert(f_jax), in which case a "jit" is implied. In that case we
      # raise an error if the lowered function contains non-replicated
      # sharding annotations.
      fun_jit = jax.jit(fun_jax)
    self.fun_jax = fun_jit
    self.args_specs = args_specs
    self.kwargs_specs = kwargs_specs
    self.native_serialization_disabled_checks = native_serialization_disabled_checks
    self.native_serialization_platforms = native_serialization_platforms

  def before_conversion(self):
    _prev_func_list = _thread_local_state.call_tf_concrete_function_list
    _thread_local_state.call_tf_concrete_function_list = []

    def _restore_context():
      _thread_local_state.call_tf_concrete_function_list = _prev_func_list

    self._restore_context = _restore_context
    _exported_device_assignment = [None]
    self.exported = _export._export_internal(
        self.fun_jax,
        platforms=self.native_serialization_platforms,
        disabled_checks=self.native_serialization_disabled_checks,
        _device_assignment_for_internal_jax2tf_use_only=_exported_device_assignment,
    )(*self.args_specs, **self.kwargs_specs)
    assert _exported_device_assignment[0] is not None
    self.device_assignment = _exported_device_assignment[0]

  def after_conversion(self):
    self._restore_context()

  def run_fun_tf(self,
                 args_flat_tf: Sequence[TfVal]
                 ) -> tuple[Sequence[TfVal], Sequence[core.ShapedArray], tree_util.PyTreeDef]:
    results = _run_exported_as_tf(args_flat_tf, self.exported)
    return results, tuple(self.exported.out_avals), self.exported.out_tree

  def get_vjp_fun(self) -> tuple[Callable, Sequence[core.AbstractValue]]:
    return _export._get_vjp_fun(self.fun_jax,
                                in_tree=self.exported.in_tree,
                                in_avals=self.exported.in_avals,
                                in_shardings_hlo=self.exported.in_shardings_hlo,
                                out_avals=self.exported.out_avals,
                                out_shardings_hlo=self.exported.out_shardings_hlo,
                                device_assignment=self.device_assignment,
                                apply_jit=True)


def dtype_of_val(val: TfVal) -> DType:
  """Computes the TensorFlow dtype using JAX's typing rules.

  If the value is a tf.Tensor, it starts with its dtype. If the value is a
  constant it uses JAX to infer its dtype. The resulting dtype follows the
  JAX type inference rules, and depends on the value of the JAX_ENABLE_X64
  flag. See README.md for how 64-bit values are treated.
  """
  tval, _ = _tfval_to_tensor_jax_dtype(val)
  return tval.dtype


@partial(api_util.api_hook, tag="jax2tf_eval_polymorphic_shapes")
def eval_polymorphic_shape(fun_jax: Callable,
                           *,
                           polymorphic_shapes=None) -> Callable:
  """Evaluates the output shape in presence of shape polymorphism.

  This is done without lowering or executing the function, same as for
  `jax.eval_shape`.

  Args:
    fun_jax: target JAX function to be called. Its arguments and return value
      should be JAX arrays, or nested standard Python containers
      (tuple/list/dict) thereof (pytrees).
    polymorphic_shapes: Specifies input shapes to be treated polymorphically
      during shape evaluation. See discussion for `jax2tf.convert`.

      .. warning:: The shape-polymorphic lowering is an experimental feature.

  Returns: a function that takes `jax.ShapeDtypeStruct`s (or any values with
    `.shape` and `.dtype` attributes) corresponding to the inputs for
    `fun_jax`, and returns a tuple with:

    * the jax.ShapeDtypeStruct corresponding to the result, as for
      `jax.eval_shape`. The shape may contain symbolic dimension expressions.
    * the value that can be passed to `polymorphic_shapes` for a subsequent
      call to `jax2tf.eval_polymorphic_shape`, or `jax2tf.convert`.

  For example:

  >>> import jax
  >>> from jax.experimental import jax2tf
  >>> from jax import numpy as jnp
  >>>
  >>> f = lambda A, x: jnp.sin(jnp.dot(A, x))
  >>> A = jax.ShapeDtypeStruct((2000, 3000), jnp.float32)
  >>> x = jax.ShapeDtypeStruct((3000, 1000), jnp.float32)
  >>> out_spec, out_poly_shape = jax2tf.eval_polymorphic_shape(f, polymorphic_shapes=["a, b", "b, c"])(A, x)
  >>> print(out_spec.shape)
  ("a", "c")
  >>> print(out_poly_shape)
  (a, c)
  >>> res_spec, res_poly_shape = jax2tf.eval_polymorphic_shape(lambda x: x.T, polymorphic_shapes=[out_poly_shape])(out_spec)
  >>> print(res_poly_shape)
  (c, a)
  """
  def do_eval_polymorphic_shape(*args_specs) -> Any:
    args_poly_specs = export.symbolic_args_specs(
        args_specs, polymorphic_shapes)
    res_poly_spec = jax.eval_shape(fun_jax, *args_poly_specs)
    # TODO(necula): For now we export the polymorphic shapes using `str`.
    res_polymorphic_shape = tree_util.tree_map(lambda r: str(r.shape),
                                               res_poly_spec)
    return res_poly_spec, res_polymorphic_shape

  return do_eval_polymorphic_shape


# Internals


def flatten_fun_jax(fun_jax: Callable,
                    in_tree,
                    ) -> tuple[Callable, Callable]:
  """Wraps the function to take a (flat) list of positional args.

  jax2tf works better and is simpler when the JAX function takes and returns
  just a tuple of values (no pytrees, no kwargs). This is in part because
  jax.vjp does not support kwargs, and we can only set tf.custom_gradient on
  functions with flat arguments and results.

  Returns:
     * the wrapped JAX function taking and returning a flat list of arguments
     * a thunk that can be called after the wrapped function has been called
       to return the output pytree.
  """
  out_tree_ref = None
  def fun_flat_jax(*args_flat_jax):
    tree_args, tree_kwargs = tree_util.tree_unflatten(in_tree, args_flat_jax)
    tree_res = fun_jax(*tree_args, **tree_kwargs)
    res_flat_jax, out_tree = tree_util.tree_flatten(tree_res)
    nonlocal out_tree_ref
    assert out_tree_ref is None or out_tree_ref == out_tree
    out_tree_ref = out_tree
    return res_flat_jax

  return fun_flat_jax, lambda: out_tree_ref


def preprocess_arg_tf(arg_idx: int, arg_tf: TfVal) -> TfVal:
  """Pre-processes the TF args.

  Returns: the pre-processed TF arg, cast to the type expected by JAX and
  named for debugging.
  """
  if not _is_tfval(arg_tf):
    msg = (f"Argument {arg_tf} of type {type(arg_tf)} of jax2tf.convert(f) "
           "should be a NumPy array, scalar, tf.Variable, or tf.Tensor")
    raise TypeError(msg)

  # May cast the args_flat to JAX types, using JAX's interpretation
  # of types of constants.
  arg_tf, _ = _tfval_to_tensor_jax_dtype(arg_tf)
  # Name input tensors; do this after we have cast the arguments.
  arg_tf = tf.identity(arg_tf, f"jax2tf_arg_{arg_idx}")
  return arg_tf


def _make_custom_gradient_fn_tf(fun_jax,
                                *,
                                impl: NativeSerializationImpl,
                                with_gradient: bool,
                                args_specs, kwargs_specs,
                                args_tf: Sequence[TfVal],
                                outs_avals: Sequence[core.ShapedArray],
                                outs_tf: Sequence[TfVal]):
  """Prepares the TF function to be used with tf.custom_gradient.

  Args:
    impl: the serialization implementation details
    with_gradient: whether to include a tf.custom_gradient
    args_specs, kwargs_specs: the jax.ShapeDtypeStructs for the args and kwargs
    args_tf: the flattened TF arguments of the primal function
    outs_avals: the flattened output JAX abstract values of the primal function
    outs_tf: the flattened TF outputs of the primal function
  """
  def grad_fn_tf(*out_cts_flat_tf: TfVal, variables=None):
    if variables:
      raise ValueError(
          "Unexpected variables used in forward pass. "
          "This should not happen for first-order differentiation. "
          f"{variables=}")

    # TODO: enable higher-order gradients
    with tf.name_scope("jax2tf_vjp"):
      def fix_out_ct(out_ct_tf, out_ct_aval: core.ShapedArray, out_tf: TfVal):
        # If the primal function has outputs of integer or bool types, and if
        # we are under a tf.function context, then TF will pass None in
        # _out_cts_flat in place of these values. We should change these to
        # float0 or else JAX gets unhappy. See issue #6975.
        if out_ct_tf is not None:
          return out_ct_tf
        assert core.primal_dtype_to_tangent_dtype(out_ct_aval.dtype) == dtypes.float0, f"{out_ct_tf=}"
        # Note that out_ct_aval.shape contains dimension variables from the
        # primal function scope. We use tf.zeros_like to make a 0 of the
        # right shape.
        return tf.zeros_like(out_tf, dtype=_tf_np_dtype_for_float0)

      out_cts_fixed_flat_tf = tuple(map(fix_out_ct, out_cts_flat_tf,
                                        outs_avals, outs_tf))
      vjp_args_flat_tf = tuple(args_tf) + out_cts_fixed_flat_tf

      fun_vjp_jax, vjp_in_avals = impl.get_vjp_fun()

      vjp_polymorphic_shapes = tuple(
          str(a.shape)  # Note: may be _DimExpr, not just DimVar
          for a in vjp_in_avals)
      in_cts_flat = convert(
          fun_vjp_jax,
          with_gradient=with_gradient,
          polymorphic_shapes=vjp_polymorphic_shapes,
          **impl.convert_kwargs)(*vjp_args_flat_tf)

    # We do not need to fix the in_cts because the TF gradient machinery
    # will adjust the unconnected gradients and those for integer types.
    return in_cts_flat

  return grad_fn_tf


def _run_exported_as_tf(args_flat_tf: Sequence[TfVal],
                        exported: export.Exported,
                        ) -> Sequence[TfVal]:
  """Runs the `exported` as an XlaCallModule TF op.

  Returns: the flattened tuple of results.
  """
  args_avals = exported.in_avals

  # TF values may be integer types for float0.
  def _convert_value(val, aval):
    # Check the shape.
    assert all(d_aval == d_val
               for d_aval, d_val in zip(aval.shape, val.shape)
               if core.is_constant_dim(d_aval)), (aval, val)
    conversion_dtype = _to_tf_dtype(aval.dtype)
    if conversion_dtype != aval.dtype:
      return tf.cast(val, conversion_dtype)
    else:
      return val

  args_flat_tf = tuple(map(_convert_value, args_flat_tf, args_avals))

  out_shapes_tf = tuple(
      tuple(d if core.is_constant_dim(d) else None
            for d in out_aval.shape)
      for out_aval in exported.out_avals)

  out_types = tuple(_to_tf_dtype(out_aval.dtype)
                    for out_aval in exported.out_avals)

  kept_args_flat_tf = [atf for i, atf in enumerate(args_flat_tf)
                       if i in exported.module_kept_var_idx]

  version = exported.calling_convention_version

  try:
    get_max_supported_version = tfxla.call_module_maximum_supported_version
  except AttributeError:
    get_max_supported_version = None

  if get_max_supported_version:
    max_supported_version = get_max_supported_version()
  else:
    max_supported_version = 6

  if version > max_supported_version:
    raise NotImplementedError(
        "XlaCallModule from your TensorFlow installation supports up to "
        f"serialization version {max_supported_version} but the serialized "
        f"module needs version {version}. "
        "You should upgrade TensorFlow, e.g., to tf_nightly."
    )

  call_module_attrs = dict(
      version=version,
      Tout=out_types,
      Sout=out_shapes_tf,
      function_list=[
          concrete_fn.function_def.signature.name
          for concrete_fn in _thread_local_state.call_tf_concrete_function_list
      ] if _thread_local_state.call_tf_concrete_function_list is not None
      else [],
      # We always set has_token_input_output=False: real tokens are required
      # only for versions less than 9, and the attribute is not used starting
      # with version 9.
      has_token_input_output=False
  )
  call_module_attrs["platforms"] = tuple(p.upper() for p in exported.platforms)

  if version >= 6:
    call_module_attrs["disabled_checks"] = tuple(
        str(dc)
        for dc in exported.disabled_safety_checks)
  else:
    if version >= 3:
      if DisabledSafetyCheck.platform() in exported.disabled_safety_checks:
        call_module_attrs["platforms"] = ()  # No platform checking

  if version >= 10:
    call_module_attrs["use_shardy_partitioner"] = (
        config.use_shardy_partitioner.value
    )

  if logging.vlog_is_on(3):
    # We already logged the MLIR module when we exported it.
    logging.vlog(3, "XlaCallModule %s", str(call_module_attrs))

  call_module_attrs["module"] = exported.mlir_module_serialized

  # Apply the shardings on arguments and results for pjit. This is redundant
  # because the mlir_module_text will already contain the shardings, but it
  # makes it easier for tools like the TPU inference converter to see the
  # sharding without digging into the `module` attribute of the `XlaCallModule`
  # op, in the same way as it is done for the legacy jax2tf conversion.
  # Do not apply XlaSharding for REPLICATED, on inputs and outputs.
  # This is an agreed convention, and also improves usability under TF eager.
  # See b/255511660.
  kept_in_shardings = []
  for i in exported.module_kept_var_idx:
    kept_in_shardings.append(exported.in_shardings_hlo[i])
  args_flat_tf = tuple(
      map(partial(_shard_value,
                  skip_replicated_sharding=tf.executing_eagerly()),
          kept_args_flat_tf, kept_in_shardings))
  res = tfxla.call_module(args_flat_tf, **call_module_attrs)

  # TODO(b/278940799): Replace the TF v1 API with public TF2 API.
  # Add the custom call tf.function into the default graph, so those functions
  # will be available during tf.SavedModel.save.
  if _thread_local_state.call_tf_concrete_function_list is not None:
    for concrete_fn in _thread_local_state.call_tf_concrete_function_list:
      tf.compat.v1.get_default_graph()._add_function_recursive(
          concrete_fn._inference_function
      )

  res = list(map(partial(_shard_value,
                         skip_replicated_sharding=tf.executing_eagerly()),
                 res, exported.out_shardings_hlo))
  res = tuple(map(_convert_value, res, exported.out_avals))
  return res


def _jax_physical_aval(aval: core.ShapedArray) -> core.ShapedArray:
  """Converts JAX avals from logical to physical, if relevant.

  JAX might have avals whose logical vs physical shape/dtype may differ, and
  only the physical view is expected to possibly relate to TF. TF impl rules
  should operate on the physical form.

  A JAX logical aval might even correspond, in principle, to several physical
  avals, but we don't support those here. Instead we assert there is only one
  and return it.
  """
  physical_aval = core.physical_aval(aval)
  assert (len(physical_aval.shape) >= len(aval.shape) and
          physical_aval.shape[:len(aval.shape)] == aval.shape), (physical_aval, aval)
  return physical_aval


def _jax_physical_dtype(dtype):
  # Assuming () is a fine stand-in shape.
  return _jax_physical_aval(core.ShapedArray((), dtype)).dtype


def _aval_to_tf_shape(aval: core.ShapedArray) -> tuple[int | None, ...]:
  """Generate a TF shape, possibly containing None for polymorphic dimensions."""
  aval = _jax_physical_aval(aval)
  return tuple(map(lambda d: None if export.is_symbolic_dim(d) else d,
                   aval.shape))


# In the TF world, we represent float0 as zeros of this type.
# We pick bool because this is what JAX uses when it lowers float0 to HLO.
_tf_np_dtype_for_float0 = np.bool_


def _to_tf_dtype(jax_dtype):
  # Note that converting _to_tf_dtype and _to_jax_dtype are not inverses,
  # due to float0 and 64-bit behavior.
  try:
    jax_dtype = _jax_physical_dtype(jax_dtype)
  except TypeError:
    # `jax_dtype` isn't actually a valid jax dtype (e.g. it is
    # tf.float32), so there is no physical dtype anyway.
    pass

  if jax_dtype == dtypes.float0:
    jax_dtype = _tf_np_dtype_for_float0

  return tf.dtypes.as_dtype(jax_dtype)


def _to_jax_dtype(tf_dtype):
  # Note that converting _to_tf_dtype and _to_jax_dtype are not inverses,
  # due to float0 and 64-bit behavior.
  dt = dtypes.canonicalize_dtype(tf_dtype.as_numpy_dtype)
  if dt not in dtypes._jax_dtype_set:
    raise TypeError(f"dtype {dt} is not a valid JAX array "
                    "type. Only arrays of numeric types are supported by JAX.")
  return dt


def _tfval_to_tensor_jax_dtype(val: TfVal,
                               jax_dtype: DType | None = None,
                               memoize_constants=False) -> tuple[TfVal, DType]:
  """Converts a scalar, ndarray, or tf.Tensor to a tf.Tensor with proper type.

  If `jax_dtype` is missing, uses JAX typing rules.
  See README.md for details regarding 64-bit values.

  Args:
    val: a scalar, ndarray, tf.Tensor, or tf.Variable
    jax_dtype: an optional dtype to use. If missing, uses JAX type inference
      rules for constants.
    memoize_constants: whether to memoize TF constants. We can't do this
      everywhere, we may be outside of a conversion scope.

  Returns:
    A tuple with a tf.Tensor with the type as needed by JAX, and the JAX type.
  """
  if isinstance(val, (tf.Tensor, tf.Variable)):
    jax_dtype = jax_dtype or _to_jax_dtype(val.dtype)  # Give JAX a chance to pick the type
    conversion_dtype = _to_tf_dtype(jax_dtype)
    if conversion_dtype != val.dtype:  # May need to cast for 64-bit values
      return tf.cast(val, conversion_dtype), jax_dtype
    else:
      return val, jax_dtype
  else:  # A constant
    jax_dtype = jax_dtype or core.abstractify(val).dtype
    # TODO(document): We assume that the value of a constant does not
    # change through the scope of the function. But it may be an ndarray, ...
    # JAX has the same problem when generating HLO.
    const_key = (id(val), jax_dtype)
    # Since we use id(val) as a cache key, we have to make sure that we keep
    # the previous `val` alive. Otherwise, for an ndarray, it can get garbage
    # collected and reused for a different value, which would create
    # correctness issues. We keep the `val` alive by storing in the cache the
    # pair `(val, tf_val)`.
    # Only memoize non-scalars. JAX will lift all non-scalar constants as
    # Jaxpr consts, to the top level of the Jaxpr. This ensures that we see
    # them early, when entering the Jaxpr, so we create the tf.const early
    # and its scope is the entire Jaxpr.
    do_memoize = (memoize_constants and np.size(val) > 1 and
                  _thread_local_state.constant_cache is not None)
    if do_memoize:
      _, tf_val = _thread_local_state.constant_cache.get(const_key, (None, None))
    else:
      tf_val = None
    if tf_val is None:
      conversion_dtype = _to_tf_dtype(jax_dtype)
      # The float0 type is not known to TF.
      if jax_dtype == dtypes.float0:
        val = np.zeros(np.shape(val), conversion_dtype.as_numpy_dtype)
      if hasattr(val, 'dtype') and dtypes.issubdtype(val.dtype, dtypes.extended):
        val = val.dtype._rules.physical_const(val)
      tf_val = tf.convert_to_tensor(val, dtype=conversion_dtype)
      if do_memoize:
        _thread_local_state.constant_cache[const_key] = (val, tf_val)
    return tf_val, jax_dtype


PartitionsOrReplicated = Union[tuple[int, ...], None]


def split_to_logical_devices(tensor: TfVal,
                             partition_dimensions: PartitionsOrReplicated):
  """Like TPUMPStrategy.experimental_split_to_logical_devices.

  For jax2tf purposes we want to avoid needing to thread the `strategy` object
  through the generated computation.
  It seems that the original function needs the strategy object only for
  error checking, which we assume is done upstream by JAX.

  Args:
    tensor: Input tensor to annotate.
    partition_dimensions: A list of integers, with one integer per tensor
      dimension, specifying in how many parts the dimension should be split.
      The product of integers must equal the number of devices per replica.

  Returns:
    an annotated tensor.
  """
  # TODO: this is only for sharded_jit. Either remove, or implement in terms
  # of _shard_values.
  if partition_dimensions is None:
    return xla_sharding.replicate(tensor, use_sharding_op=True)
  num_partition_splits = math.prod(partition_dimensions)
  tile_assignment = np.arange(num_partition_splits).reshape(
      partition_dimensions)
  return xla_sharding.tile(tensor, tile_assignment, use_sharding_op=True)


def _shard_value(val: TfVal,
                 sd: xla_client.HloSharding | None, *,
                 skip_replicated_sharding: bool) -> TfVal:
  """Apply sharding to a TfVal."""
  if sd is None:
    return val

  sharding_proto = sd.to_proto()
  if (skip_replicated_sharding and
      op_shardings.is_hlo_sharding_replicated(sd)):
    return val

  # TensorFlow relies heavily on the tile_assignment_devices proto field,
  # which is specific to the V1 sharding format, so we fall back to that
  # format.
  if (
      not sharding_proto.tile_assignment_devices
      and sharding_proto.iota_reshape_dims
  ):
    tad = list(
        np.arange(math.prod(sharding_proto.tile_assignment_dimensions))
        .reshape(sharding_proto.iota_reshape_dims)
        .transpose(sharding_proto.iota_transpose_perm)
        .flat
    )
  else:
    tad = sharding_proto.tile_assignment_devices  # type: ignore

  # To use xla_sharding.py, we must have an xla_data_pb2.OpSharding.
  xla_sharding_v1_proto: xla_data_pb2.OpSharding = xla_data_pb2.OpSharding(
      type=int(sharding_proto.type),
      tile_assignment_dimensions=sharding_proto.tile_assignment_dimensions,
      tile_assignment_devices=tad,
      replicate_on_last_tile_dim=sharding_proto.replicate_on_last_tile_dim,
      last_tile_dims=sharding_proto.last_tile_dims,
  )

  # Shardy requires the V2 sharding format.
  if config.use_shardy_partitioner.value:
    xla_sharding_v2_proto: xla_data_pb2.OpSharding = xla_data_pb2.OpSharding(
        type=int(sharding_proto.type),
        tile_assignment_dimensions=sharding_proto.tile_assignment_dimensions,
        tile_assignment_devices=sharding_proto.tile_assignment_devices,
        iota_reshape_dims=sharding_proto.iota_reshape_dims,
        iota_transpose_perm=sharding_proto.iota_transpose_perm,
        replicate_on_last_tile_dim=sharding_proto.replicate_on_last_tile_dim,
        last_tile_dims=sharding_proto.last_tile_dims,
    )
  else:
    xla_sharding_v2_proto = None

  if tf_context.executing_eagerly():
    raise ValueError(
        "A jit function with sharded arguments or results must be used under a `tf.function` context. "
        "See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#support-for-partitioning for a discussion")

  tf_version = tuple(int(v) for v in tf.__version__.split(".")[:2])
  # apply_to_tensor comes from a TensorFlow package; check the TensorFlow
  # version to make sure that it has the sharding_v2_proto parameter.
  if tf_version < (2, 20):
    return xla_sharding.Sharding(proto=xla_sharding_v1_proto).apply_to_tensor(
        val, use_sharding_op=True
    )
  return xla_sharding.Sharding(proto=xla_sharding_v1_proto).apply_to_tensor(
      val, use_sharding_op=True, sharding_v2_proto=xla_sharding_v2_proto
  )


def _register_checkpoint_pytrees():
  """Registers TF custom container types as pytrees."""
  m = tf.Module()
  # The types here are automagically changed by TensorFlow's checkpointing
  # infrastructure.
  m.a = (tf.Module(), tf.Module())
  m.b = [tf.Module(), tf.Module()]
  m.c = {"a": tf.Module()}
  tuple_wrapper = type(m.a)
  list_wrapper = type(m.b)
  dict_wrapper = type(m.c)

  # TF AutoTrackable swaps container types out for wrappers.
  assert tuple_wrapper is not tuple
  assert list_wrapper is not list
  assert dict_wrapper is not dict

  jax.tree_util.register_pytree_node(
      tuple_wrapper,
      lambda xs: (tuple(xs), None),
      lambda _, xs: tuple(xs))

  jax.tree_util.register_pytree_node(
      list_wrapper,
      lambda xs: (tuple(xs), None),
      lambda _, xs: list(xs))

  jax.tree_util.register_pytree_node(
      dict_wrapper,
      lambda s: (tuple(s.values()), tuple(s.keys())),
      lambda k, xs: dict_wrapper(zip(k, xs)))


_register_checkpoint_pytrees()
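
# -----------------------------------------------------------------------------
# Illustrative usage sketch (not part of this module's API): a minimal example
# of how `convert` is typically combined with `tf.function`. The JAX function
# `f_jax`, the 28x28 image shape, and the dimension variable `b` below are
# hypothetical placeholders chosen for illustration.
#
#   import tensorflow as tf
#   from jax.experimental import jax2tf
#
#   # Wrap the converted function in tf.function; `b` lets the batch size
#   # remain symbolic so a single trace serves all batch sizes.
#   f_tf = tf.function(
#       jax2tf.convert(f_jax, polymorphic_shapes=["(b, 28, 28)"]),
#       autograph=False,
#       input_signature=[tf.TensorSpec([None, 28, 28], tf.float32)])
#
#   # f_tf can now be called on TF tensors of any batch size and saved with
#   # tf.saved_model from a TensorFlow program.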