DriverTrac/venv/lib/python3.12/site-packages/polars/_utils/udfs.py
2025-11-28 09:08:33 +05:30

1252 lines
46 KiB
Python

"""Utilities related to user defined functions (such as those passed to `apply`)."""
from __future__ import annotations
import datetime
import dis
import inspect
import re
import sys
import warnings
from bisect import bisect_left
from collections import defaultdict
from dis import get_instructions
from inspect import signature
from itertools import count, zip_longest
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
Callable,
ClassVar,
Literal,
NamedTuple,
Union,
)
from polars._utils.cache import LRUCache
from polars._utils.various import no_default, re_escape
if TYPE_CHECKING:
from collections.abc import Iterator, MutableMapping
from collections.abc import Set as AbstractSet
from dis import Instruction
from polars._utils.various import NoDefault
if sys.version_info >= (3, 10):
from typing import TypeAlias
else:
from typing_extensions import TypeAlias
class StackValue(NamedTuple):
operator: str
operator_arity: int
left_operand: str
right_operand: str
from_module: str | None = None
MapTarget: TypeAlias = Literal["expr", "frame", "series"]
StackEntry: TypeAlias = Union[str, StackValue]
_MIN_PY311 = sys.version_info >= (3, 11)
_MIN_PY312 = _MIN_PY311 and sys.version_info >= (3, 12)
_MIN_PY314 = _MIN_PY312 and sys.version_info >= (3, 14)
_BYTECODE_PARSER_CACHE_: MutableMapping[
tuple[Callable[[Any], Any], str], BytecodeParser
] = LRUCache(32)
class OpNames:
BINARY: ClassVar[dict[str, str]] = {
"BINARY_ADD": "+",
"BINARY_AND": "&",
"BINARY_FLOOR_DIVIDE": "//",
"BINARY_LSHIFT": "<<",
"BINARY_RSHIFT": ">>",
"BINARY_MODULO": "%",
"BINARY_MULTIPLY": "*",
"BINARY_OR": "|",
"BINARY_POWER": "**",
"BINARY_SUBTRACT": "-",
"BINARY_TRUE_DIVIDE": "/",
"BINARY_XOR": "^",
}
CALL = frozenset({"CALL"} if _MIN_PY311 else {"CALL_FUNCTION", "CALL_METHOD"})
CONTROL_FLOW: ClassVar[dict[str, str]] = (
{
"POP_JUMP_FORWARD_IF_FALSE": "&",
"POP_JUMP_FORWARD_IF_TRUE": "|",
"JUMP_IF_FALSE_OR_POP": "&",
"JUMP_IF_TRUE_OR_POP": "|",
}
# note: 3.12 dropped POP_JUMP_FORWARD_IF_* opcodes
if _MIN_PY311 and not _MIN_PY312
else {
"POP_JUMP_IF_FALSE": "&",
"POP_JUMP_IF_TRUE": "|",
"JUMP_IF_FALSE_OR_POP": "&",
"JUMP_IF_TRUE_OR_POP": "|",
}
)
LOAD_VALUES = frozenset(("LOAD_CONST", "LOAD_DEREF", "LOAD_FAST", "LOAD_GLOBAL"))
LOAD_ATTR = frozenset({"LOAD_METHOD", "LOAD_ATTR"})
LOAD = LOAD_VALUES | LOAD_ATTR
SIMPLIFY_SPECIALIZED: ClassVar[dict[str, str]] = {
"LOAD_FAST_BORROW": "LOAD_FAST",
"LOAD_SMALL_INT": "LOAD_CONST",
}
SYNTHETIC: ClassVar[dict[str, int]] = {
"POLARS_EXPRESSION": 1,
}
UNARY: ClassVar[dict[str, str]] = {
"UNARY_NEGATIVE": "-",
"UNARY_POSITIVE": "+",
"UNARY_NOT": "~",
}
PARSEABLE_OPS = frozenset(
{"BINARY_OP", "BINARY_SUBSCR", "COMPARE_OP", "CONTAINS_OP", "IS_OP"}
| set(UNARY)
| set(CONTROL_FLOW)
| set(SYNTHETIC)
| LOAD_VALUES
)
MATCHABLE_OPS = (
set(SIMPLIFY_SPECIALIZED) | PARSEABLE_OPS | set(BINARY) | LOAD_ATTR | CALL
)
UNARY_VALUES = frozenset(UNARY.values())
# math module funcs that we can map to native expressions
_MATH_FUNCTIONS = frozenset(
(
"acos",
"acosh",
"asin",
"asinh",
"atan",
"atanh",
"cbrt",
"ceil",
"cos",
"cosh",
"degrees",
"exp",
"floor",
"log",
"log10",
"log1p",
"pow",
"radians",
"sin",
"sinh",
"sqrt",
"tan",
"tanh",
)
)
# numpy functions that we can map to native expressions
_NUMPY_MODULE_ALIASES = frozenset(("np", "numpy"))
_NUMPY_FUNCTIONS = frozenset(
(
# "abs", # TODO: this one clashes with Python builtin abs
"arccos",
"arccosh",
"arcsin",
"arcsinh",
"arctan",
"arctanh",
"cbrt",
"ceil",
"cos",
"cosh",
"degrees",
"exp",
"floor",
"log",
"log10",
"log1p",
"radians",
"sign",
"sin",
"sinh",
"sqrt",
"tan",
"tanh",
)
)
# python attrs/funcs that map to native expressions
_PYTHON_ATTRS_MAP = {
"date": "dt.date()",
"day": "dt.day()",
"hour": "dt.hour()",
"microsecond": "dt.microsecond()",
"minute": "dt.minute()",
"month": "dt.month()",
"second": "dt.second()",
"year": "dt.year()",
}
_PYTHON_CASTS_MAP = {"float": "Float64", "int": "Int64", "str": "String"}
_PYTHON_BUILTINS = frozenset(_PYTHON_CASTS_MAP) | {"abs"}
_PYTHON_METHODS_MAP = {
# string
"endswith": "str.ends_with",
"lower": "str.to_lowercase",
"lstrip": "str.strip_chars_start",
"removeprefix": "str.strip_prefix",
"removesuffix": "str.strip_suffix",
"replace": "str.replace",
"rstrip": "str.strip_chars_end",
"startswith": "str.starts_with",
"strip": "str.strip_chars",
"title": "str.to_titlecase",
"upper": "str.to_uppercase",
"zfill": "str.zfill",
# temporal
"date": "dt.date",
"day": "dt.day",
"hour": "dt.hour",
"isoweekday": "dt.weekday",
"microsecond": "dt.microsecond",
"month": "dt.month",
"second": "dt.second",
"strftime": "dt.strftime",
"time": "dt.time",
"year": "dt.year",
}
_MODULE_FUNCTIONS: list[dict[str, list[AbstractSet[str]]]] = [
# lambda x: numpy.func(x)
# lambda x: numpy.func(CONSTANT)
{
"argument_1_opname": [{"LOAD_FAST", "LOAD_CONST"}],
"argument_2_opname": [],
"module_opname": [OpNames.LOAD_ATTR],
"attribute_opname": [],
"module_name": [_NUMPY_MODULE_ALIASES],
"attribute_name": [],
"function_name": [_NUMPY_FUNCTIONS],
},
# lambda x: math.func(x)
# lambda x: math.func(CONSTANT)
{
"argument_1_opname": [{"LOAD_FAST", "LOAD_CONST"}],
"argument_2_opname": [],
"module_opname": [OpNames.LOAD_ATTR],
"attribute_opname": [],
"module_name": [{"math"}],
"attribute_name": [],
"function_name": [_MATH_FUNCTIONS],
},
# lambda x: json.loads(x)
{
"argument_1_opname": [{"LOAD_FAST"}],
"argument_2_opname": [],
"module_opname": [OpNames.LOAD_ATTR],
"attribute_opname": [],
"module_name": [{"json"}],
"attribute_name": [],
"function_name": [{"loads"}],
},
# lambda x: datetime.strptime(x, CONSTANT)
{
"argument_1_opname": [{"LOAD_FAST"}],
"argument_2_opname": [{"LOAD_CONST"}],
"module_opname": [OpNames.LOAD_ATTR],
"attribute_opname": [],
"module_name": [{"datetime"}],
"attribute_name": [],
"function_name": [{"strptime"}],
"check_load_global": False, # type: ignore[dict-item]
},
# lambda x: module.attribute.func(x, CONSTANT)
{
"argument_1_opname": [{"LOAD_FAST"}],
"argument_2_opname": [{"LOAD_CONST"}],
"module_opname": [{"LOAD_ATTR"}],
"attribute_opname": [OpNames.LOAD_ATTR],
"module_name": [{"datetime", "dt"}],
"attribute_name": [{"datetime"}],
"function_name": [{"strptime"}],
"check_load_global": False, # type: ignore[dict-item]
},
]
# In addition to `lambda x: func(x)`, also support cases when a unary operation
# has been applied to `x`, like `lambda x: func(-x)` or `lambda x: func(~x)`.
_MODULE_FUNCTIONS = [
{**kind, "argument_1_unary_opname": unary} # type: ignore[dict-item]
for kind in _MODULE_FUNCTIONS
for unary in [[set(OpNames.UNARY)], []]
]
# Lookup for module functions that have different names as polars expressions
_MODULE_FUNC_TO_EXPR_NAME = {
"math.acos": "arccos",
"math.acosh": "arccosh",
"math.asin": "arcsin",
"math.asinh": "arcsinh",
"math.atan": "arctan",
"math.atanh": "arctanh",
"json.loads": "str.json_decode",
}
_RE_IMPLICIT_BOOL = re.compile(r'pl\.col\("([^"]*)"\) & pl\.col\("\1"\)\.(.+)')
_RE_SERIES_NAMES = re.compile(r"^(s|srs\d?|series)\.")
_RE_STRIP_BOOL = re.compile(r"^bool\((.+)\)$")
def _get_all_caller_variables() -> dict[str, Any]:
"""Get all local and global variables from caller's frame."""
pkg_dir = Path(__file__).parent.parent
# https://stackoverflow.com/questions/17407119/python-inspect-stack-is-slow
frame = inspect.currentframe()
n = 0
try:
while frame:
fname = inspect.getfile(frame)
if fname.startswith(str(pkg_dir)):
frame = frame.f_back
n += 1
else:
break
variables: dict[str, Any]
if frame is None:
variables = {}
else:
variables = {**frame.f_locals, **frame.f_globals}
finally:
# https://docs.python.org/3/library/inspect.html
# > Though the cycle detector will catch these, destruction of the frames
# > (and local variables) can be made deterministic by removing the cycle
# > in a finally clause.
del frame
return variables
def _get_target_name(col: str, expression: str, map_target: str) -> str:
"""The name of the object against which the 'map' is being invoked."""
col_expr = f'pl.col("{col}")'
if map_target == "expr":
return col_expr
elif map_target == "series":
if _RE_SERIES_NAMES.match(expression):
return expression.split(".", 1)[0]
# note: handle overlapping name from global variables; fallback
# through "s", "srs", "series" and (finally) srs0 -> srsN...
search_expr = expression.replace(col_expr, "")
for name in ("s", "srs", "series"):
if not re.search(rf"\b{name}\b", search_expr):
return name
n = count()
while True:
name = f"srs{next(n)}"
if not re.search(rf"\b{name}\b", search_expr):
return name
msg = f"TODO: map_target = {map_target!r}"
raise NotImplementedError(msg)
class BytecodeParser:
"""Introspect UDF bytecode and determine if we can rewrite as native expression."""
_map_target_name: str | None = None
_can_attempt_rewrite: bool | None = None
_caller_variables: dict[str, Any] | None = None
_col_expression: tuple[str, str] | NoDefault | None = no_default
def __init__(self, function: Callable[[Any], Any], map_target: MapTarget) -> None:
"""
Initialize BytecodeParser instance and prepare to introspect UDFs.
Parameters
----------
function : callable
The function/lambda to disassemble and introspect.
map_target : {'expr','series','frame'}
The underlying target object type of the map operation.
"""
try:
original_instructions = get_instructions(function)
except TypeError:
# in case we hit something that can't be disassembled (eg: code object
# unavailable, like a bare numpy ufunc that isn't in a lambda/function)
original_instructions = iter([])
self._function = function
self._map_target = map_target
self._param_name = self._get_param_name(function)
self._rewritten_instructions = RewrittenInstructions(
instructions=original_instructions,
caller_variables=self._caller_variables,
function=function,
)
def _omit_implicit_bool(self, expr: str) -> str:
"""Drop extraneous/implied bool (eg: `pl.col("d") & pl.col("d").dt.date()`)."""
while _RE_IMPLICIT_BOOL.search(expr):
expr = _RE_IMPLICIT_BOOL.sub(repl=r'pl.col("\1").\2', string=expr)
return expr
@staticmethod
def _get_param_name(function: Callable[[Any], Any]) -> str | None:
"""Return single function parameter name."""
try:
# note: we do not parse/handle functions with > 1 params
sig = signature(function)
except ValueError:
return None
return (
next(iter(parameters.keys()))
if len(parameters := sig.parameters) == 1
else None
)
def _inject_nesting(
self,
expression_blocks: dict[int, str],
logical_instructions: list[Instruction],
) -> list[tuple[int, str]]:
"""Inject nesting boundaries into expression blocks (as parentheses)."""
if logical_instructions:
# reconstruct nesting for mixed 'and'/'or' ops by associating control flow
# jump offsets with their target expression blocks and applying parens
if len({inst.opname for inst in logical_instructions}) > 1:
block_offsets: list[int] = list(expression_blocks.keys())
prev_end = -1
for inst in logical_instructions:
start = block_offsets[bisect_left(block_offsets, inst.offset) - 1]
end = block_offsets[bisect_left(block_offsets, inst.argval) - 1]
if not (start == 0 and end == block_offsets[-1]):
if prev_end not in (start, end):
expression_blocks[start] = "(" + expression_blocks[start]
expression_blocks[end] += ")"
prev_end = end
for inst in logical_instructions: # inject connecting "&" and "|" ops
expression_blocks[inst.offset] = OpNames.CONTROL_FLOW[inst.opname]
return sorted(expression_blocks.items())
@property
def map_target(self) -> MapTarget:
"""The map target, eg: one of 'expr', 'frame', or 'series'."""
return self._map_target
def can_attempt_rewrite(self) -> bool:
"""
Determine if we may be able to offer a native polars expression instead.
Note that `lambda x: x` is inefficient, but we ignore it because it is not
guaranteed that using the equivalent bare constant value will return the
same output. (Hopefully nobody is writing lambdas like that anyway...)
"""
if self._can_attempt_rewrite is None:
self._can_attempt_rewrite = (
self._param_name is not None
# check minimum number of ops, ensuring all are parseable
and len(self._rewritten_instructions) >= 2
and all(
inst.opname in OpNames.PARSEABLE_OPS
for inst in self._rewritten_instructions
)
# exclude constructs/functions with multiple RETURN_VALUE ops
and sum(
1
for inst in self.original_instructions
if inst.opname == "RETURN_VALUE"
)
== 1
)
return self._can_attempt_rewrite
def dis(self) -> None:
"""Print disassembled function bytecode."""
dis.dis(self._function)
@property
def function(self) -> Callable[[Any], Any]:
"""The function being parsed."""
return self._function
@property
def original_instructions(self) -> list[Instruction]:
"""The original bytecode instructions from the function we are parsing."""
return list(self._rewritten_instructions._original_instructions)
@property
def param_name(self) -> str | None:
"""The parameter name of the function being parsed."""
return self._param_name
@property
def rewritten_instructions(self) -> list[Instruction]:
"""The rewritten bytecode instructions from the function we are parsing."""
return list(self._rewritten_instructions)
def to_expression(self, col: str) -> str | None:
"""Translate postfix bytecode instructions to polars expression/string."""
if self._col_expression is not no_default and self._col_expression is not None:
col_name, expr = self._col_expression
if col != col_name:
expr = re.sub(
rf'pl\.col\("{re_escape(col_name)}"\)',
f'pl.col("{re_escape(col)}")',
expr,
)
self._col_expression = (col, expr)
return expr
self._map_target_name = None
if self._param_name is None:
self._col_expression = None
return None
# decompose bytecode into logical 'and'/'or' expression blocks (if present)
control_flow_blocks = defaultdict(list)
logical_instructions = []
jump_offset = 0
for idx, inst in enumerate(self._rewritten_instructions):
if inst.opname in OpNames.CONTROL_FLOW:
jump_offset = self._rewritten_instructions[idx + 1].offset
logical_instructions.append(inst)
else:
control_flow_blocks[jump_offset].append(inst)
# convert each block to a polars expression string
try:
expression_strings = self._inject_nesting(
{
offset: InstructionTranslator(
instructions=ops,
caller_variables=self._caller_variables,
map_target=self._map_target,
function=self._function,
).to_expression(
col=col,
param_name=self._param_name,
depth=int(bool(logical_instructions)),
)
for offset, ops in control_flow_blocks.items()
},
logical_instructions,
)
except NotImplementedError:
self._col_expression = None
return None
polars_expr = " ".join(expr for _offset, expr in expression_strings)
# note: if no 'pl.col' in the expression, it likely represents a compound
# constant value (e.g. `lambda x: CONST + 123`), so we don't want to warn
if "pl.col(" not in polars_expr:
self._col_expression = None
return None
else:
polars_expr = self._omit_implicit_bool(polars_expr)
if self._map_target == "series":
if (target_name := self._map_target_name) is None:
target_name = _get_target_name(col, polars_expr, self._map_target)
polars_expr = polars_expr.replace(f'pl.col("{col}")', target_name)
self._col_expression = (col, polars_expr)
return polars_expr
def warn(
self,
col: str,
*,
suggestion_override: str | None = None,
udf_override: str | None = None,
) -> None:
"""Generate warning that suggests an equivalent native polars expression."""
# Import these here so that udfs can be imported without polars installed.
from polars._utils.various import (
find_stacklevel,
in_terminal_that_supports_colour,
)
from polars.exceptions import PolarsInefficientMapWarning
suggested_expression = suggestion_override or self.to_expression(col)
if suggested_expression is not None:
if (target_name := self._map_target_name) is None:
target_name = _get_target_name(
col, suggested_expression, self._map_target
)
func_name = udf_override or self._function.__name__ or "..."
if func_name == "<lambda>":
func_name = f"lambda {self._param_name}: ..."
addendum = (
'Note: in list.eval context, pl.col("") should be written as pl.element()'
if 'pl.col("")' in suggested_expression
else ""
)
apitype, clsname = (
("expressions", "Expr")
if self._map_target == "expr"
else ("series", "Series")
)
before, after = (
(
f" \033[31m- {target_name}.map_elements({func_name})\033[0m\n",
f" \033[32m+ {suggested_expression}\033[0m\n{addendum}",
)
if in_terminal_that_supports_colour()
else (
f" - {target_name}.map_elements({func_name})\n",
f" + {suggested_expression}\n{addendum}",
)
)
warnings.warn(
f"\n{clsname}.map_elements is significantly slower than the native {apitype} API.\n"
"Only use if you absolutely CANNOT implement your logic otherwise.\n"
"Replace this expression...\n"
f"{before}"
"with this one instead:\n"
f"{after}",
PolarsInefficientMapWarning,
stacklevel=find_stacklevel(),
)
class InstructionTranslator:
"""Translates Instruction bytecode to a polars expression string."""
def __init__(
self,
instructions: list[Instruction],
caller_variables: dict[str, Any] | None,
function: Callable[[Any], Any],
map_target: MapTarget,
) -> None:
self._stack = self._to_intermediate_stack(instructions, map_target)
self._caller_variables = caller_variables
self._function = function
def to_expression(self, col: str, param_name: str, depth: int) -> str:
"""Convert intermediate stack to polars expression string."""
return self._expr(self._stack, col, param_name, depth)
@staticmethod
def op(inst: Instruction) -> str:
"""Convert bytecode instruction to suitable intermediate op string."""
if (opname := inst.opname) in OpNames.CONTROL_FLOW:
return OpNames.CONTROL_FLOW[opname]
elif inst.argrepr:
return inst.argrepr
elif opname == "IS_OP":
return "is not" if inst.argval else "is"
elif opname == "CONTAINS_OP":
return "not in" if inst.argval else "in"
elif opname in OpNames.UNARY:
return OpNames.UNARY[opname]
elif opname == "BINARY_SUBSCR":
return "replace_strict"
else:
msg = (
f"unexpected or unrecognised op name ({opname})\n\n"
"Please report a bug to https://github.com/pola-rs/polars/issues "
"with the content of function you were passing to the `map` "
f"expression and the following instruction object:\n{inst!r}"
)
raise AssertionError(msg)
def _expr(self, value: StackEntry, col: str, param_name: str, depth: int) -> str:
"""Take stack entry value and convert to polars expression string."""
if isinstance(value, StackValue):
op = _RE_STRIP_BOOL.sub(r"\1", value.operator)
e1 = self._expr(value.left_operand, col, param_name, depth + 1)
if value.operator_arity == 1:
if op not in OpNames.UNARY_VALUES:
if e1.startswith("pl.col("):
call = "" if op.endswith(")") else "()"
return f"{e1}.{op}{call}"
if e1[0] in OpNames.UNARY_VALUES and e1[1:].startswith("pl.col("):
call = "" if op.endswith(")") else "()"
return f"({e1}).{op}{call}"
# support use of consts as numpy/builtin params, eg:
# "np.sin(3) + np.cos(x)", or "len('const_string') + len(x)"
if (
value.from_module in _NUMPY_MODULE_ALIASES
and op in _NUMPY_FUNCTIONS
):
pfx = "np."
elif (
value.from_module == "math"
and _MODULE_FUNC_TO_EXPR_NAME.get(f"math.{op}", op)
in _MATH_FUNCTIONS
):
pfx = "math."
else:
pfx = ""
return f"{pfx}{op}({e1})"
return f"{op}{e1}"
else:
e2 = self._expr(value.right_operand, col, param_name, depth + 1)
if op in ("is", "is not") and value.left_operand == "None":
not_ = "" if op == "is" else "not_"
return f"{e1}.is_{not_}null()"
elif op in ("in", "not in"):
not_ = "" if op == "in" else "~"
return (
f"{not_}({e1}.is_in({e2}))"
if " " in e1
else f"{not_}{e1}.is_in({e2})"
)
elif op == "replace_strict":
if not self._caller_variables:
self._caller_variables = _get_all_caller_variables()
if not isinstance(self._caller_variables.get(e1, None), dict):
msg = "require dict mapping"
raise NotImplementedError(msg)
return f"{e2}.{op}({e1})"
elif op == "<<":
# 2**e2 may be float if e2 was -ve, but if e1 << e2 was valid then
# e2 must have been +ve. therefore 2**e2 can be safely cast to
# i64, which may be necessary if chaining ops that assume i64.
return f"({e1} * 2**{e2}).cast(pl.Int64)"
elif op == ">>":
# (motivation for the cast is same as the '<<' case above)
return f"({e1} / 2**{e2}).cast(pl.Int64)"
else:
expr = f"{e1} {op} {e2}"
return f"({expr})" if depth else expr
elif value == param_name:
return f'pl.col("{col}")'
return value
def _to_intermediate_stack(
self, instructions: list[Instruction], map_target: MapTarget
) -> StackEntry:
"""Take postfix bytecode and convert to an intermediate natural-order stack."""
if map_target in ("expr", "series"):
stack: list[StackEntry] = []
for inst in instructions:
stack.append(
inst.argrepr
if inst.opname in OpNames.LOAD
else (
StackValue(
operator=self.op(inst),
operator_arity=1,
left_operand=stack.pop(), # type: ignore[arg-type]
right_operand=None, # type: ignore[arg-type]
from_module=getattr(inst, "_from_module", None),
)
if (
inst.opname in OpNames.UNARY
or OpNames.SYNTHETIC.get(inst.opname) == 1
)
else StackValue(
operator=self.op(inst),
operator_arity=2,
left_operand=stack.pop(-2), # type: ignore[arg-type]
right_operand=stack.pop(-1), # type: ignore[arg-type]
from_module=getattr(inst, "_from_module", None),
)
)
)
return stack[0]
# TODO: dataframe.map... ?
msg = f"TODO: {map_target!r} map target not yet supported."
raise NotImplementedError(msg)
class RewrittenInstructions:
"""
Standalone class that applies Instruction rewrite/filtering rules.
This significantly simplifies subsequent parsing by injecting
synthetic POLARS_EXPRESSION ops into the Instruction stream for
easy identification/translation, and separates the parsing logic
from the identification of expression translation opportunities.
"""
_ignored_ops = frozenset(
[
"COPY",
"COPY_FREE_VARS",
"NOT_TAKEN",
"POP_TOP",
"PRECALL",
"PUSH_NULL",
"RESUME",
"RETURN_VALUE",
"TO_BOOL",
]
)
def __init__(
self,
instructions: Iterator[Instruction],
function: Callable[[Any], Any],
caller_variables: dict[str, Any] | None,
) -> None:
self._function = function
self._caller_variables = caller_variables
self._original_instructions = list(instructions)
normalised_instructions = []
for inst in self._unpack_superinstructions(self._original_instructions):
if inst.opname not in self._ignored_ops:
if inst.opname not in OpNames.MATCHABLE_OPS:
self._rewritten_instructions = []
return
upgraded_inst = self._update_instruction(inst)
normalised_instructions.append(upgraded_inst)
self._rewritten_instructions = self._rewrite(normalised_instructions)
def __len__(self) -> int:
return len(self._rewritten_instructions)
def __iter__(self) -> Iterator[Instruction]:
return iter(self._rewritten_instructions)
def __getitem__(self, item: Any) -> Instruction:
return self._rewritten_instructions[item]
def _matches(
self,
idx: int,
*,
opnames: list[AbstractSet[str]],
argvals: list[AbstractSet[Any] | dict[Any, Any] | None] | None,
is_attr: bool = False,
) -> list[Instruction]:
"""
Check if a sequence of Instructions matches the specified ops/argvals.
Parameters
----------
idx
The index of the first instruction to check.
opnames
The full opname sequence that defines a match.
argvals
Associated argvals that must also match (in same position as opnames).
is_attr
Indicate if the match represents pure attribute access (cannot be called).
"""
n_required_ops, argvals = len(opnames), argvals or []
idx_offset = idx + n_required_ops
if (
is_attr
and (trailing_inst := self._instructions[idx_offset : idx_offset + 1])
and trailing_inst[0].opname in OpNames.CALL # not pure attr if called
):
return []
instructions = self._instructions[idx:idx_offset]
if len(instructions) == n_required_ops and all(
inst.opname in match_opnames
and (match_argval is None or inst.argval in match_argval)
for inst, match_opnames, match_argval in zip_longest(
instructions, opnames, argvals
)
):
return instructions
return []
def _rewrite(self, instructions: list[Instruction]) -> list[Instruction]:
"""
Apply rewrite rules, potentially injecting synthetic operations.
Rules operate on the instruction stream and can examine/modify
it as needed, pushing updates into "updated_instructions" and
returning True/False to indicate if any changes were made.
"""
self._instructions = instructions
updated_instructions: list[Instruction] = []
idx = 0
while idx < len(self._instructions):
inst, increment = self._instructions[idx], 1
if inst.opname not in OpNames.LOAD or not any(
(increment := map_rewrite(idx, updated_instructions))
for map_rewrite in (
# add any other rewrite methods here
self._rewrite_functions,
self._rewrite_methods,
self._rewrite_builtins,
self._rewrite_attrs,
)
):
updated_instructions.append(inst)
idx += increment or 1
return updated_instructions
def _rewrite_attrs(self, idx: int, updated_instructions: list[Instruction]) -> int:
"""Replace python attribute lookup with synthetic POLARS_EXPRESSION op."""
if matching_instructions := self._matches(
idx,
opnames=[{"LOAD_FAST"}, {"LOAD_ATTR"}],
argvals=[None, _PYTHON_ATTRS_MAP],
is_attr=True,
):
inst = matching_instructions[1]
expr_name = _PYTHON_ATTRS_MAP[inst.argval]
px = inst._replace(
opname="POLARS_EXPRESSION", argval=expr_name, argrepr=expr_name
)
updated_instructions.extend([matching_instructions[0], px])
return len(matching_instructions)
def _rewrite_builtins(
self, idx: int, updated_instructions: list[Instruction]
) -> int:
"""Replace builtin function calls with a synthetic POLARS_EXPRESSION op."""
if matching_instructions := self._matches(
idx,
opnames=[{"LOAD_GLOBAL"}, {"LOAD_FAST", "LOAD_CONST"}, OpNames.CALL],
argvals=[_PYTHON_BUILTINS],
):
inst1, inst2 = matching_instructions[:2]
if (argval := inst1.argval) in _PYTHON_CASTS_MAP:
dtype = _PYTHON_CASTS_MAP[argval]
argval = f"cast(pl.{dtype})"
px = inst1._replace(
opname="POLARS_EXPRESSION",
argval=argval,
argrepr=argval,
offset=inst2.offset,
)
# POLARS_EXPRESSION is mapped as a unary op, so switch instruction order
operand = inst2._replace(offset=inst1.offset)
updated_instructions.extend((operand, px))
return len(matching_instructions)
def _rewrite_functions(
self, idx: int, updated_instructions: list[Instruction]
) -> int:
"""Replace function calls with a synthetic POLARS_EXPRESSION op."""
for check_globals in (False, True):
for function_kind in _MODULE_FUNCTIONS:
if check_globals and not function_kind.get("check_load_global", True):
return 0
opnames: list[AbstractSet[str]] = (
[
{"LOAD_GLOBAL", "LOAD_DEREF"},
*function_kind["argument_1_opname"],
*function_kind["argument_1_unary_opname"],
*function_kind["argument_2_opname"],
OpNames.CALL,
]
if check_globals
else [
{"LOAD_GLOBAL", "LOAD_DEREF"},
*function_kind["module_opname"],
*function_kind["attribute_opname"],
*function_kind["argument_1_opname"],
*function_kind["argument_1_unary_opname"],
*function_kind["argument_2_opname"],
OpNames.CALL,
]
)
module_aliases = function_kind["module_name"]
if matching_instructions := self._matches(
idx,
opnames=opnames,
argvals=[
*function_kind["function_name"],
]
if check_globals
else [
*function_kind["module_name"],
*function_kind["attribute_name"],
*function_kind["function_name"],
],
):
attribute_count = len(function_kind["attribute_name"])
inst1, inst2, inst3 = matching_instructions[
attribute_count : 3 + attribute_count
]
if check_globals:
if not self._caller_variables:
self._caller_variables = _get_all_caller_variables()
if (expr_name := inst1.argval) not in self._caller_variables:
continue
else:
module_name = self._caller_variables[expr_name].__module__
if not any((module_name in m) for m in module_aliases):
continue
expr_name = _MODULE_FUNC_TO_EXPR_NAME.get(
f"{module_name}.{expr_name}", expr_name
)
elif inst1.argval == "json":
expr_name = "str.json_decode"
elif inst1.argval == "datetime":
fmt = matching_instructions[attribute_count + 3].argval
expr_name = f'str.to_datetime(format="{fmt}")'
if not self._is_stdlib_datetime(
inst1.argval,
matching_instructions[0].argval,
attribute_count,
):
# skip these instructions if not stdlib datetime function
return len(matching_instructions)
elif inst1.argval == "math":
expr_name = _MODULE_FUNC_TO_EXPR_NAME.get(
f"math.{inst2.argval}", inst2.argval
)
else:
expr_name = inst2.argval
# note: POLARS_EXPRESSION is mapped as unary op, so switch
# instruction order/offsets (for later RPE-type stack walk)
swap_inst = inst2 if check_globals else inst3
px = inst1._replace(
opname="POLARS_EXPRESSION",
argval=expr_name,
argrepr=expr_name,
offset=swap_inst.offset,
)
px._from_module = None if check_globals else (inst1.argval or None) # type: ignore[attr-defined]
operand = swap_inst._replace(offset=inst1.offset)
updated_instructions.extend(
(
operand,
matching_instructions[3 + attribute_count],
px,
)
if function_kind["argument_1_unary_opname"]
else (operand, px)
)
return len(matching_instructions)
return 0
def _rewrite_methods(
self, idx: int, updated_instructions: list[Instruction]
) -> int:
"""Replace python method calls with synthetic POLARS_EXPRESSION op."""
LOAD_METHOD = OpNames.LOAD_ATTR if _MIN_PY312 else {"LOAD_METHOD"}
if matching_instructions := (
# method call with one arg, eg: "s.endswith('!')"
self._matches(
idx,
opnames=[LOAD_METHOD, {"LOAD_CONST"}, OpNames.CALL],
argvals=[_PYTHON_METHODS_MAP],
)
or
# method call with no arg, eg: "s.lower()"
self._matches(
idx,
opnames=[LOAD_METHOD, OpNames.CALL],
argvals=[_PYTHON_METHODS_MAP],
)
):
inst = matching_instructions[0]
expr = _PYTHON_METHODS_MAP[inst.argval]
if matching_instructions[1].opname == "LOAD_CONST":
param_value = matching_instructions[1].argval
if isinstance(param_value, tuple) and expr in (
"str.starts_with",
"str.ends_with",
):
starts, ends = ("^", "") if "starts" in expr else ("", "$")
rx = "|".join(re_escape(v) for v in param_value)
q = '"' if "'" in param_value else "'"
expr = f"str.contains(r{q}{starts}({rx}){ends}{q})"
else:
expr += f"({param_value!r})"
px = inst._replace(opname="POLARS_EXPRESSION", argval=expr, argrepr=expr)
updated_instructions.append(px)
elif matching_instructions := (
# method call with three args, eg: "s.replace('!','?',count=2)"
self._matches(
idx,
opnames=[
LOAD_METHOD,
{"LOAD_CONST"},
{"LOAD_CONST"},
{"LOAD_CONST"},
OpNames.CALL,
],
argvals=[_PYTHON_METHODS_MAP],
)
or
# method call with two args, eg: "s.replace('!','?')"
self._matches(
idx,
opnames=[LOAD_METHOD, {"LOAD_CONST"}, {"LOAD_CONST"}, OpNames.CALL],
argvals=[_PYTHON_METHODS_MAP],
)
):
inst = matching_instructions[0]
expr = _PYTHON_METHODS_MAP[inst.argval]
param_values = [
i.argval
for i in matching_instructions[1 : len(matching_instructions) - 1]
]
if expr == "str.replace":
if len(param_values) == 3:
old, new, count = param_values
expr += f"({old!r},{new!r},n={count},literal=True)"
else:
old, new = param_values
expr = f"str.replace_all({old!r},{new!r},literal=True)"
else:
expr += f"({','.join(repr(v) for v in param_values)})"
px = inst._replace(opname="POLARS_EXPRESSION", argval=expr, argrepr=expr)
updated_instructions.append(px)
return len(matching_instructions)
@staticmethod
def _unpack_superinstructions(
instructions: list[Instruction],
) -> Iterator[Instruction]:
"""Expand known 'superinstructions' into their component parts."""
for inst in instructions:
if inst.opname in (
"LOAD_FAST_LOAD_FAST",
"LOAD_FAST_BORROW_LOAD_FAST_BORROW",
):
for idx in (0, 1):
yield inst._replace(
opname="LOAD_FAST",
argval=inst.argval[idx],
argrepr=inst.argval[idx],
)
else:
yield inst
@staticmethod
def _update_instruction(inst: Instruction) -> Instruction:
"""Update/modify specific instructions to simplify multi-version parsing."""
if not _MIN_PY311 and inst.opname in OpNames.BINARY:
# update older binary opcodes using py >= 3.11 'BINARY_OP' instead
inst = inst._replace(
argrepr=OpNames.BINARY[inst.opname],
opname="BINARY_OP",
)
elif _MIN_PY314:
if (opname := inst.opname) in OpNames.SIMPLIFY_SPECIALIZED:
# simplify specialised opcode variants to their more generic form
# (eg: 'LOAD_FAST_BORROW' -> 'LOAD_FAST', etc)
updated_params = {"opname": OpNames.SIMPLIFY_SPECIALIZED[inst.opname]}
if opname == "LOAD_SMALL_INT":
updated_params["argrepr"] = str(inst.argval)
inst = inst._replace(**updated_params) # type: ignore[arg-type]
elif opname == "BINARY_OP" and inst.argrepr == "[]":
# special case for new 'BINARY_OP ([])'; revert to 'BINARY_SUBSCR'
inst = inst._replace(opname="BINARY_SUBSCR", argrepr="")
return inst
def _is_stdlib_datetime(
self, function_name: str, module_name: str, attribute_count: int
) -> bool:
if not self._caller_variables:
self._caller_variables = _get_all_caller_variables()
vars = self._caller_variables
return (
attribute_count == 0 and vars.get(function_name) is datetime.datetime
) or (attribute_count == 1 and vars.get(module_name) is datetime)
def _raw_function_meta(function: Callable[[Any], Any]) -> tuple[str, str]:
"""Identify translatable calls that aren't wrapped inside a lambda/function."""
try:
func_module = function.__class__.__module__
func_name = function.__name__
except AttributeError:
return "", ""
# numpy function calls
if func_module == "numpy" and func_name in _NUMPY_FUNCTIONS:
return "np", f"{func_name}()"
# python function calls
elif func_module == "builtins":
if func_name in _PYTHON_CASTS_MAP:
return "builtins", f"cast(pl.{_PYTHON_CASTS_MAP[func_name]})"
elif func_name in _MATH_FUNCTIONS:
import math
if function is getattr(math, func_name):
expr_name = _MODULE_FUNC_TO_EXPR_NAME.get(
f"math.{func_name}", func_name
)
return "math", f"{expr_name}()"
elif func_name == "loads":
import json # double-check since it is referenced via 'builtins'
if function is json.loads:
return "json", "str.json_decode()"
return "", ""
def warn_on_inefficient_map(
function: Callable[[Any], Any], columns: list[str], map_target: MapTarget
) -> None:
"""
Generate `PolarsInefficientMapWarning` on poor usage of a `map` function.
Parameters
----------
function
The function passed to `map`.
columns
The column name(s) of the original object; in the case of an `Expr` this
will be a list of length 1, containing the expression's root name.
map_target
The target of the `map` call. One of `"expr"`, `"frame"`, or `"series"`.
"""
if map_target == "frame":
msg = "TODO: 'frame' map-function parsing"
raise NotImplementedError(msg)
# note: we only consider simple functions with a single col/param
col: str = columns and columns[0] # type: ignore[assignment]
if not col and col != "":
return None
# the parser introspects function bytecode to determine if we can
# rewrite as a (much) more optimal native polars expression instead
if (parser := _BYTECODE_PARSER_CACHE_.get(key := (function, map_target))) is None:
parser = BytecodeParser(function, map_target)
_BYTECODE_PARSER_CACHE_[key] = parser
if parser.can_attempt_rewrite():
parser.warn(col)
else:
# handle bare numpy/json functions
module, suggestion = _raw_function_meta(function)
if module and suggestion:
target_name = _get_target_name(col, suggestion, map_target)
parser._map_target_name = target_name
fn = function.__name__
parser.warn(
col,
suggestion_override=f"{target_name}.{suggestion}",
udf_override=fn if module == "builtins" else f"{module}.{fn}",
)
__all__ = ["BytecodeParser", "warn_on_inefficient_map"]