DriverTrac/venv/lib/python3.12/site-packages/narwhals/_expression_parsing.py
2025-11-28 09:08:33 +05:30

907 lines
32 KiB
Python

# Utilities for expression parsing
# Useful for backends which don't have any concept of expressions, such
# as pandas or PyArrow.
# ! Any change to this module will trigger the pyspark and pyspark-connect tests in CI
from __future__ import annotations
from enum import Enum, auto
from typing import TYPE_CHECKING, Any, Callable, Literal, cast
from narwhals._utils import zip_strict
from narwhals.dependencies import is_numpy_array_1d
from narwhals.exceptions import (
InvalidIntoExprError,
InvalidOperationError,
MultiOutputExpressionError,
)
if TYPE_CHECKING:
from collections.abc import Iterator, Sequence
from typing_extensions import Never, TypeIs
from narwhals._compliant import CompliantExpr, CompliantFrameT
from narwhals._compliant.typing import (
AliasNames,
CompliantExprAny,
CompliantFrameAny,
CompliantNamespaceAny,
EvalNames,
)
from narwhals.expr import Expr
from narwhals.series import Series
from narwhals.typing import IntoExpr, NonNestedLiteral, _1DArray
def is_expr(obj: Any) -> TypeIs[Expr]:
    """Check whether `obj` is a Narwhals Expr.

    Arguments:
        obj: Object to inspect.

    Returns:
        `True` if `obj` is an instance of `narwhals.expr.Expr`, `False` otherwise.
    """
    # `Expr` is only imported under `TYPE_CHECKING` at module level, so the
    # runtime class is imported here, when the check is actually performed.
    from narwhals.expr import Expr
    return isinstance(obj, Expr)
def is_series(obj: Any) -> TypeIs[Series[Any]]:
    """Check whether `obj` is a Narwhals Series.

    Arguments:
        obj: Object to inspect.

    Returns:
        `True` if `obj` is an instance of `narwhals.series.Series`, `False` otherwise.
    """
    # `Series` is only imported under `TYPE_CHECKING` at module level, so the
    # runtime class is imported here, when the check is actually performed.
    from narwhals.series import Series
    return isinstance(obj, Series)
def combine_evaluate_output_names(
    *exprs: CompliantExpr[CompliantFrameT, Any],
) -> EvalNames[CompliantFrameT]:
    """Build the output-name evaluator for a function combining several expressions.

    Naming follows the left-hand rule: e.g. `nw.sum_horizontal(expr1, expr2)`
    takes the first output name of `expr1`.

    Arguments:
        exprs: Expressions being combined; only the first one is consulted.

    Returns:
        A callable which, given a frame, returns at most one name: the first
        output name of the left-most expression.
    """

    def evaluate_output_names(df: CompliantFrameT) -> Sequence[str]:
        # The left-most expression is looked up lazily, when names are requested.
        leftmost = exprs[0]
        resolved_names = leftmost._evaluate_output_names(df)
        return resolved_names[:1]

    return evaluate_output_names
def combine_alias_output_names(*exprs: CompliantExprAny) -> AliasNames | None:
    """Build the aliasing function for a function combining several expressions.

    Naming follows the left-hand rule: e.g.
    `nw.sum_horizontal(expr1.alias(alias), expr2)` takes the aliasing function
    of `expr1` and applies it to the first output name of `expr1`.

    Arguments:
        exprs: Expressions being combined; only the first one is consulted.

    Returns:
        `None` if the left-most expression has no aliasing function, else a
        callable returning at most one aliased name.
    """
    if exprs[0]._alias_output_names is not None:

        def alias_output_names(names: Sequence[str]) -> Sequence[str]:
            renamed = exprs[0]._alias_output_names(names)  # type: ignore[misc]
            return renamed[:1]

        return alias_output_names
    return None
def evaluate_output_names_and_aliases(
    expr: CompliantExprAny, df: CompliantFrameAny, exclude: Sequence[str]
) -> tuple[Sequence[str], Sequence[str]]:
    """Evaluate the output names of `expr` on `df`, together with their aliases.

    Arguments:
        expr: Expression whose names are evaluated.
        df: Frame against which the output names are resolved.
        exclude: Names to drop from the result. Only applied when `expr` is a
            multi-output expression without explicit names (`MULTI_UNNAMED`).

    Returns:
        A `(output_names, aliases)` pair of equal length. If `expr` has no
        aliasing function, `aliases` is `output_names` itself.
    """
    output_names = expr._evaluate_output_names(df)
    alias_fn = expr._alias_output_names
    aliases = output_names if alias_fn is None else alias_fn(output_names)
    if exclude and expr._metadata.expansion_kind.is_multi_unnamed():
        kept = [
            (name, alias)
            for name, alias in zip_strict(output_names, aliases)
            if name not in exclude
        ]
        # Unzip explicitly rather than with `zip(*kept)`: when every name is
        # excluded, `zip()` yields nothing and unpacking into two targets
        # would raise `ValueError` instead of returning two empty sequences.
        output_names = tuple(name for name, _ in kept)
        aliases = tuple(alias for _, alias in kept)
    return output_names, aliases
