# jax/_src/pmap.py
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from functools import partial

from jax._src import core
from jax._src import dispatch
from jax._src import linear_util as lu
from jax._src import stages
from jax._src import traceback_util
from jax._src import util
from jax._src import xla_bridge as xb
from jax._src.shard_map import _shard_map, _axes_to_pspec
from jax._src.api import _shared_code_pmap, _prepare_pmap, jit
from jax._src.mesh import Mesh
from jax._src.lax import lax
from jax._src.tree_util import tree_map, tree_unflatten

map, unsafe_map = util.safe_map, map
zip, unsafe_zip = util.safe_zip, zip

traceback_util.register_exclusion(__file__)


# Implementing pmap in terms of shard_map
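#
# At a high level, the code below treats
#
#   pmap(f, axis_name='i')(xs)
#
# roughly as (a sketch, not the exact code path; P stands for
# jax.sharding.PartitionSpec):
#
#   mesh = Mesh(jax.devices(), ('i',))
#   jit(shard_map(f, mesh=mesh, in_specs=P('i'), out_specs=P('i')))(xs)
#
# with the size-1 mapped axis squeezed away on the way in and reinserted
# on the way out (see _handle_reshapes below).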
def pmap(f, axis_name=None, *, in_axes=0, out_axes=0,
         static_broadcasted_argnums=(), devices=None, backend=None,
         axis_size=None, donate_argnums=(), global_arg_shapes=None):
  del global_arg_shapes
  # TODO(vanderplas): move these definitions into jax._src and avoid local import.
  import jax.experimental.multihost_utils as mhu  # pytype: disable=import-error
  devices = tuple(devices) if devices is not None else devices
  axis_name, static_broadcasted_tuple, donate_tuple = _shared_code_pmap(
      f, axis_name, static_broadcasted_argnums, donate_argnums, in_axes,
      out_axes)
  if isinstance(axis_name, core._TempAxisName):
    axis_name = repr(axis_name)
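  # A core._TempAxisName (synthesized when the caller passes no axis_name)
  # is converted to its repr so the mesh axis name below is a plain string.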

  def infer_params(*args, __check=True, **kwargs):
    p = _prepare_pmap(f, in_axes, out_axes, static_broadcasted_tuple,
                      donate_tuple, devices, backend, axis_size, args, kwargs)
    if __check:
      for arg in p.flat_args:
        dispatch.check_arg(arg)
    mesh = Mesh(_get_devices(p, backend), (axis_name,))
    _pmapped, in_specs, out_specs = _cached_shard_map(
        p.flat_fun, mesh, p.in_axes_flat, p.out_axes_thunk, axis_name)
    jitted_f = jit(
        _pmapped,
        donate_argnums=[i for i, val in enumerate(p.donated_invars) if val])
    if __check and xb.process_count() > 1:
      flat_global_args = mhu.host_local_array_to_global_array(
          p.flat_args, mesh, list(in_specs))
    else:
      flat_global_args = p.flat_args
    return jitted_f, flat_global_args, p, mesh, out_specs, donate_tuple
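
  # In multi-process runs, user-supplied arguments are per-host shards;
  # host_local_array_to_global_array (above) assembles them into global
  # arrays matching in_specs, and global_array_to_host_local_array (below)
  # undoes this for the outputs.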
  @util.wraps(f)
  def wrapped(*args, **kwargs):
    jitted_f, flat_global_args, p, mesh, out_specs, _ = infer_params(
        *args, **kwargs)
    outs = jitted_f(*flat_global_args)
    if xb.process_count() > 1:
      outs = mhu.global_array_to_host_local_array(outs, mesh, out_specs())
    return tree_unflatten(p.out_tree(), outs)

  def lower(*args, **kwargs):
    jitted_f, flat_global_args, p, _, _, donate_tuple = infer_params(
        *args, __check=False, **kwargs)
    abstract_args = list(map(core.shaped_abstractify, flat_global_args))
    args_info = stages.make_args_info(p.in_tree, abstract_args, donate_tuple)
    lowered = jitted_f.trace(*flat_global_args).lower()
    lowered = stages.Lowered(lowered._lowering, args_info, p.out_tree(),
                             no_kwargs=lowered._no_kwargs)
    return lowered
  wrapped.lower = lower

  return wrapped
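

# Example usage (a sketch; assumes at least one local device):
#
#   import jax
#   import jax.numpy as jnp
#   n = jax.local_device_count()
#   xs = jnp.arange(n * 3.).reshape(n, 3)
#   ys = pmap(lambda x: 2. * x, axis_name='i')(xs)  # ys.shape == (n, 3)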


@lu.cache
def _cached_shard_map(flat_fun, mesh, in_axes_flat, out_axes_thunk, axis_name):
  f_transformed = flat_fun.f_transformed
  def reset_stores_f_transformed(*args, **kwargs):
    for store in flat_fun.stores:
      if store is not None:
        store.reset()
    return f_transformed(*args, **kwargs)
  flat_fun.f_transformed = reset_stores_f_transformed
  in_specs = tuple(map(partial(_axes_to_pspec, axis_name), in_axes_flat))
  out_specs = lambda: map(partial(_axes_to_pspec, axis_name), out_axes_thunk())
  fun = _handle_reshapes(flat_fun, in_axes_flat, out_axes_thunk)
  return (_shard_map(fun.call_wrapped, mesh=mesh, in_specs=in_specs,
                     out_specs=out_specs, check_vma=False,
                     axis_names=set(mesh.axis_names)),
          in_specs, out_specs)
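
# _axes_to_pspec (from jax._src.shard_map) maps a pmap axis spec to a
# PartitionSpec: roughly, in_axes=0 becomes P(axis_name), in_axes=1 becomes
# P(None, axis_name), and in_axes=None (a broadcast argument) becomes P(),
# i.e. the value is replicated across the mesh.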


@lu.transformation2
def _handle_reshapes(f, in_axes, out_axes_thunk, *args, **kwargs):
  args = tree_map(lambda x, ax: x if ax is None else lax.squeeze(x, [ax]),
                  list(args), list(in_axes))
  out = f(*args)
  return tree_map(lambda x, ax: x if ax is None else lax.expand_dims(x, [ax]),
                  list(out), list(out_axes_thunk()))
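
# Shape example (illustrative): with in_axes=0 over a mapped axis of size n,
# each per-device shard of a (n, 3) input has shape (1, 3); lax.squeeze drops
# the size-1 mapped axis so the user function sees (3,), and lax.expand_dims
# reinserts a size-1 axis at out_axes on the way out.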


def _get_devices(p, backend):
  if backend is not None and p.devices is None:
    devs = xb.devices(backend=backend)
  else:
    devs = xb.devices() if p.devices is None else p.devices
  if xb.process_count() > 1:
    return devs[:p.global_axis_size]
  return devs[:p.local_axis_size]
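
# In multi-process runs the mesh spans the first global_axis_size devices
# across all hosts; in single-process runs it uses the first local_axis_size
# devices, one device per mapped axis index.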