DriverTrac/venv/lib/python3.12/site-packages/polars/_utils/various.py

783 lines
25 KiB
Python

from __future__ import annotations
import inspect
import os
import re
import sys
import warnings
from collections import Counter
from collections.abc import (
Collection,
Generator,
Iterable,
MappingView,
Sequence,
Sized,
)
from enum import Enum
from io import BytesIO
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
Callable,
Literal,
TypeVar,
overload,
)
import polars as pl
from polars import functions as F
from polars._dependencies import _check_for_numpy, import_optional, subprocess
from polars._dependencies import numpy as np
from polars.datatypes import (
Boolean,
Date,
Datetime,
Decimal,
Duration,
Int64,
String,
Time,
)
from polars.datatypes.group import FLOAT_DTYPES, INTEGER_DTYPES
if TYPE_CHECKING:
from collections.abc import Iterator, MutableMapping, Reversible
from polars import DataFrame, Expr
from polars._typing import PolarsDataType, SizeUnit
if sys.version_info >= (3, 13):
from typing import TypeIs
else:
from typing_extensions import TypeIs
if sys.version_info >= (3, 10):
from typing import ParamSpec, TypeGuard
else:
from typing_extensions import ParamSpec, TypeGuard
P = ParamSpec("P")
T = TypeVar("T")
# note: reversed views don't match as instances of MappingView
if sys.version_info >= (3, 11):
_views: list[Reversible[Any]] = [{}.keys(), {}.values(), {}.items()]
_reverse_mapping_views = tuple(type(reversed(view)) for view in _views)
def _process_null_values(
    null_values: None | str | Sequence[str] | dict[str, str] = None,
) -> None | str | Sequence[str] | list[tuple[str, str]]:
    """
    Normalise a `null_values` argument.

    A dict is flattened into a list of `(key, value)` tuples; any other
    input (None, a string, or a sequence of strings) is returned unchanged.
    """
    if not isinstance(null_values, dict):
        return null_values
    return [*null_values.items()]
def _is_generator(val: object | Iterator[T]) -> TypeIs[Iterator[T]]:
    """
    Indicate whether `val` should be treated as a generator-like iterable.

    True for unsized iterables (which includes generators), for mapping
    views, and - on Python 3.11+ - for reversed mapping views.
    """
    if isinstance(val, (Generator, Iterable)) and not isinstance(val, Sized):
        return True
    if isinstance(val, MappingView):
        return True
    # reversed views are not MappingView instances, so check their types directly
    return sys.version_info >= (3, 11) and isinstance(val, _reverse_mapping_views)
def _is_iterable_of(val: Iterable[object], eltype: type | tuple[type, ...]) -> bool:
    """Return True if every element of `val` is an instance of `eltype`."""
    for item in val:
        if not isinstance(item, eltype):
            return False
    # also reached for an empty iterable
    return True
def is_path_or_str_sequence(
    val: object, *, allow_str: bool = False, include_series: bool = False
) -> TypeGuard[Sequence[str | Path]]:
    """
    Check that `val` is a sequence of strings or paths.

    Note that a single string is a sequence of strings by definition, use
    `allow_str=False` to return False on a single string.

    Parameters
    ----------
    val
        The object to test.
    allow_str
        If False (default), a single string is rejected.
    include_series
        If True, a Series with String dtype is accepted.
    """
    if isinstance(val, str) and allow_str is False:
        return False
    if _check_for_numpy(val) and isinstance(val, np.ndarray):
        return np.issubdtype(val.dtype, np.str_)
    if include_series and isinstance(val, pl.Series):
        return val.dtype == pl.String
    # bytes is a Sequence too, but never a sequence of str/Path
    if isinstance(val, bytes):
        return False
    return isinstance(val, Sequence) and _is_iterable_of(val, (Path, str))
