# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common utilities for generating source maps."""

import contextlib
import dataclasses
import re
from typing import Any, Protocol
from collections.abc import Sequence

from absl import flags
import jax
from jax._src import sourcemap


@dataclasses.dataclass(frozen=True)
class SourceMapDump:
  """Immutable pairing of a source map with the code it describes."""

  # The source map itself.
  source_map: sourcemap.SourceMap
  # The generated code text that `source_map` is paired with.
  generated_code: str
  # Name of the associated pass (presumably the one that produced this dump;
  # confirm against the code that constructs these).
  pass_name: str


class CompileFn(Protocol):
  """Structural type for the compile callable stored on a pass.

  Implementations are called with `work_dir`, `fn`, `f_args` and `f_kwargs`,
  plus arbitrary extra keyword arguments, and may return any value.
  """

  def __call__(self, work_dir, fn, f_args, f_kwargs, **kwargs) -> Any:
    ...


class GenerateDumpFn(Protocol):
  """Structural type for the callable that builds a `SourceMapDump`.

  Implementations are called with a `compile_result` plus arbitrary extra
  keyword arguments and return a `SourceMapDump`.
  """

  def __call__(self, compile_result: Any, **kwargs) -> SourceMapDump:
    ...


@dataclasses.dataclass(frozen=True)
class Pass:
  """A named pass: a compile callable paired with a dump generator."""

  # Key under which the pass is stored in the registry; must be unique.
  name: str
  # Callable matching the `CompileFn` protocol.
  compile_fn: CompileFn
  # Callable matching the `GenerateDumpFn` protocol.
  generate_dump: GenerateDumpFn


# Module-level registry mapping pass name -> Pass.
_pass_registry: dict[str, Pass] = {}


def register_pass(pass_: Pass):
  """Adds `pass_` to the module-level registry.

  Args:
    pass_: The pass to register.

  Raises:
    ValueError: If a pass with the same name has already been registered.
  """
  if pass_.name in _pass_registry:
    raise ValueError(f"Pass {pass_.name} already registered")
  _pass_registry[pass_.name] = pass_


def all_passes() -> Sequence[Pass]:
  """Returns a new list of every registered pass, in registration order."""
  return list(_pass_registry.values())


def filter_passes(regex: str) -> Sequence[Pass]:
  """Gets all registered passes whose display name matches the given regex.

  Matching uses `re.match`, so the pattern is anchored at the start of the
  name but does not have to consume the whole name.

  Args:
    regex: Regular expression applied to each pass's `name`.

  Returns:
    A list of the matching passes, in registration order.
  """
  return [
      pass_
      for pass_ in _pass_registry.values()
      if re.match(regex, pass_.name)
  ]


@contextlib.contextmanager
def flag_env(**kwargs):
  """A context manager for setting and restoring flags.

  Each keyword argument names an attribute on `flags.FLAGS`; its value is
  installed for the duration of the `with` body and the previous value is put
  back afterwards, whether or not the body raises.

  Args:
    **kwargs: Mapping of flag name to the temporary value.
  """
  # Snapshot every flag first: if a name cannot be read, this raises before
  # any flag has been modified.
  old_flags = {kwarg: getattr(flags.FLAGS, kwarg) for kwarg in kwargs}
  try:
    # Setting happens inside the `try` so that a failure on a later flag still
    # rolls back the flags that were already changed.
    for kwarg, new_value in kwargs.items():
      setattr(flags.FLAGS, kwarg, new_value)
    yield
  finally:
    for kwarg, old_value in old_flags.items():
      setattr(flags.FLAGS, kwarg, old_value)


def compile_with_env(f, f_args, f_kwargs, env_flags, compiler_flags):
  """Lowers and compiles `f` under `jax.jit` with flags temporarily set.

  The compiled result is discarded and the function returns `None`.

  Args:
    f: The callable to jit, lower and compile.
    f_args: Positional arguments passed to `.lower()`.
    f_kwargs: Keyword arguments passed to `.lower()`.
    env_flags: Mapping of flag name to value, applied through `flag_env` for
      the duration of lowering and compilation.
    compiler_flags: Passed as the single positional argument to `.compile()`.
  """
  with flag_env(**env_flags):
    # NOTE(review): `f` is wrapped in a lambda rather than jitted directly;
    # the pylint suppression suggests this is deliberate - confirm before
    # simplifying.
    jax.jit(lambda *args, **kwargs: f(*args, **kwargs)).lower(  # pylint: disable=unnecessary-lambda
        *f_args, **f_kwargs
    ).compile(compiler_flags)