# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Colocated Python function API implementation."""

from __future__ import annotations

import dataclasses
import inspect
import random
import threading
from typing import Any
from collections.abc import Callable, Sequence

import jax
from jax._src import api
from jax._src import tree_util
from jax._src import util
from jax._src.interpreters import pxla
from jax._src.lib import xla_client as xc
from jax._src.traceback_util import api_boundary
from jax._src.util import wraps
from jax.experimental.colocated_python import func_backend
from jax.experimental.colocated_python.serialization import (
    _deserialize_specs,
    _make_specs_for_serialized_specs,
    _serialize,
    _serialize_specs,
)
from jax.extend.ifrt_programs import ifrt_programs

# Type alias for readability; there is no runtime pytree-of-specs type.
ShapeDtypeStructTree = Any  # PyTree[api.ShapeDtypeStruct]


@dataclasses.dataclass(frozen=True, slots=True)
class FunctionInfo:
  """User function wrapped by colocated_python."""

  # The user's Python callable to run on the colocated devices.
  fun: Callable[..., Any]
  # Source-location string for `fun`, if available.
  fun_sourceinfo: str | None
  # `inspect` signature of `fun`, if it could be obtained.
  fun_signature: inspect.Signature | None


@dataclasses.dataclass(frozen=True, slots=True)
class Specialization:
  """Specialization for a colocated_python function.

  Each field is optional; a field that is `None` has not been specialized
  yet. `update` enforces write-once semantics: a field may only be filled in
  when it is still `None`.
  """

  # Pytree structure of `(args, kwargs)` expected by the function.
  in_specs_treedef: tree_util.PyTreeDef | None = None
  # Flattened leaf specs matching `in_specs_treedef`.
  in_specs_leaves: tuple[api.ShapeDtypeStruct, ...] | None = None
  # Optional function that maps input specs to output specs.
  out_specs_fn: Callable[..., ShapeDtypeStructTree] | None = None
  # Pytree structure of the function's outputs.
  out_specs_treedef: tree_util.PyTreeDef | None = None
  # Flattened leaf specs matching `out_specs_treedef`.
  out_specs_leaves: tuple[api.ShapeDtypeStruct, ...] | None = None
  # Devices on which the function executes.
  devices: xc.DeviceList | None = None

  def update(
      self,
      *,
      in_specs_treedef: tree_util.PyTreeDef | None = None,
      in_specs_leaves: tuple[api.ShapeDtypeStruct, ...] | None = None,
      out_specs_fn: Callable[..., ShapeDtypeStructTree] | None = None,
      out_specs_treedef: tree_util.PyTreeDef | None = None,
      out_specs_leaves: tuple[api.ShapeDtypeStruct, ...] | None = None,
      devices: Sequence[jax.Device] | xc.DeviceList | None = None,
  ):
    """Creates a new specialization with overrides.

    Raises:
      ValueError: If an override is given for a field that has already been
        specified on this specialization.
    """
    if in_specs_treedef is None:
      in_specs_treedef = self.in_specs_treedef
    elif self.in_specs_treedef is not None:
      raise ValueError("in_specs already specified")
    if in_specs_leaves is None:
      in_specs_leaves = self.in_specs_leaves
    elif self.in_specs_leaves is not None:
      raise ValueError("in_specs already specified")

    if out_specs_fn is None:
      out_specs_fn = self.out_specs_fn
    elif self.out_specs_fn is not None:
      raise ValueError("out_specs_fn already specified")

    if out_specs_treedef is None:
      out_specs_treedef = self.out_specs_treedef
    elif self.out_specs_treedef is not None:
      raise ValueError("out_specs already specified")
    if out_specs_leaves is None:
      out_specs_leaves = self.out_specs_leaves
    elif self.out_specs_leaves is not None:
      raise ValueError("out_specs already specified")

    if devices is None:
      devices = self.devices
    elif self.devices is not None:
      raise ValueError("devices already specified")
    elif not isinstance(devices, xc.DeviceList):
      # Normalize any device sequence to a hashable xc.DeviceList.
      devices = xc.DeviceList(tuple(devices))

    return Specialization(
        in_specs_treedef,
        in_specs_leaves,
        out_specs_fn,
        out_specs_treedef,
        out_specs_leaves,
        devices,
    )


def _get_spec(x: Any) -> api.ShapeDtypeStruct:
  """Extracts a spec for a value, which must be a JAX Array."""
  # TODO(hyeontaek): Allow Python values and automatically apply `shard_arg`
  # with a suitable sharding and layout.
  if not isinstance(x, jax.Array):
    raise ValueError(
        "colocated_python only supports jax.Array as input and output, but got"
        f" {type(x)}."
    )
  return api.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding)


