191 lines
6.2 KiB
Python
191 lines
6.2 KiB
Python
from __future__ import annotations
|
|
|
|
from typing import TYPE_CHECKING
|
|
|
|
from polars.datatypes import Boolean, Categorical, Enum, String
|
|
from polars.interchange.buffer import PolarsBuffer
|
|
from polars.interchange.protocol import (
|
|
Column,
|
|
ColumnNullType,
|
|
CopyNotAllowedError,
|
|
DtypeKind,
|
|
Endianness,
|
|
)
|
|
from polars.interchange.utils import polars_dtype_to_dtype
|
|
|
|
if TYPE_CHECKING:
|
|
from collections.abc import Iterator
|
|
from typing import Any
|
|
|
|
from polars import Series
|
|
from polars.interchange.protocol import CategoricalDescription, ColumnBuffers, Dtype
|
|
|
|
|
|
class PolarsColumn(Column):
    """
    A column object backed by a Polars Series.

    Parameters
    ----------
    column
        The Polars Series backing the column object.
    allow_copy
        Allow data to be copied during operations on this column. If set to `False`,
        an error is raised if data would be copied (for example, `get_buffers`
        raises `CopyNotAllowedError` for String columns). The flag is forwarded to
        every `PolarsColumn` and `PolarsBuffer` created from this column.
    """

    def __init__(self, column: Series, *, allow_copy: bool = True) -> None:
        self._col = column
        self._allow_copy = allow_copy

    def size(self) -> int:
        """Size of the column in elements."""
        return self._col.len()

    @property
    def offset(self) -> int:
        """
        Offset of the first element with respect to the start of the underlying buffer.

        Only Boolean columns report an offset taken from the Series buffer info;
        every other data type reports 0.
        """
        if self._col.dtype == Boolean:
            # Element 1 of the buffer info tuple is used as the offset.
            return self._col._get_buffer_info()[1]
        else:
            return 0

    @property
    def dtype(self) -> Dtype:
        """Data type of the column, converted to its interchange representation."""
        pl_dtype = self._col.dtype
        return polars_dtype_to_dtype(pl_dtype)

    @property
    def describe_categorical(self) -> CategoricalDescription:
        """
        Description of the categorical data type of the column.

        Categorical columns are reported as unordered and Enum columns as ordered.
        The categories are wrapped in a new `PolarsColumn` that inherits this
        column's `allow_copy` setting.

        Raises
        ------
        TypeError
            If the data type of the column is not categorical.
        """
        dtype = self._col.dtype
        if dtype == Categorical:
            categories = self._col.cat.get_categories()
            is_ordered = False
        elif dtype == Enum:
            categories = dtype.categories  # type: ignore[attr-defined]
            is_ordered = True
        else:
            msg = "`describe_categorical` only works on categorical columns"
            raise TypeError(msg)

        return {
            "is_ordered": is_ordered,
            "is_dictionary": True,
            "categories": PolarsColumn(categories, allow_copy=self._allow_copy),
        }

    @property
    def describe_null(self) -> tuple[ColumnNullType, int | None]:
        """
        Description of the null representation the column uses.

        Returns `(NON_NULLABLE, None)` when the column currently holds no nulls,
        and `(USE_BITMASK, 0)` otherwise.
        """
        if self.null_count == 0:
            return ColumnNullType.NON_NULLABLE, None
        else:
            return ColumnNullType.USE_BITMASK, 0

    @property
    def null_count(self) -> int:
        """The number of null elements."""
        return self._col.null_count()

    @property
    def metadata(self) -> dict[str, Any]:
        """The metadata for the column (always an empty dictionary)."""
        return {}

    def num_chunks(self) -> int:
        """Return the number of chunks the column consists of."""
        return self._col.n_chunks()

    def get_chunks(self, n_chunks: int | None = None) -> Iterator[PolarsColumn]:
        """
        Return an iterator yielding the column chunks.

        Parameters
        ----------
        n_chunks
            The number of chunks to return. Must be a multiple of the number of chunks
            in the column. If `None`, the existing chunks are yielded as they are.

        Raises
        ------
        ValueError
            If `n_chunks` is not a positive multiple of the number of chunks in the
            column. Because this method is a generator, the error is raised when
            iteration starts rather than when the method is called.

        Notes
        -----
        When `n_chunks` is higher than the number of chunks in the column, a slice
        must be performed that is not on the chunk boundary. This will trigger some
        compute if the column contains null values or if the column is of data type
        boolean.

        Every chunk is split into the same number of subchunks, so exactly
        `n_chunks` columns are yielded; trailing subchunks may be empty when a
        chunk has fewer elements than the requested number of subchunks.
        """
        total_n_chunks = self.num_chunks()
        chunks = self._col.get_chunks()

        if (n_chunks is None) or (n_chunks == total_n_chunks):
            for chunk in chunks:
                yield PolarsColumn(chunk, allow_copy=self._allow_copy)

        elif (n_chunks <= 0) or (n_chunks % total_n_chunks != 0):
            msg = (
                "`n_chunks` must be a multiple of the number of chunks of this column"
                f" ({total_n_chunks})"
            )
            raise ValueError(msg)

        else:
            subchunks_per_chunk = n_chunks // total_n_chunks
            for chunk in chunks:
                size = len(chunk)
                # Ceiling division: the subchunk length that covers the whole chunk.
                step, remainder = divmod(size, subchunks_per_chunk)
                if remainder != 0:
                    step += 1
                # An empty chunk would give a step of 0, which `range` rejects with
                # a ValueError. Clamp to 1 so empty chunks yield empty subchunks
                # and the total number of yielded columns still equals `n_chunks`.
                step = max(step, 1)
                for start in range(0, step * subchunks_per_chunk, step):
                    yield PolarsColumn(
                        chunk[start : start + step], allow_copy=self._allow_copy
                    )

    def get_buffers(self) -> ColumnBuffers:
        """
        Return a dictionary containing the underlying buffers.

        The dictionary has the keys `"data"`, `"validity"` and `"offsets"`. The
        latter two are `None` when the Series does not expose such a buffer.

        Raises
        ------
        CopyNotAllowedError
            If the column is of data type String and `allow_copy` is `False`.
        """
        dtype = self._col.dtype

        if dtype == String and not self._allow_copy:
            msg = "string buffers must be converted"
            raise CopyNotAllowedError(msg)

        buffers = self._col._get_buffers()

        return {
            "data": self._wrap_data_buffer(buffers["values"]),
            "validity": self._wrap_validity_buffer(buffers["validity"]),
            "offsets": self._wrap_offsets_buffer(buffers["offsets"]),
        }

    def _wrap_data_buffer(self, buffer: Series) -> tuple[PolarsBuffer, Dtype]:
        """Wrap the values buffer and pair it with its interchange data type."""
        interchange_buffer = PolarsBuffer(buffer, allow_copy=self._allow_copy)
        dtype = polars_dtype_to_dtype(buffer.dtype)
        return interchange_buffer, dtype

    def _wrap_validity_buffer(
        self, buffer: Series | None
    ) -> tuple[PolarsBuffer, Dtype] | None:
        """Wrap the validity buffer as a 1-bit boolean buffer, or return `None`."""
        if buffer is None:
            return None

        interchange_buffer = PolarsBuffer(buffer, allow_copy=self._allow_copy)
        dtype = (DtypeKind.BOOL, 1, "b", Endianness.NATIVE)
        return interchange_buffer, dtype

    def _wrap_offsets_buffer(
        self, buffer: Series | None
    ) -> tuple[PolarsBuffer, Dtype] | None:
        """Wrap the offsets buffer as a 64-bit integer buffer, or return `None`."""
        if buffer is None:
            return None

        interchange_buffer = PolarsBuffer(buffer, allow_copy=self._allow_copy)
        dtype = (DtypeKind.INT, 64, "l", Endianness.NATIVE)
        return interchange_buffer, dtype
|