182 lines
6.6 KiB
Python
182 lines
6.6 KiB
Python
# Copyright 2024 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""Colocated Python top-level API."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import collections
|
|
from typing import Any, overload
|
|
from collections.abc import Callable, Sequence
|
|
|
|
import jax
|
|
from jax._src import api_util
|
|
from jax._src import util
|
|
from jax.experimental.colocated_python.func import make_callable
|
|
from jax.experimental.colocated_python.obj import wrap_class
|
|
import numpy as np
|
|
|
|
|
|
@overload
def colocated_cpu_devices(
    devices_or_mesh: Sequence[jax.Device],
) -> Sequence[jax.Device]:
  """Typing overload: a sequence of devices maps to a sequence of CPU devices."""
|
|
|
|
|
|
@overload
def colocated_cpu_devices(
    devices_or_mesh: jax.sharding.Mesh,
) -> jax.sharding.Mesh:
  """Typing overload: a mesh maps to a mesh of colocated CPU devices."""
|
|
|
|
|
|
def colocated_cpu_devices(devices_or_mesh):
  """Returns the CPU devices (or CPU mesh) colocated with the given devices or mesh.

  Accelerator devices are usually accompanied by a CPU device on the same host.
  When one host has several accelerators, it may expose several CPU devices,
  each paired 1:1 with one of those accelerators.

  For a sequence of devices, the result lists the colocated CPU devices in the
  same order: output element i is the CPU device paired with input element i.
  An input device that is already a CPU device is returned as-is. For a mesh,
  the result is a mesh of the same shape whose devices are the colocated CPU
  devices.

  Args:
    devices_or_mesh: A sequence of devices, or a `jax.sharding.Mesh`.

  Returns:
    A sequence of colocated CPU devices, or a mesh of them when a mesh was
    given.
  """
  if isinstance(devices_or_mesh, jax.sharding.Mesh):
    return _colocated_cpu_mesh_cached(devices_or_mesh)

  # Normalize to a tuple, the argument type the cached helpers are declared
  # with.
  devices = (
      devices_or_mesh
      if isinstance(devices_or_mesh, tuple)
      else tuple(devices_or_mesh)
  )
  try:
    return _colocated_cpu_devices_cached(devices)
  except (ValueError, AttributeError):
    # The colocation-id lookup failed (for example: no CPU devices on the
    # client, no unique match, or a device/client without the attributes it
    # reads). Fall back to the separate-CPU-backend heuristic.
    return _colocated_cpu_devices_cached_fallback_to_cpu_backend(devices)
|
|
|
|
|
|
@util.cache(max_size=1024, trace_context_in_key=False)
def _colocated_cpu_devices_cached(
    devices: tuple[jax.Device, ...],
) -> Sequence[jax.Device]:
  """Maps each device to the unique CPU device that shares its colocation id.

  Candidate CPU devices are collected only from the client of `devices[0]`.
  NOTE(review): this presumably assumes every device in `devices` belongs to
  that same client; confirm with callers.

  Args:
    devices: Devices to map. May be empty.

  Returns:
    A list with one CPU device per input device, in input order. The list is
    memoized by `util.cache`, so callers should not mutate it.

  Raises:
    ValueError: If the client exposes no CPU devices, or if some input device
      has zero or more than one CPU device with the same `colocation_id`.
      `colocated_cpu_devices` catches this and uses its fallback path.
  """
  if not devices:
    # Nothing to map. This also avoids an IndexError on `devices[0]` below,
    # which the public caller would not catch.
    return []

  cpu_devices_by_colocation_id = collections.defaultdict(list)
  for device in devices[0].client._get_all_devices():  # pylint: disable=protected-access
    if device.device_kind == "cpu":
      cpu_devices_by_colocation_id[device.colocation_id].append(device)
  if not cpu_devices_by_colocation_id:
    raise ValueError("No CPU devices found")

  # Named `result` so it does not shadow the public `colocated_cpu_devices`.
  result = []
  for device in devices:
    # `.get` avoids inserting empty entries into the defaultdict on a miss.
    matches = cpu_devices_by_colocation_id.get(device.colocation_id, [])
    if not matches:
      raise ValueError(f"Device {device} has no colocated devices")
    if len(matches) > 1:
      raise ValueError(
          f"Ambiguous colocated devices; device {device} has"
          f" {len(matches)} colocated devices: {matches}"
      )
    result.append(matches[0])
  return result
