171 lines
4.9 KiB
Python
171 lines
4.9 KiB
Python
from __future__ import annotations
|
|
|
|
import re
|
|
from typing import TYPE_CHECKING
|
|
|
|
from polars.datatypes import (
|
|
Boolean,
|
|
Categorical,
|
|
Date,
|
|
Datetime,
|
|
Duration,
|
|
Enum,
|
|
Float32,
|
|
Float64,
|
|
Int8,
|
|
Int16,
|
|
Int32,
|
|
Int64,
|
|
String,
|
|
Time,
|
|
UInt8,
|
|
UInt16,
|
|
UInt32,
|
|
UInt64,
|
|
)
|
|
from polars.interchange.protocol import DtypeKind, Endianness
|
|
|
|
if TYPE_CHECKING:
|
|
from polars._typing import PolarsDataType
|
|
from polars.datatypes import DataTypeClass
|
|
from polars.interchange.protocol import Dtype
|
|
|
|
# Every interchange dtype built in this module is tagged with native endianness.
NE = Endianness.NATIVE

# Default interchange dtype for each supported Polars base type.
# Each row below is (Polars type, kind, bit width, Arrow C type string); the
# comprehension appends the shared endianness to form the final 4-tuple.
polars_dtype_to_dtype_map: dict[DataTypeClass, Dtype] = {
    polars_type: (kind, bit_width, arrow_c_type, NE)
    for polars_type, kind, bit_width, arrow_c_type in (
        # Signed integers
        (Int8, DtypeKind.INT, 8, "c"),
        (Int16, DtypeKind.INT, 16, "s"),
        (Int32, DtypeKind.INT, 32, "i"),
        (Int64, DtypeKind.INT, 64, "l"),
        # Unsigned integers
        (UInt8, DtypeKind.UINT, 8, "C"),
        (UInt16, DtypeKind.UINT, 16, "S"),
        (UInt32, DtypeKind.UINT, 32, "I"),
        (UInt64, DtypeKind.UINT, 64, "L"),
        # Floating point
        (Float32, DtypeKind.FLOAT, 32, "f"),
        (Float64, DtypeKind.FLOAT, 64, "g"),
        # Boolean and string
        (Boolean, DtypeKind.BOOL, 1, "b"),
        (String, DtypeKind.STRING, 8, "U"),
        # Temporal types; Datetime and Duration entries are the defaults used
        # when the dtype is not instantiated with a specific time unit/zone.
        (Date, DtypeKind.DATETIME, 32, "tdD"),
        (Time, DtypeKind.DATETIME, 64, "ttu"),
        (Datetime, DtypeKind.DATETIME, 64, "tsu:"),
        (Duration, DtypeKind.DATETIME, 64, "tDu"),
        # Categorical types
        (Categorical, DtypeKind.CATEGORICAL, 32, "I"),
        (Enum, DtypeKind.CATEGORICAL, 32, "I"),
    )
}
|
|
|
|
|
|
def polars_dtype_to_dtype(dtype: PolarsDataType) -> Dtype:
    """
    Convert Polars data type to interchange protocol data type.

    The base type of `dtype` is looked up in `polars_dtype_to_dtype_map` first;
    instantiated `Datetime` and `Duration` dtypes are then converted using
    their own time unit (and time zone) instead of the table default.

    Raises
    ------
    ValueError
        If the base type of `dtype` has no entry in the lookup table.
    """
    try:
        default_dtype = polars_dtype_to_dtype_map[dtype.base_type()]
    except KeyError as exc:
        msg = f"data type {dtype!r} not supported by the interchange protocol"
        raise ValueError(msg) from exc

    # Instantiated temporal dtypes carry parameters the table cannot express.
    if isinstance(dtype, Datetime):
        return _datetime_to_dtype(dtype)
    if isinstance(dtype, Duration):
        return _duration_to_dtype(dtype)

    return default_dtype
|
|
|
|
|
|
def _datetime_to_dtype(dtype: Datetime) -> Dtype:
    """
    Build the interchange dtype for an instantiated Datetime.

    The format string is ``ts<unit>:<zone>``, where ``<unit>`` is the first
    character of the dtype's time unit and ``<zone>`` is its time zone, or an
    empty string when no time zone is set.
    """
    unit_char = dtype.time_unit[0]
    zone = dtype.time_zone
    if zone is None:
        zone = ""
    return (DtypeKind.DATETIME, 64, f"ts{unit_char}:{zone}", NE)
|
|
|
|
|
|
def _duration_to_dtype(dtype: Duration) -> Dtype:
    """
    Build the interchange dtype for an instantiated Duration.

    The format string is ``tD<unit>``, where ``<unit>`` is the first character
    of the dtype's time unit.
    """
    unit_char = dtype.time_unit[0]
    return (DtypeKind.DATETIME, 64, f"tD{unit_char}", NE)
