DriverTrac/venv/lib/python3.12/site-packages/polars/ml/utilities.py
2025-11-28 09:08:33 +05:30

31 lines
970 B
Python

from typing import Any
from polars import DataFrame
from polars._dependencies import numpy as np
from polars._typing import IndexOrder
from polars.datatypes import Array, List
def frame_to_numpy(
df: DataFrame,
*,
writable: bool,
target: str,
order: IndexOrder = "fortran",
) -> np.ndarray[Any, Any]:
"""Convert a DataFrame to a NumPy array for use with Jax or PyTorch."""
for nm, tp in df.schema.items():
if tp == List:
msg = f"cannot convert List column {nm!r} to {target} (use Array dtype instead)"
raise TypeError(msg) from None
if df.width == 1 and df.schema.dtypes()[0] == Array:
arr = df[df.columns[0]].to_numpy(writable=writable)
else:
arr = df.to_numpy(writable=writable, order=order)
if arr.dtype == object:
msg = f"cannot convert DataFrame to {target} (mixed type columns result in `object` dtype)\n{df.schema!r}"
raise TypeError(msg)
return arr