|
|
|
|
|
|
@util.cache(max_size=1024, trace_context_in_key=False)
def _colocated_cpu_devices_cached_fallback_to_cpu_backend(
    devices: tuple[jax.Device, ...],
) -> Sequence[jax.Device]:
  """Fallback CPU-device lookup used by `colocated_cpu_devices`.

  `colocated_cpu_devices` calls this when the colocation-id lookup fails. It
  picks a pool of candidate CPU devices and maps each input device to the
  candidate found at that device's position in `jax.devices()`. The input is
  truncated to at most as many devices as there are candidates.

  Args:
    devices: Non-empty tuple of devices; `devices[0]` decides where the CPU
      candidates come from.

  Returns:
    A list of CPU devices, one per (possibly truncated) input device.
  """
  # TODO(hyeontaek): Remove this fallback path once a PjRt-IFRT backend defines
  # CPU devices by its own instead of using a separate CPU backend.
  first_device = devices[0]
  if first_device.device_kind == "cpu":
    # The backend of the input devices already defines CPU devices; use them.
    candidates = [
        d
        for d in first_device.client._get_all_devices()  # pylint: disable=protected-access
        if d.device_kind == "cpu"
    ]
  else:
    # PjRt-IFRT on a non-CPU platform currently defines CPU devices on a
    # separate CPU backend.
    candidates = jax.local_devices(backend="cpu")
  position_by_id = {d.id: pos for pos, d in enumerate(jax.devices())}

  # NOTE(review): the candidate list is indexed by each device's position in
  # the global `jax.devices()` list. This may raise KeyError/IndexError when the
  # two lists are not aligned (e.g. multi-process setups) -- confirm intended.
  limit = min(len(candidates), len(devices))
  result = []
  for device in devices[:limit]:
    result.append(candidates[position_by_id[device.id]])
  return result
|
|
|
|
|
|
@util.cache(max_size=1024, trace_context_in_key=False)
def _colocated_cpu_mesh_cached(mesh: jax.sharding.Mesh) -> jax.sharding.Mesh:
  """Builds a mesh shaped like `mesh` whose devices are the colocated CPUs.

  The per-device lookup goes through `colocated_cpu_devices`, so it reuses that
  function's device-tuple cache. This function is cached as well, to avoid
  constructing a new `Mesh` object on every call with the same mesh.
  """
  source_devices = tuple(mesh.devices.flat)
  cpu_devices = colocated_cpu_devices(source_devices)
  # `mesh.devices.flat` walks the device array in row-major order, and
  # `reshape` uses the same order, so each CPU device lands in its source
  # device's slot.
  cpu_device_grid = np.array(cpu_devices).reshape(mesh.axis_sizes)
  return jax.sharding.Mesh(
      cpu_device_grid, mesh.axis_names, axis_types=mesh.axis_types
  )
|
|
|
|
|
|
def colocated_python(fun: Callable[..., Any]):
|
|
"""Executes the given Python function on the same devices as the arguments.
|
|
|
|
The returned colocated Python callable lets the user run a serializable Python
|
|
function on the same devices as the arguments, potentially on remote hosts.
|
|
|
|
Python callable implements `specialize` and `__call__` methods. See their
|
|
docstrings for details and https://docs.jax.dev/en/latest/notebooks/colocated-python.html
|
|
for examples.
|
|
|
|
Args:
|
|
fun: An original function to wrap as an I/O callable.
|
|
|
|
Returns:
|
|
Colocated Python callable with no initial specialization.
|
|
"""
|
|
return make_callable(
|
|
fun, api_util.fun_sourceinfo(fun), api_util.fun_signature(fun)
|
|
)
|
|
|
|
|
|
def colocated_python_class(cls: type[object]) -> type[object]:
|
|
"""Creates a wrapper class that executes the given Python class methods on the same devices as the arguments.
|
|
|
|
The wrapper class exposes the returned type's methods, and can be instantiated
|
|
on JAX. An actual object will be instantiated on the host of the devices of
|
|
the arguments' when a method of the wrapper instance is called for the first
|
|
time.
|
|
|
|
The actual object will persist while the wrapper object is alive, and will be
|
|
destroyed asynchronously when the wrapper object is destroyed. Note that if
|
|
the wrapper object is destroyed immediately without any method call, actual
|
|
objects will not be created.
|
|
|
|
Args:
|
|
cls: The class to wrap as a colocated Python object.
|
|
|
|
Returns:
|
|
Wrapper class.
|
|
"""
|
|
return wrap_class(cls, api_util.fun_sourceinfo(cls))
|