|
|
|
|
|
|
# Reverse lookup: interchange kind -> bit width -> Polars data type.
# Only kinds that are fully determined by their bit width are listed here.
dtype_to_polars_dtype_map: dict[DtypeKind, dict[int, PolarsDataType]] = {
    DtypeKind.INT: {8: Int8, 16: Int16, 32: Int32, 64: Int64},
    DtypeKind.UINT: {8: UInt8, 16: UInt16, 32: UInt32, 64: UInt64},
    DtypeKind.FLOAT: {32: Float32, 64: Float64},
    # Both 1-bit and 8-bit booleans resolve to Boolean.
    DtypeKind.BOOL: {1: Boolean, 8: Boolean},
    DtypeKind.STRING: {8: String},
}
|
|
|
|
|
|
def dtype_to_polars_dtype(dtype: Dtype) -> PolarsDataType:
    """
    Convert interchange protocol data type to Polars data type.

    Datetime-kind dtypes are resolved from their format string, categorical
    dtypes always map to `Enum`, and every other kind is resolved by bit width
    through `dtype_to_polars_dtype_map`.

    Raises
    ------
    NotImplementedError
        If the kind / bit width combination has no entry in the lookup table.
    """
    kind, bit_width, format_str, _endianness = dtype

    if kind == DtypeKind.DATETIME:
        return _temporal_dtype_to_polars_dtype(format_str, dtype)
    if kind == DtypeKind.CATEGORICAL:
        return Enum

    try:
        polars_dtype = dtype_to_polars_dtype_map[kind][bit_width]
    except KeyError as exc:
        msg = f"unsupported data type: {dtype!r}"
        raise NotImplementedError(msg) from exc
    return polars_dtype
|
|
|
|
|
|
def _temporal_dtype_to_polars_dtype(format_str: str, dtype: Dtype) -> PolarsDataType:
    """
    Resolve a datetime-kind interchange dtype from its format string.

    Recognised formats are ``tdD`` (Date), ``ttu`` (Time), ``ts<unit>:<zone>``
    (Datetime) and ``tD<unit>`` (Duration), where ``<unit>`` is one of ``m``,
    ``u`` or ``n``. An empty ``<zone>`` yields a Datetime without time zone.

    Raises
    ------
    NotImplementedError
        If `format_str` matches none of the recognised formats. `dtype` is
        only used to build the error message.
    """
    # Fixed format strings first; the patterns below cannot match these.
    if format_str == "tdD":
        return Date
    if format_str == "ttu":
        return Time

    datetime_match = re.fullmatch(r"ts([mun]):(.*)", format_str)
    if datetime_match is not None:
        unit_char, zone = datetime_match.groups()
        return Datetime(
            time_unit=unit_char + "s",  # type: ignore[arg-type]
            time_zone=zone or None,
        )

    duration_match = re.fullmatch(r"tD([mun])", format_str)
    if duration_match is not None:
        unit = duration_match.group(1) + "s"
        return Duration(time_unit=unit)  # type: ignore[arg-type]

    msg = f"unsupported temporal data type: {dtype!r}"
    raise NotImplementedError(msg)
|
|
|
|
|
|
def get_buffer_length_in_elements(buffer_size: int, dtype: Dtype) -> int:
    """
    Get the length of a buffer in elements.

    Parameters
    ----------
    buffer_size
        Size of the buffer in bytes.
    dtype
        Interchange dtype tuple; its second item is the bit width per element.

    Returns
    -------
    int
        `buffer_size` floor-divided by the number of bytes per element.

    Raises
    ------
    ValueError
        If the bit width is not a positive multiple of 8.
    """
    bits_per_element = dtype[1]
    bytes_per_element, rest = divmod(bits_per_element, 8)
    # A zero width would divide by zero below and a negative width would give
    # a negative length, so both are rejected along with sub-byte widths.
    if rest > 0 or bytes_per_element <= 0:
        msg = f"cannot get buffer length for buffer with dtype {dtype!r}"
        raise ValueError(msg)
    return buffer_size // bytes_per_element
|
|
|
|
|
|
def polars_dtype_to_data_buffer_dtype(dtype: PolarsDataType) -> PolarsDataType:
    """
    Get the data type of the data buffer.

    Integer, float and Boolean dtypes are returned unchanged. Temporal dtypes
    map to `Int32` for `Date` and `Int64` otherwise, `String` maps to `UInt8`,
    and `Enum` / `Categorical` map to `UInt32`.

    Raises
    ------
    NotImplementedError
        If `dtype` falls into none of the categories above.
    """
    if dtype.is_integer() or dtype.is_float() or dtype == Boolean:
        return dtype

    if dtype.is_temporal():
        if dtype == Date:
            return Int32
        return Int64

    if dtype == String:
        return UInt8

    if dtype in (Enum, Categorical):
        return UInt32

    msg = f"unsupported data type: {dtype}"
    raise NotImplementedError(msg)
|