156 lines
4.9 KiB
Python
156 lines
4.9 KiB
Python
from __future__ import annotations
|
|
|
|
import contextlib
|
|
import sys
|
|
from functools import lru_cache
|
|
from pathlib import Path
|
|
from typing import TYPE_CHECKING, Any
|
|
|
|
from polars._utils.parse import parse_into_list_of_expressions
|
|
from polars._utils.wrap import wrap_expr
|
|
|
|
with contextlib.suppress(ImportError): # Module not available when building docs
|
|
import polars._plr as plr
|
|
|
|
if TYPE_CHECKING:
|
|
from collections.abc import Iterable
|
|
|
|
from polars import Expr
|
|
from polars._typing import IntoExpr
|
|
|
|
__all__ = ["register_plugin_function"]
|
|
|
|
|
|
def register_plugin_function(
|
|
*,
|
|
plugin_path: Path | str,
|
|
function_name: str,
|
|
args: IntoExpr | Iterable[IntoExpr],
|
|
kwargs: dict[str, Any] | None = None,
|
|
is_elementwise: bool = False,
|
|
changes_length: bool = False,
|
|
returns_scalar: bool = False,
|
|
cast_to_supertype: bool = False,
|
|
input_wildcard_expansion: bool = False,
|
|
pass_name_to_apply: bool = False,
|
|
use_abs_path: bool = False,
|
|
) -> Expr:
|
|
"""
|
|
Register a plugin function.
|
|
|
|
See the `user guide <https://docs.pola.rs/user-guide/plugins/expr_plugins>`_
|
|
for more information about plugins.
|
|
|
|
Parameters
|
|
----------
|
|
plugin_path
|
|
Path to the plugin package. Accepts either the file path to the dynamic library
|
|
file or the path to the directory containing it.
|
|
function_name
|
|
The name of the Rust function to register.
|
|
args
|
|
The arguments passed to this function. These get passed to the `input`
|
|
argument on the Rust side, and have to be expressions (or be convertible
|
|
to expressions).
|
|
kwargs
|
|
Non-expression arguments to the plugin function. These must be
|
|
JSON serializable.
|
|
is_elementwise
|
|
Indicate that the function operates on scalars only. This will potentially
|
|
trigger fast paths.
|
|
changes_length
|
|
Indicate that the function will change the length of the expression.
|
|
For example, a `unique` or `slice` operation.
|
|
returns_scalar
|
|
Automatically explode on unit length if the function ran as final aggregation.
|
|
This is the case for aggregations like `sum`, `min`, `covariance` etc.
|
|
cast_to_supertype
|
|
Cast the input expressions to their supertype.
|
|
input_wildcard_expansion
|
|
Expand wildcard expressions before executing the function.
|
|
pass_name_to_apply
|
|
If set to `True`, the `Series` passed to the function in a group-by operation
|
|
will ensure the name is set. This is an extra heap allocation per group.
|
|
use_abs_path
|
|
If set to `True`, the path will be resolved to an absolute path.
|
|
The path to the dynamic library is relative to the virtual environment by
|
|
default.
|
|
|
|
Returns
|
|
-------
|
|
Expr
|
|
|
|
Warnings
|
|
--------
|
|
This is highly unsafe as this will call the C function loaded by
|
|
`plugin::function_name`.
|
|
|
|
The parameters you set dictate how Polars will handle the function.
|
|
Make sure they are correct!
|
|
"""
|
|
pyexprs = parse_into_list_of_expressions(args)
|
|
serialized_kwargs = _serialize_kwargs(kwargs)
|
|
plugin_path = _resolve_plugin_path(plugin_path, use_abs_path=use_abs_path)
|
|
|
|
return wrap_expr(
|
|
plr.register_plugin_function(
|
|
plugin_path=str(plugin_path),
|
|
function_name=function_name,
|
|
args=pyexprs,
|
|
kwargs=serialized_kwargs,
|
|
is_elementwise=is_elementwise,
|
|
input_wildcard_expansion=input_wildcard_expansion,
|
|
returns_scalar=returns_scalar,
|
|
cast_to_supertype=cast_to_supertype,
|
|
pass_name_to_apply=pass_name_to_apply,
|
|
changes_length=changes_length,
|
|
)
|
|
)
|
|
|
|
|
|
def _serialize_kwargs(kwargs: dict[str, Any] | None) -> bytes:
|
|
"""Serialize the function's keyword arguments."""
|
|
if not kwargs:
|
|
return b""
|
|
|
|
import pickle
|
|
|
|
# Use the highest pickle protocol supported the serde-pickle crate:
|
|
# https://docs.rs/serde-pickle/latest/serde_pickle/
|
|
return pickle.dumps(kwargs, protocol=5)
|
|
|
|
|
|
@lru_cache(maxsize=16)
|
|
def _resolve_plugin_path(path: Path | str, *, use_abs_path: bool = False) -> Path:
|
|
"""Get the file path of the dynamic library file."""
|
|
if not isinstance(path, Path):
|
|
path = Path(path)
|
|
|
|
if path.is_file():
|
|
return _resolve_file_path(path, use_abs_path=use_abs_path)
|
|
|
|
for p in path.iterdir():
|
|
if _is_dynamic_lib(p):
|
|
return _resolve_file_path(p, use_abs_path=use_abs_path)
|
|
|
|
msg = f"no dynamic library found at path: {path}"
|
|
raise FileNotFoundError(msg)
|
|
|
|
|
|
def _is_dynamic_lib(path: Path) -> bool:
|
|
return path.is_file() and path.suffix in (".so", ".dll", ".pyd")
|
|
|
|
|
|
def _resolve_file_path(path: Path, *, use_abs_path: bool = False) -> Path:
|
|
venv_path = Path(sys.prefix)
|
|
|
|
if use_abs_path:
|
|
return path.resolve()
|
|
else:
|
|
try:
|
|
file_path = path.relative_to(venv_path)
|
|
except ValueError: # Fallback
|
|
file_path = path.resolve()
|
|
|
|
return file_path
|