240 lines
8.5 KiB
Python
240 lines
8.5 KiB
Python
from __future__ import annotations
|
|
|
|
from functools import partial
|
|
from typing import TYPE_CHECKING, Any, Protocol, overload
|
|
|
|
from narwhals._compliant.typing import (
|
|
CompliantExprT,
|
|
CompliantFrameT,
|
|
CompliantLazyFrameT,
|
|
DepthTrackingExprT,
|
|
EagerDataFrameT,
|
|
EagerExprT,
|
|
EagerSeriesT_co,
|
|
LazyExprT,
|
|
NativeFrameT,
|
|
NativeSeriesT,
|
|
)
|
|
from narwhals._utils import (
|
|
exclude_column_names,
|
|
get_column_names,
|
|
passthrough_column_names,
|
|
)
|
|
from narwhals.dependencies import is_numpy_array_2d
|
|
|
|
if TYPE_CHECKING:
|
|
from collections.abc import Iterable, Sequence
|
|
|
|
from typing_extensions import TypeAlias, TypeIs
|
|
|
|
from narwhals._compliant.selectors import CompliantSelectorNamespace
|
|
from narwhals._utils import Implementation, Version
|
|
from narwhals.typing import (
|
|
ConcatMethod,
|
|
Into1DArray,
|
|
IntoDType,
|
|
IntoSchema,
|
|
NonNestedLiteral,
|
|
_2DArray,
|
|
)
|
|
|
|
Incomplete: TypeAlias = Any
|
|
|
|
__all__ = [
|
|
"CompliantNamespace",
|
|
"DepthTrackingNamespace",
|
|
"EagerNamespace",
|
|
"LazyNamespace",
|
|
]
|
|
|
|
|
|
class CompliantNamespace(Protocol[CompliantFrameT, CompliantExprT]):
    """Structural interface shared by every backend namespace.

    Parameterised by the compliant frame type and the compliant expression
    type. Only the column-selection helpers (`all`, `col`, `exclude`, `nth`)
    carry an implementation here; they all delegate to constructors on
    `self._expr`. Every other member is a stub left to the implementer.
    """

    # NOTE: `narwhals`
    _implementation: Implementation
    _version: Version

    @property
    def _expr(self) -> type[CompliantExprT]:
        """Expression class that the selection helpers below build from."""
        ...

    # NOTE: `polars`
    def all(self) -> CompliantExprT:
        """Build an expression via `from_column_names` using `get_column_names`."""
        build = self._expr.from_column_names
        return build(get_column_names, context=self)

    def col(self, *names: str) -> CompliantExprT:
        """Build an expression that passes the given column `names` through."""
        build = self._expr.from_column_names
        return build(passthrough_column_names(names), context=self)

    def exclude(self, *names: str) -> CompliantExprT:
        """Build an expression from `exclude_column_names` bound to `names`."""
        build = self._expr.from_column_names
        evaluate_names = partial(exclude_column_names, names=names)
        return build(evaluate_names, context=self)

    def nth(self, indices: Sequence[int]) -> CompliantExprT:
        """Build an expression from column positions.

        `indices` is unpacked into positional arguments of
        `from_column_indices`.
        """
        build = self._expr.from_column_indices
        return build(*indices, context=self)

    # The remaining members are protocol stubs; signatures only.
    def len(self) -> CompliantExprT: ...

    def lit(
        self, value: NonNestedLiteral, dtype: IntoDType | None
    ) -> CompliantExprT: ...

    def all_horizontal(
        self, *exprs: CompliantExprT, ignore_nulls: bool
    ) -> CompliantExprT: ...

    def any_horizontal(
        self, *exprs: CompliantExprT, ignore_nulls: bool
    ) -> CompliantExprT: ...

    def sum_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ...

    def mean_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ...

    def min_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ...

    def max_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ...

    def concat(
        self, items: Iterable[CompliantFrameT], *, how: ConcatMethod
    ) -> CompliantFrameT: ...

    def concat_str(
        self, *exprs: CompliantExprT, separator: str, ignore_nulls: bool
    ) -> CompliantExprT: ...

    @property
    def selectors(self) -> CompliantSelectorNamespace[Any, Any]: ...

    def coalesce(self, *exprs: CompliantExprT) -> CompliantExprT: ...

    # NOTE: typing this accurately requires 2x more `TypeVar`s
    def from_native(self, data: Any, /) -> Any: ...

    def is_native(self, obj: Any, /) -> TypeIs[Any]:
        """Return `True` if `obj` can be passed to `from_native`."""
        ...