class ExprKind(Enum):
    """Describe which kind of expression we are dealing with.

    The properties below classify each kind by whether it may depend on row
    order, whether it works row-by-row, and whether it yields a single value.
    """

    LITERAL = auto()
    """e.g. `nw.lit(1)`"""

    AGGREGATION = auto()
    """Reduces to a single value, not affected by row order, e.g. `nw.col('a').mean()`"""

    ORDERABLE_AGGREGATION = auto()
    """Reduces to a single value, affected by row order, e.g. `nw.col('a').arg_max()`"""

    ELEMENTWISE = auto()
    """Preserves length, can operate without context for surrounding rows, e.g. `nw.col('a').abs()`."""

    ORDERABLE_WINDOW = auto()
    """Depends on the rows around it and on their order, e.g. `diff`."""

    WINDOW = auto()
    """Depends on the rows around it and possibly their order, e.g. `rank`."""

    FILTRATION = auto()
    """Changes length, not affected by row order, e.g. `drop_nulls`."""

    ORDERABLE_FILTRATION = auto()
    """Changes length, affected by row order, e.g. `tail`."""

    OVER = auto()
    """Results from calling `.over` on expression."""

    COL = auto()
    """Results from calling `nw.col`."""

    NTH = auto()
    """Results from calling `nw.nth`."""

    EXCLUDE = auto()
    """Results from calling `nw.exclude`."""

    ALL = auto()
    """Results from calling `nw.all`."""

    SELECTOR = auto()
    """Results from creating an expression with a selector."""

    WHEN_THEN = auto()
    """Results from `when/then expression`, possibly followed by `otherwise`."""

    SERIES = auto()
    """Results from converting a Series to Expr."""

    @property
    def is_orderable(self) -> bool:
        """Whether the operation may be affected by `order_by`.

        Covers operations such as `cum_sum`, `diff`, `rank`, `arg_max`, ...
        """
        orderable_kinds = (
            ExprKind.ORDERABLE_AGGREGATION,
            ExprKind.ORDERABLE_FILTRATION,
            ExprKind.ORDERABLE_WINDOW,
            ExprKind.WINDOW,
        )
        return any(self is kind for kind in orderable_kinds)

    @property
    def is_elementwise(self) -> bool:
        """Whether the operation can act on each row independently of its neighbours.

        Covers operations such as `abs()`, `__add__`, `sum_horizontal`, ...
        """
        row_independent_kinds = (
            ExprKind.LITERAL,
            ExprKind.ELEMENTWISE,
            ExprKind.COL,
            ExprKind.NTH,
            ExprKind.EXCLUDE,
            ExprKind.ALL,
            ExprKind.SELECTOR,
            ExprKind.WHEN_THEN,
            ExprKind.SERIES,
        )
        return any(self is kind for kind in row_independent_kinds)

    @property
    def is_scalar_like(self) -> bool:
        """Whether the kind is a literal or one of the aggregation kinds."""
        single_value_kinds = (
            ExprKind.LITERAL,
            ExprKind.AGGREGATION,
            ExprKind.ORDERABLE_AGGREGATION,
        )
        return any(self is kind for kind in single_value_kinds)
def is_scalar_like(obj: CompliantExprAny) -> bool:
    """Return the `is_scalar_like` flag stored in the metadata of `obj`.

    Arguments:
        obj: Compliant expression whose `_metadata` is inspected.
    """
    return obj._metadata.is_scalar_like
class ExpansionKind(Enum):
    """Describe what kind of expansion the expression performs."""

    SINGLE = auto()
    """e.g. `nw.col('a'), nw.sum_horizontal(nw.all())`"""

    MULTI_NAMED = auto()
    """e.g. `nw.col('a', 'b')`"""

    MULTI_UNNAMED = auto()
    """e.g. `nw.all()`, nw.nth(0, 1)"""

    def is_multi_unnamed(self) -> bool:
        """Whether this is the `MULTI_UNNAMED` kind."""
        return self is ExpansionKind.MULTI_UNNAMED

    def is_multi_output(self) -> bool:
        """Whether this is one of the two multi-output kinds."""
        # `SINGLE` is the only member which isn't multi-output.
        return self is not ExpansionKind.SINGLE

    def __and__(self, other: ExpansionKind) -> Literal[ExpansionKind.MULTI_UNNAMED]:
        """Combine two expansion kinds.

        Only `MULTI_UNNAMED & MULTI_UNNAMED` is supported
        (e.g. `nw.selectors.all() - nw.selectors.numeric()`).

        Raises:
            AssertionError: For any other combination.
        """
        both_unnamed = (
            self is ExpansionKind.MULTI_UNNAMED
            and other is ExpansionKind.MULTI_UNNAMED
        )
        if not both_unnamed:  # pragma: no cover
            # Don't attempt anything more complex, keep it simple and raise in
            # the face of ambiguity.
            msg = f"Unsupported ExpansionKind combination, got {self} and {other}, please report a bug."
            raise AssertionError(msg)
        return ExpansionKind.MULTI_UNNAMED
