DriverTrac/venv/lib/python3.12/site-packages/polars/ml/torch.py

214 lines
6.7 KiB
Python

# mypy: disable-error-code="unused-ignore"
from __future__ import annotations
from typing import TYPE_CHECKING
from polars._utils.unstable import issue_unstable_warning
from polars.dataframe import DataFrame
from polars.expr import Expr
from polars.selectors import exclude
if TYPE_CHECKING:
import sys
from collections.abc import Sequence
from torch import Tensor, memory_format
if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self
try:
import torch
from torch.utils.data import TensorDataset
except ImportError:
msg = (
"Required package 'torch' not installed.\n"
"Please install it using the command `pip install torch`."
)
raise ImportError(msg) from None
__all__ = ["PolarsDataset"]
class PolarsDataset(TensorDataset):  # type: ignore[misc]
    """
    TensorDataset class specialized for use with Polars DataFrames.

    .. warning::
        This functionality is considered **unstable**. It may be changed
        at any point without it being considered a breaking change.

    Parameters
    ----------
    frame
        Polars DataFrame containing the data that will be retrieved as Tensors.
    label
        One or more column names or expressions that label the feature data; results
        in `(features,label)` tuples, where all non-label columns are considered
        to be features. If no label is designated then each returned item is a
        simple `(features,)` tuple containing all row elements.
    features
        One or more column names or expressions that represent the feature data.
        If not provided, all columns not designated as labels are considered to be
        features.

    Notes
    -----
    * Integer, slice, range, integer list/Tensor Dataset indexing is all supported.
    * Designating multi-element labels is also supported.

    Examples
    --------
    >>> from torch.utils.data import DataLoader
    >>> df = pl.DataFrame(
    ...     data=[
    ...         (0, 1, 1.5),
    ...         (1, 0, -0.5),
    ...         (2, 0, 0.0),
    ...         (3, 1, -2.25),
    ...     ],
    ...     schema=["lbl", "feat1", "feat2"],
    ...     orient="row",
    ... )

    Create a Dataset from a Polars DataFrame, standardising the dtype and
    separating the label/feature columns.

    >>> ds = df.to_torch("dataset", label="lbl", dtype=pl.Float32)
    >>> ds  # doctest: +IGNORE_RESULT
    <PolarsDataset [len:4, features:2, labels:1] at 0x156B033B0>
    >>> ds.features
    tensor([[ 1.0000,  1.5000],
            [ 0.0000, -0.5000],
            [ 0.0000,  0.0000],
            [ 1.0000, -2.2500]])
    >>> ds[0]
    (tensor([1.0000, 1.5000]), tensor(0.))

    The Dataset can be used standalone, or in conjunction with a DataLoader.

    >>> dl = DataLoader(ds, batch_size=2)
    >>> list(dl)
    [[tensor([[ 1.0000,  1.5000],
              [ 0.0000, -0.5000]]),
      tensor([0., 1.])],
     [tensor([[ 0.0000,  0.0000],
              [ 1.0000, -2.2500]]),
      tensor([2., 3.])]]

    Note that the label can be given as an expression as well as a column name,
    allowing for independent transform and dtype adjustment from the feature
    columns.

    >>> ds = df.to_torch(
    ...     "dataset",
    ...     dtype=pl.Float32,
    ...     label=(pl.col("lbl") * 8).cast(pl.Int16),
    ... )
    >>> ds[:2]
    (tensor([[ 1.0000,  1.5000],
             [ 0.0000, -0.5000]]), tensor([0, 8], dtype=torch.int16))
    """

    # `tensors` is the attribute the parent TensorDataset indexes into; it is
    # always `(features,)` or `(features, labels)`.
    tensors: tuple[Tensor, ...]
    labels: Tensor | None
    features: Tensor

    def __init__(
        self,
        frame: DataFrame,
        *,
        label: str | Expr | Sequence[str | Expr] | None = None,
        features: str | Expr | Sequence[str | Expr] | None = None,
    ) -> None:
        issue_unstable_warning("`PolarsDataset` is considered unstable.")

        # Normalise a single label into a list so it can be unpacked into `select`.
        if isinstance(label, (str, Expr)):
            label = [label]

        # An `Expr` has no well-defined truth value, so it must be detected with
        # `isinstance` *before* any truthiness test; the same guard is needed
        # whether or not a label was supplied.
        has_features = isinstance(features, Expr) or bool(features)

        label_frame: DataFrame | None = None
        if not label:
            # Unlabelled: every item is a one-element `(features,)` tuple.
            feature_frame = frame.select(features) if has_features else frame
            self.features = feature_frame.to_torch()
            self.tensors = (self.features,)
            self.labels = None
        else:
            label_frame = frame.select(*label)
            # Decide 1D-vs-2D on the number of columns actually produced rather
            # than on the number of label arguments: a single expression may
            # expand to several columns, and taking `to_series()` would then
            # silently keep only the first of them.
            self.labels = (  # type: ignore[attr-defined]
                label_frame if label_frame.width > 1 else label_frame.to_series()
            ).to_torch()

            # Without explicit features, everything that is not a label column
            # is treated as a feature.
            feature_frame = frame.select(
                features if has_features else exclude(label_frame.columns)
            )
            self.features = feature_frame.to_torch()
            self.tensors = (self.features, self.labels)  # type: ignore[assignment]

        # Column counts are cached for use in `__repr__`.
        self._n_labels = 0 if (label_frame is None) else label_frame.width
        self._n_features = feature_frame.width

    def __copy__(self) -> Self:
        """Return a shallow copy of this PolarsDataset."""
        # Build a throwaway instance from a one-cell frame, then overwrite its
        # state so that the copy shares the original's tensors.
        dummy_frame = DataFrame({"blank": [0]})
        dataset_copy = self.__class__(dummy_frame)
        for attr in (
            "tensors",
            "labels",
            "features",
            "_n_labels",
            "_n_features",
        ):
            setattr(dataset_copy, attr, getattr(self, attr))
        return dataset_copy

    def __repr__(self) -> str:
        """Return a string representation of the PolarsDataset."""
        return (
            f"<{type(self).__name__} "
            f"[len:{len(self)},"
            f" features:{self._n_features},"
            f" labels:{self._n_labels}"
            f"] at 0x{id(self):X}>"
        )

    def half(
        self,
        *,
        features: bool = True,
        labels: bool = True,
        memory_format: memory_format = torch.preserve_format,
    ) -> Self:
        """
        Return a copy of this PolarsDataset with the numeric data converted to f16.

        Parameters
        ----------
        features
            Convert feature data to half precision (f16).
        labels
            Convert label data to half precision (f16).
        memory_format
            Desired memory format for the modified tensors.
        """
        ds = self.__copy__()
        if features:
            ds.features = self.features.to(torch.float16, memory_format=memory_format)

        # Rebuild `tensors` so that indexing reflects any converted tensor(s).
        if self.labels is not None:
            if labels:
                ds.labels = self.labels.to(torch.float16, memory_format=memory_format)
            ds.tensors = (ds.features, ds.labels)  # type: ignore[assignment]
        else:
            ds.tensors = (ds.features,)
        return ds

    @property
    def schema(self) -> dict[str, torch.dtype | None]:
        """Return the features/labels schema."""
        return {
            "features": self.features.dtype,
            "labels": self.labels.dtype if self.labels is not None else None,
        }