def is_bool_sequence(
    val: object, *, include_series: bool = False
) -> TypeGuard[Sequence[bool]]:
    """
    Check whether the given sequence is a sequence of booleans.

    A numpy array qualifies by dtype, as does a Series when `include_series`
    is set; any other `Sequence` qualifies if every element is a bool.
    """
    if _check_for_numpy(val) and isinstance(val, np.ndarray):
        return val.dtype == np.bool_
    if include_series and isinstance(val, pl.Series):
        return val.dtype == pl.Boolean
    if not isinstance(val, Sequence):
        return False
    return _is_iterable_of(val, bool)
def is_int_sequence(
    val: object, *, include_series: bool = False
) -> TypeGuard[Sequence[int]]:
    """
    Check whether the given sequence is a sequence of integers.

    A numpy array qualifies by dtype, as does a Series when `include_series`
    is set; any other `Sequence` qualifies if every element is an int.
    """
    if _check_for_numpy(val) and isinstance(val, np.ndarray):
        return np.issubdtype(val.dtype, np.integer)
    if include_series and isinstance(val, pl.Series):
        return val.dtype.is_integer()
    if not isinstance(val, Sequence):
        return False
    return _is_iterable_of(val, int)
def is_sequence(
    val: object, *, include_series: bool = False
) -> TypeGuard[Sequence[Any]]:
    """
    Check whether the given input is a numpy array or python sequence.

    Strings are never considered sequences here; a Series only counts
    when `include_series` is set.
    """
    if _check_for_numpy(val) and isinstance(val, np.ndarray):
        return True
    accepted = (pl.Series, Sequence) if include_series else Sequence
    return isinstance(val, accepted) and not isinstance(val, str)
def is_str_sequence(
    val: object, *, allow_str: bool = False, include_series: bool = False
) -> TypeGuard[Sequence[str]]:
    """
    Check that `val` is a sequence of strings.

    Note that a single string is a sequence of strings by definition, use
    `allow_str=False` to return False on a single string.

    Parameters
    ----------
    val
        The object to test.
    allow_str
        If False (default), a single string is rejected.
    include_series
        If True, a Series with String dtype is accepted.
    """
    if isinstance(val, str) and allow_str is False:
        return False
    if _check_for_numpy(val) and isinstance(val, np.ndarray):
        return np.issubdtype(val.dtype, np.str_)
    if include_series and isinstance(val, pl.Series):
        return val.dtype == pl.String
    if not isinstance(val, Sequence):
        return False
    return _is_iterable_of(val, str)
def is_column(obj: Any) -> bool:
    """Indicate if the given object is a basic/unaliased column."""
    # imported inside the function rather than at module level
    from polars.expr import Expr

    if not isinstance(obj, Expr):
        return False
    return obj.meta.is_column()
def warn_null_comparison(obj: Any) -> None:
    """Emit a `UserWarning` if `obj` is None (a possibly unintentional comparison)."""
    if obj is not None:
        return
    msg = "Comparisons with None always result in null. Consider using `.is_null()` or `.is_not_null()`."
    warnings.warn(msg, category=UserWarning, stacklevel=find_stacklevel())
def range_to_series(
    name: str, rng: range, dtype: PolarsDataType | None = None
) -> pl.Series:
    """
    Fast conversion of the given range to a Series.

    Parameters
    ----------
    name
        Name given to the resulting Series.
    rng
        The python range to convert.
    dtype
        Target dtype; falls back to Int64 when not given.
    """
    dtype = dtype or Int64
    bounds = {"start": rng.start, "end": rng.stop, "step": rng.step}
    if dtype.is_integer():
        # integer dtypes are passed straight to `int_range`
        srs = F.int_range(**bounds, dtype=dtype, eager=True)  # type: ignore[call-overload]
    else:
        # other dtypes are cast from the default integer range
        srs = F.int_range(**bounds, eager=True).cast(dtype)
    return srs.alias(name)
def range_to_slice(rng: range) -> slice:
    """Build a slice with the same start, stop and step as the given range."""
    first, last, stride = rng.start, rng.stop, rng.step
    return slice(first, last, stride)
def _in_notebook() -> bool:
    """
    Report whether the IPython config contains an "IPKernelApp" entry.

    Returns False if IPython cannot be imported, or if there is no active
    IPython instance to read the config from.
    """
    try:
        from IPython import get_ipython

        kernel_present = "IPKernelApp" in get_ipython().config  # pragma: no cover
    except (ImportError, AttributeError):
        # ImportError: IPython unavailable; AttributeError: no usable `.config`
        return False
    return kernel_present