class ExprNode:
    """An operation to create or modify an expression.

    Parameters:
        kind: ExprKind of operation.
        name: Name of function, as defined in the compliant protocols.
        exprs: Expressifiable arguments to function.
        str_as_lit: Whether to interpret strings as literals when they
            are present in `exprs`.
        allow_multi_output: Whether to allow any of `exprs` to be multi-output.
        kwargs: Other (non-expressifiable) arguments to function.
    """

    def __init__(
        self,
        kind: ExprKind,
        name: str,
        /,
        *exprs: IntoExpr | NonNestedLiteral,
        str_as_lit: bool = False,
        allow_multi_output: bool = False,
        **kwargs: Any,
    ) -> None:
        self.kind: ExprKind = kind
        self.name: str = name
        self.exprs: Sequence[IntoExpr | NonNestedLiteral] = exprs
        self.kwargs: dict[str, Any] = kwargs
        self.str_as_lit: bool = str_as_lit
        self.allow_multi_output: bool = allow_multi_output

        # Cached results of `is_orderable` / `is_elementwise` (`None` = not computed yet).
        self._is_orderable_cached: bool | None = None
        self._is_elementwise_cached: bool | None = None

    def __repr__(self) -> str:
        """Render the node as `name(exprs..., key=value...)`.

        `col` nodes are special-cased to show just their column names.
        """
        if self.name == "col":
            names = ", ".join(str(x) for x in self.kwargs["names"])
            return f"col({names})"
        arg_str = []
        if self.exprs:
            arg_str.append(", ".join(str(x) for x in self.exprs))
        if self.kwargs:
            arg_str.append(
                ", ".join(f"{key}={value}" for key, value in self.kwargs.items())
            )
        return f"{self.name}({', '.join(arg_str)})"

    def as_dict(self) -> dict[str, Any]:  # pragma: no cover
        """Return the node's public attributes as a dictionary (just for debugging)."""
        return {
            "kind": self.kind,
            "name": self.name,
            "exprs": self.exprs,
            "kwargs": self.kwargs,
            "str_as_lit": self.str_as_lit,
            "allow_multi_output": self.allow_multi_output,
        }

    def _with_kwargs(self, **kwargs: Any) -> ExprNode:
        """Return a new node with the same kind, name and `exprs`, but new `kwargs`.

        Both `str_as_lit` and `allow_multi_output` are carried over, so that
        the copy parses its `exprs` the same way as the original does.
        """
        return self.__class__(
            self.kind,
            self.name,
            *self.exprs,
            str_as_lit=self.str_as_lit,
            allow_multi_output=self.allow_multi_output,
            **kwargs,
        )

    def _push_down_over_node_in_place(
        self, over_node: ExprNode, over_node_without_order_by: ExprNode
    ) -> None:
        """Apply an `over` node to the expression arguments of this node, in place.

        Arguments:
            over_node: The `over` node; its `order_by` and `partition_by`
                kwargs are read.
            over_node_without_order_by: Node applied to arguments which aren't
                orderable but aren't fully elementwise either.
        """
        exprs: list[IntoExpr | NonNestedLiteral] = []
        # Note: please keep this as a for-loop (rather than a list-comprehension)
        # so that pytest-cov highlights any uncovered branches.
        over_node_order_by = over_node.kwargs["order_by"]
        over_node_partition_by = over_node.kwargs["partition_by"]
        for expr in self.exprs:
            if not is_expr(expr):
                # Non-expression arguments are passed through untouched.
                exprs.append(expr)
            elif over_node_order_by and any(
                expr_node.is_orderable() for expr_node in expr._nodes
            ):
                exprs.append(expr._with_over_node(over_node))
            elif over_node_partition_by and not all(
                expr_node.is_elementwise() for expr_node in expr._nodes
            ):
                exprs.append(expr._with_over_node(over_node_without_order_by))
            else:
                # If there's no `partition_by`, then `over_node_without_order_by` is a no-op.
                exprs.append(expr)
        self.exprs = exprs

    def is_orderable(self) -> bool:
        """Whether this node, or any node of its expression arguments, is orderable.

        The result is computed on first call and cached on the instance.
        """
        if self._is_orderable_cached is None:
            # Note: don't combine these if/then statements so that pytest-cov shows if
            # anything is uncovered.
            if self.kind.is_orderable:  # noqa: SIM114
                self._is_orderable_cached = True
            elif any(
                any(node.is_orderable() for node in expr._nodes)
                for expr in self.exprs
                if is_expr(expr)
            ):
                self._is_orderable_cached = True
            else:
                self._is_orderable_cached = False
        return self._is_orderable_cached

    def is_elementwise(self) -> bool:
        """Whether this node and every node of its expression arguments is elementwise.

        The result is computed on first call and cached on the instance.
        """
        if self._is_elementwise_cached is None:
            # Note: don't combine these if/then statements so that pytest-cov shows if
            # anything is uncovered.
            if not self.kind.is_elementwise:  # noqa: SIM114
                self._is_elementwise_cached = False
            elif any(
                any(not node.is_elementwise() for node in expr._nodes)
                for expr in self.exprs
                if is_expr(expr)
            ):
                self._is_elementwise_cached = False
            else:
                self._is_elementwise_cached = True
        return self._is_elementwise_cached