def _infer_devices_from_args(args: Sequence[Any]) -> xc.DeviceList | None:
  """Returns a representative device list from function call arguments.

  Returns `None` when no argument carries a sharding; raises if arguments
  disagree on their device lists.
  """
  device_list_set: set[xc.DeviceList] = set()
  for x in args:
    sharding = getattr(x, "sharding", None)
    if sharding is not None:
      device_list_set.add(x.sharding._internal_device_list)
  if not device_list_set:
    return None
  if len(device_list_set) != 1:
    raise ValueError(
        "All arguments must use the same device list, but got"
        f" multiple device lists: {device_list_set}."
    )
  return device_list_set.pop()


def _compile_to_executable(
    name: str,
    fun: Callable[..., Any],
    in_specs_treedef: tree_util.PyTreeDef,
    in_specs_leaves: tuple[api.ShapeDtypeStruct, ...],
    out_specs_treedef: tree_util.PyTreeDef,
    out_specs_leaves: tuple[api.ShapeDtypeStruct, ...],
    devices: xc.DeviceList,
) -> Callable[..., Any]:
  """Compiles a Python function into a runtime executable.

  The function and its full specialization are serialized into an IFRT
  colocated-Python program compiled on the client of `devices[0]`. The
  returned callable executes that program on flattened argument leaves and
  unflattens the results using `out_specs_treedef`.
  """
  # Everything the remote side needs to reconstruct and run the function.
  fun_and_specialization = (
      fun,
      in_specs_treedef,
      in_specs_leaves,
      out_specs_treedef,
      out_specs_leaves,
      devices,
  )
  pickled_function = _serialize(fun_and_specialization)
  program = ifrt_programs.make_colocated_python_program(
      name, pickled_function, devices, in_specs_leaves, out_specs_leaves
  )
  ifrt_client = devices[0].client
  out_sdss = tuple(
      jax.core.ShapedArray(sds.shape, sds.dtype) for sds in out_specs_leaves
  )
  out_shardings = tuple(sds.sharding for sds in out_specs_leaves)
  try:
    compile_options = ifrt_programs.make_colocated_python_compile_options()
    loaded_executable = ifrt_client.compile_ifrt_program(
        program, compile_options
    )
    # Handlers convert raw per-device buffers back into jax.Arrays.
    out_handlers = pxla.global_avals_to_results_handler(
        out_sdss, out_shardings, committed=True  # type: ignore
    ).handlers

    def call(*args, **kwargs):
      args_leaves = tree_util.tree_leaves((args, kwargs))
      execute_result = loaded_executable.execute_sharded(
          args_leaves, with_tokens=False
      )
      results = execute_result.consume_with_handlers(out_handlers)
      return tree_util.tree_unflatten(out_specs_treedef, results)

    return call
  except jax.errors.JaxRuntimeError as e:
    # TODO(hyeontaek): Implement colocated Python support in McJAX and remove
    # this fallback path.
    if "PjRtCompiler requires an HloProgram" in str(e):
      # Runtime does not support colocated Python; run the function locally.
      return fun
    raise


def _make_output_specs_and_push_result_fun(
    info: FunctionInfo, specialization: Specialization, uid: int
) -> Callable[..., Any]:
  """Creates a function that computes output specs and pushes the result to the result store.

  Used for the first (synchronous) call when output specs are unknown: the
  lowered function runs the user function, stashes the result leaves in the
  backend result store under `uid`, and returns only the serialized output
  specs. A later `_make_pop_result_fun` call retrieves the stored result.
  """
  assert specialization.in_specs_treedef is not None
  assert specialization.in_specs_leaves is not None
  assert specialization.out_specs_treedef is None
  assert specialization.out_specs_leaves is None
  assert specialization.devices is not None

  devices = specialization.devices

  def lowered_fun(*args, **kwargs) -> jax.Array:
    result = info.fun(*args, **kwargs)
    result_leaves, out_treedef = tree_util.tree_flatten(result)
    out_spec_leaves = tuple(_get_spec(x) for x in result_leaves)
    # Keep the actual result server-side; only the specs travel back now.
    func_backend.SINGLETON_RESULT_STORE.push(uid, result_leaves)
    return _serialize_specs(out_treedef, out_spec_leaves, devices)

  # The executable's declared output is the serialized-specs payload.
  out_specs_leaves, out_specs_treedef = tree_util.tree_flatten(
      _make_specs_for_serialized_specs(specialization.devices),
  )
  name = getattr(info.fun, "__name__", "unknown")
  name = f"{name}_output_specs_and_push_result"
  return _compile_to_executable(
      name=name,
      fun=lowered_fun,
      in_specs_treedef=specialization.in_specs_treedef,
      in_specs_leaves=specialization.in_specs_leaves,
      out_specs_treedef=out_specs_treedef,
      out_specs_leaves=tuple(out_specs_leaves),
      devices=specialization.devices,
  )


