# Copyright 2020 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Any

from jax._src import array
from jax._src import dtypes
from jax._src import xla_bridge
from jax._src.api import device_put
from jax._src.lax.lax import _array_copy
from jax._src.lib import _jax
from jax._src.lib import jaxlib_extension_version
from jax._src.lib import xla_client
from jax._src.numpy import lax_numpy as jnp
from jax._src.numpy import scalar_types as jnp_types
from jax._src.sharding import Sharding
from jax._src.typing import Array, DLDeviceType, DTypeLike

import numpy as np


DLPACK_VERSION = (0, 8)
MIN_DLPACK_VERSION = (0, 5)

# The set of dtypes that dlpack supports.
# Note: make sure to use a "type", not a dtype instance, when looking up this
# set, because their hashes differ. For example:
#   hash(jnp.float32) != hash(jnp.dtype(jnp.float32))
#   hash(jnp.float32) == hash(jnp.dtype(jnp.float32).type)

# TODO(vanderplas): remove this set
SUPPORTED_DTYPES: frozenset[DTypeLike] = frozenset({
    jnp_types.int8, jnp_types.int16, jnp_types.int32, jnp_types.int64,
    jnp_types.uint8, jnp_types.uint16, jnp_types.uint32, jnp_types.uint64,
    jnp_types.float16, jnp_types.bfloat16, jnp_types.float32,
    jnp_types.float64, jnp_types.complex64, jnp_types.complex128,
    jnp_types.bool_})

SUPPORTED_DTYPES_SET: frozenset[np.dtype] = frozenset(
    {np.dtype(dt) for dt in SUPPORTED_DTYPES})


def is_supported_dtype(dtype: DTypeLike) -> bool:
  """Check if dtype is supported by jax.dlpack."""
  if dtype is None:
    # NumPy would silently cast None to float64, which may be surprising.
    raise TypeError(f"Expected a string or dtype-like object; got {dtype=}")
  return np.dtype(dtype) in SUPPORTED_DTYPES_SET
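
# Illustrative expectations (a sketch; support can vary with the dtypes
# available in a given JAX build):
#   is_supported_dtype(np.float32)        -> True
#   is_supported_dtype("int8")            -> True
#   is_supported_dtype(np.dtype(object))  -> False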


def _to_dlpack(x: Array, stream: int | Any | None,
               src_device: _jax.Device | None = None,
               device: _jax.Device | None = None,
               copy: bool | None = None):
  if src_device is None:
    src_device, = x.devices()
  if device and (src_device is None or device != src_device):
    if copy is not None and not copy:
      raise ValueError(
          f"Specified {device=} which requires a copy since the source device "
          f"is {repr(src_device)}, however copy=False. Set copy=True or "
          "copy=None to perform the requested operation."
      )
    else:
      # The requested device differs from the source, so transfer (which
      # implies a copy).
      arr = device_put(x, device)
  else:
    # No transfer requested: copy explicitly if asked to, otherwise export
    # the existing buffer as-is.
    arr = _array_copy(x) if copy else x
  return _jax.buffer_to_dlpack_managed_tensor(
      arr.addressable_data(0), stream=stream
  )
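
# Summary of the copy/device handling above (illustrative, not part of the
# API):
#   target differs, copy=False      -> ValueError
#   target differs, copy=None/True  -> device_put (implies a copy)
#   same device,    copy=True       -> explicit buffer copy
#   same device,    copy=False/None -> export the existing buffer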


_DL_DEVICE_TO_PLATFORM = {
    DLDeviceType.kDLCPU: "cpu",
    DLDeviceType.kDLCUDA: "cuda",
    DLDeviceType.kDLROCM: "rocm",
}
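
# For example (a sketch): a producer whose ``__dlpack_device__()`` returns
# ``(DLDeviceType.kDLCUDA, 0)`` maps to the "cuda" platform, and the matching
# backend then resolves local device 0 by its hardware id.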


def to_dlpack(x: Array, stream: int | Any | None = None,
              src_device: _jax.Device | None = None,
              dl_device: tuple[DLDeviceType, int] | None = None,
              max_version: tuple[int, int] | None = None,
              copy: bool | None = None):
  """Returns a DLPack tensor that encapsulates a :class:`~jax.Array` ``x``.

  Args:
    x: a :class:`~jax.Array`, on either CPU or GPU.
    stream: optional platform-dependent stream to wait on until the buffer is
      ready. This corresponds to the `stream` argument to ``__dlpack__``
      documented in https://dmlc.github.io/dlpack/latest/python_spec.html.
    src_device: either a CPU or GPU :class:`~jax.Device`.
    dl_device: a tuple of ``(dl_device_type, local_hardware_id)`` in DLPack
      format, e.g. as produced by ``__dlpack_device__``.
    max_version: the maximum DLPack version that the consumer (i.e. the caller
      of ``__dlpack__``) supports, in the form of a 2-tuple of
      ``(major, minor)``. This function is not guaranteed to return a capsule
      of version ``max_version``.
    copy: a boolean indicating whether or not to copy the input. If
      ``copy=True`` then the function must always copy. If ``copy=False`` then
      the function must never copy, and must raise an error when a copy is
      deemed necessary. If ``copy=None`` then the function must avoid a copy
      if possible but may copy if needed.

  Returns:
    A DLPack PyCapsule object.

  Note:
    While JAX arrays are always immutable, ``DLPackManagedTensor`` buffers
    cannot be marked as immutable, and it is possible for processes external
    to JAX to mutate them in-place. If a DLPack buffer derived from a JAX
    array is mutated, it may lead to undefined behavior when using the
    associated JAX array. When JAX eventually supports
    ``DLManagedTensorVersioned`` (DLPack 1.0), it will be possible to specify
    that a buffer is read-only.
  """
  if not isinstance(x, array.ArrayImpl):
    raise TypeError("Argument to to_dlpack must be a jax.Array, "
                    f"got {type(x)}")

  device = None
  dl_device_type, local_hardware_id = dl_device if dl_device else (None, None)
  if dl_device_type:
    try:
      dl_device_platform = _DL_DEVICE_TO_PLATFORM[dl_device_type]
      backend = xla_bridge.get_backend(dl_device_platform)
      device = backend.device_from_local_hardware_id(local_hardware_id)
    except KeyError:
      # https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__dlpack__.html
      # recommends using BufferError.
      raise BufferError(
          "The device specification passed to to_dlpack contains an"
          f" unsupported device type (DLDeviceType: {dl_device_type})"
      ) from None

  # As new versions are adopted over time, we can maintain some legacy paths
  # for compatibility, mediated through the max_version parameter.
  # TODO(micky774): Deprecate default usage of DLPackManagedTensor when XLA
  # supports DLManagedTensorVersioned (DLPack version 1.0) and repurpose the
  # current _to_dlpack as a legacy path for (0,5) <= max_version < (1,0).
  if max_version is None or max_version >= DLPACK_VERSION:
    # Latest supported version. (This branch is currently identical to the
    # legacy branch below; they will diverge once DLPack 1.0 is supported.)
    return _to_dlpack(
        x, stream=stream,
        src_device=src_device,
        device=device,
        copy=copy
    )
  elif max_version >= MIN_DLPACK_VERSION:
    # Oldest supported version.
    return _to_dlpack(
        x, stream=stream,
        src_device=src_device,
        device=device,
        copy=copy
    )
  else:
    raise BufferError(
        f"JAX does not support any version below {MIN_DLPACK_VERSION} but "
        f"version ({max_version}) was requested."
    )
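
# A minimal usage sketch (assumes a CPU backend; ``np.from_dlpack`` requires
# NumPy >= 1.22 and is shown here only as one possible consumer):
#
#   >>> import jax.numpy as jnp
#   >>> from jax.dlpack import to_dlpack
#   >>> x = jnp.arange(4, dtype=jnp.float32)
#   >>> capsule = to_dlpack(x, copy=True)  # a PyCapsule named "dltensor"
#   >>> np.from_dlpack(x)  # consumers can also ingest the array directly
#   array([0., 1., 2., 3.], dtype=float32)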


def _check_device(device, dlpack_device, copy):
  if device and dlpack_device != device:
    if copy is not None and not copy:
      raise ValueError(
          f"Specified {device=} which requires a copy since the source device "
          f"is {repr(dlpack_device)}, however copy=False. Set copy=True or "
          "copy=None to perform the requested operation."
      )


def _place_array(_arr, device, dlpack_device, copy):
  if device and dlpack_device != device:
    return device_put(_arr, device)
  if copy:
    return jnp.array(_arr, copy=True)
  return _arr


def _is_tensorflow_tensor(external_array):
  t = type(external_array)
  return (
      t.__qualname__ == "EagerTensor"
      and t.__module__.endswith("tensorflow.python.framework.ops")
  )
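
# Illustrative: for a TensorFlow eager tensor ``t = tf.constant(1.0)``,
# ``type(t).__qualname__`` is "EagerTensor" and ``type(t).__module__`` ends
# with "tensorflow.python.framework.ops", so the predicate above is True.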


def from_dlpack(external_array,
                device: _jax.Device | Sharding | None = None,
                copy: bool | None = None):
  """Returns a :class:`~jax.Array` representation of a DLPack tensor.

  The returned :class:`~jax.Array` shares memory with ``external_array`` if
  no device transfer or copy was requested.

  Args:
    external_array: An array object that has ``__dlpack__`` and
      ``__dlpack_device__`` methods.
    device: The (optional) :py:class:`Device`, representing the device on
      which the returned array should be placed. If given, then the result is
      committed to the device. If unspecified, the resulting array will be
      unpacked onto the same device it originated from. Setting ``device`` to
      a device different from the source of ``external_array`` will require a
      copy, meaning ``copy`` must be set to either ``True`` or ``None``.
    copy: An (optional) boolean, controlling whether or not a copy is
      performed. If ``copy=True`` then a copy is always performed, even if
      unpacked onto the same device. If ``copy=False`` then a copy is never
      performed and an error is raised if one would be necessary. If
      ``copy=None`` then a copy may be performed if needed for a device
      transfer.

  Returns:
    A jax.Array

  Note:
    While JAX arrays are always immutable, dlpack buffers cannot be marked as
    immutable, and it is possible for processes external to JAX to mutate
    them in-place. If a jax Array is constructed from a dlpack buffer and the
    buffer is later modified in-place, it may lead to undefined behavior when
    using the associated JAX array.
  """
  if isinstance(device, Sharding):
    device_set = device.device_set
    if len(device_set) > 1:
      raise ValueError(
          "from_dlpack can only unpack a dlpack tensor onto a single device,"
          f" but a Sharding with {len(device_set)} devices was provided."
      )
    device, = device_set
  if (not hasattr(external_array, "__dlpack__")
      or not hasattr(external_array, "__dlpack_device__")):
    raise TypeError(
        "The array passed to from_dlpack must have __dlpack__ and"
        " __dlpack_device__ methods."
    )

  dl_device_type, device_id = external_array.__dlpack_device__()
  try:
    dl_device_platform = _DL_DEVICE_TO_PLATFORM[dl_device_type]
  except KeyError:
    raise TypeError(
        "Array passed to from_dlpack is on unsupported device type "
        f"(DLDeviceType: {dl_device_type}, array: {external_array})"
    ) from None

  backend = xla_bridge.get_backend(dl_device_platform)
  dlpack_device = backend.device_from_local_hardware_id(device_id)
  _check_device(device, dlpack_device, copy)
  if _is_tensorflow_tensor(external_array):
    # TensorFlow does not support stream=.
    stream = None
  else:
    try:
      stream = dlpack_device.get_stream_for_external_ready_events()
    except _jax.JaxRuntimeError as err:
      if "UNIMPLEMENTED" in str(err):
        stream = None
      else:
        raise
  dlpack = external_array.__dlpack__(stream=stream)

  try:
    if jaxlib_extension_version < 384:
      arr = _jax.dlpack_managed_tensor_to_buffer(
          dlpack, dlpack_device, stream)
    else:
      arr = _jax.dlpack_managed_tensor_to_buffer(
          dlpack, dlpack_device, stream, copy)
  except xla_client.XlaRuntimeError as e:
    se = str(e)
    if "is not aligned to" in se:
      i = se.index("is not aligned to")
      raise ValueError(
          "Specified input which requires a copy since the source data "
          f"buffer {se[i:]} However copy=False. Set copy=True or "
          "copy=None to perform the requested operation."
      )
    else:
      raise
  # TODO(phawkins): when we are ready to support x64 arrays in
  # non-x64 mode, change the semantics to not canonicalize here.
  arr = jnp.asarray(arr, dtype=dtypes.canonicalize_dtype(arr.dtype))
  if copy and jaxlib_extension_version >= 384:
    # The copy was already handled by dlpack_managed_tensor_to_buffer.
    copy = None
  return _place_array(arr, device, dlpack_device, copy)
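
# A minimal round-trip sketch (assumes a CPU backend; NumPy >= 1.22 arrays
# provide __dlpack__ and __dlpack_device__, so they can be passed directly):
#
#   >>> import numpy as np
#   >>> from jax.dlpack import from_dlpack
#   >>> y = from_dlpack(np.arange(4, dtype=np.float32))  # zero-copy on CPU
#   >>> y
#   Array([0., 1., 2., 3.], dtype=float32)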