class ExprMetadata:
    """Expression metadata.

    Instances form a linked list through `prev`, one entry per `ExprNode`
    applied to the expression.

    Parameters:
        expansion_kind: What kind of expansion the expression performs.
        has_windows: Whether it already contains window functions.
        is_elementwise: Whether it can operate row-by-row without context
            of the other rows around it.
        is_literal: Whether it is just a literal wrapped in an expression.
        is_scalar_like: Whether it is a literal or an aggregation.
        n_orderable_ops: The number of order-dependent operations. In the
            lazy case, this number must be `0` by the time the expression
            is evaluated.
        preserves_length: Whether the expression preserves the input length.
        current_node: The current ExprNode in the linked list.
        prev: Reference to the previous ExprMetadata in the linked list (None for root).
    """

    __slots__ = (
        "current_node",
        "expansion_kind",
        "has_windows",
        "is_elementwise",
        "is_literal",
        "is_scalar_like",
        "n_orderable_ops",
        "preserves_length",
        "prev",
    )

    def __init__(
        self,
        expansion_kind: ExpansionKind,
        *,
        has_windows: bool = False,
        n_orderable_ops: int = 0,
        preserves_length: bool = True,
        is_elementwise: bool = True,
        is_scalar_like: bool = False,
        is_literal: bool = False,
        current_node: ExprNode,
        prev: ExprMetadata | None = None,
    ) -> None:
        if is_literal:
            assert is_scalar_like  # noqa: S101  # debug assertion
        self.expansion_kind: ExpansionKind = expansion_kind
        self.has_windows: bool = has_windows
        self.n_orderable_ops: int = n_orderable_ops
        self.is_elementwise: bool = is_elementwise
        self.preserves_length: bool = preserves_length
        self.is_scalar_like: bool = is_scalar_like
        self.is_literal: bool = is_literal
        self.current_node: ExprNode = current_node
        self.prev: ExprMetadata | None = prev

    def __init_subclass__(cls, /, *args: Any, **kwds: Any) -> Never:  # pragma: no cover
        """Forbid subclassing by raising `TypeError`."""
        msg = f"Cannot subclass {cls.__name__!r}"
        raise TypeError(msg)

    def __repr__(self) -> str:  # pragma: no cover
        # Nodes are displayed from root to current.
        nodes = tuple(reversed(tuple(self.iter_nodes_reversed())))
        return (
            f"ExprMetadata(\n"
            f"  expansion_kind: {self.expansion_kind},\n"
            f"  has_windows: {self.has_windows},\n"
            f"  n_orderable_ops: {self.n_orderable_ops},\n"
            f"  is_elementwise: {self.is_elementwise},\n"
            f"  preserves_length: {self.preserves_length},\n"
            f"  is_scalar_like: {self.is_scalar_like},\n"
            f"  is_literal: {self.is_literal},\n"
            f"  nodes: {nodes},\n"
            ")"
        )

    def iter_nodes_reversed(self) -> Iterator[ExprNode]:
        """Iterate through all nodes from current to root."""
        current: ExprMetadata | None = self
        while current is not None:
            yield current.current_node
            current = current.prev

    @classmethod
    def from_node(
        cls, node: ExprNode, *compliant_exprs: CompliantExprAny
    ) -> ExprMetadata:
        """Build root metadata for `node`, dispatching on `node.kind`."""
        return KIND_TO_METADATA_CONSTRUCTOR[node.kind](node, *compliant_exprs)

    def with_node(
        self,
        node: ExprNode,
        compliant_expr: CompliantExprAny,
        *compliant_expr_args: CompliantExprAny,
    ) -> ExprMetadata:
        """Build the metadata resulting from applying `node`, dispatching on `node.kind`."""
        return KIND_TO_METADATA_UPDATER[node.kind](
            self, node, compliant_expr, *compliant_expr_args
        )

    @classmethod
    def from_aggregation(cls, node: ExprNode) -> ExprMetadata:
        """Root metadata for an aggregation: scalar-like, not length-preserving."""
        return cls(
            ExpansionKind.SINGLE,
            is_elementwise=False,
            preserves_length=False,
            is_scalar_like=True,
            current_node=node,
            prev=None,
        )

    @classmethod
    def from_literal(cls, node: ExprNode) -> ExprMetadata:
        """Root metadata for a literal: scalar-like and flagged as literal."""
        return cls(
            ExpansionKind.SINGLE,
            is_elementwise=True,
            preserves_length=False,
            is_literal=True,
            is_scalar_like=True,
            current_node=node,
            prev=None,
        )

    @classmethod
    def from_series(cls, node: ExprNode) -> ExprMetadata:
        """Root metadata for a Series converted to an expression."""
        return cls(ExpansionKind.SINGLE, current_node=node, prev=None)

    @classmethod
    def from_col(cls, node: ExprNode) -> ExprMetadata:
        """Root metadata for `nw.col`: single-output for one name, else multi-named."""
        # e.g. `nw.col('a')`
        return (
            cls(ExpansionKind.SINGLE, current_node=node, prev=None)
            if len(node.kwargs["names"]) == 1
            else cls.from_selector_multi_named(node)
        )

    @classmethod
    def from_nth(cls, node: ExprNode) -> ExprMetadata:
        """Root metadata for `nw.nth`: single-output for one index, else multi-unnamed."""
        # e.g. `nw.nth(0)` is single-output, whereas `nw.nth(0, 1)` selects
        # several columns by position without naming them, which is what
        # `ExpansionKind.MULTI_UNNAMED` describes.
        return (
            cls(ExpansionKind.SINGLE, current_node=node, prev=None)
            if len(node.kwargs["indices"]) == 1
            else cls.from_selector_multi_unnamed(node)
        )

    @classmethod
    def from_selector_multi_named(cls, node: ExprNode) -> ExprMetadata:
        """Root metadata for a multi-output selection with explicit names."""
        # e.g. `nw.col('a', 'b')`
        return cls(ExpansionKind.MULTI_NAMED, current_node=node, prev=None)

    @classmethod
    def from_selector_multi_unnamed(cls, node: ExprNode) -> ExprMetadata:
        """Root metadata for a multi-output selection without explicit names."""
        # e.g. `nw.all()`
        return cls(ExpansionKind.MULTI_UNNAMED, current_node=node, prev=None)

    @classmethod
    def from_elementwise(
        cls, node: ExprNode, *compliant_exprs: CompliantExprAny
    ) -> ExprMetadata:
        """Root metadata for an elementwise function of `compliant_exprs`.

        The result is always single-output.
        """
        return combine_metadata(
            *compliant_exprs, to_single_output=True, current_node=node, prev=None
        )

    @property
    def is_filtration(self) -> bool:
        """Whether the expression changes length without being scalar-like."""
        return not self.preserves_length and not self.is_scalar_like

    def with_aggregation(self, node: ExprNode, _ce: CompliantExprAny) -> ExprMetadata:
        """Metadata after applying an aggregation.

        Raises:
            InvalidOperationError: If the expression is already scalar-like.
        """
        if self.is_scalar_like:
            msg = "Can't apply aggregations to scalar-like expressions."
            raise InvalidOperationError(msg)
        return ExprMetadata(
            self.expansion_kind,
            has_windows=self.has_windows,
            n_orderable_ops=self.n_orderable_ops,
            preserves_length=False,
            is_elementwise=False,
            is_scalar_like=True,
            is_literal=False,
            current_node=node,
            prev=self,
        )

    def with_orderable_aggregation(
        self, node: ExprNode, _ce: CompliantExprAny
    ) -> ExprMetadata:
        """Metadata after applying an order-dependent aggregation.

        Raises:
            InvalidOperationError: If the expression is already scalar-like.
        """
        # Deprecated, used only in stable.v1.
        if self.is_scalar_like:  # pragma: no cover
            msg = "Can't apply aggregations to scalar-like expressions."
            raise InvalidOperationError(msg)
        return ExprMetadata(
            self.expansion_kind,
            has_windows=self.has_windows,
            n_orderable_ops=self.n_orderable_ops + 1,
            preserves_length=False,
            is_elementwise=False,
            is_scalar_like=True,
            is_literal=False,
            current_node=node,
            prev=self,
        )

    def with_elementwise(
        self,
        node: ExprNode,
        compliant_expr: CompliantExprAny,
        *compliant_expr_args: CompliantExprAny,
    ) -> ExprMetadata:
        """Metadata after applying an elementwise operation.

        The metadata of `compliant_expr` and of all of its arguments is combined.
        """
        return combine_metadata(
            compliant_expr,
            *compliant_expr_args,
            to_single_output=False,
            current_node=node,
            prev=compliant_expr._metadata,
        )

    def with_window(self, node: ExprNode, _ce: CompliantExprAny) -> ExprMetadata:
        """Metadata after applying a window function.

        Raises:
            InvalidOperationError: If the expression is scalar-like.
        """
        # Window function which may (but doesn't have to) be used with `over(order_by=...)`.
        if self.is_scalar_like:
            msg = "Can't apply window (e.g. `rank`) to scalar-like expression."
            raise InvalidOperationError(msg)
        return ExprMetadata(
            self.expansion_kind,
            has_windows=self.has_windows,
            # The function isn't order-dependent (but, users can still use `order_by` if they wish!),
            # so we don't increment `n_orderable_ops`.
            n_orderable_ops=self.n_orderable_ops,
            preserves_length=self.preserves_length,
            is_elementwise=False,
            is_scalar_like=False,
            is_literal=False,
            current_node=node,
            prev=self,
        )

    def with_orderable_window(
        self, node: ExprNode, _ce: CompliantExprAny
    ) -> ExprMetadata:
        """Metadata after applying an order-dependent window function.

        Raises:
            InvalidOperationError: If the expression is scalar-like.
        """
        # Window function which must be used with `over(order_by=...)`.
        if self.is_scalar_like:
            msg = "Can't apply orderable window (e.g. `diff`, `shift`) to scalar-like expression."
            raise InvalidOperationError(msg)
        return ExprMetadata(
            self.expansion_kind,
            has_windows=self.has_windows,
            n_orderable_ops=self.n_orderable_ops + 1,
            preserves_length=self.preserves_length,
            is_elementwise=False,
            is_scalar_like=False,
            is_literal=False,
            current_node=node,
            prev=self,
        )

    def with_ordered_over(self, node: ExprNode, _ce: CompliantExprAny) -> ExprMetadata:
        """Metadata after applying `over` with `order_by`.

        Raises:
            InvalidOperationError: If `over` is nested, if the expression is
                elementwise or a filtration, or if the expression isn't orderable.
        """
        if self.has_windows:
            msg = "Cannot nest `over` statements."
            raise InvalidOperationError(msg)
        if self.is_elementwise or self.is_filtration:
            msg = (
                "Cannot use `over` on expressions which are elementwise\n"
                "(e.g. `abs`) or which change length (e.g. `drop_nulls`)."
            )
            raise InvalidOperationError(msg)
        n_orderable_ops = self.n_orderable_ops
        # Most recent node which isn't just aliasing.
        last_op_kind = next(self.op_nodes_reversed()).kind
        if not n_orderable_ops and last_op_kind is not ExprKind.WINDOW:
            msg = (
                "Cannot use `order_by` in `over` on expression which isn't orderable.\n"
                "If your expression is orderable, then make sure that `over(order_by=...)`\n"
                "comes immediately after the order-dependent expression.\n\n"
                "Hint: instead of\n"
                "  - `(nw.col('price').diff() + 1).over(order_by='date')`\n"
                "write:\n"
                "  + `nw.col('price').diff().over(order_by='date') + 1`\n"
            )
            raise InvalidOperationError(msg)
        if last_op_kind.is_orderable and n_orderable_ops > 0:
            # The order-dependent operation right before `over` has now been
            # given an ordering, so it no longer counts as outstanding.
            n_orderable_ops -= 1
        return ExprMetadata(
            self.expansion_kind,
            has_windows=True,
            n_orderable_ops=n_orderable_ops,
            preserves_length=True,
            is_elementwise=False,
            is_scalar_like=False,
            is_literal=False,
            current_node=node,
            prev=self,
        )

    def with_partitioned_over(
        self, node: ExprNode, _ce: CompliantExprAny
    ) -> ExprMetadata:
        """Metadata after applying `over` with `partition_by` only.

        Raises:
            InvalidOperationError: If `over` is nested, or if the expression is
                elementwise or a filtration.
        """
        if self.has_windows:
            msg = "Cannot nest `over` statements."
            raise InvalidOperationError(msg)
        if self.is_elementwise or self.is_filtration:
            msg = (
                "Cannot use `over` on expressions which are elementwise\n"
                "(e.g. `abs`) or which change length (e.g. `drop_nulls`)."
            )
            raise InvalidOperationError(msg)
        return ExprMetadata(
            self.expansion_kind,
            has_windows=True,
            n_orderable_ops=self.n_orderable_ops,
            preserves_length=True,
            is_elementwise=False,
            is_scalar_like=False,
            is_literal=False,
            current_node=node,
            prev=self,
        )

    def with_over(self, node: ExprNode, _ce: CompliantExprAny) -> ExprMetadata:
        """Metadata after applying `over`, dispatching on the node's `order_by` kwarg.

        Raises:
            InvalidOperationError: If neither `order_by` nor `partition_by` is set.
        """
        if node.kwargs["order_by"]:
            return self.with_ordered_over(node, _ce)
        if not node.kwargs["partition_by"]:  # pragma: no cover
            msg = "At least one of `partition_by` or `order_by` must be specified."
            raise InvalidOperationError(msg)
        return self.with_partitioned_over(node, _ce)

    def with_filtration(
        self, node: ExprNode, *compliant_exprs: CompliantExprAny
    ) -> ExprMetadata:
        """Metadata after applying a length-changing operation.

        `has_windows` and `n_orderable_ops` are computed from `compliant_exprs`.

        Raises:
            InvalidOperationError: If the expression is scalar-like.
        """
        if self.is_scalar_like:
            msg = "Can't apply filtration (e.g. `drop_nulls`) to scalar-like expression."
            raise InvalidOperationError(msg)
        result_has_windows = any(x._metadata.has_windows for x in compliant_exprs)
        result_n_orderable_ops = sum(x._metadata.n_orderable_ops for x in compliant_exprs)
        return ExprMetadata(
            self.expansion_kind,
            has_windows=result_has_windows,
            n_orderable_ops=result_n_orderable_ops,
            preserves_length=False,
            is_elementwise=False,
            is_scalar_like=False,
            is_literal=False,
            current_node=node,
            prev=self,
        )

    def with_orderable_filtration(
        self, node: ExprNode, _ce: CompliantExprAny
    ) -> ExprMetadata:
        """Metadata after applying an order-dependent, length-changing operation.

        Raises:
            InvalidOperationError: If the expression is scalar-like.
        """
        if self.is_scalar_like:
            msg = "Can't apply filtration (e.g. `drop_nulls`) to scalar-like expression."
            raise InvalidOperationError(msg)
        return ExprMetadata(
            self.expansion_kind,
            has_windows=self.has_windows,
            n_orderable_ops=self.n_orderable_ops + 1,
            preserves_length=False,
            is_elementwise=False,
            is_scalar_like=False,
            is_literal=False,
            current_node=node,
            prev=self,
        )

    def op_nodes_reversed(self) -> Iterator[ExprNode]:
        """Iterate from current to root, skipping nodes which only do aliasing."""
        for node in self.iter_nodes_reversed():
            if node.name.startswith(("name.", "alias")):
                # Skip nodes which only do aliasing.
                continue
            yield node
