DriverTrac/venv/lib/python3.12/site-packages/polars/_utils/getitem.py
2025-11-28 09:08:33 +05:30

458 lines
15 KiB
Python

from __future__ import annotations
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, NoReturn, overload
import polars._reexport as pl
import polars.functions as F
from polars._dependencies import _check_for_numpy
from polars._dependencies import numpy as np
from polars._utils.constants import U32_MAX
from polars._utils.slice import PolarsSlice
from polars._utils.various import qualified_type_name, range_to_slice
from polars.datatypes.classes import (
Boolean,
Int8,
Int16,
Int32,
Int64,
String,
UInt32,
UInt64,
)
from polars.meta.index_type import get_index_type
if TYPE_CHECKING:
from collections.abc import Iterable
from polars import DataFrame, Series
from polars._typing import (
MultiColSelector,
MultiIndexSelector,
SingleColSelector,
SingleIndexSelector,
)
# Names exported by `from <this module> import *`; every other helper
# defined below is underscore-prefixed and private to this module.
__all__ = [
"get_df_item_by_key",
"get_series_item_by_key",
]
@overload
def get_series_item_by_key(s: Series, key: SingleIndexSelector) -> Any: ...
@overload
def get_series_item_by_key(s: Series, key: MultiIndexSelector) -> Series: ...


def get_series_item_by_key(
    s: Series, key: SingleIndexSelector | MultiIndexSelector
) -> Any | Series:
    """
    Select one or more elements from the Series.

    Parameters
    ----------
    s
        Series to index into.
    key
        Selector, dispatched on its type:

        - `int`: forwarded to `get_index_signed` on the underlying `_s` object.
        - `slice` / `range`: applied as a slice (a `range` is first converted
          with `range_to_slice`).
        - `Sequence`: an empty one yields `s.clear()`; otherwise the values are
          loaded into an `Int64` Series, converted to indices and gathered.
        - `Series` / NumPy array: converted to indices and gathered.

    Raises
    ------
    TypeError
        If the first element of a sequence is a `bool`, if a sequence cannot be
        turned into an `Int64` Series, or if the key type is not handled.
    """
    if isinstance(key, int):
        return s._s.get_index_signed(key)

    if isinstance(key, (slice, range)):
        as_slice = range_to_slice(key) if isinstance(key, range) else key
        return _select_elements_by_slice(s, as_slice)

    if isinstance(key, Sequence):
        if not key:
            return s.clear()

        first = key[0]
        if isinstance(first, bool):
            _raise_on_boolean_mask()

        # Only the construction is guarded: a failure there is reported in
        # terms of the first element's type, without the original traceback.
        try:
            raw_positions = pl.Series("", key, dtype=Int64)
        except TypeError:
            msg = f"cannot select elements using Sequence with elements of type {qualified_type_name(first)!r}"
            raise TypeError(msg) from None

        positions = _convert_series_to_indices(raw_positions, s.len())
    elif isinstance(key, pl.Series):
        positions = _convert_series_to_indices(key, s.len())
    elif _check_for_numpy(key) and isinstance(key, np.ndarray):
        positions = _convert_np_ndarray_to_indices(key, s.len())
    else:
        msg = f"cannot select elements using key of type {qualified_type_name(key)!r}: {key!r}"
        raise TypeError(msg)

    return _select_elements_by_index(s, positions)
def _select_elements_by_slice(s: Series, key: slice) -> Series:
    """Apply the slice `key` to `s` through `PolarsSlice` and return the result."""
    slicer = PolarsSlice(s)
    return slicer.apply(key)  # type: ignore[return-value]
def _select_elements_by_index(s: Series, key: Series) -> Series:
    """Gather the elements of `s` at the positions held by the index Series `key`."""
    gathered = s._s.gather_with_series(key._s)
    return s._from_pyseries(gathered)
