# Copyright 2020 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Provides JAX and TensorFlow interoperation APIs."""
from __future__ import annotations
from collections.abc import Callable, Sequence
from functools import partial
import contextlib
import math
import os
import threading
from typing import Any, Union
import warnings
from absl import logging
import numpy as np
import jax
from jax import tree_util
from jax import export
from jax._src import api
from jax._src import api_util
from jax._src import config
from jax._src import core
from jax._src import dtypes
from jax._src import op_shardings
from jax._src import source_info_util
from jax._src import util
from jax._src.export import _export
from jax._src.export import shape_poly
from jax._src.lib import xla_client
import tensorflow as tf
# These don't have public equivalents.
# pylint: disable=g-direct-tensorflow-import
from tensorflow.compiler.tf2xla.python import xla as tfxla
from tensorflow.compiler.xla import xla_data_pb2
try:
from tensorflow.python.compiler.xla.experimental import xla_sharding
except ModuleNotFoundError:
# This can be removed when TF 2.10 support is no longer needed.
from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding
from tensorflow.python.eager import context as tf_context
# pylint: enable=g-direct-tensorflow-import
NameStack = source_info_util.NameStack
PolyShape = shape_poly.PolyShape # TODO: deprecate
DType = Any
DisabledSafetyCheck = export.DisabledSafetyCheck
map = util.safe_map
zip = util.safe_zip
# A value suitable for use in a TF tracing context: tf.Tensor, tf.Variable,
# or a Python scalar or numpy.ndarray. (A tf.EagerTensor is a tf.Tensor.)
TfVal = Any
PrecisionType = int # Enum xla_data.PrecisionConfig.Precision
def _is_tfval(v: TfVal) -> bool:
if isinstance(v, (tf.Tensor, tf.Variable)):
return True
try:
# Include all convertible types, even if not supported on accelerators.
with tf.device("CPU"):
tf.constant(v)
return True
except Exception:  # Conversion failures only; let SystemExit/KeyboardInterrupt propagate.
return False
class _DefaultNativeSerialization:
pass
DEFAULT_NATIVE_SERIALIZATION = _DefaultNativeSerialization()
# In order to ensure that JAX picks up the proper user-frame for source
# locations we will register the TensorFlow source path as an internal
# path with source_info_util. The typical stack when a JAX primitive
# conversion happens is:
# jax2tf.process_primitive (top of stack)
# jax tracing machinery ...
# tf.custom_gradient machinery ...
# jax2tf.converted_fun
# tf function machinery ...
# user code invokes the converted function on TF tensors
#
# We need to skip over not only JAX internal frames, but TF internal frames
# also.
# We register the TensorFlow source path lazily
_has_registered_tf_source_path = False
class _ThreadLocalState(threading.local):
def __init__(self):
# Keep track of whether we are inside a call_tf. In that context we disable
# the safety check that we are not inside JAX transformations.
self.inside_call_tf = False
# Maps dimension variables to TF expressions, for non-native lowering
self.shape_env: Sequence[tuple[str, TfVal]] = ()
# A list collecting all TF concrete functions called by stablehlo.custom_call.
# This is used only by native serialization (unlike all the other
# thread-local state).
self.call_tf_concrete_function_list: list[Any] | None = None
# Cache for TF constants created for non-scalar constant values, keyed by
# (id(val), jax_dtype) and storing (val, tf_val) pairs. None means constant
# memoization is disabled. Referenced by _tfval_to_tensor_jax_dtype.
self.constant_cache: dict[Any, tuple[Any, TfVal]] | None = None
_thread_local_state = _ThreadLocalState()
@contextlib.contextmanager
def inside_call_tf():
# Set the inside_call_tf flag for a context.
prev = _thread_local_state.inside_call_tf
_thread_local_state.inside_call_tf = True
try:
yield
finally:
_thread_local_state.inside_call_tf = prev
def get_thread_local_state_call_tf_concrete_function_list() -> (
list[Any] | None
):
return _thread_local_state.call_tf_concrete_function_list
@partial(api_util.api_hook, tag="jax2tf_convert")
def convert(fun_jax: Callable,
*,
polymorphic_shapes: str | PolyShape | None | Sequence[str | PolyShape | None] = None,
polymorphic_constraints: Sequence[str] = (),
with_gradient: bool = True,
enable_xla: bool = DEFAULT_NATIVE_SERIALIZATION, # type: ignore
native_serialization: bool | _DefaultNativeSerialization = DEFAULT_NATIVE_SERIALIZATION, # type: ignore
native_serialization_platforms: Sequence[str] | None = None,
native_serialization_disabled_checks: Sequence[DisabledSafetyCheck] = (),
) -> Callable:
"""Allows calling a JAX function from a TensorFlow program.
See
[README](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md)
for more details about usage and common problems.
Args:
fun_jax: target JAX function to be called. Its arguments and return value
should be JAX arrays, or nested standard Python containers
(tuple/list/dict) thereof (pytrees).
polymorphic_shapes: Specifies input shapes to be treated polymorphically
during lowering.
.. warning:: The shape-polymorphic lowering is an experimental feature.
It is meant to be sound, but it is known to reject some JAX programs
that are shape polymorphic. The details of this feature can change.
It should be `None` (all arguments are monomorphic), a single PolyShape
or string (applies to all arguments), or a tuple/list of the same length
as the function arguments. For each argument the shape specification
should be `None` (monomorphic argument), or a Python object with the
same pytree structure as the argument.
See [how optional parameters are matched to
arguments](https://docs.jax.dev/en/latest/pytrees.html#applying-optional-parameters-to-pytrees).
A shape specification for an array argument should be an object
`PolyShape(dim0, dim1, ..., dimn)`
where each `dim` is a dimension specification: a positive integer denoting
a monomorphic dimension of the given size, or a string denoting a
dimension variable assumed to range over non-zero dimension sizes, or
the special placeholder string "_" denoting a monomorphic dimension
whose size is given by the actual argument. As a shortcut, an Ellipsis
suffix in the list of dimension specifications stands for a list of "_"
placeholders.
For convenience, a shape specification can also be given as a string
representation, e.g.: "batch, ...", "batch, height, width, _", possibly
with surrounding parentheses: "(batch, ...)".
The lowering fails if it cannot ensure that it would produce the same
sequence of TF ops for any non-zero values of the dimension variables.
polymorphic_shapes are only supported for positional arguments; shape
polymorphism is not supported for keyword arguments.
See [the README](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#shape-polymorphic-conversion)
for more details.
polymorphic_constraints: a sequence of constraints on symbolic dimension
expressions, of the form `e1 >= e2` or `e1 <= e2`.
See more details at https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints.
with_gradient: if set (default), add a tf.custom_gradient to the lowered
function, by converting ``jax.vjp(fun)``. This means that reverse-mode
TensorFlow AD is supported for the output TensorFlow function, and the
value of the gradient will be JAX-accurate.
native_serialization_platforms: Specifies the platform(s)
for which to lower the code. Must be a tuple of
strings, including a subset of: 'cpu', 'cuda', 'rocm', 'tpu'.
The default (`None`) specifies the JAX default
backend on the machine where the lowering is done.
native_serialization_disabled_checks: Disables the specified safety checks.
See docstring of `DisabledSafetyCheck`.
Returns:
A version of `fun_jax` that expects TfVals as arguments (or
tuple/lists/dicts thereof), and returns TfVals as outputs, and uses
only TensorFlow ops and thus can be called from a TensorFlow program.
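For example, here is a minimal usage sketch (the function, shapes, and dtypes
below are illustrative only):

>>> import tensorflow as tf
>>> import jax.numpy as jnp
>>> from jax.experimental import jax2tf
>>> f_jax = lambda x: jnp.sin(jnp.cos(x))
>>> # "b" marks the leading dimension as a shape-polymorphic batch dimension.
>>> f_tf = tf.function(jax2tf.convert(f_jax, polymorphic_shapes=["(b, 4)"]),
...                    autograph=False)
>>> _ = f_tf(tf.ones((3, 4), dtype=tf.float32))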
"""
if native_serialization is not DEFAULT_NATIVE_SERIALIZATION:
warnings.warn(
"The `native_serialization` parameter is deprecated and "
"will be removed in a future version of JAX.",
DeprecationWarning, stacklevel=2)
del native_serialization
if enable_xla is not DEFAULT_NATIVE_SERIALIZATION:
warnings.warn(
"The `enable_xla` parameter is deprecated and "
"will be removed in a future version of JAX.",
DeprecationWarning, stacklevel=2)
del enable_xla
if native_serialization_platforms:
if (not isinstance(native_serialization_platforms, (list, tuple)) or
not all(p in ["cpu", "cuda", "rocm", "tpu"]
for p in native_serialization_platforms)):
raise ValueError(
"native_serialization_platforms must be a sequence "
"containing a subset of {'cpu', 'cuda', 'rocm', 'tpu'}. "
f"Got: {native_serialization_platforms}")
native_serialization_platforms = tuple(native_serialization_platforms)
api.check_callable(fun_jax)
def converted_fun_tf(*args_tf: TfVal, **kwargs_tf: TfVal) -> TfVal:
# TODO: is there a better way to check if we are inside a transformation?
if not core.trace_state_clean() and not _thread_local_state.inside_call_tf:
# It is Ok to nest convert when we are inside a call_tf
raise ValueError(
"convert must be used outside all JAX transformations. " +
f"Trace state: {core.trace_ctx}")
global _has_registered_tf_source_path
if not _has_registered_tf_source_path:
source_info_util.register_exclusion(os.path.dirname(tf.__file__))
_has_registered_tf_source_path = True
def jax_arg_spec_from_tf(a: TfVal) -> jax.ShapeDtypeStruct:
# The shape and JAX dtype for a TF argument
tf_arg_shape = np.shape(a)
# Fix the shape for TF1
tf_arg_shape = tuple(d.value
if isinstance(d, tf.compat.v1.Dimension) else d
for d in tf_arg_shape)
_, a_jax_dtype = _tfval_to_tensor_jax_dtype(a)
# We count on the fact that jax.ShapeDtypeStruct allows shapes that
# contain None.
return jax.ShapeDtypeStruct(tf_arg_shape, a_jax_dtype)
args_jax_specs = tree_util.tree_map(jax_arg_spec_from_tf, args_tf)
args_specs = export.symbolic_args_specs(
args_jax_specs, polymorphic_shapes,
constraints=polymorphic_constraints)
# The polymorphic_shapes argument refers to positional arguments only.
# We assume None for the kwargs.
kwargs_jax_specs = tree_util.tree_map(jax_arg_spec_from_tf, kwargs_tf)
kwargs_specs = export.symbolic_args_specs(
kwargs_jax_specs, None)
combined_args_tf = (args_tf, kwargs_tf)
args_flat_tf: Sequence[TfVal]
args_flat_tf, args_kwargs_tree = tree_util.tree_flatten(combined_args_tf)
args_flat_tf = tuple(
map(preprocess_arg_tf, range(len(args_flat_tf)), args_flat_tf))
impl = NativeSerializationImpl(
fun_jax,
args_specs=args_specs, kwargs_specs=kwargs_specs,
native_serialization_platforms=native_serialization_platforms,
native_serialization_disabled_checks=native_serialization_disabled_checks)
try:
impl.before_conversion()
outs_tree: tree_util.PyTreeDef = None # type: ignore
if with_gradient:
@tf.custom_gradient
def converted_fun_flat_with_custom_gradient_tf(*args_flat_tf: TfVal) -> TfVal:
nonlocal outs_tree
outs_tf, outs_avals, outs_tree = impl.run_fun_tf(args_flat_tf)
return (tuple(outs_tf),
_make_custom_gradient_fn_tf(
fun_jax,
impl=impl,
with_gradient=with_gradient,
args_specs=args_specs, kwargs_specs=kwargs_specs,
args_tf=args_flat_tf,
outs_avals=outs_avals,
outs_tf=outs_tf))
outs_flat_tf = converted_fun_flat_with_custom_gradient_tf(*args_flat_tf)
else:
outs_tf, _, outs_tree = impl.run_fun_tf(args_flat_tf)
message = ("The jax2tf-converted function does not support gradients. "
"Use the `with_gradient` parameter to enable gradients.")
# We use PreventGradient, which is propagated through a SavedModel.
outs_flat_tf = [
tf.raw_ops.PreventGradient(input=o, message=message)
for o in outs_tf
]
finally:
impl.after_conversion()
outs_flat_tf = [tf.identity(x, "jax2tf_out") for x in outs_flat_tf]
out_tf = tree_util.tree_unflatten(outs_tree, outs_flat_tf)
return out_tf
return converted_fun_tf
class NativeSerializationImpl:
def __init__(self, fun_jax, *,
args_specs, kwargs_specs,
native_serialization_platforms: Sequence[str] | None,
native_serialization_disabled_checks: Sequence[DisabledSafetyCheck]):
self.convert_kwargs = dict(native_serialization_platforms=native_serialization_platforms,
native_serialization_disabled_checks=native_serialization_disabled_checks)
if hasattr(fun_jax, "trace"):
# If fun_jax is already a jit or pjit-ed function, we do not wrap it in
# another jit, and we allow shardings.
fun_jit = fun_jax
else:
# We support convert(pjit(f_jax)) and convert(jit(f_jax)) but also
# convert(f_jax), in which case a "jit" is implied. In that case we raise
# an error if the lowered function contains non-replicated sharding annotations.
fun_jit = jax.jit(fun_jax)
self.fun_jax = fun_jit
self.args_specs = args_specs
self.kwargs_specs = kwargs_specs
self.native_serialization_disabled_checks = native_serialization_disabled_checks
self.native_serialization_platforms = native_serialization_platforms
def before_conversion(self):
_prev_func_list = _thread_local_state.call_tf_concrete_function_list
_thread_local_state.call_tf_concrete_function_list = []
def _restore_context():
_thread_local_state.call_tf_concrete_function_list = _prev_func_list
self._restore_context = _restore_context
_exported_device_assignment = [None]
self.exported = _export._export_internal(
self.fun_jax,
platforms=self.native_serialization_platforms,
disabled_checks=self.native_serialization_disabled_checks,
_device_assignment_for_internal_jax2tf_use_only=_exported_device_assignment,
)(*self.args_specs, **self.kwargs_specs)
assert _exported_device_assignment[0] is not None
self.device_assignment = _exported_device_assignment[0]
def after_conversion(self):
self._restore_context()
def run_fun_tf(self,
args_flat_tf: Sequence[TfVal]
) -> tuple[Sequence[TfVal], Sequence[core.ShapedArray], tree_util.PyTreeDef]:
results = _run_exported_as_tf(args_flat_tf, self.exported)
return results, tuple(self.exported.out_avals), self.exported.out_tree
def get_vjp_fun(self) -> tuple[Callable,
Sequence[core.AbstractValue]]:
return _export._get_vjp_fun(self.fun_jax,
in_tree=self.exported.in_tree,
in_avals=self.exported.in_avals,
in_shardings_hlo=self.exported.in_shardings_hlo,
out_avals=self.exported.out_avals,
out_shardings_hlo=self.exported.out_shardings_hlo,
device_assignment=self.device_assignment,
apply_jit=True)
def dtype_of_val(val: TfVal) -> DType:
"""Computes the TensorFlow dtype using JAX's typing rules.
If the value is a tf.Tensor, it starts with its dtype. If the value is a
constant it uses JAX to infer its dtype. The resulting dtype follows the
JAX type inference rules, and depends on the value of the
JAX_ENABLE_X64 flag.
See README.md for how 64-bit values are treated.
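For example (the result depends on the JAX_ENABLE_X64 flag): with the default
JAX_ENABLE_X64=0, ``dtype_of_val(np.float64(3.))`` is ``tf.float32``, while
with JAX_ENABLE_X64=1 it is ``tf.float64``.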
"""
tval, _ = _tfval_to_tensor_jax_dtype(val)
return tval.dtype
@partial(api_util.api_hook, tag="jax2tf_eval_polymorphic_shapes")
def eval_polymorphic_shape(fun_jax: Callable,
*,
polymorphic_shapes=None) -> Callable:
"""Evaluates the output shape in presence of shape polymorphism.
This is done without lowering or executing the function, same as for
`jax.eval_shape`.
Args:
fun_jax: target JAX function to be called. Its arguments and return value
should be JAX arrays, or nested standard Python containers
(tuple/list/dict) thereof (pytrees).
polymorphic_shapes: Specifies input shapes to be treated polymorphically
during shape evaluation. See discussion for `jax2tf.convert`.
.. warning:: The shape-polymorphic lowering is an experimental feature.
Returns: a function that takes `jax.ShapeDtypeStruct`s (or any values
with `.shape` and `.dtype` attributes) corresponding to the inputs for
`fun_jax`, and returns a tuple with:
* the jax.ShapeDtypeStruct corresponding to the result, as for
`jax.eval_shape`. The shape may contain symbolic dimension expressions.
* the value that can be passed to `polymorphic_shapes` for a subsequent
call to `jax2tf.eval_polymorphic_shape`, or `jax2tf.convert`.
For example:
>>> import jax
>>> from jax.experimental import jax2tf
>>> from jax import numpy as jnp
>>>
>>> f = lambda A, x: jnp.sin(jnp.dot(A, x))
>>> A = jax.ShapeDtypeStruct((2000, 3000), jnp.float32)
>>> x = jax.ShapeDtypeStruct((3000, 1000), jnp.float32)
>>> out_spec, out_poly_shape = jax2tf.eval_polymorphic_shape(f, polymorphic_shapes=["a, b", "b, c"])(A, x)
>>> print(out_spec.shape)
("a", "c")
>>> print(out_poly_shape)
(a, c)
>>> res_spec, res_poly_shape = jax2tf.eval_polymorphic_shape(lambda x: x.T, polymorphic_shapes=[out_poly_shape])(out_spec)
>>> print(res_poly_shape)
(c, a)
"""
def do_eval_polymorphic_shape(*args_specs) -> Any:
args_poly_specs = export.symbolic_args_specs(
args_specs, polymorphic_shapes)
res_poly_spec = jax.eval_shape(fun_jax, *args_poly_specs)
# TODO(necula): For now we export the polymorphic shapes using `str`.
res_polymorphic_shape = tree_util.tree_map(lambda r: str(r.shape), res_poly_spec)
return res_poly_spec, res_polymorphic_shape
return do_eval_polymorphic_shape
# Internals
def flatten_fun_jax(fun_jax: Callable,
in_tree,
) -> tuple[Callable, Callable]:
"""Wraps the function to take a (flat) list of positional args.
jax2tf works better and is simpler when the JAX function takes and returns
just a tuple of values (no pytrees, no kwargs). This is in part because
jax.vjp does not support kwargs, and we can only use
tf.custom_gradient on functions with flat arguments and results.
Returns:
* the wrapped JAX function taking and returning a flat list of arguments
* a thunk that can be called after the wrapped function has been called
to return the output pytree.
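For example, if the original call is ``fun_jax({"a": x})`` (no kwargs) and it
returns the tuple ``(r1, r2)``, then the wrapped function takes the single
flat argument ``x`` and returns the flat list ``[r1, r2]``; the thunk then
returns the PyTreeDef of the ``(r1, r2)`` tuple.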
"""
out_tree_ref = None
def fun_flat_jax(*args_flat_jax):
tree_args, tree_kwargs = tree_util.tree_unflatten(in_tree, args_flat_jax)
tree_res = fun_jax(*tree_args, **tree_kwargs)
res_flat_jax, out_tree = tree_util.tree_flatten(tree_res)
nonlocal out_tree_ref
assert out_tree_ref is None or out_tree_ref == out_tree
out_tree_ref = out_tree
return res_flat_jax
return fun_flat_jax, lambda: out_tree_ref
def preprocess_arg_tf(arg_idx: int,
arg_tf: TfVal) -> TfVal:
"""Pre-processes a TF argument.
Returns: the pre-processed TF argument, cast to the JAX-inferred dtype and
named for debugging.
"""
if not _is_tfval(arg_tf):
msg = (f"Argument {arg_tf} of type {type(arg_tf)} of jax2tf.convert(f) should "
"be a NumPy array, a scalar, a tf.Variable, or a tf.Tensor")
raise TypeError(msg)
# May cast the args_flat to JAX types, using JAX's interpretation
# of types of constants.
arg_tf, _ = _tfval_to_tensor_jax_dtype(arg_tf)
# Name input tensors; do this after we have cast the arguments
arg_tf = tf.identity(arg_tf, f"jax2tf_arg_{arg_idx}")
return arg_tf
def _make_custom_gradient_fn_tf(fun_jax,
*,
impl: NativeSerializationImpl,
with_gradient: bool,
args_specs, kwargs_specs,
args_tf: Sequence[TfVal],
outs_avals: Sequence[core.ShapedArray],
outs_tf: Sequence[TfVal]):
"""Prepares the TF function to be used with tf.custom_gradient.
Args:
impl: the serialization implementation details
with_gradient: whether to include a tf.custom_gradient
args_specs, kwargs_specs: the jax.ShapeDtypeStructs for the args and kwargs
args_tf: the flattened TF arguments of the primal function
outs_avals: the flattened output JAX abstract values of the primal function
outs_tf: the flattened TF outputs of the primal function
"""
def grad_fn_tf(*out_cts_flat_tf: TfVal,
variables=None):
if variables:
raise ValueError(
"Unexpected variables used in forward pass. "
"This should not happen for first-order differentiation. "
f"{variables=}")
# TODO: enable higher-order gradients
with tf.name_scope("jax2tf_vjp"):
def fix_out_ct(out_ct_tf, out_ct_aval: core.ShapedArray, out_tf: TfVal):
# If the primal function has outputs of integer or bool types, and we are
# under a tf.function context, then TF will pass None (instead of a tensor)
# as the cotangent for those outputs. We must replace these Nones with
# float0 zeros, or else JAX complains. See issue #6975.
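# For example, if the primal outputs are (f32[3], i32[2]), TF may pass
# (ct0, None) here; the None is replaced below with tf.zeros_like of the
# i32 output, using the bool stand-in dtype for float0.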
if out_ct_tf is not None:
return out_ct_tf
assert core.primal_dtype_to_tangent_dtype(out_ct_aval.dtype) == dtypes.float0, f"{out_ct_tf=}"
# Note that out_ct_aval.shape contains dimension variable from the
# primal function scope. We use tf.zeros_like to make a 0 of the right shape.
return tf.zeros_like(out_tf, dtype=_tf_np_dtype_for_float0)
out_cts_fixed_flat_tf = tuple(map(fix_out_ct, out_cts_flat_tf, outs_avals, outs_tf))
vjp_args_flat_tf = tuple(args_tf) + out_cts_fixed_flat_tf
fun_vjp_jax, vjp_in_avals = impl.get_vjp_fun()
vjp_polymorphic_shapes = tuple(
str(a.shape) # Note: may be _DimExpr, not just DimVar
for a in vjp_in_avals)
in_cts_flat = convert(
fun_vjp_jax,
with_gradient=with_gradient,
polymorphic_shapes=vjp_polymorphic_shapes,
**impl.convert_kwargs)(*vjp_args_flat_tf)
# We do not need to fix the in_cts because the TF gradient machinery
# will adjust the unconnected gradients and those for integer types.
return in_cts_flat
return grad_fn_tf
def _run_exported_as_tf(args_flat_tf: Sequence[TfVal],
exported: export.Exported,
) -> Sequence[TfVal]:
"""Runs the `exported` as an XlaCallModule TF op.
Returns: the flattened tuple of results.
"""
args_avals = exported.in_avals
# TF values may be integer types for float0
def _convert_value(val, aval):
# Check the shape
assert all(d_aval == d_val
for d_aval, d_val in zip(aval.shape, val.shape)
if core.is_constant_dim(d_aval)), (aval, val)
conversion_dtype = _to_tf_dtype(aval.dtype)
if conversion_dtype != aval.dtype:
return tf.cast(val, conversion_dtype)
else:
return val
args_flat_tf = tuple(map(_convert_value, args_flat_tf, args_avals))
out_shapes_tf = tuple(
tuple(d if core.is_constant_dim(d) else None
for d in out_aval.shape)
for out_aval in exported.out_avals)
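# For example, an output aval with shape (b, 4), where b is a symbolic
# dimension, contributes the entry (None, 4) to Sout.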
out_types = tuple(_to_tf_dtype(out_aval.dtype) for out_aval in exported.out_avals)
kept_args_flat_tf = [atf for i, atf in enumerate(args_flat_tf) if i in exported.module_kept_var_idx]
version = exported.calling_convention_version
try:
get_max_supported_version = tfxla.call_module_maximum_supported_version
except AttributeError:
get_max_supported_version = None
if get_max_supported_version:
max_supported_version = get_max_supported_version()
else:
max_supported_version = 6
if version > max_supported_version:
raise NotImplementedError(
"XlaCallModule from your TensorFlow installation supports up to "
f"serialization version {max_supported_version} but the serialized "
f"module needs version {version}. "
"You should upgrade TensorFlow, e.g., to tf_nightly."
)
call_module_attrs = dict(
version=version,
Tout=out_types,
Sout=out_shapes_tf,
function_list=[
concrete_fn.function_def.signature.name
for concrete_fn in _thread_local_state.call_tf_concrete_function_list
] if _thread_local_state.call_tf_concrete_function_list is not None else [],
# We always set has_token_input_output=False: setting it would require real
# tokens for versions less than 9, and the attribute is not used starting
# with version 9.
has_token_input_output=False
)
call_module_attrs["platforms"] = tuple(p.upper() for p in exported.platforms)
if version >= 6:
call_module_attrs["disabled_checks"] = tuple(
str(dc)
for dc in exported.disabled_safety_checks)
else:
if version >= 3:
if DisabledSafetyCheck.platform() in exported.disabled_safety_checks:
call_module_attrs["platforms"] = () # No platform checking
if version >= 10:
call_module_attrs["use_shardy_partitioner"] = (
config.use_shardy_partitioner.value
)
if logging.vlog_is_on(3):
# We already logged the MLIR module when we exported it.
logging.vlog(3, "XlaCallModule %s", str(call_module_attrs))
call_module_attrs["module"] = exported.mlir_module_serialized
# Apply the shardings on arguments and results for pjit. This is redundant
# because the mlir_module_text will already contain the shardings, but it
# makes it easier for tools like the TPU inference converter to see the
# sharding without digging into the `module` attribute of the `XlaCallModule`
# op, in the same way as it is done for the legacy jax2tf conversion.
# Do not apply XlaSharding for REPLICATED, on inputs and outputs.
# This is an agreed convention, and also improves usability under TF eager.
# See b/255511660.
kept_in_shardings = []
for i in exported.module_kept_var_idx:
kept_in_shardings.append(exported.in_shardings_hlo[i])
args_flat_tf = tuple(
map(partial(_shard_value,
skip_replicated_sharding=tf.executing_eagerly()),
kept_args_flat_tf, kept_in_shardings))
res = tfxla.call_module(args_flat_tf, **call_module_attrs)
# TODO(b/278940799): Replace the TF v1 API with public TF2 API.
# Add the custom call tf.function into the default graph, so those functions
# will be available during tf.SavedModel.save.
if _thread_local_state.call_tf_concrete_function_list is not None:
for concrete_fn in _thread_local_state.call_tf_concrete_function_list:
tf.compat.v1.get_default_graph()._add_function_recursive(
concrete_fn._inference_function
)
res = list(map(partial(_shard_value,
skip_replicated_sharding=tf.executing_eagerly()),
res, exported.out_shardings_hlo))
res = tuple(map(_convert_value, res, exported.out_avals))
return res
def _jax_physical_aval(aval: core.ShapedArray) -> core.ShapedArray:
"""Converts JAX avals from logical to physical, if relevant.
JAX might have avals whose logical vs physical shape/dtype may
differ, and only the physical view is expected to possibly
relate to TF. TF impl rules should operate on the physical form.
A JAX logical aval might even correspond, in principle, to several
physical avals, but we don't support those here. Instead we assert
there is only one and return it.
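For example, a jax.random key array with logical shape (3,) and a threefry
key dtype corresponds to the physical aval uint32[3, 2].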
"""
physical_aval = core.physical_aval(aval)
assert (len(physical_aval.shape) >= len(aval.shape) and
physical_aval.shape[:len(aval.shape)] == aval.shape), (physical_aval, aval)
return physical_aval
def _jax_physical_dtype(dtype):
# assuming () is a fine stand-in shape
return _jax_physical_aval(core.ShapedArray((), dtype)).dtype
def _aval_to_tf_shape(aval: core.ShapedArray) -> tuple[int | None, ...]:
"""Generate a TF shape, possibly containing None for polymorphic dimensions."""
aval = _jax_physical_aval(aval)
return tuple(map(lambda d: None if export.is_symbolic_dim(d) else d,
aval.shape))
# In the TF world, we represent float0 as zeros of this type.
# We pick bool because this is what JAX uses when it lowers float0 to HLO.
_tf_np_dtype_for_float0 = np.bool_
def _to_tf_dtype(jax_dtype):
# Note that _to_tf_dtype and _to_jax_dtype are not inverses of each other,
# due to float0 and 64-bit behavior.
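# For example, _to_tf_dtype(dtypes.float0) is tf.bool, and with the default
# JAX_ENABLE_X64=0, _to_jax_dtype(tf.float64) is float32, so a round trip does
# not necessarily recover the original dtype.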
try:
jax_dtype = _jax_physical_dtype(jax_dtype)
except TypeError:
# `jax_dtype` isn't actually a valid jax dtype (e.g. it is
# tf.float32), so there is no physical dtype anyway
pass
if jax_dtype == dtypes.float0:
jax_dtype = _tf_np_dtype_for_float0
return tf.dtypes.as_dtype(jax_dtype)
def _to_jax_dtype(tf_dtype):
# Note that _to_tf_dtype and _to_jax_dtype are not inverses of each other,
# due to float0 and 64-bit behavior.
dt = dtypes.canonicalize_dtype(tf_dtype.as_numpy_dtype)
if dt not in dtypes._jax_dtype_set:
raise TypeError(f"dtype {dt} is not a valid JAX array "
"type. Only arrays of numeric types are supported by JAX.")
return dt
def _tfval_to_tensor_jax_dtype(val: TfVal,
jax_dtype: DType | None = None,
memoize_constants=False) -> tuple[TfVal, DType]:
"""Converts a scalar, ndarray, or tf.Tensor to a tf.Tensor with proper type.
If `jax_dtype` is missing, uses JAX typing rules.
See README.md for details regarding 64-bit values.
Args:
val: a scalar, ndarray, tf.Tensor, or tf.Variable
jax_dtype: an optional dtype to use. If missing, uses JAX type inference
rules for constants.
memoize_constants: whether to memoize TF constants. We can't do this
everywhere because we may be outside of a conversion scope.
Returns:
a tuple with a tf.Tensor with the type as needed by JAX, and the JAX type.
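For example, with the default JAX_ENABLE_X64=0, a Python float constant such
as 3.14 becomes a tf.float32 tensor paired with the JAX dtype float32, and a
tf.Tensor of dtype tf.float64 is cast to tf.float32.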
"""
if isinstance(val, (tf.Tensor, tf.Variable)):
jax_dtype = jax_dtype or _to_jax_dtype(val.dtype) # Give JAX a chance to pick the type
conversion_dtype = _to_tf_dtype(jax_dtype)
if conversion_dtype != val.dtype: # May need to cast for 64-bit values
return tf.cast(val, conversion_dtype), jax_dtype
else:
return val, jax_dtype
else: # A constant
jax_dtype = jax_dtype or core.abstractify(val).dtype
# TODO(document): We assume that the value of a constant does not
# change through the scope of the function. But it may be an ndarray, ...
# JAX has the same problem when generating HLO.
const_key = (id(val), jax_dtype)
# Since we use id(val) as a cache key, we have to make sure that we keep
# the previous `val` alive. Otherwise, for an ndarray, it can get garbage
# collected and reused for a different value, which would create correctness
# issues. We keep the `val` alive by storing in the cache the pair
# `(val, tf_val)`.
# Only memoize non-scalars. JAX will lift all non-scalar constants as
# Jaxpr consts, to the top level of the Jaxpr. This ensures that we see them
# early, when entering the Jaxpr, so we create the tf.const early and its
# scope is the entire Jaxpr.
do_memoize = (memoize_constants and np.size(val) > 1 and _thread_local_state.constant_cache is not None)
if do_memoize:
_, tf_val = _thread_local_state.constant_cache.get(const_key, (None, None))
else:
tf_val = None
if tf_val is None:
conversion_dtype = _to_tf_dtype(jax_dtype)
# The float0 type is not known to TF.
if jax_dtype == dtypes.float0:
val = np.zeros(np.shape(val), conversion_dtype.as_numpy_dtype)
if hasattr(val, 'dtype') and dtypes.issubdtype(val.dtype, dtypes.extended):
val = val.dtype._rules.physical_const(val)
tf_val = tf.convert_to_tensor(val, dtype=conversion_dtype)
if do_memoize:
_thread_local_state.constant_cache[const_key] = (val, tf_val)
return tf_val, jax_dtype
PartitionsOrReplicated = Union[tuple[int, ...], None]
def split_to_logical_devices(tensor: TfVal,
partition_dimensions: PartitionsOrReplicated):
"""Like TPUMPStrategy.experimental_split_to_logical_devices.
For jax2tf purposes we want to avoid needing to thread the `strategy` object
through the generated computation. It seems that the original function needs
the strategy object only for error checking, which we assume is done upstream
by JAX.
Args:
tensor: Input tensor to annotate.
partition_dimensions: A list of integers, with one integer per tensor
dimension, specifying in how many parts the dimension should be split. The
product of the integers must equal the number of devices per replica.
Returns:
an annotated tensor.
"""
# TODO: this is only for sharded_jit. Either remove, or implement in terms
# of _shard_values.
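# For example, partition_dimensions=(2, 1) splits the first dimension of a
# 2D tensor across two logical devices, while partition_dimensions=None
# replicates the tensor.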
if partition_dimensions is None:
return xla_sharding.replicate(tensor, use_sharding_op=True)
num_partition_splits = math.prod(partition_dimensions)
tile_assignment = np.arange(num_partition_splits).reshape(
partition_dimensions)
return xla_sharding.tile(tensor, tile_assignment, use_sharding_op=True)
def _shard_value(val: TfVal,
sd: xla_client.HloSharding | None, *,
skip_replicated_sharding: bool) -> TfVal:
"""Apply sharding to a TfVal."""
if sd is None:
return val
sharding_proto = sd.to_proto()
if (skip_replicated_sharding and
op_shardings.is_hlo_sharding_replicated(sd)):
return val
# TensorFlow relies heavily on the tile_assignment_devices proto field, which
# is specific to the V1 sharding format, so we fall back to that format here.
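# For example, with tile_assignment_dimensions=[2, 2], iota_reshape_dims=[2, 2]
# and iota_transpose_perm=[1, 0], the V1 tile_assignment_devices computed below
# are [0, 2, 1, 3].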
if (
not sharding_proto.tile_assignment_devices
and sharding_proto.iota_reshape_dims
):
tad = list(
np.arange(math.prod(sharding_proto.tile_assignment_dimensions))
.reshape(sharding_proto.iota_reshape_dims)
.transpose(sharding_proto.iota_transpose_perm)
.flat
)
else:
tad = sharding_proto.tile_assignment_devices # type: ignore
# To use xla_sharding.py, we must have a xla_data_pb2.OpSharding.
xla_sharding_v1_proto: xla_data_pb2.OpSharding = xla_data_pb2.OpSharding(
type=int(sharding_proto.type),
tile_assignment_dimensions=sharding_proto.tile_assignment_dimensions,
tile_assignment_devices=tad,
replicate_on_last_tile_dim=sharding_proto.replicate_on_last_tile_dim,
last_tile_dims=sharding_proto.last_tile_dims,
)
# Shardy requires V2 sharding format.
if config.use_shardy_partitioner.value:
xla_sharding_v2_proto: xla_data_pb2.OpSharding = xla_data_pb2.OpSharding(
type=int(sharding_proto.type),
tile_assignment_dimensions=sharding_proto.tile_assignment_dimensions,
tile_assignment_devices=sharding_proto.tile_assignment_devices,
iota_reshape_dims=sharding_proto.iota_reshape_dims,
iota_transpose_perm=sharding_proto.iota_transpose_perm,
replicate_on_last_tile_dim=sharding_proto.replicate_on_last_tile_dim,
last_tile_dims=sharding_proto.last_tile_dims,
)
else:
xla_sharding_v2_proto = None
if tf_context.executing_eagerly():
raise ValueError(
"A jit function with sharded arguments or results must be used under a `tf.function` context. "
"See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#support-for-partitioning for a discussion")
tf_version = tuple(int(v) for v in tf.__version__.split(".")[:2])
# apply_to_tensor comes from a TensorFlow package; check the TensorFlow
# version to make sure that it supports the sharding_v2_proto parameter.
if tf_version < (2, 20):
return xla_sharding.Sharding(proto=xla_sharding_v1_proto).apply_to_tensor(
val, use_sharding_op=True
)
return xla_sharding.Sharding(proto=xla_sharding_v1_proto).apply_to_tensor(
val, use_sharding_op=True, sharding_v2_proto=xla_sharding_v2_proto
)
def _register_checkpoint_pytrees():
"""Registers TF custom container types as pytrees."""
m = tf.Module()
# The types here are automagically changed by TensorFlow's checkpointing
# infrastructure.
m.a = (tf.Module(), tf.Module())
m.b = [tf.Module(), tf.Module()]
m.c = {"a": tf.Module()}
tuple_wrapper = type(m.a)
list_wrapper = type(m.b)
dict_wrapper = type(m.c)
# TF AutoTrackable swaps container types out for wrappers.
assert tuple_wrapper is not tuple
assert list_wrapper is not list
assert dict_wrapper is not dict
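# After these wrappers are registered as pytree nodes (below), e.g.
# jax.tree_util.tree_leaves(m.b) yields the two tf.Module leaves instead of
# treating the ListWrapper itself as a single leaf.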
jax.tree_util.register_pytree_node(tuple_wrapper, lambda xs:
(tuple(xs), None), lambda _, xs: tuple(xs))
jax.tree_util.register_pytree_node(list_wrapper, lambda xs: (tuple(xs), None),
lambda _, xs: list(xs))
jax.tree_util.register_pytree_node(
dict_wrapper,
lambda s: (tuple(s.values()), tuple(s.keys())),
lambda k, xs: dict_wrapper(zip(k, xs)))
_register_checkpoint_pytrees()