DriverTrac/venv/lib/python3.12/site-packages/jax/experimental/source_mapper/common.py
2025-11-28 09:08:33 +05:30

93 lines
2.4 KiB
Python

# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common utilities for generating source maps."""
import contextlib
import dataclasses
import re
from typing import Any, Protocol
from collections.abc import Sequence
from absl import flags
import jax
from jax._src import sourcemap
@dataclasses.dataclass(frozen=True)
class SourceMapDump:
  """Immutable pairing of generated code with its source map.

  This is the value a ``GenerateDumpFn`` returns for a pass.
  """

  # Source map that accompanies ``generated_code``.
  source_map: sourcemap.SourceMap
  # The generated code paired with ``source_map``.
  generated_code: str
  # Name of the pass this dump belongs to (presumably ``Pass.name`` — confirm).
  pass_name: str
class CompileFn(Protocol):
  """Structural (duck-typed) signature for a pass's compilation step.

  The return value is opaque at this level (``Any``). Presumably it is handed
  to the paired ``GenerateDumpFn`` as ``compile_result`` — confirm with callers.
  """

  def __call__(self, work_dir, fn, f_args, f_kwargs, **kwargs) -> Any:
    """Expected to compile ``fn`` for the given arguments.

    Args:
      work_dir: presumably a directory for compiler artifacts — confirm.
      fn: the function to compile.
      f_args: positional arguments for ``fn``.
      f_kwargs: keyword arguments for ``fn``.
      **kwargs: implementation-specific options.
    """
class GenerateDumpFn(Protocol):
  """Structural (duck-typed) signature for building a ``SourceMapDump``."""

  def __call__(self, compile_result: Any, **kwargs) -> SourceMapDump:
    """Expected to turn ``compile_result`` into a ``SourceMapDump``.

    Args:
      compile_result: opaque value, presumably what the paired ``CompileFn``
        returned — confirm with callers.
      **kwargs: implementation-specific options.
    """
@dataclasses.dataclass(frozen=True)
class Pass:
  """Immutable description of a named source-map pass.

  Bundles how to compile (``compile_fn``) with how to turn the compile result
  into a dump (``generate_dump``). Registered under ``name`` by
  ``register_pass``.
  """

  # Registry key; also the string ``filter_passes`` matches its regex against.
  name: str
  # Compilation step for this pass.
  compile_fn: CompileFn
  # Produces a SourceMapDump from a compile result.
  generate_dump: GenerateDumpFn
# Module-level registry of passes, keyed by ``Pass.name``; dict insertion order
# is registration order.
_pass_registry: dict[str, "Pass"] = {}


def register_pass(pass_: Pass):
  """Adds ``pass_`` to the module-level registry under its ``name``.

  Args:
    pass_: the pass to register.

  Raises:
    ValueError: if a pass with the same name has already been registered.
  """
  key = pass_.name
  if key not in _pass_registry:
    _pass_registry[key] = pass_
    return
  raise ValueError(f"Pass {key} already registered")
def all_passes() -> Sequence[Pass]:
  """Returns a new list of every registered pass, in registration order."""
  return [*_pass_registry.values()]
def filter_passes(regex: str) -> Sequence[Pass]:
  """Returns the registered passes whose ``name`` matches ``regex``.

  Matching uses ``re.match``, so the pattern is anchored at the start of the
  name but not at its end. Results keep registration order.
  """

  def _name_matches(candidate):
    # A match object is always truthy; ``None`` means no match.
    return re.match(regex, candidate.name)

  return list(filter(_name_matches, _pass_registry.values()))
@contextlib.contextmanager
def flag_env(**kwargs):
  """Temporarily overrides absl flags and restores their prior values on exit.

  Args:
    **kwargs: flag name -> value to hold for the duration of the ``with``
      block.

  Yields:
    None. All overridden flags are restored when the block exits, whether it
    exits normally, by exception, or because applying an override failed.
  """
  # Snapshot every flag before mutating any, so a failed lookup (e.g. an
  # unknown flag name) leaves all flags untouched.
  old_flags = {name: getattr(flags.FLAGS, name) for name in kwargs}
  try:
    # Apply overrides inside ``try``: if one assignment fails part-way (e.g. a
    # value rejected by a flag validator), the overrides already applied are
    # still rolled back by ``finally`` instead of leaking process-wide.
    for name, new_value in kwargs.items():
      setattr(flags.FLAGS, name, new_value)
    yield
  finally:
    for name, old_value in old_flags.items():
      setattr(flags.FLAGS, name, old_value)
def compile_with_env(f, f_args, f_kwargs, env_flags, compiler_flags):
  """JIT-lowers and compiles ``f`` while ``env_flags`` are temporarily applied.

  Args:
    f: function to compile.
    f_args: positional arguments used to lower ``f``.
    f_kwargs: keyword arguments used to lower ``f``.
    env_flags: flag name -> value overrides applied via ``flag_env`` for the
      duration of lowering and compilation.
    compiler_flags: forwarded positionally to ``Lowered.compile`` (its
      compiler-options argument).

  Returns:
    The compiled executable produced by ``Lowered.compile``. Callers that only
    need the side effects of compiling under ``env_flags`` may ignore it.
  """
  with flag_env(**env_flags):
    # NOTE(review): a fresh wrapper is jitted on every call, presumably so
    # jit's cache (keyed on the function object) cannot reuse an earlier
    # compilation and the compiler always runs under ``env_flags``. Confirm
    # before simplifying to ``jax.jit(f)``.
    jitted = jax.jit(lambda *args, **kwargs: f(*args, **kwargs))  # pylint: disable=unnecessary-lambda
    # Returning from inside the ``with`` still restores the flags on exit.
    return jitted.lower(*f_args, **f_kwargs).compile(compiler_flags)