# `str` overlaps with `Sequence[str]`
# We can ignore this but we must keep this overload ordering
@overload
def get_df_item_by_key(
    df: DataFrame, key: tuple[SingleIndexSelector, SingleColSelector]
) -> Any: ...
@overload
def get_df_item_by_key(  # type: ignore[overload-overlap]
    df: DataFrame, key: str | tuple[MultiIndexSelector, SingleColSelector]
) -> Series: ...
@overload
def get_df_item_by_key(
    df: DataFrame,
    key: (
        SingleIndexSelector
        | MultiIndexSelector
        | MultiColSelector
        | tuple[SingleIndexSelector, MultiColSelector]
        | tuple[MultiIndexSelector, MultiColSelector]
    ),
) -> DataFrame: ...


def get_df_item_by_key(
    df: DataFrame,
    key: (
        SingleIndexSelector
        | SingleColSelector
        | MultiColSelector
        | MultiIndexSelector
        | tuple[SingleIndexSelector, SingleColSelector]
        | tuple[SingleIndexSelector, MultiColSelector]
        | tuple[MultiIndexSelector, SingleColSelector]
        | tuple[MultiIndexSelector, MultiColSelector]
    ),
) -> DataFrame | Series | Any:
    """
    Get part of the DataFrame as a new DataFrame, Series, or scalar.

    Parameters
    ----------
    df
        DataFrame to index into.
    key
        Either a single selector or a two-element tuple. A two-element tuple
        is read as `(rows, columns)`, except when its first element is a
        `bool` or `str`, in which case the whole tuple is a column selector.

    Notes
    -----
    A single non-string key is tried as a row selector first. Only if that
    raises `TypeError` is it retried as a column selector.
    """
    # Two inputs, e.g. df[1, 2:5]
    if isinstance(key, tuple) and len(key) == 2:
        row_sel, col_sel = key

        # df[True, False] and df["a", "b"] are unambiguous column selections.
        if isinstance(row_sel, (bool, str)):
            return _select_columns(df, key)  # type: ignore[arg-type]

        # Columns are narrowed first, then rows are picked from that subset.
        subset = _select_columns(df, col_sel)
        if subset.is_empty():
            return subset
        if isinstance(subset, pl.Series):
            return get_series_item_by_key(subset, row_sel)
        return _select_rows(subset, row_sel)

    # Single string input, e.g. df["a"]. Handled explicitly because an empty
    # string would otherwise be treated as an empty Sequence in `_select_rows`.
    if isinstance(key, str):
        return df.get_column(key)

    # Single input - df[1] - or multiple inputs - df["a", "b", "c"]
    try:
        result = _select_rows(df, key)  # type: ignore[arg-type]
    except TypeError:
        result = _select_columns(df, key)
    return result
# `str` overlaps with `Sequence[str]`
# We can ignore this but we must keep this overload ordering
@overload
def _select_columns(df: DataFrame, key: SingleColSelector) -> Series: ...  # type: ignore[overload-overlap]
@overload
def _select_columns(df: DataFrame, key: MultiColSelector) -> DataFrame: ...