|
|
|
|
|
|
class DepthTrackingNamespace(
    CompliantNamespace[CompliantFrameT, DepthTrackingExprT],
    Protocol[CompliantFrameT, DepthTrackingExprT],
):
    """`CompliantNamespace` specialised to depth-tracking expressions.

    The three overrides below delegate to `self._expr.from_column_names`
    exactly like the base protocol; what changes is the declared return type
    (`DepthTrackingExprT`).
    """

    def all(self) -> DepthTrackingExprT:
        """Build an expression via `from_column_names` using `get_column_names`."""
        build = self._expr.from_column_names
        return build(get_column_names, context=self)

    def col(self, *names: str) -> DepthTrackingExprT:
        """Build an expression that passes the given column `names` through."""
        build = self._expr.from_column_names
        return build(passthrough_column_names(names), context=self)

    def exclude(self, *names: str) -> DepthTrackingExprT:
        """Build an expression from `exclude_column_names` bound to `names`."""
        build = self._expr.from_column_names
        evaluate_names = partial(exclude_column_names, names=names)
        return build(evaluate_names, context=self)
|
|
|
|
|
|
class LazyNamespace(
    CompliantNamespace[CompliantLazyFrameT, LazyExprT],
    Protocol[CompliantLazyFrameT, LazyExprT, NativeFrameT],
):
    """Namespace protocol for lazy backends.

    Adds the native-frame type parameter and implements `is_native` /
    `from_native` on top of the `_lazyframe` class supplied by the
    implementer.
    """

    @property
    def _backend_version(self) -> tuple[int, ...]:
        """Backend version tuple, as reported by `self._implementation`."""
        implementation = self._implementation
        return implementation._backend_version()

    @property
    def _lazyframe(self) -> type[CompliantLazyFrameT]:
        """Compliant lazy-frame class used to recognise and wrap native frames."""
        ...

    def is_native(self, obj: Any, /) -> TypeIs[NativeFrameT]:
        """Return whether `self._lazyframe` recognises `obj` as native."""
        return self._lazyframe._is_native(obj)

    def from_native(self, data: NativeFrameT | Any, /) -> CompliantLazyFrameT:
        """Wrap a native frame in the compliant lazy-frame class.

        Raises:
            TypeError: if `self._lazyframe` does not recognise `data`.
        """
        if not self._lazyframe._is_native(data):
            msg = f"Unsupported type: {type(data).__name__!r}"  # pragma: no cover
            raise TypeError(msg)
        return self._lazyframe.from_native(data, context=self)
