from __future__ import annotations

import inspect
import os
import re
import sys
import warnings
from collections import Counter
from collections.abc import (
    Collection,
    Generator,
    Iterable,
    MappingView,
    Sequence,
    Sized,
)
from enum import Enum
from io import BytesIO
from pathlib import Path
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Literal,
    TypeVar,
    overload,
)

import polars as pl
from polars import functions as F
from polars._dependencies import _check_for_numpy, import_optional, subprocess
from polars._dependencies import numpy as np
from polars.datatypes import (
    Boolean,
    Date,
    Datetime,
    Decimal,
    Duration,
    Int64,
    String,
    Time,
)
from polars.datatypes.group import FLOAT_DTYPES, INTEGER_DTYPES

if TYPE_CHECKING:
    from collections.abc import Iterator, MutableMapping, Reversible

    from polars import DataFrame, Expr
    from polars._typing import PolarsDataType, SizeUnit

    if sys.version_info >= (3, 13):
        from typing import TypeIs
    else:
        from typing_extensions import TypeIs

    if sys.version_info >= (3, 10):
        from typing import ParamSpec, TypeGuard
    else:
        from typing_extensions import ParamSpec, TypeGuard

    P = ParamSpec("P")
    T = TypeVar("T")

# note: reversed views don't match as instances of MappingView
if sys.version_info >= (3, 11):
    _views: list[Reversible[Any]] = [{}.keys(), {}.values(), {}.items()]
    _reverse_mapping_views = tuple(type(reversed(view)) for view in _views)


def _process_null_values(
    null_values: None | str | Sequence[str] | dict[str, str] = None,
) -> None | str | Sequence[str] | list[tuple[str, str]]:
    if isinstance(null_values, dict):
        return list(null_values.items())
    else:
        return null_values
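
# Illustrative behaviour (a minimal sketch; the values are hypothetical):
#
#     >>> _process_null_values({"a": "NA", "b": "-"})
#     [('a', 'NA'), ('b', '-')]
#     >>> _process_null_values("NA")
#     'NA'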


def _is_generator(val: object | Iterator[T]) -> TypeIs[Iterator[T]]:
    return (
        (isinstance(val, (Generator, Iterable)) and not isinstance(val, Sized))
        or isinstance(val, MappingView)
        or (sys.version_info >= (3, 11) and isinstance(val, _reverse_mapping_views))
    )
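
# Illustrative behaviour (assumes a standard CPython session):
#
#     >>> _is_generator(x for x in range(3))  # generator: iterable but not sized
#     True
#     >>> _is_generator({}.keys())  # mapping views are matched explicitly
#     True
#     >>> _is_generator([1, 2, 3])  # a sized sequence is not a generator
#     False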


def _is_iterable_of(val: Iterable[object], eltype: type | tuple[type, ...]) -> bool:
    """Check whether the given iterable contains only elements of the given type(s)."""
    return all(isinstance(x, eltype) for x in val)


def is_path_or_str_sequence(
    val: object, *, allow_str: bool = False, include_series: bool = False
) -> TypeGuard[Sequence[str | Path]]:
    """
    Check that `val` is a sequence of strings or paths.

    Note that a single string is itself a sequence of strings by definition;
    pass `allow_str=True` if a lone string should also be accepted.
    """
    if allow_str is False and isinstance(val, str):
        return False
    elif _check_for_numpy(val) and isinstance(val, np.ndarray):
        return np.issubdtype(val.dtype, np.str_)
    elif include_series and isinstance(val, pl.Series):
        return val.dtype == pl.String
    return (
        not isinstance(val, bytes)
        and isinstance(val, Sequence)
        and _is_iterable_of(val, (Path, str))
    )
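
# Illustrative behaviour:
#
#     >>> is_path_or_str_sequence(["data.csv", Path("more.csv")])
#     True
#     >>> is_path_or_str_sequence("data.csv")  # a lone str needs allow_str=True
#     False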


def is_bool_sequence(
    val: object, *, include_series: bool = False
) -> TypeGuard[Sequence[bool]]:
    """Check whether the given sequence is a sequence of booleans."""
    if _check_for_numpy(val) and isinstance(val, np.ndarray):
        return val.dtype == np.bool_
    elif include_series and isinstance(val, pl.Series):
        return val.dtype == pl.Boolean
    return isinstance(val, Sequence) and _is_iterable_of(val, bool)


def is_int_sequence(
    val: object, *, include_series: bool = False
) -> TypeGuard[Sequence[int]]:
    """Check whether the given sequence is a sequence of integers."""
    if _check_for_numpy(val) and isinstance(val, np.ndarray):
        return np.issubdtype(val.dtype, np.integer)
    elif include_series and isinstance(val, pl.Series):
        return val.dtype.is_integer()
    return isinstance(val, Sequence) and _is_iterable_of(val, int)


def is_sequence(
    val: object, *, include_series: bool = False
) -> TypeGuard[Sequence[Any]]:
    """Check whether the given input is a numpy array or python sequence (not str)."""
    return (_check_for_numpy(val) and isinstance(val, np.ndarray)) or (
        isinstance(val, (pl.Series, Sequence) if include_series else Sequence)
        and not isinstance(val, str)
    )


def is_str_sequence(
    val: object, *, allow_str: bool = False, include_series: bool = False
) -> TypeGuard[Sequence[str]]:
    """
    Check that `val` is a sequence of strings.

    Note that a single string is itself a sequence of strings by definition;
    pass `allow_str=True` if a lone string should also be accepted.
    """
    if allow_str is False and isinstance(val, str):
        return False
    elif _check_for_numpy(val) and isinstance(val, np.ndarray):
        return np.issubdtype(val.dtype, np.str_)
    elif include_series and isinstance(val, pl.Series):
        return val.dtype == pl.String
    return isinstance(val, Sequence) and _is_iterable_of(val, str)
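
# Illustrative behaviour of the sequence checks above:
#
#     >>> is_bool_sequence([True, False]), is_bool_sequence([1, 0])
#     (True, False)
#     >>> is_int_sequence(range(3))
#     True
#     >>> is_str_sequence(["a", "b"])
#     True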


def is_column(obj: Any) -> bool:
    """Indicate if the given object is a basic/unaliased column."""
    from polars.expr import Expr

    return isinstance(obj, Expr) and obj.meta.is_column()
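
# Illustrative behaviour:
#
#     >>> is_column(pl.col("foo"))
#     True
#     >>> is_column(pl.col("foo").alias("bar"))  # aliased, no longer a bare column
#     False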


def warn_null_comparison(obj: Any) -> None:
    """Warn for possibly unintentional comparisons with None."""
    if obj is None:
        warnings.warn(
            "Comparisons with None always result in null. Consider using `.is_null()` or `.is_not_null()`.",
            UserWarning,
            stacklevel=find_stacklevel(),
        )


def range_to_series(
    name: str, rng: range, dtype: PolarsDataType | None = None
) -> pl.Series:
    """Fast conversion of the given range to a Series."""
    dtype = dtype or Int64
    if dtype.is_integer():
        srs = F.int_range(  # type: ignore[call-overload]
            start=rng.start, end=rng.stop, step=rng.step, dtype=dtype, eager=True
        )
    else:
        srs = F.int_range(
            start=rng.start, end=rng.stop, step=rng.step, eager=True
        ).cast(dtype)
    return srs.alias(name)
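
# Illustrative behaviour:
#
#     >>> range_to_series("n", range(0, 6, 2)).to_list()
#     [0, 2, 4]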


def range_to_slice(rng: range) -> slice:
    """Return the given range as an equivalent slice."""
    return slice(rng.start, rng.stop, rng.step)


def _in_notebook() -> bool:
    try:
        from IPython import get_ipython

        if "IPKernelApp" not in get_ipython().config:  # pragma: no cover
            return False
    except ImportError:
        return False
    except AttributeError:
        return False
    return True


def _in_marimo_notebook() -> bool:
    try:
        import marimo as mo

        return mo.running_in_notebook()  # pragma: no cover
    except ImportError:
        return False


def arrlen(obj: Any) -> int | None:
    """Return the length of a (non-string, non-dict) sequence, or None otherwise."""
    try:
        return None if isinstance(obj, (str, dict)) else len(obj)
    except TypeError:
        return None
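
# Illustrative behaviour:
#
#     >>> arrlen([1, 2, 3]), arrlen("abc"), arrlen(42)
#     (3, None, None)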


def normalize_filepath(path: str | Path, *, check_not_directory: bool = True) -> str:
    """Create a string path, expanding the home directory if present."""
    # don't use pathlib here as it modifies slashes (s3:// -> s3:/)
    path = os.path.expanduser(path)  # noqa: PTH111
    if (
        check_not_directory
        and os.path.exists(path)  # noqa: PTH110
        and os.path.isdir(path)  # noqa: PTH112
    ):
        msg = f"expected a file path; {path!r} is a directory"
        raise IsADirectoryError(msg)
    return path


def parse_version(version: Sequence[str | int]) -> tuple[int, ...]:
    """Simple version parser; split into a tuple of ints for comparison."""
    if isinstance(version, str):
        version = version.split(".")
    return tuple(int(re.sub(r"\D", "", str(v))) for v in version)
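
# Illustrative behaviour (non-digit characters are stripped per component):
#
#     >>> parse_version("0.20.31")
#     (0, 20, 31)
#     >>> parse_version((1, "2"))
#     (1, 2)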


def ordered_unique(values: Sequence[Any]) -> list[Any]:
    """Return unique list of sequence values, maintaining their order of appearance."""
    seen: set[Any] = set()
    add_ = seen.add
    return [v for v in values if not (v in seen or add_(v))]


def deduplicate_names(names: Iterable[str]) -> list[str]:
    """Ensure name uniqueness by appending a counter to subsequent duplicates."""
    seen: MutableMapping[str, int] = Counter()
    deduped = []
    for nm in names:
        deduped.append(f"{nm}{seen[nm] - 1}" if nm in seen else nm)
        seen[nm] += 1
    return deduped
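
# Illustrative behaviour:
#
#     >>> ordered_unique([3, 1, 3, 2, 1])
#     [3, 1, 2]
#     >>> deduplicate_names(["a", "b", "a", "a"])
#     ['a', 'b', 'a0', 'a1']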


@overload
def scale_bytes(sz: int, unit: SizeUnit) -> int | float: ...


@overload
def scale_bytes(sz: Expr, unit: SizeUnit) -> Expr: ...


def scale_bytes(sz: int | Expr, unit: SizeUnit) -> int | float | Expr:
    """Scale size in bytes to other size units (eg: "kb", "mb", "gb", "tb")."""
    if unit in {"b", "bytes"}:
        return sz
    elif unit in {"kb", "kilobytes"}:
        return sz / 1024
    elif unit in {"mb", "megabytes"}:
        return sz / 1024**2
    elif unit in {"gb", "gigabytes"}:
        return sz / 1024**3
    elif unit in {"tb", "terabytes"}:
        return sz / 1024**4
    else:
        msg = f"`unit` must be one of {{'b', 'kb', 'mb', 'gb', 'tb'}}, got {unit!r}"
        raise ValueError(msg)
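
# Illustrative behaviour:
#
#     >>> scale_bytes(1_048_576, "mb")
#     1.0
#     >>> scale_bytes(512, "kb")
#     0.5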


def _cast_repr_strings_with_schema(
    df: DataFrame, schema: dict[str, PolarsDataType | None]
) -> DataFrame:
    """
    Utility function to cast table repr/string values into frame-native types.

    Parameters
    ----------
    df
        Dataframe containing string-repr column data.
    schema
        DataFrame schema containing the desired end-state types.

    Notes
    -----
    Table repr strings are less strict (or different) than equivalent CSV data, so they
    need special handling; as this function is only used for reprs, parsing is flexible.
    """
    tp: PolarsDataType | None
    if not df.is_empty():
        for tp in df.schema.values():
            if tp != String:
                msg = f"DataFrame should contain only String repr data; found {tp!r}"
                raise TypeError(msg)

    special_floats = {"-inf", "+inf", "inf", "nan"}

    # duration string scaling
    ns_sec = 1_000_000_000
    duration_scaling = {
        "ns": 1,
        "us": 1_000,
        "µs": 1_000,
        "ms": 1_000_000,
        "s": ns_sec,
        "m": ns_sec * 60,
        "h": ns_sec * 60 * 60,
        "d": ns_sec * 3_600 * 24,
        "w": ns_sec * 3_600 * 24 * 7,
    }

    # identify duration units and convert to nanoseconds
    def str_duration_(td: str | None) -> int | None:
        return (
            None
            if td is None
            else sum(
                int(value) * duration_scaling[unit.strip()]
                for value, unit in re.findall(r"([+-]?\d+)(\D+)", td)
            )
        )

    cast_cols = {}
    for c, tp in schema.items():
        if tp is not None:
            if tp.base_type() == Datetime:
                tp_base = Datetime(tp.time_unit)  # type: ignore[union-attr]
                d = F.col(c).str.replace(r"[A-Z ]+$", "")
                cast_cols[c] = (
                    F.when(d.str.len_bytes() == 19)
                    .then(d + ".000000000")
                    .otherwise(d + "000000000")
                    .str.slice(0, 29)
                    .str.strptime(tp_base, "%Y-%m-%d %H:%M:%S.%9f")
                )
                if getattr(tp, "time_zone", None) is not None:
                    cast_cols[c] = cast_cols[c].dt.replace_time_zone(tp.time_zone)  # type: ignore[union-attr]
            elif tp == Date:
                cast_cols[c] = F.col(c).str.strptime(tp, "%Y-%m-%d")  # type: ignore[arg-type]
            elif tp == Time:
                cast_cols[c] = (
                    F.when(F.col(c).str.len_bytes() == 8)
                    .then(F.col(c) + ".000000000")
                    .otherwise(F.col(c) + "000000000")
                    .str.slice(0, 18)
                    .str.strptime(tp, "%H:%M:%S.%9f")  # type: ignore[arg-type]
                )
            elif tp == Duration:
                cast_cols[c] = (
                    F.col(c)
                    .map_elements(str_duration_, return_dtype=Int64)
                    .cast(Duration("ns"))
                    .cast(tp)
                )
            elif tp == Boolean:
                cast_cols[c] = F.col(c).replace_strict({"true": True, "false": False})
            elif tp in INTEGER_DTYPES:
                int_string = F.col(c).str.replace_all(r"[^\d+-]", "")
                cast_cols[c] = (
                    pl.when(int_string.str.len_bytes() > 0).then(int_string).cast(tp)
                )
            elif tp in FLOAT_DTYPES or tp.base_type() == Decimal:
                # identify integer/fractional parts
                integer_part = F.col(c).str.replace(r"^(.*)\D(\d*)$", "$1")
                fractional_part = F.col(c).str.replace(r"^(.*)\D(\d*)$", "$2")
                cast_cols[c] = (
                    # check for empty string, special floats, or integer format
                    pl.when(
                        F.col(c).str.contains(r"^[+-]?\d*$")
                        | F.col(c).str.to_lowercase().is_in(special_floats)
                    )
                    .then(pl.when(F.col(c).str.len_bytes() > 0).then(F.col(c)))
                    # check for scientific notation
                    .when(F.col(c).str.contains("[eE]"))
                    .then(F.col(c).str.replace(r"[^eE\d]", "."))
                    .otherwise(
                        # recombine sanitised integer/fractional components
                        pl.concat_str(
                            integer_part.str.replace_all(r"[^\d+-]", ""),
                            fractional_part,
                            separator=".",
                        )
                    )
                    .cast(String)
                    .cast(tp)
                )
            elif tp != df.schema[c]:
                cast_cols[c] = F.col(c).cast(tp)

    return df.with_columns(**cast_cols) if cast_cols else df
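
# Illustrative usage (a minimal sketch; the input frame must hold only String
# columns containing repr-formatted values):
#
#     >>> df = pl.DataFrame({"n": ["1", "2"], "dt": ["2024-01-01", "2024-01-02"]})
#     >>> _cast_repr_strings_with_schema(df, {"n": Int64, "dt": Date}).dtypes
#     [Int64, Date]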


# when building docs (with Sphinx) we need access to the functions
# associated with the namespaces from the class, as we don't have
# an instance; @sphinx_accessor is a @property that allows this.
NS = TypeVar("NS")


class sphinx_accessor(property):
    def __get__(  # type: ignore[override]
        self,
        instance: Any,
        cls: type[NS],
    ) -> NS:
        try:
            return self.fget(  # type: ignore[misc]
                instance if isinstance(instance, cls) else cls
            )
        except (AttributeError, ImportError):
            return self  # type: ignore[return-value]


BUILDING_SPHINX_DOCS = os.getenv("BUILDING_SPHINX_DOCS")


class _NoDefault(Enum):
    # "borrowed" from
    # https://github.com/pandas-dev/pandas/blob/e7859983a814b1823cf26e3b491ae2fa3be47c53/pandas/_libs/lib.pyx#L2736-L2748
    no_default = "NO_DEFAULT"

    def __repr__(self) -> str:
        return "<no_default>"


# the "no_default" sentinel should typically be used when one of the valid parameter
# values is None, as otherwise we cannot determine if the caller has set that value.
no_default = _NoDefault.no_default
NoDefault = Literal[_NoDefault.no_default]
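
# Illustrative pattern (hypothetical function; shows why the sentinel exists):
#
#     def set_option(value: int | None | NoDefault = no_default) -> None:
#         if value is no_default:
#             ...  # caller omitted `value`; None remains a valid explicit choice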


def find_stacklevel() -> int:
    """
    Find the first place in the stack that is not inside Polars.

    Taken from:
    https://github.com/pandas-dev/pandas/blob/ab89c53f48df67709a533b6a95ce3d911871a0a8/pandas/util/_exceptions.py#L30-L51
    """
    pkg_dir = str(Path(pl.__file__).parent)

    # https://stackoverflow.com/questions/17407119/python-inspect-stack-is-slow
    frame = inspect.currentframe()
    n = 0
    try:
        while frame:
            fname = inspect.getfile(frame)
            if fname.startswith(pkg_dir) or (
                (qualname := getattr(frame.f_code, "co_qualname", None))
                # ignore @singledispatch wrappers
                and qualname.startswith("singledispatch.")
            ):
                frame = frame.f_back
                n += 1
            else:
                break
    finally:
        # https://docs.python.org/3/library/inspect.html
        # > Though the cycle detector will catch these, destruction of the frames
        # > (and local variables) can be made deterministic by removing the cycle
        # > in a 'finally' clause.
        del frame
    return n


def issue_warning(message: str, category: type[Warning], **kwargs: Any) -> None:
    """
    Issue a warning.

    Parameters
    ----------
    message
        The message associated with the warning.
    category
        The warning category.
    **kwargs
        Additional arguments for `warnings.warn`. Note that the `stacklevel` is
        determined automatically.
    """
    warnings.warn(
        message=message, category=category, stacklevel=find_stacklevel(), **kwargs
    )


def _get_stack_locals(
    of_type: type | Collection[type] | Callable[[Any], bool] | None = None,
    *,
    named: str | Collection[str] | None = None,
    n_objects: int | None = None,
    n_frames: int | None = None,
) -> dict[str, Any]:
    """
    Retrieve f_locals from all (or the last `n`) stack frames from the calling location.

    Parameters
    ----------
    of_type
        Only return objects of this type; can be a single class, tuple of
        classes, or a callable that returns True/False if the object being
        tested is considered a match.
    named
        If specified, only return objects matching the given name(s).
    n_objects
        If specified, return only the most recent `n` matching objects.
    n_frames
        If specified, look at objects in the last `n` stack frames only.
    """
    objects = {}
    examined_frames = 0

    if n_frames is None:
        n_frames = sys.maxsize

    if inspect.isfunction(of_type):
        matches_type = of_type
    else:
        if isinstance(of_type, Collection):
            of_type = tuple(of_type)

        def matches_type(obj: Any) -> bool:  # type: ignore[misc]
            return isinstance(obj, of_type)  # type: ignore[arg-type]

    if named is not None:
        if isinstance(named, str):
            named = (named,)
        elif not isinstance(named, set):
            named = set(named)

    stack_frame = inspect.currentframe()
    stack_frame = getattr(stack_frame, "f_back", None)
    try:
        while stack_frame and examined_frames < n_frames:
            local_items = list(stack_frame.f_locals.items())
            for nm, obj in reversed(local_items):
                if (
                    nm not in objects
                    and (named is None or nm in named)
                    and (of_type is None or matches_type(obj))
                ):
                    objects[nm] = obj
                    if n_objects is not None and len(objects) >= n_objects:
                        return objects

            stack_frame = stack_frame.f_back
            examined_frames += 1
    finally:
        # https://docs.python.org/3/library/inspect.html
        # > Though the cycle detector will catch these, destruction of the frames
        # > (and local variables) can be made deterministic by removing the cycle
        # > in a 'finally' clause.
        del stack_frame

    return objects
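
# Illustrative usage (a sketch; results depend entirely on the caller's frames):
#
#     df = pl.DataFrame({"a": [1, 2]})
#     _get_stack_locals(of_type=pl.DataFrame, n_objects=1)  # -> {"df": df}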


# this is called from rust
def _polars_warn(msg: str, category: type[Warning] = UserWarning) -> None:
    warnings.warn(
        msg,
        category=category,
        stacklevel=find_stacklevel(),
    )


def extend_bool(
    value: bool | Sequence[bool],
    n_match: int,
    value_name: str,
    match_name: str,
) -> Sequence[bool]:
    """Ensure the given bool or sequence of bools is the correct length."""
    values = [value] * n_match if isinstance(value, bool) else value
    if n_match != len(values):
        msg = (
            f"the length of `{value_name}` ({len(values)}) "
            f"does not match the length of `{match_name}` ({n_match})"
        )
        raise ValueError(msg)
    return values
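
# Illustrative behaviour:
#
#     >>> extend_bool(True, 3, "descending", "by")
#     [True, True, True]
#     >>> extend_bool([True, False], 2, "descending", "by")
#     [True, False]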


def in_terminal_that_supports_colour() -> bool:
    """
    Determine (within reason) if we are in an interactive terminal that supports color.

    Note: this is not exhaustive, but it covers a lot (most?) of the common cases.
    """
    if hasattr(sys.stdout, "isatty"):
        # can enhance as necessary, but this is a reasonable start
        return (
            sys.stdout.isatty()
            and (
                sys.platform != "win32"
                or "ANSICON" in os.environ
                or "WT_SESSION" in os.environ
                or os.environ.get("TERM_PROGRAM") == "vscode"
                or os.environ.get("TERM") == "xterm-256color"
            )
        ) or os.environ.get("PYCHARM_HOSTED") == "1"
    return False


def parse_percentiles(
    percentiles: Sequence[float] | float | None, *, inject_median: bool = False
) -> Sequence[float]:
    """
    Transform raw percentiles into our preferred format, optionally injecting the 50th.

    Raises a ValueError if the percentile sequence is invalid
    (e.g. outside the range [0, 1]).
    """
    if isinstance(percentiles, float):
        percentiles = [percentiles]
    elif percentiles is None:
        percentiles = []
    if not all((0 <= p <= 1) for p in percentiles):
        msg = "`percentiles` must all be in the range [0, 1]"
        raise ValueError(msg)

    sub_50_percentiles = sorted(p for p in percentiles if p < 0.5)
    at_or_above_50_percentiles = sorted(p for p in percentiles if p >= 0.5)

    if inject_median and (
        not at_or_above_50_percentiles or at_or_above_50_percentiles[0] != 0.5
    ):
        at_or_above_50_percentiles = [0.5, *at_or_above_50_percentiles]

    return [*sub_50_percentiles, *at_or_above_50_percentiles]
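
# Illustrative behaviour:
#
#     >>> parse_percentiles([0.9, 0.1], inject_median=True)
#     [0.1, 0.5, 0.9]
#     >>> parse_percentiles(0.25)
#     [0.25]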


def re_escape(s: str) -> str:
    """Escape a string for use in a Polars (Rust) regex."""
    # note: almost the same as the standard python 're.escape' function, but
    # escapes _only_ those metachars with meaning to the rust regex crate
    re_rust_metachars = r"\\?()|\[\]{}^$#&~.+*-"
    return re.sub(f"([{re_rust_metachars}])", r"\\\1", s)
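
# Illustrative behaviour (escapes only the Rust regex metacharacters):
#
#     >>> re_escape("file(1).txt")
#     'file\\(1\\)\\.txt'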


# Don't rename or move. This is used by polars cloud
def display_dot_graph(
    *,
    dot: str,
    show: bool = True,
    output_path: str | Path | None = None,
    raw_output: bool = False,
    figsize: tuple[float, float] = (16.0, 12.0),
) -> str | None:
    if raw_output:
        # we do not show a graph, nor save a graph to disk
        return dot

    output_type = (
        "svg"
        if _in_notebook()
        or _in_marimo_notebook()
        or "POLARS_DOT_SVG_VIEWER" in os.environ
        else "png"
    )

    try:
        graph = subprocess.check_output(
            ["dot", "-Nshape=box", "-T" + output_type], input=dot.encode()
        )
    except (ImportError, FileNotFoundError):
        msg = (
            "the graphviz `dot` binary should be on your PATH "
            "(if not installed, you can download it here: https://graphviz.org/download/)"
        )
        raise ImportError(msg) from None

    if output_path:
        Path(output_path).write_bytes(graph)

    if not show:
        return None

    if _in_notebook():
        from IPython.display import SVG, display

        return display(SVG(graph))
    elif _in_marimo_notebook():
        import marimo as mo

        return mo.Html(graph.decode())
    else:
        if (cmd := os.environ.get("POLARS_DOT_SVG_VIEWER", None)) is not None:
            import tempfile

            with tempfile.NamedTemporaryFile(suffix=".svg") as file:
                file.write(graph)
                file.flush()
                cmd = cmd.replace("%file%", file.name)
                subprocess.run(cmd, shell=True)
            return None

        import_optional(
            "matplotlib",
            err_prefix="",
            err_suffix="should be installed to show graphs",
        )
        import matplotlib.image as mpimg
        import matplotlib.pyplot as plt

        plt.figure(figsize=figsize)
        img = mpimg.imread(BytesIO(graph))
        plt.axis("off")
        plt.imshow(img)
        plt.show()
        return None


def qualified_type_name(obj: Any, *, qualify_polars: bool = False) -> str:
    """
    Return the module-qualified name of the given object as a string.

    Parameters
    ----------
    obj
        The object to get the qualified name for.
    qualify_polars
        If False (default), omit the module path for our own (Polars) objects.
    """
    if isinstance(obj, type):
        module = obj.__module__
        name = obj.__name__
    else:
        module = obj.__class__.__module__
        name = obj.__class__.__name__

    if (
        not module
        or module == "builtins"
        or (not qualify_polars and module.startswith("polars."))
    ):
        return name

    return f"{module}.{name}"
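
# Illustrative behaviour:
#
#     >>> qualified_type_name(0.5)
#     'float'
#     >>> qualified_type_name(pl.Series([1]))  # polars module path omitted by default
#     'Series'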


def require_same_type(current: Any, other: Any) -> None:
    """
    Raise an error if the two arguments are not of the same type.

    The check will not raise an error if one object is of a subclass of the other.

    Parameters
    ----------
    current
        The object whose type is used for the check.
    other
        An object that has to be of the same type.
    """
    if not isinstance(other, type(current)) and not isinstance(current, type(other)):
        msg = (
            f"expected `other` to be a {qualified_type_name(current)!r}, "
            f"not {qualified_type_name(other)!r}"
        )
        raise TypeError(msg)