DriverTrac/venv/lib/python3.12/site-packages/polars/_utils/slice.py

216 lines
7.8 KiB
Python

from __future__ import annotations
from typing import TYPE_CHECKING, Union
import polars._reexport as pl
if TYPE_CHECKING:
from polars import DataFrame, LazyFrame, Series
FrameOrSeries = Union["DataFrame", "Series"]
class PolarsSlice:
    """
    Slice an eager Polars object (DataFrame or Series) with a Python slice.

    Negative indices and negative and/or non-unit strides are fully supported.
    """

    # normalised slice bounds and step; populated by `_slice_setup`
    stop: int
    start: int
    stride: int
    # number of contiguous rows spanned by the slice, before any striding
    slice_length: int
    # True when the normalised slice covers the whole object
    is_unbounded: bool
    # the object being sliced
    obj: FrameOrSeries

    def __init__(self, obj: FrameOrSeries) -> None:
        self.obj = obj

    @staticmethod
    def _as_original(lazy: LazyFrame, original: FrameOrSeries) -> FrameOrSeries:
        """Collect the lazy result and return it as the same type as `original`."""
        collected = lazy.collect()
        if isinstance(original, pl.DataFrame):
            return collected
        return collected.to_series()

    @staticmethod
    def _lazify(obj: FrameOrSeries) -> LazyFrame:
        """Convert the input to a LazyFrame so both input types share one code path."""
        if isinstance(obj, pl.Series):
            return obj.to_frame().lazy()
        return obj.lazy()

    def _slice_positive(self, obj: LazyFrame) -> LazyFrame:
        """Take the contiguous window, then keep every `stride`-th row."""
        # `apply` resolves stride == 1 through a fast path, so stride > 1 here
        window = obj.slice(self.start, self.slice_length)
        return window.gather_every(self.stride)

    def _slice_negative(self, obj: LazyFrame) -> LazyFrame:
        """Take the contiguous window, reverse it, then apply the absolute stride."""
        step = abs(self.stride)
        # with a negative stride `stop` is the exclusive lower bound, so the
        # window begins one row after it
        reversed_window = obj.slice(self.stop + 1, self.slice_length).reverse()
        if step > 1:
            return reversed_window.gather_every(step)
        return reversed_window

    def _slice_setup(self, s: slice) -> None:
        """Normalise the slice bounds and record whether it is unbounded/empty."""
        # the object length is known, so the bounds can be resolved to
        # concrete indices up-front (this also rejects a zero step)
        n_rows = len(self.obj)
        start, stop, stride = slice(s.start, s.stop, s.step).indices(n_rows)

        # does the slice actually cover every row?
        if stride >= 1:
            covers_all = (start <= 0) and (stop >= n_rows)
        else:
            covers_all = (stop == -1) and (start >= n_rows - 1)
        self.is_unbounded = covers_all

        # work out how many contiguous rows the slice spans
        if self.obj.is_empty():
            span = 0
        elif covers_all:
            span = n_rows
        else:
            # bounds that run against the direction of travel select nothing
            wrong_direction = (start > stop) if stride > 0 else (start < stop)
            span = 0 if (start == stop or wrong_direction) else abs(stop - start)
        self.slice_length = span

        self.start, self.stop, self.stride = start, stop, stride

    def apply(self, s: slice) -> FrameOrSeries:
        """Apply the slice, preferring a single eager operation where one exists."""
        self._slice_setup(s)

        # single-operation fast paths
        if self.slice_length == 0:
            return self.obj.clear()
        if self.is_unbounded and abs(self.stride) == 1:
            return self.obj.reverse() if (self.stride < 0) else self.obj.clone()
        if self.start >= 0 and self.stop >= 0 and self.stride == 1:
            return self.obj.slice(self.start, self.slice_length)
        if self.stride < 0 and self.slice_length == 1:
            return self.obj.slice(self.stop + 1, 1)

        # everything else needs several chained operations; run them lazily
        lazy_input = self._lazify(self.obj)
        slicer = self._slice_positive if self.stride > 0 else self._slice_negative
        return self._as_original(slicer(lazy_input), self.obj)