def _select_columns(
    df: DataFrame, key: SingleColSelector | MultiColSelector
) -> DataFrame | Series:
    """
    Select one or more columns from the DataFrame.

    Parameters
    ----------
    df
        DataFrame to take columns from.
    key
        Column selector, dispatched on its type:

        - `int` / `str`: a single column, by position or by name.
        - `slice`: bounds may be integers or column names; a column name used
          as the stop bound is included in the result.
        - `range`: column positions.
        - `Sequence`, `Series` or 1D NumPy array: booleans act as a mask,
          integers as positions and strings as names. For a `Sequence` the
          kind is decided from the first element only. An empty selector gives
          an empty frame of the same class as `df`.

    Raises
    ------
    TypeError
        If the key, or the element type it holds, cannot select columns.
    """
    if isinstance(key, int):
        return df.to_series(key)
    if isinstance(key, str):
        return df.get_column(key)

    if isinstance(key, slice):
        lower, upper, stride = key.start, key.stop, key.step
        # Fast path for common case: df[x, :]
        if lower is None and upper is None and stride is None:
            return df

        # Translate column names into positions; `+ 1` keeps the named stop
        # column inside the selection.
        if isinstance(lower, str):
            lower = df.get_column_index(lower)
        if isinstance(upper, str):
            upper = df.get_column_index(upper) + 1

        positions = range(df.width)[slice(lower, upper, stride)]
        return _select_columns_by_index(df, positions)

    if isinstance(key, range):
        return _select_columns_by_index(df, key)

    if isinstance(key, Sequence):
        if not key:
            return df.__class__()

        first = key[0]
        # `bool` is tested before `int` because it is an `int` subclass.
        if isinstance(first, bool):
            return _select_columns_by_mask(df, key)  # type: ignore[arg-type]
        if isinstance(first, int):
            return _select_columns_by_index(df, key)  # type: ignore[arg-type]
        if isinstance(first, str):
            return _select_columns_by_name(df, key)  # type: ignore[arg-type]

        msg = f"cannot select columns using Sequence with elements of type {qualified_type_name(first)!r}"
        raise TypeError(msg)

    if isinstance(key, pl.Series):
        if key.is_empty():
            return df.__class__()

        dtype = key.dtype
        if dtype == String:
            return _select_columns_by_name(df, key)
        if dtype.is_integer():
            return _select_columns_by_index(df, key)
        if dtype == Boolean:
            return _select_columns_by_mask(df, key)

        msg = f"cannot select columns using Series of type {dtype}"
        raise TypeError(msg)

    if _check_for_numpy(key) and isinstance(key, np.ndarray):
        if key.ndim > 1:
            msg = "multi-dimensional NumPy arrays not supported as index"
            raise TypeError(msg)
        if key.ndim == 0:
            # Promote a zero-dimensional array to a one-element 1D array.
            key = np.atleast_1d(key)

        if len(key) == 0:
            return df.__class__()

        dtype_kind = key.dtype.kind
        if dtype_kind in ("i", "u"):
            return _select_columns_by_index(df, key)
        if dtype_kind == "b":
            return _select_columns_by_mask(df, key)
        if isinstance(key[0], str):
            return _select_columns_by_name(df, key)

        msg = f"cannot select columns using NumPy array of type {key.dtype}"
        raise TypeError(msg)

    msg = (
        f"cannot select columns using key of type {qualified_type_name(key)!r}: {key!r}"
    )
    raise TypeError(msg)
def _select_columns_by_index(df: DataFrame, key: Iterable[int]) -> DataFrame:
    """Build a frame of the same class as `df` from the columns at positions `key`."""
    picked = list(map(df.to_series, key))
    return df.__class__(picked)
def _select_columns_by_name(df: DataFrame, key: Iterable[str]) -> DataFrame:
    """Select the columns named in `key`; the names are materialised into a list first."""
    names = list(key)
    selected = df._df.select(names)
    return df._from_pydf(selected)
def _select_columns_by_mask(
    df: DataFrame, key: Sequence[bool] | Series | np.ndarray[Any, Any]
) -> DataFrame:
    """
    Select the columns whose entry in the boolean mask `key` is truthy.

    Raises
    ------
    ValueError
        If the mask length differs from the number of columns in `df`.
    """
    if len(key) == df.width:
        # Positions of the truthy mask entries, in column order.
        kept = [position for position, flag in enumerate(key) if flag]
        return _select_columns_by_index(df, kept)

    msg = f"expected {df.width} values when selecting columns by boolean mask, got {len(key)}"
    raise ValueError(msg)
@overload
def _select_rows(df: DataFrame, key: SingleIndexSelector) -> Series: ...
@overload
def _select_rows(df: DataFrame, key: MultiIndexSelector) -> DataFrame: ...