def _in_marimo_notebook() -> bool:
    """Return marimo's `running_in_notebook()` result, or False without marimo."""
    try:
        import marimo

        in_marimo = marimo.running_in_notebook()  # pragma: no cover
    except ImportError:
        in_marimo = False
    return in_marimo
def arrlen(obj: Any) -> int | None:
    """Return length of (non-string/dict) sequence; returns None for non-sequences."""
    # strings and dicts have a length, but are deliberately not counted
    if isinstance(obj, (str, dict)):
        return None
    try:
        return len(obj)
    except TypeError:
        # object has no usable `len`
        return None
def normalize_filepath(path: str | Path, *, check_not_directory: bool = True) -> str:
    """
    Create a string path, expanding the home directory if present.

    Parameters
    ----------
    path
        The path to normalise.
    check_not_directory
        If True (default), raise if the expanded path is an existing directory.

    Raises
    ------
    IsADirectoryError
        If `check_not_directory` is set and the path points at a directory.
    """
    # don't use pathlib here as it modifies slashes (s3:// -> s3:/)
    expanded = os.path.expanduser(path)  # noqa: PTH111
    if not check_not_directory:
        return expanded

    points_at_directory = (
        os.path.exists(expanded)  # noqa: PTH110
        and os.path.isdir(expanded)  # noqa: PTH112
    )
    if points_at_directory:
        msg = f"expected a file path; {expanded!r} is a directory"
        raise IsADirectoryError(msg)
    return expanded
def parse_version(version: Sequence[str | int]) -> tuple[int, ...]:
    """
    Simple version parser; split into a tuple of ints for comparison.

    A string is split on "."; every component then has its non-digit
    characters removed before being converted to an int.
    """
    parts = version.split(".") if isinstance(version, str) else version
    digits_only = (re.sub(r"\D", "", str(part)) for part in parts)
    return tuple(map(int, digits_only))
def ordered_unique(values: Sequence[Any]) -> list[Any]:
    """Return unique list of sequence values, maintaining their order of appearance."""
    # dict keys are unique and keep insertion order, so the first occurrence wins
    return list(dict.fromkeys(values))
def deduplicate_names(names: Iterable[str]) -> list[str]:
    """
    Ensure name uniqueness by appending a counter to subsequent duplicates.

    The first occurrence of a name is kept as-is; later occurrences get a
    zero-based suffix ("a", "a0", "a1", ...).
    """
    counts: MutableMapping[str, int] = Counter()
    result = []
    for name in names:
        # a Counter reports 0 for names it has not seen yet
        n_prior = counts[name]
        result.append(name if n_prior == 0 else f"{name}{n_prior - 1}")
        counts[name] += 1
    return result
@overload
def scale_bytes(sz: int, unit: SizeUnit) -> int | float: ...


@overload
def scale_bytes(sz: Expr, unit: SizeUnit) -> Expr: ...


def scale_bytes(sz: int | Expr, unit: SizeUnit) -> int | float | Expr:
    """
    Scale size in bytes to other size units (eg: "kb", "mb", "gb", "tb").

    Parameters
    ----------
    sz
        Size in bytes (an int, or an expression).
    unit
        Target unit, in short ("kb") or long ("kilobytes") form.

    Raises
    ------
    ValueError
        If `unit` is not a recognised size unit.
    """
    # power of 1024 that each unit divides by
    exponent_by_unit = {
        "b": 0,
        "bytes": 0,
        "kb": 1,
        "kilobytes": 1,
        "mb": 2,
        "megabytes": 2,
        "gb": 3,
        "gigabytes": 3,
        "tb": 4,
        "terabytes": 4,
    }
    exponent = exponent_by_unit.get(unit)
    if exponent is None:
        msg = f"`unit` must be one of {{'b', 'kb', 'mb', 'gb', 'tb'}}, got {unit!r}"
        raise ValueError(msg)
    if exponent == 0:
        # bytes are returned untouched (no division)
        return sz
    return sz / 1024**exponent