|
|
|
|
|
|
class EagerNamespace(
    DepthTrackingNamespace[EagerDataFrameT, EagerExprT],
    Protocol[EagerDataFrameT, EagerSeriesT_co, EagerExprT, NativeFrameT, NativeSeriesT],
):
    """Namespace protocol for eager backends.

    Implementers supply the `_dataframe` / `_series` classes plus the native
    hooks (`_if_then_else`, `_concat_*`); this protocol builds `when_then`,
    `is_native`, `from_native`, `from_numpy` and `concat` on top of them.
    """

    @property
    def _backend_version(self) -> tuple[int, ...]:
        """Backend version tuple, as reported by `self._implementation`."""
        return self._implementation._backend_version()

    @property
    def _dataframe(self) -> type[EagerDataFrameT]:
        """Compliant dataframe class for this backend."""
        ...

    @property
    def _series(self) -> type[EagerSeriesT_co]:
        """Compliant series class for this backend."""
        ...

    def _if_then_else(
        self,
        when: NativeSeriesT,
        then: NativeSeriesT,
        otherwise: NativeSeriesT | None = None,
    ) -> NativeSeriesT:
        """Backend hook operating on native series.

        `when_then` calls it with the native predicate, the native `then`
        series and, when given, the native `otherwise` series.
        """
        ...

    def when_then(
        self, predicate: EagerExprT, then: EagerExprT, otherwise: EagerExprT | None = None
    ) -> EagerExprT:
        """Build a conditional expression from `predicate`, `then` and `otherwise`.

        Each operand is evaluated as a single-output expression against the
        frame, the resulting series are broadcast-aligned together, and the
        native series are handed to `_if_then_else`. Output names are taken
        from `then` when it exposes them; otherwise they fall back to
        `["literal"]`.
        """

        def func(df: EagerDataFrameT) -> Sequence[EagerSeriesT_co]:
            predicate_s = df._evaluate_single_output_expr(predicate)
            # Bind the aligner before `predicate_s` is rebound below.
            align = predicate_s._align_full_broadcast

            then_s = df._evaluate_single_output_expr(then)
            # Exactly one branch runs: align and select once, not twice.
            if otherwise is None:
                predicate_s, then_s = align(predicate_s, then_s)
                result = self._if_then_else(predicate_s.native, then_s.native)
            else:
                otherwise_s = df._evaluate_single_output_expr(otherwise)
                predicate_s, then_s, otherwise_s = align(predicate_s, then_s, otherwise_s)
                result = self._if_then_else(
                    predicate_s.native, then_s.native, otherwise_s.native
                )
            # Re-wrap the native result using the `then` series as template.
            return [then_s._with_native(result)]

        return self._expr._from_callable(
            func=func,
            evaluate_output_names=getattr(
                then, "_evaluate_output_names", lambda _df: ["literal"]
            ),
            alias_output_names=getattr(then, "_alias_output_names", None),
            context=predicate,
        )

    def is_native(self, obj: Any, /) -> TypeIs[NativeFrameT | NativeSeriesT]:
        """Return whether `obj` is a native dataframe or a native series."""
        return self._dataframe._is_native(obj) or self._series._is_native(obj)

    @overload
    def from_native(self, data: NativeFrameT, /) -> EagerDataFrameT: ...
    @overload
    def from_native(self, data: NativeSeriesT, /) -> EagerSeriesT_co: ...
    def from_native(
        self, data: NativeFrameT | NativeSeriesT | Any, /
    ) -> EagerDataFrameT | EagerSeriesT_co:
        """Wrap a native dataframe or series in its compliant counterpart.

        The dataframe check runs first, then the series check.

        Raises:
            TypeError: if `data` is recognised as neither.
        """
        if self._dataframe._is_native(data):
            return self._dataframe.from_native(data, context=self)
        if self._series._is_native(data):
            return self._series.from_native(data, context=self)
        msg = f"Unsupported type: {type(data).__name__!r}"
        raise TypeError(msg)

    @overload
    def from_numpy(self, data: Into1DArray, /, schema: None = ...) -> EagerSeriesT_co: ...

    @overload
    def from_numpy(
        self, data: _2DArray, /, schema: IntoSchema | Sequence[str] | None
    ) -> EagerDataFrameT: ...

    def from_numpy(
        self,
        data: Into1DArray | _2DArray,
        /,
        schema: IntoSchema | Sequence[str] | None = None,
    ) -> EagerDataFrameT | EagerSeriesT_co:
        """Build a compliant dataframe or series from a numpy array.

        Arrays that satisfy `is_numpy_array_2d` become a dataframe and
        receive `schema`; anything else becomes a series, and `schema` is
        not forwarded in that case.
        """
        if is_numpy_array_2d(data):
            return self._dataframe.from_numpy(data, schema=schema, context=self)
        return self._series.from_numpy(data, context=self)

    # Native concatenation hooks, one per strategy dispatched by `concat`.
    def _concat_diagonal(self, dfs: Sequence[NativeFrameT], /) -> NativeFrameT: ...
    def _concat_horizontal(
        self, dfs: Sequence[NativeFrameT | Any], /
    ) -> NativeFrameT: ...
    def _concat_vertical(self, dfs: Sequence[NativeFrameT], /) -> NativeFrameT: ...

    def concat(
        self, items: Iterable[EagerDataFrameT], *, how: ConcatMethod
    ) -> EagerDataFrameT:
        """Concatenate compliant dataframes with the strategy named by `how`.

        The native frames are extracted, passed to the matching `_concat_*`
        hook, and the native result is wrapped again.

        Raises:
            NotImplementedError: if `how` is not `"horizontal"`,
                `"vertical"` or `"diagonal"`.
        """
        dfs = [item.native for item in items]
        if how == "horizontal":
            native = self._concat_horizontal(dfs)
        elif how == "vertical":
            native = self._concat_vertical(dfs)
        elif how == "diagonal":
            native = self._concat_diagonal(dfs)
        else:  # pragma: no cover
            raise NotImplementedError
        return self._dataframe.from_native(native, context=self)
|