def _select_rows(
    df: DataFrame, key: SingleIndexSelector | MultiIndexSelector
) -> DataFrame | Series:
    """
    Select one or more rows from the DataFrame.

    Parameters
    ----------
    df
        DataFrame to take rows from.
    key
        Row selector, dispatched on its type:

        - `int`: bounds-checked against `df.height`, then taken with
          `df.slice(key, 1)`.
        - `slice` / `range`: applied as a slice (a `range` is first converted
          with `range_to_slice`).
        - `Sequence`: an empty one yields `df.clear()`; otherwise the values
          are loaded into an `Int64` Series, converted to indices and gathered.
        - `Series` / NumPy array: converted to indices and gathered.

    Raises
    ------
    IndexError
        If an integer key lies outside `[-height, height)`.
    TypeError
        If the first element of a sequence is a `bool`, or the key type is not
        handled. An error raised while building the `Int64` Series is not
        caught here.
    """
    if isinstance(key, int):
        num_rows = df.height
        if not (-num_rows <= key < num_rows):
            msg = f"index {key} is out of bounds for DataFrame of height {num_rows}"
            raise IndexError(msg)
        return df.slice(key, 1)

    if isinstance(key, (slice, range)):
        row_slice = range_to_slice(key) if isinstance(key, range) else key
        return _select_rows_by_slice(df, row_slice)

    if isinstance(key, Sequence):
        if not key:
            return df.clear()
        if isinstance(key[0], bool):
            _raise_on_boolean_mask()
        positions = _convert_series_to_indices(
            pl.Series("", key, dtype=Int64), df.height
        )
    elif isinstance(key, pl.Series):
        positions = _convert_series_to_indices(key, df.height)
    elif _check_for_numpy(key) and isinstance(key, np.ndarray):
        positions = _convert_np_ndarray_to_indices(key, df.height)
    else:
        msg = f"cannot select rows using key of type {qualified_type_name(key)!r}: {key!r}"
        raise TypeError(msg)

    return _select_rows_by_index(df, positions)
def _select_rows_by_slice(df: DataFrame, key: slice) -> DataFrame:
    """Apply the slice `key` to the rows of `df` through `PolarsSlice`."""
    slicer = PolarsSlice(df)
    return slicer.apply(key)  # type: ignore[return-value]
def _select_rows_by_index(df: DataFrame, key: Series) -> DataFrame:
    """Gather the rows of `df` at the positions held by the index Series `key`."""
    gathered = df._df.gather_with_series(key._s)
    return df._from_pydf(gathered)
# UTILS
def _convert_series_to_indices(s: Series, size: int) -> Series:
    """
    Convert a Series to indices, taking into account negative values.

    The cases below are ordered from fastest to slowest:

    - A Series that already has the index dtype returned by
      `get_index_type()` is returned unchanged.
    - Any other integer Series without negative values is cast to the
      index dtype.
    - A signed Series containing negative values has those values shifted by
      `size` to make them absolute before it is cast to the index dtype.

    Parameters
    ----------
    s
        Series of (possibly negative) integer positions.
    size
        Length of the object being indexed.

    Raises
    ------
    TypeError
        If `s` is a boolean mask or has any other non-integer dtype.
    ValueError
        If the index dtype is `UInt32` and a 64-bit Series holds a position
        rejected by the bounds checks below.
    """
    idx_type = get_index_type()
    if s.dtype == idx_type:
        return s

    if not s.dtype.is_integer():
        if s.dtype == Boolean:
            _raise_on_boolean_mask()
        msg = f"cannot treat Series of type {s.dtype} as indices"
        raise TypeError(msg)

    if s.len() == 0:
        return pl.Series(s.name, [], dtype=idx_type)

    small_index = idx_type == UInt32
    if small_index:
        if s.dtype in {Int64, UInt64} and s.max() >= U32_MAX:  # type: ignore[operator]
            msg = "index positions should be smaller than 2^32"
            raise ValueError(msg)
        if s.dtype == Int64 and s.min() < -U32_MAX:  # type: ignore[operator]
            msg = "index positions should be greater than or equal to -2^32"
            raise ValueError(msg)

    has_negative = s.dtype.is_signed_integer() and s.min() < 0  # type: ignore[operator]
    if not has_negative:
        return s.cast(idx_type)

    # Narrow signed dtypes are widened before `size` is added to them.
    if small_index:
        widened = s.cast(Int32) if s.dtype in {Int8, Int16} else s
    else:
        widened = s.cast(Int64) if s.dtype in {Int8, Int16, Int32} else s

    # Update negative indexes to absolute indexes.
    position = F.col(widened.name)
    absolute = F.when(position < 0).then(size + position).otherwise(position)
    return widened.to_frame().select(absolute.cast(idx_type)).to_series(0)