def _cast_repr_strings_with_schema(
    df: DataFrame, schema: dict[str, PolarsDataType | None]
) -> DataFrame:
    """
    Utility function to cast table repr/string values into frame-native types.

    Parameters
    ----------
    df
        Dataframe containing string-repr column data.
    schema
        DataFrame schema containing the desired end-state types.

    Returns
    -------
    DataFrame
        The frame with the converted columns, or `df` itself if no
        column needed converting.

    Raises
    ------
    TypeError
        If a non-empty `df` contains a column that is not String.

    Notes
    -----
    Table repr strings are less strict (or different) than equivalent CSV data, so need
    special handling; as this function is only used for reprs, parsing is flexible.
    """
    tp: PolarsDataType | None
    # every input column must be String (only validated for non-empty frames)
    if not df.is_empty():
        for tp in df.schema.values():
            if tp != String:
                msg = f"DataFrame should contain only String repr data; found {tp!r}"
                raise TypeError(msg)

    special_floats = {"-inf", "+inf", "inf", "nan"}

    # duration string scaling
    ns_sec = 1_000_000_000
    duration_scaling = {
        "ns": 1,
        "us": 1_000,
        "µs": 1_000,
        "ms": 1_000_000,
        "s": ns_sec,
        "m": ns_sec * 60,
        "h": ns_sec * 60 * 60,
        "d": ns_sec * 3_600 * 24,
        "w": ns_sec * 3_600 * 24 * 7,
    }

    # identify duration units and convert to nanoseconds
    def str_duration_(td: str | None) -> int | None:
        # each "<signed int><unit>" pair is scaled to nanoseconds and summed;
        # a unit missing from `duration_scaling` raises KeyError
        return (
            None
            if td is None
            else sum(
                int(value) * duration_scaling[unit.strip()]
                for value, unit in re.findall(r"([+-]?\d+)(\D+)", td)
            )
        )

    # maps column name -> expression producing the converted column
    cast_cols = {}
    for c, tp in schema.items():
        # columns with no target dtype (None) are left untouched
        if tp is not None:
            if tp.base_type() == Datetime:
                tp_base = Datetime(tp.time_unit)  # type: ignore[union-attr]
                # drop a trailing run of uppercase letters/spaces
                # (presumably a timezone suffix - TODO confirm)
                d = F.col(c).str.replace(r"[A-Z ]+$", "")
                # a 19-byte value ("YYYY-MM-DD HH:MM:SS") has no fraction, so a
                # full one is appended; otherwise the existing fraction is
                # zero-padded. Truncating to 29 bytes leaves exactly nine
                # fractional digits for the "%9f" directive.
                cast_cols[c] = (
                    F.when(d.str.len_bytes() == 19)
                    .then(d + ".000000000")
                    .otherwise(d + "000000000")
                    .str.slice(0, 29)
                    .str.strptime(tp_base, "%Y-%m-%d %H:%M:%S.%9f")
                )
                # parsed without a timezone; attach the target one afterwards
                if getattr(tp, "time_zone", None) is not None:
                    cast_cols[c] = cast_cols[c].dt.replace_time_zone(tp.time_zone)  # type: ignore[union-attr]
            elif tp == Date:
                cast_cols[c] = F.col(c).str.strptime(tp, "%Y-%m-%d")  # type: ignore[arg-type]
            elif tp == Time:
                # same padding approach as Datetime: 8 bytes is "HH:MM:SS",
                # and 18 bytes is that plus "." and nine fractional digits
                cast_cols[c] = (
                    F.when(F.col(c).str.len_bytes() == 8)
                    .then(F.col(c) + ".000000000")
                    .otherwise(F.col(c) + "000000000")
                    .str.slice(0, 18)
                    .str.strptime(tp, "%H:%M:%S.%9f")  # type: ignore[arg-type]
                )
            elif tp == Duration:
                # parsed per-element in python to integer nanoseconds, then cast
                # via Duration("ns") to the requested duration dtype
                cast_cols[c] = (
                    F.col(c)
                    .map_elements(str_duration_, return_dtype=Int64)
                    .cast(Duration("ns"))
                    .cast(tp)
                )
            elif tp == Boolean:
                cast_cols[c] = F.col(c).replace_strict({"true": True, "false": False})
            elif tp in INTEGER_DTYPES:
                # keep only digits and sign characters; strings that end up
                # empty are not selected by the `when` condition
                int_string = F.col(c).str.replace_all(r"[^\d+-]", "")
                cast_cols[c] = (
                    pl.when(int_string.str.len_bytes() > 0).then(int_string).cast(tp)
                )
            elif tp in FLOAT_DTYPES or tp.base_type() == Decimal:
                # identify integer/fractional parts
                # (the greedy `.*` means the split happens at the last non-digit
                # character: $1 is everything before it, $2 the trailing digits)
                integer_part = F.col(c).str.replace(r"^(.*)\D(\d*)$", "$1")
                fractional_part = F.col(c).str.replace(r"^(.*)\D(\d*)$", "$2")
                cast_cols[c] = (
                    # check for empty string, special floats, or integer format
                    pl.when(
                        F.col(c).str.contains(r"^[+-]?\d*$")
                        | F.col(c).str.to_lowercase().is_in(special_floats)
                    )
                    .then(pl.when(F.col(c).str.len_bytes() > 0).then(F.col(c)))
                    # check for scientific notation
                    .when(F.col(c).str.contains("[eE]"))
                    .then(F.col(c).str.replace(r"[^eE\d]", "."))
                    .otherwise(
                        # recombine sanitised integer/fractional components
                        pl.concat_str(
                            integer_part.str.replace_all(r"[^\d+-]", ""),
                            fractional_part,
                            separator=".",
                        )
                    )
                    .cast(String)
                    .cast(tp)
                )
            elif tp != df.schema[c]:
                # any other dtype: plain cast, only if it differs from the current one
                cast_cols[c] = F.col(c).cast(tp)

    return df.with_columns(**cast_cols) if cast_cols else df
