from __future__ import annotations
|
|
|
|
import warnings
|
|
from typing import TYPE_CHECKING, Any, Callable, cast
|
|
|
|
import pandas as pd
|
|
|
|
from narwhals._compliant import DepthTrackingExpr, LazyExpr
|
|
from narwhals._dask.expr_dt import DaskExprDateTimeNamespace
|
|
from narwhals._dask.expr_str import DaskExprStringNamespace
|
|
from narwhals._dask.utils import (
|
|
add_row_index,
|
|
align_series_full_broadcast,
|
|
make_group_by_kwargs,
|
|
narwhals_to_native_dtype,
|
|
)
|
|
from narwhals._expression_parsing import evaluate_nodes, evaluate_output_names_and_aliases
|
|
from narwhals._pandas_like.expr import window_kwargs_to_pandas_equivalent
|
|
from narwhals._pandas_like.utils import get_dtype_backend, native_to_narwhals_dtype
|
|
from narwhals._utils import (
|
|
Implementation,
|
|
generate_temporary_column_name,
|
|
no_default,
|
|
not_implemented,
|
|
)
|
|
from narwhals.exceptions import InvalidOperationError
|
|
|
|
if TYPE_CHECKING:
|
|
from collections.abc import Sequence
|
|
|
|
import dask.dataframe.dask_expr as dx
|
|
from typing_extensions import Self
|
|
|
|
from narwhals._compliant.typing import (
|
|
AliasNames,
|
|
EvalNames,
|
|
EvalSeries,
|
|
NarwhalsAggregation,
|
|
)
|
|
from narwhals._dask.dataframe import DaskLazyFrame
|
|
from narwhals._dask.namespace import DaskNamespace
|
|
from narwhals._typing import NoDefault
|
|
from narwhals._utils import Version, _LimitedContext
|
|
from narwhals.typing import (
|
|
FillNullStrategy,
|
|
IntoDType,
|
|
ModeKeepStrategy,
|
|
RollingInterpolationMethod,
|
|
)
|
|
|
|
|
|
def simple_aggregation(attr: str) -> Any:
    """Build a DaskExpr aggregation that maps 1:1 onto a Dask method named `attr`.

    The generated method calls ``getattr(native_series, attr)(**kwargs)`` and
    converts the scalar result back via ``.to_series()``.
    """

    def method(self, **kwargs):  # noqa: ANN001, ANN003, ANN202
        return self._with_callable(
            lambda expr: getattr(expr, attr)(**kwargs).to_series()
        )

    return method
|
|
|
|
|
|
def simple_method(attr: str) -> Any:
    """Build a DaskExpr method that maps 1:1 onto a Dask method named `attr`.

    Unlike `simple_aggregation`, the wrapped call already returns a series, so
    no ``.to_series()`` conversion is needed.
    """

    def method(self, **kwargs):  # noqa: ANN001, ANN003, ANN202
        return self._with_callable(lambda expr: getattr(expr, attr)(**kwargs))

    return method
|
|
|
|
|
|
def simple_binary(attr: str) -> Any:
    """Build a binary DaskExpr operator that delegates to `_binary_op` with `attr`."""

    def method(self, other):  # noqa: ANN001, ANN202
        return self._binary_op(attr, other)

    return method
|
|
|
|
|
|
def trivial_binary_right(op: Callable[..., dx.Series]) -> Any:
    """Build a reflected binary DaskExpr operator from a plain two-argument callable.

    Delegates to `_reverse_binary_op`, which applies `op` with the operands swapped.
    """

    def method(self, other):  # noqa: ANN001, ANN202
        return self._reverse_binary_op(op, other)

    return method