def _convert_np_ndarray_to_indices(arr: np.ndarray[Any, Any], size: int) -> Series:
    """
    Convert a NumPy ndarray to indices, taking into account negative values.

    The cases below are ordered from fastest to slowest:

    - Unsigned arrays are converted straight to the index dtype
      (`np.uint32` when `get_index_type()` is `UInt32`, otherwise `np.uint64`).
    - Signed arrays without negative values are converted the same way.
    - Signed arrays with negative values have those values shifted by `size`
      to make them absolute before the conversion.

    Parameters
    ----------
    arr
        Zero- or one-dimensional integer array of (possibly negative)
        positions. A zero-dimensional array is promoted to one dimension.
    size
        Length of the object being indexed.

    Returns
    -------
    Series
        Unnamed Series with the index dtype.

    Raises
    ------
    TypeError
        If the array has more than one dimension, is a boolean mask, or has
        any other non-integer dtype.
    ValueError
        If the index dtype is `UInt32` and a 64-bit array holds a position
        rejected by the bounds checks below.
    """
    if arr.ndim == 0:
        arr = np.atleast_1d(arr)
    if arr.ndim != 1:
        msg = "only 1D NumPy arrays can be treated as indices"
        raise TypeError(msg)

    idx_type = get_index_type()

    if len(arr) == 0:
        return pl.Series("", [], dtype=idx_type)

    # Numpy array with signed or unsigned integers.
    if arr.dtype.kind not in ("i", "u"):
        if arr.dtype.kind == "b":
            _raise_on_boolean_mask()
        else:
            msg = f"cannot treat NumPy array of type {arr.dtype} as indices"
            raise TypeError(msg)

    if idx_type == UInt32:
        # The membership test must use a tuple, not a set: a `np.dtype`
        # instance compares equal to `np.int64` but does not hash like it, so
        # a set lookup never matches and this bounds check would be skipped,
        # letting `astype(np.uint32)` below silently wrap large positions.
        if arr.dtype in (np.int64, np.uint64) and arr.max() >= U32_MAX:
            msg = "index positions should be smaller than 2^32"
            raise ValueError(msg)
        if arr.dtype == np.int64 and arr.min() < -U32_MAX:
            msg = "index positions should be greater than or equal to -2^32"
            raise ValueError(msg)

    if arr.dtype.kind == "i" and arr.min() < 0:
        # Narrow signed dtypes are widened before `size` is added to them.
        if idx_type == UInt32:
            if arr.dtype in (np.int8, np.int16):
                arr = arr.astype(np.int32)
        else:
            if arr.dtype in (np.int8, np.int16, np.int32):
                arr = arr.astype(np.int64)

        # Update negative indexes to absolute indexes.
        arr = np.where(arr < 0, size + arr, arr)

    # numpy conversion is much faster
    arr = arr.astype(np.uint32) if idx_type == UInt32 else arr.astype(np.uint64)

    return pl.Series("", arr, dtype=idx_type)
def _raise_on_boolean_mask() -> NoReturn:
    """Raise the `TypeError` reported when a boolean mask is used to select rows."""
    raise TypeError(
        "selecting rows by passing a boolean mask to `__getitem__` is not supported"
        "\n\nHint: Use the `filter` method instead."
    )