# when building docs (with Sphinx) we need access to the functions
# associated with the namespaces from the class, as we don't have
# an instance; @sphinx_accessor is a @property that allows this.
NS = TypeVar("NS")


class sphinx_accessor(property):
    """Property variant whose getter also works when accessed on the class."""

    def __get__(  # type: ignore[override]
        self,
        instance: Any,
        cls: type[NS],
    ) -> NS:
        """
        Call the getter with the instance, or with `cls` if there is none.

        If the getter raises AttributeError or ImportError, the accessor
        object itself is returned instead.
        """
        try:
            target = instance if isinstance(instance, cls) else cls
            return self.fget(target)  # type: ignore[misc]
        except (AttributeError, ImportError):
            return self  # type: ignore[return-value]
# value of the BUILDING_SPHINX_DOCS environment variable (None if unset),
# read once when this module is imported
BUILDING_SPHINX_DOCS = os.environ.get("BUILDING_SPHINX_DOCS")
class _NoDefault(Enum):
    """Single-member enum providing the `no_default` sentinel."""

    # "borrowed" from
    # https://github.com/pandas-dev/pandas/blob/e7859983a814b1823cf26e3b491ae2fa3be47c53/pandas/_libs/lib.pyx#L2736-L2748
    no_default = "NO_DEFAULT"

    def __repr__(self) -> str:
        return "<no_default>"


# the "no_default" sentinel should typically be used when one of the valid parameter
# values is None, as otherwise we cannot determine if the caller has set that value.
no_default = _NoDefault.no_default