class LazyPolarsSlice:
    """
    Apply python slice object to Polars LazyFrame.

    Only slices with efficient computation paths that map directly
    to existing lazy methods are supported.
    """

    # the LazyFrame being sliced
    obj: LazyFrame

    def __init__(self, obj: LazyFrame) -> None:
        self.obj = obj

    def apply(self, s: slice) -> LazyFrame:
        """
        Apply a slice operation.

        Note that LazyFrame is designed primarily for efficient computation and does not
        know its own length so, unlike DataFrame, certain slice patterns (such as those
        requiring negative stop/step) may not be supported.

        Parameters
        ----------
        s
            The slice to apply.

        Returns
        -------
        LazyFrame
            The result of mapping the slice onto `clear`, `clone`, `reverse`,
            `gather_every`, `head`, `tail` and/or `slice`.

        Raises
        ------
        ValueError
            If the step is zero, the stop is negative, a negative step is
            combined with start+stop, or no lazy mapping exists for the slice.
        """
        # a zero step is never a valid slice; reject it explicitly rather
        # than letting the `or 1` default below silently treat it as 1
        if s.step == 0:
            msg = "slice step cannot be zero"
            raise ValueError(msg)

        start = s.start or 0
        step = s.step or 1

        # fail on operations that require length to do efficiently
        if s.stop and s.stop < 0:
            msg = "negative stop is not supported for lazy slices"
            raise ValueError(msg)
        if step < 0 and (start > 0 or s.stop is not None) and (start != s.stop):
            if not (start > 0 > step and s.stop is None):
                msg = "negative stride is not supported in conjunction with start+stop"
                raise ValueError(msg)

        # ---------------------------------------
        # empty slice patterns
        # ---------------------------------------
        # [:0]
        # [i:<=i]
        # [i:>=i:-k]  (the guard above only lets `start == stop` reach here)
        if (step > 0 and (s.stop is not None and start >= s.stop)) or (
            step < 0
            and (s.start is not None and s.stop is not None and s.stop >= s.start >= 0)
        ):
            return self.obj.clear()

        # ---------------------------------------
        # straight-through mappings for "reverse"
        # and/or "gather_every"
        # ---------------------------------------
        # [:], [0:]     => clone()
        # [::k], [0::k] => gather_every(k),
        # [::-1]        => reverse(),
        # [::-k]        => reverse().gather_every(abs(k))
        #
        # an explicit start of 0 is equivalent to an omitted start only when
        # moving forwards; with a negative step it means "first row only",
        # which is handled by the head/reverse mapping below
        elif s.stop is None and (s.start is None or (start == 0 and step > 0)):
            if step == 1:
                return self.obj.clone()
            elif step > 1:
                return self.obj.gather_every(step)
            elif step == -1:
                return self.obj.reverse()
            elif step < -1:
                return self.obj.reverse().gather_every(abs(step))

        # ---------------------------------------
        # straight-through mappings for "head",
        # "reverse" and "gather_every"
        # ---------------------------------------
        # [i::-1] => head(i+1).reverse()
        # [i::k], k<-1 => head(i+1).reverse().gather_every(abs(k))
        elif start >= 0 > step and s.stop is None:
            obj = self.obj.head(s.start + 1).reverse()
            return obj if (abs(step) == 1) else obj.gather_every(abs(step))

        # ---------------------------------------
        # straight-through mappings for "head"
        # ---------------------------------------
        # [:j] => head(j)
        # [:j:k] => head(j).gather_every(k)
        elif start == 0 and (s.stop or 0) >= 1:
            obj = self.obj.head(s.stop)
            return obj if (step == 1) else obj.gather_every(step)

        # ---------------------------------------
        # straight-through mappings for "tail"
        # ---------------------------------------
        # [-i:] => tail(abs(i))
        # [-i::k] => tail(abs(i)).gather_every(k)
        elif start < 0 and s.stop is None and step > 0:
            obj = self.obj.tail(abs(start))
            return obj if (step == 1) else obj.gather_every(step)

        # ---------------------------------------
        # straight-through mappings for "slice"
        # ---------------------------------------
        # [i:] => slice(i)
        # [i:j] => slice(i,j-i)
        # [i:j:k] => slice(i,j-i).gather_every(k)
        elif start > 0 and (s.stop is None or s.stop >= 0):
            slice_length = None if (s.stop is None) else (s.stop - start)
            obj = self.obj.slice(start, slice_length)
            return obj if (step == 1) else obj.gather_every(step)

        # nothing matched: the slice cannot be expressed without knowing the length
        msg = (
            f"the given slice {s!r} is not supported by lazy computation"
            "\n\nConsider a more efficient approach, or construct explicitly with other methods."
        )
        raise ValueError(msg)