216 lines
7.8 KiB
Python
216 lines
7.8 KiB
Python
from __future__ import annotations
|
|
|
|
from typing import TYPE_CHECKING, Union
|
|
|
|
import polars._reexport as pl
|
|
|
|
if TYPE_CHECKING:
|
|
from polars import DataFrame, LazyFrame, Series
|
|
|
|
# Eager containers that PolarsSlice accepts and returns. The members are written
# as string forward references because the polars classes are only imported
# under TYPE_CHECKING.
FrameOrSeries = Union["DataFrame", "Series"]
|
|
|
|
|
|
class PolarsSlice:
    """
    Resolve a Python slice object against a Polars DataFrame or Series.

    Negative indices and non-unit (including negative) strides are handled.
    """

    stop: int
    start: int
    stride: int
    slice_length: int
    is_unbounded: bool
    obj: FrameOrSeries

    def __init__(self, obj: FrameOrSeries) -> None:
        self.obj = obj

    @staticmethod
    def _as_original(lazy: LazyFrame, original: FrameOrSeries) -> FrameOrSeries:
        """Collect the lazy result and return it as the type of `original`."""
        collected = lazy.collect()
        if isinstance(original, pl.DataFrame):
            return collected
        return collected.to_series()

    @staticmethod
    def _lazify(obj: FrameOrSeries) -> LazyFrame:
        """Convert the input to a LazyFrame; a Series is wrapped in a frame first."""
        if isinstance(obj, pl.Series):
            return obj.to_frame().lazy()
        return obj.lazy()

    def _slice_positive(self, obj: LazyFrame) -> LazyFrame:
        """Take the start/length window, then keep every `stride`-th row."""
        # `apply` has already handled every stride == 1 case, so stride > 1 here
        window = obj.slice(self.start, self.slice_length)
        return window.gather_every(self.stride)

    def _slice_negative(self, obj: LazyFrame) -> LazyFrame:
        """Take the window just above `stop`, reverse it, then apply the stride."""
        step = abs(self.stride)
        reversed_window = obj.slice(self.stop + 1, self.slice_length).reverse()
        if step > 1:
            return reversed_window.gather_every(step)
        return reversed_window

    def _slice_setup(self, s: slice) -> None:
        """Normalise the slice bounds and record the span/unbounded state on self."""
        # the object length is known, so the bounds can be made concrete
        n_rows = len(self.obj)
        start, stop, stride = slice(s.start, s.stop, s.step).indices(n_rows)

        # does the normalised slice cover the whole object?
        if stride >= 1:
            unbounded = (start <= 0) and (stop >= n_rows)
        else:
            unbounded = (stop == -1) and (start >= n_rows - 1)
        self.is_unbounded = unbounded

        # span of rows between the bounds (before any stride is applied)
        if self.obj.is_empty():
            span = 0
        elif unbounded:
            span = n_rows
        elif (
            (start == stop)
            or (stride > 0 and start > stop)
            or (stride < 0 and start < stop)
        ):
            # bounds are equal, or point away from the direction of travel
            span = 0
        else:
            span = abs(stop - start)
        self.slice_length = span

        self.start, self.stop, self.stride = start, stop, stride

    def apply(self, s: slice) -> FrameOrSeries:
        """Apply the slice, preferring a single eager call where one suffices."""
        self._slice_setup(s)

        # nothing selected
        if self.slice_length == 0:
            return self.obj.clear()

        # whole object, forwards or backwards
        if self.is_unbounded and self.stride in (-1, 1):
            if self.stride < 0:
                return self.obj.reverse()
            return self.obj.clone()

        # contiguous forward window
        if self.start >= 0 and self.stop >= 0 and self.stride == 1:
            return self.obj.slice(self.start, self.slice_length)

        # backward slice spanning one row: no reversal or stride needed
        if self.stride < 0 and self.slice_length == 1:
            return self.obj.slice(self.stop + 1, 1)

        # everything else needs more than one operation, so go through lazy
        lazy_input = self._lazify(self.obj)
        if self.stride > 0:
            lazy_result = self._slice_positive(lazy_input)
        else:
            lazy_result = self._slice_negative(lazy_input)
        return self._as_original(lazy_result, self.obj)
