from __future__ import annotations import contextlib import sys from functools import lru_cache from pathlib import Path from typing import TYPE_CHECKING, Any from polars._utils.parse import parse_into_list_of_expressions from polars._utils.wrap import wrap_expr with contextlib.suppress(ImportError): # Module not available when building docs import polars._plr as plr if TYPE_CHECKING: from collections.abc import Iterable from polars import Expr from polars._typing import IntoExpr __all__ = ["register_plugin_function"] def register_plugin_function( *, plugin_path: Path | str, function_name: str, args: IntoExpr | Iterable[IntoExpr], kwargs: dict[str, Any] | None = None, is_elementwise: bool = False, changes_length: bool = False, returns_scalar: bool = False, cast_to_supertype: bool = False, input_wildcard_expansion: bool = False, pass_name_to_apply: bool = False, use_abs_path: bool = False, ) -> Expr: """ Register a plugin function. See the `user guide `_ for more information about plugins. Parameters ---------- plugin_path Path to the plugin package. Accepts either the file path to the dynamic library file or the path to the directory containing it. function_name The name of the Rust function to register. args The arguments passed to this function. These get passed to the `input` argument on the Rust side, and have to be expressions (or be convertible to expressions). kwargs Non-expression arguments to the plugin function. These must be JSON serializable. is_elementwise Indicate that the function operates on scalars only. This will potentially trigger fast paths. changes_length Indicate that the function will change the length of the expression. For example, a `unique` or `slice` operation. returns_scalar Automatically explode on unit length if the function ran as final aggregation. This is the case for aggregations like `sum`, `min`, `covariance` etc. cast_to_supertype Cast the input expressions to their supertype. input_wildcard_expansion Expand wildcard expressions before executing the function. pass_name_to_apply If set to `True`, the `Series` passed to the function in a group-by operation will ensure the name is set. This is an extra heap allocation per group. use_abs_path If set to `True`, the path will be resolved to an absolute path. The path to the dynamic library is relative to the virtual environment by default. Returns ------- Expr Warnings -------- This is highly unsafe as this will call the C function loaded by `plugin::function_name`. The parameters you set dictate how Polars will handle the function. Make sure they are correct! """ pyexprs = parse_into_list_of_expressions(args) serialized_kwargs = _serialize_kwargs(kwargs) plugin_path = _resolve_plugin_path(plugin_path, use_abs_path=use_abs_path) return wrap_expr( plr.register_plugin_function( plugin_path=str(plugin_path), function_name=function_name, args=pyexprs, kwargs=serialized_kwargs, is_elementwise=is_elementwise, input_wildcard_expansion=input_wildcard_expansion, returns_scalar=returns_scalar, cast_to_supertype=cast_to_supertype, pass_name_to_apply=pass_name_to_apply, changes_length=changes_length, ) ) def _serialize_kwargs(kwargs: dict[str, Any] | None) -> bytes: """Serialize the function's keyword arguments.""" if not kwargs: return b"" import pickle # Use the highest pickle protocol supported the serde-pickle crate: # https://docs.rs/serde-pickle/latest/serde_pickle/ return pickle.dumps(kwargs, protocol=5) @lru_cache(maxsize=16) def _resolve_plugin_path(path: Path | str, *, use_abs_path: bool = False) -> Path: """Get the file path of the dynamic library file.""" if not isinstance(path, Path): path = Path(path) if path.is_file(): return _resolve_file_path(path, use_abs_path=use_abs_path) for p in path.iterdir(): if _is_dynamic_lib(p): return _resolve_file_path(p, use_abs_path=use_abs_path) msg = f"no dynamic library found at path: {path}" raise FileNotFoundError(msg) def _is_dynamic_lib(path: Path) -> bool: return path.is_file() and path.suffix in (".so", ".dll", ".pyd") def _resolve_file_path(path: Path, *, use_abs_path: bool = False) -> Path: venv_path = Path(sys.prefix) if use_abs_path: return path.resolve() else: try: file_path = path.relative_to(venv_path) except ValueError: # Fallback file_path = path.resolve() return file_path