# Dispatch table used by `ExprMetadata.from_node`: maps the kind of a root node
# to the constructor which builds its metadata. Constructors are called as
# `constructor(node, *compliant_exprs)`, hence the variadic callable type.
KIND_TO_METADATA_CONSTRUCTOR: dict[ExprKind, Callable[..., ExprMetadata]] = {
    ExprKind.AGGREGATION: ExprMetadata.from_aggregation,
    ExprKind.ALL: ExprMetadata.from_selector_multi_unnamed,
    ExprKind.ELEMENTWISE: ExprMetadata.from_elementwise,
    ExprKind.EXCLUDE: ExprMetadata.from_selector_multi_unnamed,
    ExprKind.SERIES: ExprMetadata.from_series,
    ExprKind.COL: ExprMetadata.from_col,
    ExprKind.LITERAL: ExprMetadata.from_literal,
    ExprKind.NTH: ExprMetadata.from_nth,
    ExprKind.SELECTOR: ExprMetadata.from_selector_multi_unnamed,
}
# Dispatch table used by `ExprMetadata.with_node`: maps the kind of a non-root
# node to the method which derives the new metadata from the existing one.
# Updaters are called as `updater(metadata, node, compliant_expr, *args)`.
KIND_TO_METADATA_UPDATER: dict[ExprKind, Callable[..., ExprMetadata]] = {
    ExprKind.AGGREGATION: ExprMetadata.with_aggregation,
    ExprKind.ELEMENTWISE: ExprMetadata.with_elementwise,
    ExprKind.FILTRATION: ExprMetadata.with_filtration,
    ExprKind.ORDERABLE_AGGREGATION: ExprMetadata.with_orderable_aggregation,
    ExprKind.ORDERABLE_FILTRATION: ExprMetadata.with_orderable_filtration,
    ExprKind.OVER: ExprMetadata.with_over,
    ExprKind.ORDERABLE_WINDOW: ExprMetadata.with_orderable_window,
    ExprKind.WINDOW: ExprMetadata.with_window,
}
def combine_metadata(
    *compliant_exprs: CompliantExprAny,
    to_single_output: bool,
    current_node: ExprNode,
    prev: ExprMetadata | None,
) -> ExprMetadata:
    """Merge the metadata of several expressions into a single `ExprMetadata`.

    Arguments:
        compliant_exprs: Expression arguments.
        to_single_output: Whether the result is always single-output, regardless
            of the inputs (e.g. `nw.sum_horizontal`).
        current_node: The current node being added.
        prev: ExprMetadata of previous node.

    Raises:
        InvalidOperationError: If more than one input is a filtration, or if a
            filtration is combined with a length-preserving input.
    """
    expansion = ExpansionKind.SINGLE
    any_windows = False
    orderable_total = 0
    filtration_count = 0
    # Length is preserved as soon as one input preserves it...
    keeps_length = False
    # ...whereas these three only hold if they hold for every input.
    all_elementwise = True
    all_scalar_like = True
    all_literal = True

    for position, compliant in enumerate(compliant_exprs):
        md = compliant._metadata
        assert md is not None  # noqa: S101
        if md.expansion_kind.is_multi_output() and not to_single_output:
            if position > 0:
                expansion = expansion & md.expansion_kind
            else:
                expansion = md.expansion_kind
        any_windows |= md.has_windows
        orderable_total += md.n_orderable_ops
        keeps_length |= md.preserves_length
        all_elementwise &= md.is_elementwise
        all_scalar_like &= md.is_scalar_like
        all_literal &= md.is_literal
        filtration_count += int(md.is_filtration)

    if filtration_count > 1:
        msg = "Length-changing expressions can only be used in isolation, or followed by an aggregation"
        raise InvalidOperationError(msg)
    if keeps_length and filtration_count:
        msg = "Cannot combine length-changing expressions with length-preserving ones or aggregations"
        raise InvalidOperationError(msg)

    return ExprMetadata(
        expansion,
        has_windows=any_windows,
        n_orderable_ops=orderable_total,
        preserves_length=keeps_length,
        is_elementwise=all_elementwise,
        is_scalar_like=all_scalar_like,
        is_literal=all_literal,
        current_node=current_node,
        prev=prev,
    )