# annotation alias for parameters that accept the sentinel
NoDefault = Literal[_NoDefault.no_default]
def find_stacklevel() -> int:
    """
    Find the first place in the stack that is not inside Polars.

    Starting at this function's own frame, counts consecutive frames whose
    source file path starts with the Polars package directory (frames of
    `singledispatch` wrappers are counted as well), and returns that count.

    Taken from:
    https://github.com/pandas-dev/pandas/blob/ab89c53f48df67709a533b6a95ce3d911871a0a8/pandas/util/_exceptions.py#L30-L51
    """
    pkg_dir = str(Path(pl.__file__).parent)

    # https://stackoverflow.com/questions/17407119/python-inspect-stack-is-slow
    frame = inspect.currentframe()
    depth = 0
    try:
        while frame:
            in_polars = inspect.getfile(frame).startswith(pkg_dir)
            if not in_polars:
                # ignore @singledispatch wrappers
                qualname = getattr(frame.f_code, "co_qualname", None)
                if not (qualname and qualname.startswith("singledispatch.")):
                    break
            frame = frame.f_back
            depth += 1
    finally:
        # https://docs.python.org/3/library/inspect.html
        # > Though the cycle detector will catch these, destruction of the frames
        # > (and local variables) can be made deterministic by removing the cycle
        # > in a 'finally' clause.
        del frame
    return depth
def issue_warning(message: str, category: type[Warning], **kwargs: Any) -> None:
    """
    Issue a warning.

    Parameters
    ----------
    message
        The message associated with the warning.
    category
        The warning category.
    **kwargs
        Additional arguments for `warnings.warn`. Note that the `stacklevel` is
        determined automatically.
    """
    level = find_stacklevel()
    warnings.warn(message=message, category=category, stacklevel=level, **kwargs)
def _get_stack_locals(
    of_type: type | Collection[type] | Callable[[Any], bool] | None = None,
    *,
    named: str | Collection[str] | None = None,
    n_objects: int | None = None,
    n_frames: int | None = None,
) -> dict[str, Any]:
    """
    Retrieve f_locals from all (or the last 'n') stack frames from the calling location.

    Parameters
    ----------
    of_type
        Only return objects of this type; can be a single class, tuple of
        classes, or a callable that returns True/False if the object being
        tested is considered a match. Any callable that is not itself a class
        is accepted (plain functions, lambdas, builtins, callable instances).
    n_objects
        If specified, return only the most recent `n` matching objects.
    n_frames
        If specified, look at objects in the last `n` stack frames only.
    named
        If specified, only return objects matching the given name(s).

    Returns
    -------
    dict
        Mapping of local variable name to object; when the same name occurs
        in several frames, the frame closest to the caller wins.
    """
    if n_frames is None:
        n_frames = sys.maxsize

    # normalise `named` into a set once, for cheap membership tests
    if isinstance(named, str):
        named = {named}
    elif named is not None and not isinstance(named, set):
        named = set(named)

    # resolve `of_type` into a single predicate (None means "match anything")
    matches_type: Callable[[Any], bool] | None = None
    if of_type is not None:
        type_spec: Any = None
        if isinstance(of_type, type):
            # checked first: classes are callable too, but must be matched
            # with isinstance rather than called as predicates
            type_spec = of_type
        elif isinstance(of_type, Collection):
            type_spec = tuple(of_type)
        elif callable(of_type):
            matches_type = of_type
        else:
            # anything else (eg: a union type) is handed to isinstance as-is
            type_spec = of_type

        if matches_type is None:

            def matches_type(obj: Any) -> bool:  # type: ignore[misc]
                return isinstance(obj, type_spec)

    objects: dict[str, Any] = {}
    examined_frames = 0

    # start from the caller's frame, not our own
    stack_frame = inspect.currentframe()
    stack_frame = getattr(stack_frame, "f_back", None)
    try:
        while stack_frame and examined_frames < n_frames:
            # snapshot the locals, then walk them most-recently-defined first
            local_items = list(stack_frame.f_locals.items())
            for nm, obj in reversed(local_items):
                if (
                    nm not in objects
                    and (named is None or nm in named)
                    and (matches_type is None or matches_type(obj))
                ):
                    objects[nm] = obj
                    if n_objects is not None and len(objects) >= n_objects:
                        return objects

            stack_frame = stack_frame.f_back
            examined_frames += 1
    finally:
        # https://docs.python.org/3/library/inspect.html
        # > Though the cycle detector will catch these, destruction of the frames
        # > (and local variables) can be made deterministic by removing the cycle
        # > in a finally clause.
        del stack_frame

    return objects
