DriverTrac/venv/lib/python3.12/site-packages/jax/_src/op_shardings.py

120 lines
4.3 KiB
Python

# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Sharding utilities"""
from __future__ import annotations
from collections.abc import Sequence
import itertools
from typing import Union
import numpy as np
from jax._src.lib import xla_client as xc
def get_num_ways_dim_sharded(
hlo_sharding: xc.HloSharding, allow_partial_manual: bool = False
) -> tuple[list[int], int]:
assert not hlo_sharding.is_manual()
if hlo_sharding.is_replicated():
return [], 1
if hlo_sharding.is_unreduced():
return [], 1
partitions = hlo_sharding.tile_assignment_dimensions()
subgroup_types = hlo_sharding.subgroup_types()
if subgroup_types == [xc.OpSharding.Type.REPLICATED]:
return list(partitions[:-1]), partitions[-1]
elif subgroup_types == [xc.OpSharding.Type.UNREDUCED]:
return list(partitions[:-1]), 1
elif set(subgroup_types) == {xc.OpSharding.Type.REPLICATED,
xc.OpSharding.Type.UNREDUCED}:
replicated_loc = subgroup_types.index(xc.OpSharding.Type.REPLICATED)
return list(partitions[:-2]), partitions[-2:][replicated_loc]
elif allow_partial_manual and xc.OpSharding.Type.MANUAL in subgroup_types:
if subgroup_types == [xc.OpSharding.Type.MANUAL]:
return list(partitions[:-1]), 1
else:
assert (set(subgroup_types) ==
{xc.OpSharding.Type.REPLICATED, xc.OpSharding.Type.MANUAL})
replicated_loc = subgroup_types.index(xc.OpSharding.Type.REPLICATED)
return list(partitions[:-2]), partitions[-2:][replicated_loc]
elif hlo_sharding.replicate_on_last_tile_dim():
return list(partitions[:-1]), partitions[-1]
else:
if subgroup_types:
raise NotImplementedError(f"Unhandled OpSharding type: {hlo_sharding}. "
"Please open a bug report!")
return list(partitions), 1
def is_hlo_sharding_replicated(hc: xc.HloSharding) -> bool:
return True if hc.num_devices() == 1 else hc.is_replicated()
def are_hlo_shardings_equal(hc1: xc.HloSharding, hc2: xc.HloSharding) -> bool:
if hc1 is hc2:
return True
if is_hlo_sharding_replicated(hc1) and is_hlo_sharding_replicated(hc2):
return True
return hc1 == hc2
_Index = Union[int, slice, tuple[Union[int, slice], ...]]
def op_sharding_to_numpy_indices(
hlo_sharding: xc.HloSharding, shape: Sequence[int],
num_devices: int) -> np.ndarray:
indices = np.empty(num_devices, dtype=np.object_)
# num_devices is required as an argument when hlo_sharding is
# REPLICATED. `jax.device_count()` cannot be used because you can create
# an opsharding with less number of devices than `jax.device_count()`.
if is_hlo_sharding_replicated(hlo_sharding):
indices.fill((slice(None),) * len(shape))
return indices
assert num_devices == hlo_sharding.num_devices()
partitions, num_replicas = get_num_ways_dim_sharded(hlo_sharding)
assert len(partitions) == len(shape), (len(partitions), len(shape))
axis_indices: list[Sequence[_Index]] = []
for dim, n_shards in zip(shape, partitions):
if n_shards == 1:
axis_indices.append([slice(None)])
elif n_shards > 1:
shard_size, ragged = divmod(dim, n_shards)
assert not ragged, (dim, n_shards)
axis_indices.append([slice(i * shard_size, (i + 1) * shard_size)
for i in range(n_shards)])
else:
raise AssertionError('Unrecognized number of shards. Please file a bug!')
device_it = iter(hlo_sharding.tile_assignment_devices())
for idxs in itertools.product(*axis_indices):
for _ in range(num_replicas):
indices[next(device_it)] = idxs
return indices
def op_sharding_to_indices(
op_sharding: xc.HloSharding, shape: Sequence[int],
num_devices: int) -> tuple[tuple[slice, ...], ...]:
indices = op_sharding_to_numpy_indices(op_sharding, shape, num_devices)
return tuple(indices.flat)