# DriverTrac/venv/lib/python3.12/site-packages/polars/plugins.py
#
# 156 lines
# 4.9 KiB
# Python

from __future__ import annotations
import contextlib
import sys
from functools import lru_cache
from pathlib import Path
from typing import TYPE_CHECKING, Any
from polars._utils.parse import parse_into_list_of_expressions
from polars._utils.wrap import wrap_expr
with contextlib.suppress(ImportError): # Module not available when building docs
import polars._plr as plr
if TYPE_CHECKING:
from collections.abc import Iterable
from polars import Expr
from polars._typing import IntoExpr
__all__ = ["register_plugin_function"]
def register_plugin_function(
    *,
    plugin_path: Path | str,
    function_name: str,
    args: IntoExpr | Iterable[IntoExpr],
    kwargs: dict[str, Any] | None = None,
    is_elementwise: bool = False,
    changes_length: bool = False,
    returns_scalar: bool = False,
    cast_to_supertype: bool = False,
    input_wildcard_expansion: bool = False,
    pass_name_to_apply: bool = False,
    use_abs_path: bool = False,
) -> Expr:
    """
    Register a plugin function.

    See the `user guide <https://docs.pola.rs/user-guide/plugins/expr_plugins>`_
    for more information about plugins.

    Parameters
    ----------
    plugin_path
        Location of the plugin package: either the dynamic library file itself
        or the directory that contains it.
    function_name
        Name of the Rust function to register.
    args
        Arguments for the function. They are forwarded to the `input` argument
        on the Rust side and must be expressions (or convertible to
        expressions).
    kwargs
        Non-expression arguments to the plugin function. These must be
        JSON serializable; they are run through `_serialize_kwargs` before
        being handed over.
    is_elementwise
        Indicate that the function operates on scalars only. This will
        potentially trigger fast paths.
    changes_length
        Indicate that the function will change the length of the expression,
        as a `unique` or `slice` operation would.
    returns_scalar
        Automatically explode on unit length if the function ran as final
        aggregation. This is the case for aggregations like `sum`, `min`,
        `covariance` etc.
    cast_to_supertype
        Cast the input expressions to their supertype.
    input_wildcard_expansion
        Expand wildcard expressions before executing the function.
    pass_name_to_apply
        If set to `True`, the `Series` passed to the function in a group-by
        operation will ensure the name is set. This is an extra heap
        allocation per group.
    use_abs_path
        If set to `True`, the path will be resolved to an absolute path.
        By default the path to the dynamic library is relative to the virtual
        environment.

    Returns
    -------
    Expr

    Warnings
    --------
    This is highly unsafe as this will call the C function loaded by
    `plugin::function_name`.

    The parameters you set dictate how Polars will handle the function.
    Make sure they are correct!
    """
    # Normalise every input up front: expressions, kwargs payload, library path.
    input_exprs = parse_into_list_of_expressions(args)
    kwargs_payload = _serialize_kwargs(kwargs)
    library_path = _resolve_plugin_path(plugin_path, use_abs_path=use_abs_path)

    # The native layer expects the library location as a plain string.
    native_expr = plr.register_plugin_function(
        plugin_path=str(library_path),
        function_name=function_name,
        args=input_exprs,
        kwargs=kwargs_payload,
        is_elementwise=is_elementwise,
        input_wildcard_expansion=input_wildcard_expansion,
        returns_scalar=returns_scalar,
        cast_to_supertype=cast_to_supertype,
        pass_name_to_apply=pass_name_to_apply,
        changes_length=changes_length,
    )
    return wrap_expr(native_expr)
def _serialize_kwargs(kwargs: dict[str, Any] | None) -> bytes:
    """
    Turn the plugin's keyword arguments into a bytes payload.

    `None` and an empty dict both yield an empty byte string; anything else
    is pickled.
    """
    if kwargs:
        # Imported lazily so the module is only loaded when there is a payload.
        import pickle

        # Protocol 5 is the highest pickle protocol supported by the
        # serde-pickle crate:
        # https://docs.rs/serde-pickle/latest/serde_pickle/
        return pickle.dumps(kwargs, protocol=5)
    return b""
@lru_cache(maxsize=16)
def _resolve_plugin_path(path: Path | str, *, use_abs_path: bool = False) -> Path:
    """
    Locate the dynamic library file for a plugin.

    `path` may point at the library file itself or at a directory. For a
    directory, the first entry that `_is_dynamic_lib` accepts is used.
    Results are memoised per `(path, use_abs_path)` pair.

    Raises
    ------
    FileNotFoundError
        If `path` is not a file and no entry in it is a dynamic library.
    """
    path = path if isinstance(path, Path) else Path(path)

    # A direct file path needs no directory scan.
    if path.is_file():
        return _resolve_file_path(path, use_abs_path=use_abs_path)

    library = next(
        (entry for entry in path.iterdir() if _is_dynamic_lib(entry)),
        None,
    )
    if library is None:
        msg = f"no dynamic library found at path: {path}"
        raise FileNotFoundError(msg)
    return _resolve_file_path(library, use_abs_path=use_abs_path)
def _is_dynamic_lib(path: Path) -> bool:
    """Return True when `path` is an existing file ending in .so, .dll or .pyd."""
    if not path.is_file():
        return False
    return path.suffix in (".so", ".dll", ".pyd")
def _resolve_file_path(path: Path, *, use_abs_path: bool = False) -> Path:
    """
    Express a library file path in the form handed to the plugin loader.

    With `use_abs_path` the path is resolved to an absolute path. Otherwise
    it is made relative to `sys.prefix`; if the file does not live under that
    prefix, the resolved absolute path is returned instead.
    """
    if use_abs_path:
        return path.resolve()

    prefix = Path(sys.prefix)
    try:
        return path.relative_to(prefix)
    except ValueError:
        # Not located under the interpreter prefix: fall back to absolute.
        return path.resolve()