|
|
|
|
|
|
class DaskExpr(
    LazyExpr["DaskLazyFrame", "dx.Series"],  # pyright: ignore[reportInvalidTypeArguments]
    DepthTrackingExpr["DaskLazyFrame", "dx.Series"],  # pyright: ignore[reportInvalidTypeArguments]
):
    """Dask backend implementation of a Narwhals lazy expression.

    An instance wraps `_call`, a function which, given a `DaskLazyFrame`,
    produces a sequence of native `dx.Series`. Most methods return a *new*
    `DaskExpr` whose `_call` composes the previous one with an extra
    operation; nothing is evaluated until the expression is applied to a
    frame.
    """

    _implementation: Implementation = Implementation.DASK

    # Methods which are simple Narwhals->Dask translations.
    # Keep this simple, resist the temptation to do anything complex or clever.
    __add__ = simple_binary("__add__")
    __sub__ = simple_binary("__sub__")
    __mul__ = simple_binary("__mul__")
    __truediv__ = simple_binary("__truediv__")
    __pow__ = simple_binary("__pow__")
    __mod__ = simple_binary("__mod__")
    __eq__ = simple_binary("__eq__")
    __ne__ = simple_binary("__ne__")
    __gt__ = simple_binary("__gt__")
    __ge__ = simple_binary("__ge__")
    __lt__ = simple_binary("__lt__")
    __le__ = simple_binary("__le__")
    __and__ = simple_binary("__and__")
    __or__ = simple_binary("__or__")
    # Reflected operators: Dask evaluates `op(other, expr)` with operands swapped.
    __rsub__ = trivial_binary_right(lambda x, y: x - y)
    __rtruediv__ = trivial_binary_right(lambda x, y: x / y)
    __rpow__ = trivial_binary_right(lambda x, y: x**y)
    __rmod__ = trivial_binary_right(lambda x, y: x % y)
    # Aggregations whose Dask method shares the Narwhals name.
    all = simple_aggregation("all")
    any = simple_aggregation("any")
    count = simple_aggregation("count")
    kurtosis = simple_aggregation("kurtosis")
    max = simple_aggregation("max")
    mean = simple_aggregation("mean")
    min = simple_aggregation("min")
    skew = simple_aggregation("skew")
    std = simple_aggregation("std")
    sum = simple_aggregation("sum")
    var = simple_aggregation("var")
    # Elementwise methods; right-hand side is the Dask/pandas spelling.
    __invert__ = simple_method("__invert__")
    abs = simple_method("abs")
    diff = simple_method("diff")
    drop_nulls = simple_method("dropna")
    is_null = simple_method("isna")
    round = simple_method("round")
    unique = simple_method("unique")

    def __init__(
        self,
        call: EvalSeries[DaskLazyFrame, dx.Series],  # pyright: ignore[reportInvalidTypeForm]
        *,
        evaluate_output_names: EvalNames[DaskLazyFrame],
        alias_output_names: AliasNames | None,
        version: Version,
    ) -> None:
        """Store the evaluation callback and output-name bookkeeping.

        Arguments:
            call: Maps a `DaskLazyFrame` to the native series this expression produces.
            evaluate_output_names: Computes the output column names for a given frame.
            alias_output_names: Optional renaming applied on top of the output names.
            version: Narwhals API version this expression was created under.
        """
        self._call = call
        self._evaluate_output_names = evaluate_output_names
        self._alias_output_names = alias_output_names
        self._version = version

    def __call__(self, df: DaskLazyFrame) -> Sequence[dx.Series]:
        """Evaluate this expression against `df`, returning native series."""
        return self._call(df)

    def __narwhals_namespace__(self) -> DaskNamespace:  # pragma: no cover
        # Local import avoids a circular import at module load time.
        from narwhals._dask.namespace import DaskNamespace

        return DaskNamespace(version=self._version)

    def broadcast(self) -> Self:
        """Return an expression that reduces each result to its first scalar.

        Used to broadcast aggregation results back against a full-length frame.
        """

        def func(df: DaskLazyFrame) -> list[dx.Series]:
            # result.loc[0][0] is a workaround for dask~<=2024.10.0/dask_expr~<=1.1.16
            # that raised a KeyError for result[0] during collection.
            return [result.loc[0][0] for result in self(df)]

        return self.__class__(
            func,
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=self._alias_output_names,
            version=self._version,
        )

    @classmethod
    def from_column_names(
        cls: type[Self],
        evaluate_column_names: EvalNames[DaskLazyFrame],
        /,
        *,
        context: _LimitedContext,
    ) -> Self:
        """Create an expression that selects columns by (lazily evaluated) name."""

        def func(df: DaskLazyFrame) -> list[dx.Series]:
            try:
                return [
                    df._native_frame[column_name]
                    for column_name in evaluate_column_names(df)
                ]
            except KeyError as e:
                # Re-raise with a friendlier "column not found" error when possible.
                if error := df._check_columns_exist(evaluate_column_names(df)):
                    raise error from e
                raise

        return cls(
            func,
            evaluate_output_names=evaluate_column_names,
            alias_output_names=None,
            version=context._version,
        )

    @classmethod
    def from_column_indices(cls, *column_indices: int, context: _LimitedContext) -> Self:
        """Create an expression that selects columns by positional index."""

        def func(df: DaskLazyFrame) -> list[dx.Series]:
            return [df.native.iloc[:, i] for i in column_indices]

        return cls(
            func,
            evaluate_output_names=cls._eval_names_indices(column_indices),
            alias_output_names=None,
            version=context._version,
        )

    def _with_callable(
        self,
        # First argument to `call` should be `dx.Series`
        call: Callable[..., dx.Series],
        /,
        **expressifiable_args: Self,
    ) -> Self:
        """Core combinator: apply `call` to each native series this expression yields.

        Any keyword arguments are themselves expressions; they are evaluated to
        single native series first and passed to `call` by keyword.
        """

        def func(df: DaskLazyFrame) -> list[dx.Series]:
            native_results: list[dx.Series] = []
            native_series_list = self._call(df)
            other_native_series = {
                key: df._evaluate_single_output_expr(value)
                for key, value in expressifiable_args.items()
            }
            for native_series in native_series_list:
                result_native = call(native_series, **other_native_series)
                native_results.append(result_native)
            return native_results

        return self.__class__(
            func,
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=self._alias_output_names,
            version=self._version,
        )

    def _with_alias_output_names(self, func: AliasNames | None, /) -> Self:
        """Return a copy whose alias function is `func` composed over any existing one."""
        current_alias_output_names = self._alias_output_names
        alias_output_names = (
            None
            if func is None
            else func
            if current_alias_output_names is None
            else lambda output_names: func(current_alias_output_names(output_names))
        )
        return type(self)(
            call=self._call,
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=alias_output_names,
            version=self._version,
        )

    def _with_binary(
        self, call: Callable[[dx.Series, Any], dx.Series], other: Any
    ) -> Self:
        """Apply binary `call` with `other` expressified as a keyword argument."""
        return self._with_callable(lambda expr, other: call(expr, other), other=other)

    def _binary_op(self, op_name: str, other: Any) -> Self:
        """Apply the native dunder method `op_name` (e.g. `__add__`) with `other`."""
        return self._with_binary(lambda expr, other: getattr(expr, op_name)(other), other)

    def _reverse_binary_op(
        self, operator_func: Callable[..., dx.Series], other: Any
    ) -> Self:
        """Apply `operator_func` with swapped operands; result is named "literal"."""
        return self._with_binary(
            lambda expr, other: operator_func(other, expr), other
        ).alias("literal")

    def __floordiv__(self, other: Any) -> Self:
        """Floor division; positions where the divisor is 0 become null."""

        def _floordiv(
            df: DaskLazyFrame, series: dx.Series, other: dx.Series
        ) -> dx.Series:
            series, other = align_series_full_broadcast(df, series, other)
            return (series.__floordiv__(other)).where(other != 0, None)

        def func(df: DaskLazyFrame) -> list[dx.Series]:
            other_series = df._evaluate_single_output_expr(other)
            return [_floordiv(df, series, other_series) for series in self(df)]

        return self.__class__(
            func,
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=self._alias_output_names,
            version=self._version,
        )

    def __rfloordiv__(self, other: Any) -> Self:
        """Reflected floor division; zero divisors become null, result named "literal"."""

        def _rfloordiv(
            df: DaskLazyFrame, series: dx.Series, other: dx.Series
        ) -> dx.Series:
            series, other = align_series_full_broadcast(df, series, other)
            return (other.__floordiv__(series)).where(series != 0, None)

        def func(df: DaskLazyFrame) -> list[dx.Series]:
            other_native = df._evaluate_single_output_expr(other)
            return [_rfloordiv(df, series, other_native) for series in self(df)]

        return self.__class__(
            func,
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=self._alias_output_names,
            version=self._version,
        ).alias("literal")

    def median(self) -> Self:
        """Approximate median (`median_approximate`); numeric dtypes only."""
        from narwhals.exceptions import InvalidOperationError

        def func(s: dx.Series) -> dx.Series:
            dtype = native_to_narwhals_dtype(s.dtype, self._version, Implementation.DASK)
            if not dtype.is_numeric():
                msg = "`median` operation not supported for non-numeric input type."
                raise InvalidOperationError(msg)
            return s.median_approximate().to_series()

        return self._with_callable(func)

    def shift(self, n: int) -> Self:
        """Shift values by `n` positions."""
        return self._with_callable(lambda expr: expr.shift(n))

    def cum_sum(self, *, reverse: bool) -> Self:
        """Cumulative sum; `reverse=True` is unsupported by Dask."""
        if reverse:  # pragma: no cover
            # https://github.com/dask/dask/issues/11802
            msg = "`cum_sum(reverse=True)` is not supported with Dask backend"
            raise NotImplementedError(msg)

        return self._with_callable(lambda expr: expr.cumsum())

    def cum_count(self, *, reverse: bool) -> Self:
        """Cumulative count of non-null values; `reverse=True` is unsupported."""
        if reverse:  # pragma: no cover
            msg = "`cum_count(reverse=True)` is not supported with Dask backend"
            raise NotImplementedError(msg)

        # Count non-nulls by summing the inverted null mask.
        return self._with_callable(lambda expr: (~expr.isna()).astype(int).cumsum())

    def cum_min(self, *, reverse: bool) -> Self:
        """Cumulative minimum; `reverse=True` is unsupported by Dask."""
        if reverse:  # pragma: no cover
            msg = "`cum_min(reverse=True)` is not supported with Dask backend"
            raise NotImplementedError(msg)

        return self._with_callable(lambda expr: expr.cummin())

    def cum_max(self, *, reverse: bool) -> Self:
        """Cumulative maximum; `reverse=True` is unsupported by Dask."""
        if reverse:  # pragma: no cover
            msg = "`cum_max(reverse=True)` is not supported with Dask backend"
            raise NotImplementedError(msg)

        return self._with_callable(lambda expr: expr.cummax())

    def cum_prod(self, *, reverse: bool) -> Self:
        """Cumulative product; `reverse=True` is unsupported by Dask."""
        if reverse:  # pragma: no cover
            msg = "`cum_prod(reverse=True)` is not supported with Dask backend"
            raise NotImplementedError(msg)

        return self._with_callable(lambda expr: expr.cumprod())

    def rolling_sum(self, window_size: int, *, min_samples: int, center: bool) -> Self:
        """Rolling-window sum via the native `rolling(...)` API."""
        return self._with_callable(
            lambda expr: expr.rolling(
                window=window_size, min_periods=min_samples, center=center
            ).sum()
        )

    def rolling_mean(self, window_size: int, *, min_samples: int, center: bool) -> Self:
        """Rolling-window mean via the native `rolling(...)` API."""
        return self._with_callable(
            lambda expr: expr.rolling(
                window=window_size, min_periods=min_samples, center=center
            ).mean()
        )

    def rolling_var(
        self, window_size: int, *, min_samples: int, center: bool, ddof: int
    ) -> Self:
        """Rolling-window variance; Dask's rolling `var` only supports `ddof=1`."""
        if ddof == 1:
            return self._with_callable(
                lambda expr: expr.rolling(
                    window=window_size, min_periods=min_samples, center=center
                ).var()
            )
        msg = "Dask backend only supports `ddof=1` for `rolling_var`"
        raise NotImplementedError(msg)

    def rolling_std(
        self, window_size: int, *, min_samples: int, center: bool, ddof: int
    ) -> Self:
        """Rolling-window standard deviation; only `ddof=1` is supported."""
        if ddof == 1:
            return self._with_callable(
                lambda expr: expr.rolling(
                    window=window_size, min_periods=min_samples, center=center
                ).std()
            )
        msg = "Dask backend only supports `ddof=1` for `rolling_std`"
        raise NotImplementedError(msg)

    def floor(self) -> Self:
        """Elementwise floor via `dask.array.floor`."""
        import dask.array as da

        return self._with_callable(da.floor)

    def ceil(self) -> Self:
        """Elementwise ceiling via `dask.array.ceil`."""
        import dask.array as da

        return self._with_callable(da.ceil)

    def fill_nan(self, value: float | None) -> Self:
        """Replace NaN (not null) values with `value` (or null when `value is None`).

        The replacement sentinel depends on the dtype backend: nullable-backed
        columns get `pd.NA`, NumPy-backed columns get `float('nan')`.
        """
        value_nullable = pd.NA if value is None else value
        value_numpy = float("nan") if value is None else value

        def func(expr: dx.Series) -> dx.Series:
            # If/when pandas exposes an API which distinguishes NaN vs null, use that.
            # `expr != expr` is True only at NaN positions.
            mask = cast("dx.Series", expr != expr)  # noqa: PLR0124
            # Null positions compare as null, not NaN — keep them untouched.
            mask = mask.fillna(False)
            fill = (
                value_nullable
                if get_dtype_backend(expr.dtype, self._implementation)
                else value_numpy
            )
            return expr.mask(mask, fill)  # pyright: ignore[reportArgumentType]

        return self._with_callable(func)

    def fill_null(
        self, value: Self | None, strategy: FillNullStrategy | None, limit: int | None
    ) -> Self:
        """Fill nulls with a value, or forward/backward fill per `strategy`."""

        def func(expr: dx.Series) -> dx.Series:
            if value is not None:
                res_ser = expr.fillna(value)
            else:
                # NOTE(review): assumes upstream validation guarantees `strategy`
                # is set when `value` is None — anything non-"forward" bfills.
                res_ser = (
                    expr.ffill(limit=limit)
                    if strategy == "forward"
                    else expr.bfill(limit=limit)
                )
            return res_ser

        return self._with_callable(func)

    def clip(self, lower_bound: Self, upper_bound: Self) -> Self:
        """Clip values to `[lower_bound, upper_bound]` (both expressions)."""
        return self._with_callable(
            lambda expr, lower_bound, upper_bound: expr.clip(
                lower=lower_bound, upper=upper_bound
            ),
            lower_bound=lower_bound,
            upper_bound=upper_bound,
        )

    def clip_lower(self, lower_bound: Self) -> Self:
        """Clip values from below only."""
        return self._with_callable(
            lambda expr, lower_bound: expr.clip(lower=lower_bound),
            lower_bound=lower_bound,
        )

    def clip_upper(self, upper_bound: Self) -> Self:
        """Clip values from above only."""
        return self._with_callable(
            lambda expr, upper_bound: expr.clip(upper=upper_bound),
            upper_bound=upper_bound,
        )

    def n_unique(self) -> Self:
        """Count distinct values, including nulls (`dropna=False`)."""
        return self._with_callable(lambda expr: expr.nunique(dropna=False).to_series())

    def is_nan(self) -> Self:
        """Boolean mask of NaN positions; raises for non-numeric dtypes."""

        def func(expr: dx.Series) -> dx.Series:
            dtype = native_to_narwhals_dtype(
                expr.dtype, self._version, self._implementation
            )
            if dtype.is_numeric():
                # NaN is the only value unequal to itself.
                return expr != expr  # pyright: ignore[reportReturnType] # noqa: PLR0124
            msg = f"`.is_nan` only supported for numeric dtypes and not {dtype}, did you mean `.is_null`?"
            raise InvalidOperationError(msg)

        return self._with_callable(func)

    def len(self) -> Self:
        """Length of the column, as a one-element series."""
        return self._with_callable(lambda expr: expr.size.to_series())

    def quantile(
        self, quantile: float, interpolation: RollingInterpolationMethod
    ) -> Self:
        """Quantile; only `linear` interpolation on a single partition is supported."""
        if interpolation == "linear":

            def func(expr: dx.Series) -> dx.Series:
                if expr.npartitions > 1:
                    msg = "`Expr.quantile` is not supported for Dask backend with multiple partitions."
                    raise NotImplementedError(msg)
                return expr.quantile(
                    q=quantile, method="dask"
                ).to_series()  # pragma: no cover

            return self._with_callable(func)
        msg = "`higher`, `lower`, `midpoint`, `nearest` - interpolation methods are not supported by Dask. Please use `linear` instead."
        raise NotImplementedError(msg)

    def is_first_distinct(self) -> Self:
        """True at the first occurrence of each distinct value.

        Implemented by attaching a row index, grouping by value, and marking
        rows whose index equals the group's minimum index.
        """

        def func(expr: dx.Series) -> dx.Series:
            _name = expr.name
            col_token = generate_temporary_column_name(
                n_bytes=8, columns=[_name], prefix="row_index_"
            )
            frame = add_row_index(expr.to_frame(), col_token)
            group_by_kwargs = make_group_by_kwargs(drop_null_keys=False)
            first_distinct_index = frame.groupby(_name, **group_by_kwargs).agg(
                {col_token: "min"}
            )[col_token]
            return frame[col_token].isin(first_distinct_index)

        return self._with_callable(func)

    def is_last_distinct(self) -> Self:
        """True at the last occurrence of each distinct value.

        Same strategy as `is_first_distinct`, using the group's maximum index.
        """

        def func(expr: dx.Series) -> dx.Series:
            _name = expr.name
            col_token = generate_temporary_column_name(
                n_bytes=8, columns=[_name], prefix="row_index_"
            )
            frame = add_row_index(expr.to_frame(), col_token)
            group_by_kwargs = make_group_by_kwargs(drop_null_keys=False)
            last_distinct_index = frame.groupby(_name, **group_by_kwargs).agg(
                {col_token: "max"}
            )[col_token]
            return frame[col_token].isin(last_distinct_index)

        return self._with_callable(func)

    def is_unique(self) -> Self:
        """True where a value occurs exactly once (group size == 1)."""

        def func(expr: dx.Series) -> dx.Series:
            _name = expr.name
            group_by_kwargs = make_group_by_kwargs(drop_null_keys=False)
            return (
                expr.to_frame()
                .groupby(_name, **group_by_kwargs)
                .transform("size", meta=(_name, int))
                == 1
            )

        return self._with_callable(func)

    def is_in(self, other: Any) -> Self:
        """Membership test against the values in `other`."""
        return self._with_callable(lambda expr: expr.isin(other))

    def null_count(self) -> Self:
        """Number of null values, as a one-element series."""
        return self._with_callable(lambda expr: expr.isna().sum().to_series())

    def _over_without_partition_by(self, order_by: Sequence[str]) -> Self:
        # This is something like `nw.col('a').cum_sum().order_by(key)`
        # which we can always easily support, as it doesn't require grouping.
        def func(df: DaskLazyFrame) -> Sequence[dx.Series]:
            return self(df.sort(*order_by, descending=False, nulls_last=False))

        return self.__class__(
            func,
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=self._alias_output_names,
            version=self._version,
        )

    def over(self, partition_by: Sequence[str], order_by: Sequence[str]) -> Self:
        """Window function: evaluate this expression per `partition_by` group.

        Only simple aggregations over elementwise inputs are supported; the
        leaf aggregation is translated to a pandas-style `groupby.transform`.
        """
        if not partition_by:
            # Upstream guarantees at least one of partition_by/order_by is set.
            assert order_by  # noqa: S101
            return self._over_without_partition_by(order_by)
        # pandas is a required dependency of dask so it's safe to import this
        from narwhals._pandas_like.group_by import PandasLikeGroupBy

        # We have something like prev.leaf().over(...) (e.g. `nw.col('a').sum().over('b')`), where:
        # - `prev` must be elementwise (in the example: `nw.col('a')`)
        # - `leaf` must be a "simple" function, i.e. one that pandas supports in `transform`
        #   (in the example: `sum`)
        #
        # We first evaluate `prev` as-is, and then evaluate `leaf().over(...)`` by using `transform`
        # or other DataFrameGroupBy methods.
        meta = self._metadata
        if partition_by and (
            meta.prev is not None and not meta.prev.is_elementwise
        ):  # pragma: no cover
            msg = (
                "Only elementary expressions are supported for `.over` in dask backend "
                "when `partition_by` is specified.\n\n"
                "Please see: "
                "https://narwhals-dev.github.io/narwhals/concepts/improve_group_by_operation/"
            )
            raise NotImplementedError(msg)

        if order_by:
            # Wrong results https://github.com/dask/dask/issues/11806.
            msg = "`over` with `order_by` is not yet supported in Dask."
            raise NotImplementedError(msg)

        # Oldest node first; the last node is the aggregation to translate.
        nodes = list(reversed(list(self._metadata.iter_nodes_reversed())))
        leaf_node = nodes[-1]
        function_name = cast("NarwhalsAggregation", leaf_node.name)
        try:
            dask_function_name = PandasLikeGroupBy._REMAP_AGGS[function_name]
        except KeyError:
            # window functions are unsupported: https://github.com/dask/dask/issues/11806
            msg = (
                f"Unsupported function: {function_name} in `over` context.\n\n"
                f"Supported functions are {', '.join(PandasLikeGroupBy._REMAP_AGGS)}\n"
            )
            raise NotImplementedError(msg) from None
        dask_kwargs = window_kwargs_to_pandas_equivalent(function_name, leaf_node.kwargs)

        def func(df: DaskLazyFrame) -> Sequence[dx.Series]:
            plx = self.__narwhals_namespace__()
            if meta.prev is not None:
                # Materialise everything up to (but excluding) the leaf aggregation.
                df = df.with_columns(cast("DaskExpr", evaluate_nodes(nodes[:-1], plx)))
            _, aliases = evaluate_output_names_and_aliases(self, df, [])

            with warnings.catch_warnings():
                # https://github.com/dask/dask/issues/11804
                warnings.filterwarnings(
                    "ignore", message=".*`meta` is not specified", category=UserWarning
                )
                group_by_kwargs = make_group_by_kwargs(drop_null_keys=False)
                grouped = df.native.groupby(partition_by, **group_by_kwargs)
                if dask_function_name == "size":
                    # `size` transforms the whole group, not a column selection.
                    if len(aliases) != 1:  # pragma: no cover
                        msg = "Safety check failed, please report a bug."
                        raise AssertionError(msg)
                    res_native = grouped.transform(
                        dask_function_name, **dask_kwargs
                    ).to_frame(aliases[0])
                else:
                    res_native = grouped[list(aliases)].transform(
                        dask_function_name, **dask_kwargs
                    )
            result_frame = df._with_native(res_native).native
            return [result_frame[name] for name in aliases]

        return self.__class__(
            func,
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=self._alias_output_names,
            version=self._version,
        )

    def cast(self, dtype: IntoDType) -> Self:
        """Cast to the native equivalent of the Narwhals dtype `dtype`."""

        def func(expr: dx.Series) -> dx.Series:
            native_dtype = narwhals_to_native_dtype(dtype, self._version)
            return expr.astype(native_dtype)

        return self._with_callable(func)

    def is_finite(self) -> Self:
        """Boolean mask of finite values via `dask.array.isfinite`."""
        import dask.array as da

        return self._with_callable(da.isfinite)

    def log(self, base: float) -> Self:
        """Logarithm in `base`, via the change-of-base identity ln(x)/ln(base)."""
        import dask.array as da

        def _log(expr: dx.Series) -> dx.Series:
            return da.log(expr) / da.log(base)

        return self._with_callable(_log)

    def exp(self) -> Self:
        """Elementwise exponential via `dask.array.exp`."""
        import dask.array as da

        return self._with_callable(da.exp)

    def sqrt(self) -> Self:
        """Elementwise square root via `dask.array.sqrt`."""
        import dask.array as da

        return self._with_callable(da.sqrt)

    def mode(self, *, keep: ModeKeepStrategy) -> Self:
        """Most frequent value(s); `keep="any"` returns a single arbitrary mode."""

        def func(expr: dx.Series) -> dx.Series:
            _name = expr.name
            result = expr.to_frame().mode()[_name]
            return result.head(1) if keep == "any" else result

        return self._with_callable(func)

    def replace_strict(
        self,
        default: DaskExpr | NoDefault,
        old: Sequence[Any],
        new: Sequence[Any],
        *,
        return_dtype: IntoDType | None,
    ) -> Self:
        """Map `old` values to `new`; anything not in `old` becomes `default`.

        Unlike other backends, Dask requires an explicit `default` because
        unmatched values cannot be detected lazily.
        """
        if default is no_default:
            msg = "`replace_strict` requires an explicit value for `default` for dask backend."
            raise ValueError(msg)

        mapping = dict(zip(old, new))
        old_ = list(old)

        def func(df: DaskLazyFrame) -> list[dx.Series]:
            default_series = df._evaluate_single_output_expr(default)
            results = [
                # Replace matches, then overwrite non-matches with the default.
                series.replace(mapping).where(series.isin(old_), default_series)
                for series in self(df)
            ]

            if return_dtype:
                native_dtype = narwhals_to_native_dtype(return_dtype, self._version)
                return [res.astype(native_dtype) for res in results]

            return results

        return self.__class__(
            func,
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=self._alias_output_names,
            version=self._version,
        )

    @property
    def str(self) -> DaskExprStringNamespace:
        # String-method namespace (`.str.*`).
        return DaskExprStringNamespace(self)

    @property
    def dt(self) -> DaskExprDateTimeNamespace:
        # Datetime-method namespace (`.dt.*`).
        return DaskExprDateTimeNamespace(self)

    # Operations not (yet) supported by the Dask backend.
    filter = not_implemented()
    first = not_implemented()
    rank = not_implemented()
    last = not_implemented()

    # namespaces
    list: not_implemented = not_implemented()  # type: ignore[assignment]
    struct: not_implemented = not_implemented()  # type: ignore[assignment]
|