# this is called from rust
def _polars_warn(msg: str, category: type[Warning] = UserWarning) -> None:
    """Issue `msg` as a warning of the given category, using `find_stacklevel()`."""
    level = find_stacklevel()
    warnings.warn(msg, category=category, stacklevel=level)
def extend_bool(
    value: bool | Sequence[bool],
    n_match: int,
    value_name: str,
    match_name: str,
) -> Sequence[bool]:
    """
    Ensure the given bool or sequence of bools is the correct length.

    Parameters
    ----------
    value
        A single bool (repeated `n_match` times) or a sequence of bools.
    n_match
        The required length.
    value_name
        Name of the value parameter, used in the error message.
    match_name
        Name of the parameter whose length must be matched, used in the
        error message.

    Raises
    ------
    ValueError
        If the resulting sequence does not have `n_match` elements.
    """
    if isinstance(value, bool):
        values = [value] * n_match
    else:
        values = value

    n_values = len(values)
    if n_values == n_match:
        return values

    msg = (
        f"the length of `{value_name}` ({n_values}) "
        f"does not match the length of `{match_name}` ({n_match})"
    )
    raise ValueError(msg)
def in_terminal_that_supports_colour() -> bool:
    """
    Determine (within reason) if we are in an interactive terminal that supports color.

    Note: this is not exhaustive, but it covers a lot (most?) of the common cases.
    """
    stream = sys.stdout
    if not hasattr(stream, "isatty"):
        return False

    # can enhance as necessary, but this is a reasonable start
    if stream.isatty():
        env = os.environ
        colour_capable = (
            sys.platform != "win32"
            or "ANSICON" in env
            or "WT_SESSION" in env
            or env.get("TERM_PROGRAM") == "vscode"
            or env.get("TERM") == "xterm-256color"
        )
        if colour_capable:
            return True

    # PYCHARM_HOSTED="1" counts even when stdout is not a tty
    return os.environ.get("PYCHARM_HOSTED") == "1"
def parse_percentiles(
    percentiles: Sequence[float] | float | None, *, inject_median: bool = False
) -> Sequence[float]:
    """
    Transform raw percentiles into our preferred format (a sorted list).

    Parameters
    ----------
    percentiles
        A single percentile, a sequence of percentiles, or None (treated as
        no percentiles). A scalar may be a float or an int (eg: 0 or 1).
    inject_median
        If True, add the 50th percentile (0.5) when it is not already present.

    Returns
    -------
    list
        The percentiles in ascending order.

    Raises
    ------
    ValueError
        If any percentile is invalid (e.g. outside the range [0, 1]).
    """
    if percentiles is None:
        pcts: list[float] = []
    elif isinstance(percentiles, (int, float)) and not isinstance(percentiles, bool):
        # ints are valid wherever a float is expected; bools are left to fail
        # below, exactly as any other non-iterable input does
        pcts = [percentiles]
    else:
        # materialise once, as the values are traversed several times below
        pcts = list(percentiles)

    if not all((0 <= p <= 1) for p in pcts):
        msg = "`percentiles` must all be in the range [0, 1]"
        raise ValueError(msg)

    sub_50_percentiles = sorted(p for p in pcts if p < 0.5)
    at_or_above_50_percentiles = sorted(p for p in pcts if p >= 0.5)

    # after sorting, a median (if present) is first in the upper partition
    if inject_median and (
        not at_or_above_50_percentiles or at_or_above_50_percentiles[0] != 0.5
    ):
        at_or_above_50_percentiles.insert(0, 0.5)

    return [*sub_50_percentiles, *at_or_above_50_percentiles]
def re_escape(s: str) -> str:
    """
    Escape a string for use in a Polars (Rust) regex.

    Every regex metacharacter found in `s` is prefixed with a backslash.
    """
    # note: almost the same as the standard python 're.escape' function, but
    # escapes _only_ those metachars with meaning to the rust regex crate
    re_rust_metachars = r"\\?()|\[\]{}^$#&~.+*-"
    metachar_pattern = f"([{re_rust_metachars}])"
    escaped = re.sub(metachar_pattern, r"\\\1", s)
    return escaped
