31 lines
970 B
Python
31 lines
970 B
Python
from typing import Any
|
|
|
|
from polars import DataFrame
|
|
from polars._dependencies import numpy as np
|
|
from polars._typing import IndexOrder
|
|
from polars.datatypes import Array, List
|
|
|
|
|
|
def frame_to_numpy(
    df: DataFrame,
    *,
    writable: bool,
    target: str,
    order: IndexOrder = "fortran",
) -> np.ndarray[Any, Any]:
    """
    Export a DataFrame as a NumPy array on behalf of a downstream consumer.

    Parameters
    ----------
    df
        Frame to convert.
    writable
        Forwarded unchanged to the underlying `to_numpy` call.
    target
        Name of the consumer the array is being built for; it is only
        interpolated into the error messages raised here.
    order
        Index order forwarded to `DataFrame.to_numpy`. It is not used when the
        frame consists of a single Array column, because that column is
        converted through `Series.to_numpy` instead.

    Returns
    -------
    numpy.ndarray
        The converted data; never of `object` dtype.

    Raises
    ------
    TypeError
        If any column has the List dtype, or if the conversion yields an
        array whose dtype is `object`.
    """
    schema = df.schema

    # Lazily look for the first List column; only that one is reported.
    list_column = next(
        (name for name, dtype in schema.items() if dtype == List),
        None,
    )
    if list_column is not None:
        msg = f"cannot convert List column {list_column!r} to {target} (use Array dtype instead)"
        raise TypeError(msg) from None

    # A frame holding exactly one Array column is converted via its Series
    # rather than through the frame-level conversion.
    single_array_column = df.width == 1 and schema.dtypes()[0] == Array
    result = (
        df[df.columns[0]].to_numpy(writable=writable)
        if single_array_column
        else df.to_numpy(writable=writable, order=order)
    )

    # An `object` result is rejected rather than handed on to the consumer.
    if result.dtype != object:
        return result

    msg = f"cannot convert DataFrame to {target} (mixed type columns result in `object` dtype)\n{schema!r}"
    raise TypeError(msg)
|