def check_expressions_preserve_length(
    *args: CompliantExprAny, function_name: str
) -> None:
    """Ensure that every expression in `args` is length-preserving.

    For Series input, we don't raise (yet), we let such checks happen later,
    as this function works lazily and so can't evaluate lengths.

    Arguments:
        args: Expressions to check.
        function_name: Name of the calling function, used in the error message.

    Raises:
        InvalidOperationError: If any argument isn't length-preserving.
    """
    offending = any(not candidate._metadata.preserves_length for candidate in args)
    if offending:
        msg = f"Expressions which aggregate or change length cannot be passed to '{function_name}'."
        raise InvalidOperationError(msg)
def _parse_into_expr(
    arg: IntoExpr | NonNestedLiteral | _1DArray,
    *,
    str_as_lit: bool = False,
    backend: Any = None,
    allow_literal: bool = True,
) -> Expr:
    """Convert `arg` into a Narwhals `Expr`.

    Arguments:
        arg: A string, 1D NumPy array, Series, Expr, or literal value.
        str_as_lit: Whether strings are literals rather than column names.
        backend: Backend used to build a Series from a 1D NumPy array.
        allow_literal: Whether anything which isn't one of the above may be
            wrapped as a literal.

    Raises:
        InvalidIntoExprError: If `arg` would have to be treated as a literal
            while `allow_literal` is `False`.
    """
    from narwhals.functions import col, lit, new_series

    refers_to_column = isinstance(arg, str) and not str_as_lit
    if refers_to_column:
        return col(arg)
    if is_numpy_array_1d(arg):
        # Arrays go through an unnamed Series first.
        array_as_series = new_series("", arg, backend=backend)
        return array_as_series._to_expr()
    if is_series(arg):
        return arg._to_expr()
    if is_expr(arg):
        return arg
    if allow_literal:
        return lit(arg)
    raise InvalidIntoExprError.from_invalid_type(type(arg))