# Don't rename or move. This is used by polars cloud
def display_dot_graph(
    *,
    dot: str,
    show: bool = True,
    output_path: str | Path | None = None,
    raw_output: bool = False,
    figsize: tuple[float, float] = (16.0, 12.0),
) -> str | None:
    """
    Render a graphviz "dot" graph definition, optionally saving and/or showing it.

    Parameters
    ----------
    dot
        The graph definition, in graphviz dot format.
    show
        If True (default), display the rendered graph.
    output_path
        If given, the rendered image bytes are written to this path.
    raw_output
        If True, return `dot` unchanged; nothing is rendered, saved or shown.
    figsize
        Figure size passed to matplotlib when that is used for display.

    Raises
    ------
    ImportError
        If the graphviz `dot` binary cannot be run.
    """
    if raw_output:
        # we do not show a graph, nor save a graph to disk
        return dot

    # SVG is used for notebooks and for a user-configured viewer; PNG otherwise
    output_type = (
        "svg"
        if _in_notebook()
        or _in_marimo_notebook()
        or "POLARS_DOT_SVG_VIEWER" in os.environ
        else "png"
    )

    try:
        graph = subprocess.check_output(
            ["dot", "-Nshape=box", "-T" + output_type], input=f"{dot}".encode()
        )
    except (ImportError, FileNotFoundError):
        # the two sentences are separate literals: keep the space between them
        msg = (
            "the graphviz `dot` binary should be on your PATH. "
            "(If not installed you can download here: https://graphviz.org/download/)"
        )
        raise ImportError(msg) from None

    if output_path:
        Path(output_path).write_bytes(graph)

    if not show:
        return None

    if _in_notebook():
        from IPython.display import SVG, display

        return display(SVG(graph))
    elif _in_marimo_notebook():
        import marimo as mo

        return mo.Html(f"{graph.decode()}")
    else:
        if (cmd := os.environ.get("POLARS_DOT_SVG_VIEWER", None)) is not None:
            import tempfile

            # the viewer command runs while the temporary file still exists;
            # "%file%" in the command is replaced with the file's path.
            # NOTE: the command comes from the environment and runs via the shell
            with tempfile.NamedTemporaryFile(suffix=".svg") as file:
                file.write(graph)
                file.flush()
                cmd = cmd.replace("%file%", file.name)
                subprocess.run(cmd, shell=True)
            return None

        import_optional(
            "matplotlib",
            err_prefix="",
            err_suffix="should be installed to show graphs",
        )
        import matplotlib.image as mpimg
        import matplotlib.pyplot as plt

        plt.figure(figsize=figsize)
        img = mpimg.imread(BytesIO(graph))
        plt.axis("off")
        plt.imshow(img)
        plt.show()
        return None
def qualified_type_name(obj: Any, *, qualify_polars: bool = False) -> str:
    """
    Return the module-qualified name of the given object as a string.

    Parameters
    ----------
    obj
        The object to get the qualified name for; if it is not itself a
        class, its class is used.
    qualify_polars
        If False (default), omit the module path for our own (Polars) objects.
    """
    klass = obj if isinstance(obj, type) else obj.__class__
    module, name = klass.__module__, klass.__name__

    # bare name for builtins, module-less types and (by default) polars types
    use_bare_name = (
        not module
        or module == "builtins"
        or (not qualify_polars and module.startswith("polars."))
    )
    return name if use_bare_name else f"{module}.{name}"
def require_same_type(current: Any, other: Any) -> None:
    """
    Raise an error if the two arguments are not of the same type.

    The check will not raise an error if one object is of a subclass of the other.

    Parameters
    ----------
    current
        The object the type of which is being checked against.
    other
        An object that has to be of the same type.

    Raises
    ------
    TypeError
        If neither object is an instance of the other's type.
    """
    if isinstance(other, type(current)) or isinstance(current, type(other)):
        return

    expected = qualified_type_name(current)
    received = qualified_type_name(other)
    msg = f"expected `other` to be a {expected!r}, not {received!r}"
    raise TypeError(msg)