def _make_pop_result_fun(
    info: FunctionInfo, specialization: Specialization, uid: int
) -> Callable[..., Any]:
  """Makes a function that pops results from the result store.

  Counterpart of `_make_output_specs_and_push_result_fun`: takes no inputs
  and returns the result leaves previously pushed under `uid`, unflattened
  with the now-known output treedef.
  """
  assert specialization.out_specs_treedef is not None
  assert specialization.out_specs_leaves is not None
  assert specialization.devices is not None

  out_specs_treedef = specialization.out_specs_treedef

  def lowered_fun():
    result_leaves = func_backend.SINGLETON_RESULT_STORE.pop(uid)
    return tree_util.tree_unflatten(out_specs_treedef, result_leaves)

  # Input specs for a zero-argument call: empty args tuple, empty kwargs.
  in_specs_leaves, in_specs_treedef = tree_util.tree_flatten((
      # args
      (),
      # kwargs
      {},
  ))
  name = getattr(info.fun, "__name__", "unknown")
  name = f"{name}_pop_result"
  return _compile_to_executable(
      name=name,
      fun=lowered_fun,
      in_specs_treedef=in_specs_treedef,
      in_specs_leaves=tuple(in_specs_leaves),
      out_specs_treedef=specialization.out_specs_treedef,
      out_specs_leaves=specialization.out_specs_leaves,
      devices=specialization.devices,
  )


def _make_async_execution_fun(
    info: FunctionInfo, specialization: Specialization
) -> Callable[..., Any]:
  """Makes a function that asynchronously executes the function.

  Requires a fully-specified specialization (input specs, output specs, and
  devices all known).
  """
  assert specialization.in_specs_treedef is not None
  assert specialization.in_specs_leaves is not None
  assert specialization.out_specs_treedef is not None
  assert specialization.out_specs_leaves is not None
  assert specialization.devices is not None

  name = getattr(info.fun, "__name__", "unknown")
  return _compile_to_executable(
      name=name,
      fun=info.fun,
      in_specs_treedef=specialization.in_specs_treedef,
      in_specs_leaves=specialization.in_specs_leaves,
      out_specs_treedef=specialization.out_specs_treedef,
      out_specs_leaves=specialization.out_specs_leaves,
      devices=specialization.devices,
  )


@jax._src.util.cache(max_size=None)
def _get_specialized_func(
    info: FunctionInfo, specialization: Specialization
) -> Callable[..., Any]:
  """Returns a specialized function for the given specialization.

  Cached on `(info, specialization)`. The returned function lazily resolves
  output specs on the first call (either by running the function once via
  the push/pop result-store path, or by evaluating `out_specs_fn`), then
  switches to a fully asynchronous execution function.
  """
  util.test_event("colocated_python_func._get_specialized_func")
  assert specialization.in_specs_treedef is not None
  assert specialization.in_specs_leaves is not None
  assert specialization.devices is not None
  # Key for stashing/retrieving results in the backend result store.
  uid = random.getrandbits(63)

  # Guards one-time initialization of `async_execution_func` below.
  mutex = threading.Lock()
  # Asynchronous execution function that has known output_specs.
  async_execution_func = None

  def specialized_func(*args, **kwargs):
    """Specialized function to be executed with given args and kwargs."""
    nonlocal specialization, async_execution_func
    with mutex:
      if async_execution_func is None:
        if specialization.out_specs_treedef is None:
          if specialization.out_specs_fn is None:
            # No output specs available: run once synchronously, keeping the
            # result server-side, and get back the serialized output specs.
            serialized_out_specs = _make_output_specs_and_push_result_fun(
                info, specialization, uid
            )(*args, **kwargs)

            # Waits for the output_specs. This may block.
            out_specs_treedef, out_specs_leaves = _deserialize_specs(
                serialized_out_specs
            )

            # Subsequent calls would use async_execution_func with discovered
            # output_specs.
            specialization = specialization.update(
                out_specs_treedef=out_specs_treedef,
                out_specs_leaves=out_specs_leaves,
            )
            async_execution_func = _make_async_execution_fun(
                info, specialization
            )

            # This call's result was pushed to the result store; pop it now.
            return _make_pop_result_fun(info, specialization, uid)()
          else:
            # Compute out_specs using out_specs_fn and inputs.
            args_specs, kwargs_specs = tree_util.tree_map(
                _get_spec, (args, kwargs)
            )
            out_specs = specialization.out_specs_fn(*args_specs, **kwargs_specs)
            # Type checking is ignored to silence mypy error: Incompatible types
            # in assignment (expression has type "list[Any]", variable has type
            # "tuple[ShapeDtypeStruct, ...]") [assignment]
            out_specs_leaves, out_specs_treedef = tree_util.tree_flatten(  # type: ignore[assignment]
                out_specs
            )
            specialization = specialization.update(
                out_specs_treedef=out_specs_treedef,
                out_specs_leaves=tuple(out_specs_leaves),
            )
            async_execution_func = _make_async_execution_fun(
                info, specialization
            )
            # Fall-through.
        else:
          async_execution_func = _make_async_execution_fun(info, specialization)
          # Fall-through.
    # Asynchronous execution runs outside of the mutex to allow concurrent
    # execution for inline executors.
    return async_execution_func(*args, **kwargs)

  return specialized_func