def evaluate_into_exprs(
    *exprs: IntoExpr | NonNestedLiteral | _1DArray,
    ns: CompliantNamespaceAny,
    str_as_lit: bool,
    allow_multi_output: bool,
) -> Iterator[CompliantExprAny]:
    """Lazily convert each of `exprs` into a compliant expression for `ns`.

    Arguments:
        exprs: Values to convert.
        ns: Compliant namespace in which the expressions are created.
        str_as_lit: Whether strings are literals rather than column names.
        allow_multi_output: Whether multi-output expressions are accepted.

    Yields:
        One compliant expression per element of `exprs`, in order.

    Raises:
        MultiOutputExpressionError: If a multi-output expression is found
            while `allow_multi_output` is `False`.
    """
    for into_expr in exprs:
        narwhals_expr = _parse_into_expr(
            into_expr, str_as_lit=str_as_lit, backend=ns._implementation
        )
        compliant = narwhals_expr._to_compliant_expr(ns)
        # Metadata is only inspected when multi-output is disallowed.
        if (
            not allow_multi_output
            and compliant._metadata.expansion_kind.is_multi_output()
        ):
            msg = "Multi-output expressions are not allowed in this context."
            raise MultiOutputExpressionError(msg)
        yield compliant
def maybe_broadcast_ces(*compliant_exprs: CompliantExprAny) -> list[CompliantExprAny]:
    """Broadcast scalar-like expressions when mixed with non-scalar-like ones.

    If every input is scalar-like, nothing is broadcast.

    Arguments:
        compliant_exprs: Expressions which will be combined together.

    Returns:
        The expressions, in order, with scalar-like ones replaced by their
        broadcast version (carrying the original metadata) where needed.
    """
    needs_broadcast = not all(is_scalar_like(ce) for ce in compliant_exprs)
    out: list[CompliantExprAny] = []
    for ce in compliant_exprs:
        if not (needs_broadcast and is_scalar_like(ce)):
            out.append(ce)
            continue
        broadcasted: CompliantExprAny = ce.broadcast()
        # Make sure to preserve metadata.
        broadcasted._opt_metadata = ce._metadata
        out.append(broadcasted)
    return out
