115 lines
3.3 KiB
Python
115 lines
3.3 KiB
Python
# Copyright 2021 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
# Wadler-Lindig pretty printer.
|
|
#
|
|
# References:
|
|
# Wadler, P., 1998. A prettier printer. Journal of Functional Programming,
|
|
# pp.223-244.
|
|
#
|
|
# Lindig, C. 2000. Strictly Pretty.
|
|
# https://lindig.github.io/papers/strictly-pretty-2000.pdf
|
|
#
|
|
# Hafiz, A. 2021. Strictly Annotated: A Pretty-Printer With Support for
|
|
# Annotations. https://ayazhafiz.com/articles/21/strictly-annotated
|
|
#
|
|
|
|
from __future__ import annotations
|
|
|
|
from collections.abc import Sequence
|
|
from functools import partial
|
|
import sys
|
|
from typing import Any
|
|
|
|
from jax._src import config
|
|
from jax._src.lib import _pretty_printer as _pretty_printer
|
|
|
|
|
|
# User-facing flag `jax_pprint_use_color` (default True). Its value is only
# read by `Doc.format` when the caller leaves `use_color` unset and the output
# target was detected as color-capable.
_PPRINT_USE_COLOR = config.bool_state(
    'jax_pprint_use_color', True,
    help='Enable jaxpr pretty-printing with colorful syntax highlighting.')
|
|
|
|
def _can_use_color() -> bool:
  """Reports whether the current output target should receive colored text.

  Jupyter notebooks (shell class ``ZMQInteractiveShell``) and Google Colab
  shells are always treated as color-capable. In every other case the answer
  is whether ``sys.stdout`` has an ``isatty`` method and reports a terminal.
  """
  # `get_ipython` is only expected to be resolvable inside IPython sessions;
  # anywhere else the name lookup raises NameError and we skip these checks.
  try:
    kernel = get_ipython()  # type: ignore[name-defined]
  except NameError:
    pass
  else:
    kernel_cls = kernel.__class__
    if kernel_cls.__name__ == "ZMQInteractiveShell":
      # Jupyter Notebook
      return True
    if "colab" in str(kernel_cls):
      # Google Colab (external or internal)
      return True
  # Fall back to asking whether stdout is an interactive terminal.
  return hasattr(sys.stdout, 'isatty') and sys.stdout.isatty()
|
|
|
|
# Probed once at import time; `Doc.format` consults this cached result rather
# than re-detecting the environment on every call.
CAN_USE_COLOR = _can_use_color()

# Public aliases for the document type and color enums that
# `_pretty_printer` provides.
Color, Intensity, Doc = (
    _pretty_printer.Color, _pretty_printer.Intensity, _pretty_printer.Doc)
|
|
|
|
def _format(
    self, width: int = 80, *, use_color: bool | None = None,
    annotation_prefix: str = " # ",
    source_map: list[list[tuple[int, int, Any]]] | None = None
) -> str:
  """Formats a pretty-printer document as a string.

  Installed just below as ``Doc.format``. It resolves the color setting and
  then delegates to the document's own ``_format`` method.

  Args:
    width: passed through unchanged to ``Doc._format`` (presumably the target
      line width; TODO confirm against the native implementation).
    use_color: whether to colorize the output. When ``None``, color is used
      only if ``CAN_USE_COLOR`` is true and the ``jax_pprint_use_color`` flag
      is enabled.
    annotation_prefix: passed through unchanged to ``Doc._format``
      (presumably the text emitted before each annotation; TODO confirm).
    source_map: for each line in the output, contains a list of
      (start column, end column, source) tuples. Each tuple associates a
      region of output text with a source.

  Returns:
    The string produced by ``Doc._format``.
  """
  # `and` short-circuits, so the flag is only read when the environment can
  # display color at all.
  resolved_color = (
      (CAN_USE_COLOR and _PPRINT_USE_COLOR.value)
      if use_color is None else use_color)
  return self._format(
      width,
      use_color=resolved_color,
      annotation_prefix=annotation_prefix,
      source_map=source_map,
  )

Doc.format = _format
# str(doc) renders with all-default formatting options.
Doc.__str__ = lambda doc: doc.format()  # type: ignore[method-assign]
|
|
|
|
# Document constructors, re-exported from `_pretty_printer`.
(nil, text, concat, brk, group, nest, color, source_map) = (
    _pretty_printer.nil,
    _pretty_printer.text,
    _pretty_printer.concat,
    _pretty_printer.brk,
    _pretty_printer.group,
    _pretty_printer.nest,
    _pretty_printer.color,
    _pretty_printer.source_map,
)

# Preset styles built on `color`: normal-intensity magenta for type
# annotations, bright blue for keywords.
type_annotation = partial(
    color, intensity=Intensity.NORMAL, foreground=Color.MAGENTA)
keyword = partial(color, intensity=Intensity.BRIGHT, foreground=Color.BLUE)
|
|
|
|
|
|
def join(sep: Doc, docs: Sequence[Doc]) -> Doc:
  """Concatenates `docs`, separated by `sep`.

  Args:
    sep: document placed between every pair of adjacent entries of `docs`.
    docs: the documents to join. The input is copied into a list first, so
      it is only iterated once.

  Returns:
    ``nil()`` if `docs` is empty, the single element itself if it has exactly
    one entry, and otherwise ``concat([d0, sep, d1, sep, ..., dn])``.
  """
  items = list(docs)
  if not items:
    return nil()
  first, *rest = items
  if not rest:
    # Avoid wrapping a lone document in a concat node.
    return first
  parts = [first]
  for item in rest:
    parts.extend((sep, item))
  return concat(parts)
|