def make_callable(
    fun: Callable[..., Any],
    fun_sourceinfo: str | None,
    fun_signature: inspect.Signature | None,
):
  """Makes a colocated Python callable."""
  return _make_callable(
      FunctionInfo(fun, fun_sourceinfo, fun_signature), Specialization()
  )


def _make_callable(info: FunctionInfo, specialization: Specialization):
  """Internal implementation of make_callable."""

  def specialize(
      in_specs: ShapeDtypeStructTree | None = None,
      out_specs_fn: Callable[..., ShapeDtypeStructTree] | None = None,
      devices: Sequence[jax.Device] | None = None,
  ):
    """Returns a colocated Python callable with extra specialization.

    Args:
      in_specs: Optionally specifies the expected input specs. Input specs are
        expressed as a `PyTree[ShapeDtypeStruct]` for `(args, kwargs)` of a
        function call.
      out_specs_fn: Optionally specifies a function that computes the output
        specs from input specs. If unspecified, colocated Python will compute
        the output specs during the very first execution, and this execution
        will be synchronous.
      devices: Optionally specifies the devices to execute the function on. Must
        be provided if `in_specs` has no leaves because devices cannot be
        inferred from input specs or arguments.

    Returns:
      A colocated Python callable with extra specialization.
    """
    # TODO(hyeontaek): Allow unspecified devices for zero-leaf `in_specs` if
    # `out_specs_fn(in_specs)` returns at least one leaf that we can use for
    # inferring `devices`.
    if in_specs is None:
      in_specs_leaves, in_specs_treedef = None, None
    else:
      in_specs_leaves_list, in_specs_treedef = tree_util.tree_flatten(in_specs)
      in_specs_leaves = tuple(in_specs_leaves_list)
    return _make_callable(
        info,
        specialization.update(
            in_specs_treedef=in_specs_treedef,
            in_specs_leaves=in_specs_leaves,
            out_specs_fn=out_specs_fn,
            devices=devices,
        ),
    )

  @api_boundary
  def __call__(*args, **kwargs):
    """Executes the given Python function on the same devices as the arguments or as specialized.

    If the callable has not been specialized with output shapes and shardings
    (see `specialize` above), the very first call will run synchronously to
    discover output shapes and shardings, and will run asynchronously after.
    If specialized with output shapes and shardings, every execution of the
    callable will be asynchronous.
    """
    args_leaves, in_specs_treedef = tree_util.tree_flatten((args, kwargs))
    in_specs_leaves = tuple(_get_spec(x) for x in args_leaves)

    if specialization.in_specs_treedef is None:
      # Allow input polymorphism by applying input_specs specialization
      # temporarily for this call.
      return _make_callable(
          info,
          specialization.update(
              in_specs_treedef=in_specs_treedef,
              in_specs_leaves=in_specs_leaves,
          ),
      )(*args, **kwargs)

    if specialization.devices is None:
      devices = _infer_devices_from_args(args_leaves)
      if devices is None:
        raise ValueError(
            "No devices found. colocated_python function without input"
            " arguments must be first specialized with devices."
        )
      # Allow device polymorphism by applying devices specialization
      # temporarily for this call.
      return _make_callable(info, specialization.update(devices=devices))(
          *args, **kwargs
      )

    # Assertion is added to silence mypy error: Unsupported operand types for !=
    # ("PyTreeDef" and "None") [operator]
    assert isinstance(specialization.in_specs_treedef, tree_util.PyTreeDef)

    # If input_specs is known, verify that it matches actual inputs.
    if (specialization.in_specs_treedef != in_specs_treedef
        or specialization.in_specs_leaves != in_specs_leaves):
      raise ValueError(
          "Input specs in specialization and input specs of arguments must have"
          " the same pytree structure, but they have the following structural"
          " differences:\n" + ("\n".join(
              f" - {tree_util.keystr(path)} is a {thing1} in value 1 and"
              f" a {thing2} in value 2, so {explanation}.\n"
              for path, thing1, thing2, explanation in (
                  tree_util.equality_errors_pytreedef(
                      specialization.in_specs_treedef, in_specs_treedef
                  )))))

    return _get_specialized_func(info, specialization)(*args, **kwargs)

  # Preserve the wrapped function's metadata and expose `specialize` on the
  # returned callable.
  __call__ = wraps(info.fun)(__call__)
  __call__.specialize = specialize
  return __call__