def evaluate_root_node(node: ExprNode, ns: CompliantNamespaceAny) -> CompliantExprAny:
    """Evaluate the first node of an expression into a compliant expression.

    Arguments:
        node: Root node; `node.name` is looked up on `ns`, either directly or
            as `module.method`.
        ns: Compliant namespace providing the function.

    Returns:
        The compliant expression, with its metadata set from `node`.
    """
    if node.name not in {"col", "exclude"}:
        if "." not in node.name:
            func = getattr(ns, node.name)
        else:
            module_name, method_name = node.name.split(".")
            func = getattr(getattr(ns, module_name), method_name)
        parsed_args = evaluate_into_exprs(
            *node.exprs,
            ns=ns,
            str_as_lit=node.str_as_lit,
            allow_multi_output=node.allow_multi_output,
        )
        ces = maybe_broadcast_ces(*parsed_args)
        ce = cast("CompliantExprAny", func(*ces, **node.kwargs))
    else:
        # There's too much potential for Sequence[str] vs str bugs, so we pass down
        # `names` positionally rather than as a sequence of strings.
        selector = getattr(ns, node.name)
        ce = selector(*node.kwargs["names"])
        ces = []
    ce._opt_metadata = ExprMetadata.from_node(node, *ces)
    return ce
def evaluate_node(
    compliant_expr: CompliantExprAny, node: ExprNode, ns: CompliantNamespaceAny
) -> CompliantExprAny:
    """Apply a non-root node to an already-evaluated compliant expression.

    Arguments:
        compliant_expr: Expression on which the node operates.
        node: Node to apply; `node.name` is looked up on the expression,
            either directly or as `accessor.method`.
        ns: Compliant namespace used to evaluate the node's arguments.

    Returns:
        The resulting compliant expression, with updated metadata.
    """
    previous_md: ExprMetadata = compliant_expr._metadata
    parsed_args = evaluate_into_exprs(
        *node.exprs,
        ns=ns,
        str_as_lit=node.str_as_lit,
        allow_multi_output=node.allow_multi_output,
    )
    compliant_expr, *compliant_expr_args = maybe_broadcast_ces(
        compliant_expr, *parsed_args
    )
    # Metadata is updated before the function is called.
    updated_md = previous_md.with_node(node, compliant_expr, *compliant_expr_args)
    if "." not in node.name:
        func = getattr(compliant_expr, node.name)
    else:
        accessor_name, method_name = node.name.split(".")
        func = getattr(getattr(compliant_expr, accessor_name), method_name)
    result = cast("CompliantExprAny", func(*compliant_expr_args, **node.kwargs))
    result._opt_metadata = updated_md
    return result
def evaluate_nodes(
    nodes: Sequence[ExprNode], ns: CompliantNamespaceAny
) -> CompliantExprAny:
    """Evaluate a sequence of nodes into a single compliant expression.

    Arguments:
        nodes: Nodes to evaluate; the first one is the root, and each
            subsequent one is applied to the result of the previous ones.
        ns: Compliant namespace used for evaluation.

    Returns:
        The compliant expression obtained after applying every node.
    """
    result = evaluate_root_node(nodes[0], ns)
    for follow_up in nodes[1:]:
        result = evaluate_node(result, follow_up, ns)
    return result