|
|
|
|
|
|
class LazyPolarsSlice:
    """
    Apply python slice object to Polars LazyFrame.

    Only slices with efficient computation paths that map directly
    to existing lazy methods are supported.
    """

    obj: LazyFrame

    def __init__(self, obj: LazyFrame) -> None:
        self.obj = obj

    def apply(self, s: slice) -> LazyFrame:
        """
        Apply a slice operation.

        Note that LazyFrame is designed primarily for efficient computation and does not
        know its own length so, unlike DataFrame, certain slice patterns (such as those
        requiring negative stop/step) may not be supported.

        Parameters
        ----------
        s
            The slice to apply. A missing start is treated as 0 and a missing
            step as 1.

        Returns
        -------
        LazyFrame
            The result of chaining the lazy methods that the slice maps onto.

        Raises
        ------
        ValueError
            If the step is zero, if the stop is negative, if a negative step is
            combined with a stop that differs from the start, or if the slice
            matches none of the supported patterns.
        """
        # an explicit zero step must not be confused with "no step given";
        # reject it the same way Python sequences do
        if s.step == 0:
            msg = "slice step cannot be zero"
            raise ValueError(msg)

        start = s.start or 0
        step = s.step or 1

        # fail on operations that require length to do efficiently
        if s.stop and s.stop < 0:
            msg = "negative stop is not supported for lazy slices"
            raise ValueError(msg)
        # a negative stride is only usable without a stop (or with the
        # degenerate start == stop, which is an empty slice handled below)
        if step < 0 and s.stop is not None and start != s.stop:
            msg = "negative stride is not supported in conjunction with start+stop"
            raise ValueError(msg)

        # ---------------------------------------
        # empty slice patterns
        # ---------------------------------------
        # [:0]
        # [i:<=i]
        # [i:>=i:-k]
        if (step > 0 and (s.stop is not None and start >= s.stop)) or (
            step < 0
            and (s.start is not None and s.stop is not None and s.stop >= s.start >= 0)
        ):
            return self.obj.clear()

        # ---------------------------------------
        # straight-through mappings for "reverse"
        # and/or "gather_every"
        # ---------------------------------------
        # [:]    => clone()
        # [::k]  => gather_every(k),
        # [::-1] => reverse(),
        # [::-k] => reverse().gather_every(abs(k))
        elif s.start is None and s.stop is None:
            if step == 1:
                return self.obj.clone()
            elif step > 1:
                return self.obj.gather_every(step)
            elif step == -1:
                return self.obj.reverse()
            elif step < -1:
                return self.obj.reverse().gather_every(abs(step))

        # ---------------------------------------
        # straight-through mappings for "head",
        # "reverse" and "gather_every"
        # ---------------------------------------
        # [i::-1] => head(i+1).reverse()
        # [i::k], k<-1 => head(i+1).reverse().gather_every(abs(k))
        elif start >= 0 > step and s.stop is None:
            # s.start is not None here (the both-None case is handled above),
            # so the normalised `start` is the same value
            obj = self.obj.head(start + 1).reverse()
            return obj if (abs(step) == 1) else obj.gather_every(abs(step))

        # ---------------------------------------
        # straight-through mappings for "head"
        # ---------------------------------------
        # [:j]   => head(j)
        # [:j:k] => head(j).gather_every(k)
        elif start == 0 and (s.stop or 0) >= 1:
            obj = self.obj.head(s.stop)
            return obj if (step == 1) else obj.gather_every(step)

        # ---------------------------------------
        # straight-through mappings for "tail"
        # ---------------------------------------
        # [-i:]  => tail(abs(i))
        # [-i::k] => tail(abs(i)).gather_every(k)
        elif start < 0 and s.stop is None and step > 0:
            obj = self.obj.tail(abs(start))
            return obj if (step == 1) else obj.gather_every(step)

        # ---------------------------------------
        # straight-through mappings for "slice"
        # ---------------------------------------
        # [i:]    => slice(i)
        # [i:j]   => slice(i,j-i)
        # [i:j:k] => slice(i,j-i).gather_every(k)
        elif start > 0 and (s.stop is None or s.stop >= 0):
            slice_length = None if (s.stop is None) else (s.stop - start)
            obj = self.obj.slice(start, slice_length)
            return obj if (step == 1) else obj.gather_every(step)

        msg = (
            f"the given slice {s!r} is not supported by lazy computation"
            "\n\nConsider a more efficient approach, or construct explicitly with other methods."
        )
        raise ValueError(msg)
|