266 lines
7.9 KiB
Python
266 lines
7.9 KiB
Python
from __future__ import annotations
|
|
|
|
import contextlib
|
|
import sys
|
|
from collections import OrderedDict
|
|
from collections.abc import Mapping
|
|
from typing import TYPE_CHECKING, Literal, Union, overload
|
|
|
|
from polars._typing import PythonDataType
|
|
from polars._utils.unstable import unstable
|
|
from polars.datatypes import DataType, DataTypeClass, is_polars_dtype
|
|
from polars.datatypes._parse import parse_into_dtype
|
|
from polars.exceptions import DuplicateError
|
|
from polars.interchange.protocol import CompatLevel
|
|
|
|
with contextlib.suppress(ImportError): # Module not available when building docs
|
|
from polars._plr import (
|
|
init_polars_schema_from_arrow_c_schema,
|
|
polars_schema_field_from_arrow_c_schema,
|
|
polars_schema_to_pycapsule,
|
|
)
|
|
|
|
if TYPE_CHECKING:
|
|
from collections.abc import Iterable
|
|
|
|
from polars import DataFrame, LazyFrame
|
|
from polars._typing import ArrowSchemaExportable
|
|
|
|
if sys.version_info >= (3, 10):
|
|
from typing import TypeAlias
|
|
else:
|
|
from typing_extensions import TypeAlias
|
|
|
|
if sys.version_info >= (3, 10):
|
|
|
|
def _required_init_args(tp: DataTypeClass) -> bool:
|
|
# note: this check is ~20% faster than the check for a
|
|
# custom "__init__", below, but is not available on py39
|
|
return bool(tp.__annotations__)
|
|
else:
|
|
|
|
def _required_init_args(tp: DataTypeClass) -> bool:
|
|
# indicates override of the default __init__
|
|
# (eg: this type requires specific args)
|
|
return "__init__" in tp.__dict__
|
|
|
|
|
|
BaseSchema = OrderedDict[str, DataType]
|
|
SchemaInitDataType: TypeAlias = Union[DataType, DataTypeClass, PythonDataType]
|
|
|
|
__all__ = ["Schema"]
|
|
|
|
|
|
def _check_dtype(tp: DataType | DataTypeClass) -> DataType:
|
|
if not isinstance(tp, DataType):
|
|
# note: if nested/decimal, or has signature params, this implies required args
|
|
if tp.is_nested() or tp.is_decimal() or _required_init_args(tp):
|
|
msg = f"dtypes must be fully-specified, got: {tp!r}"
|
|
raise TypeError(msg)
|
|
tp = tp()
|
|
return tp # type: ignore[return-value]
|
|
|
|
|
|
class Schema(BaseSchema):
|
|
"""
|
|
Ordered mapping of column names to their data type.
|
|
|
|
Parameters
|
|
----------
|
|
schema
|
|
The schema definition given by column names and their associated
|
|
Polars data type. Accepts a mapping, or an iterable of tuples, or any
|
|
object implementing the `__arrow_c_schema__` PyCapsule interface
|
|
(e.g. pyarrow schemas).
|
|
|
|
Examples
|
|
--------
|
|
Define a schema by passing instantiated data types.
|
|
|
|
>>> schema = pl.Schema(
|
|
... {
|
|
... "foo": pl.String(),
|
|
... "bar": pl.Duration("us"),
|
|
... "baz": pl.Array(pl.Int8, 4),
|
|
... }
|
|
... )
|
|
>>> schema
|
|
Schema({'foo': String, 'bar': Duration(time_unit='us'), 'baz': Array(Int8, shape=(4,))})
|
|
|
|
Access the data type associated with a specific column name.
|
|
|
|
>>> schema["baz"]
|
|
Array(Int8, shape=(4,))
|
|
|
|
Access various schema properties using the `names`, `dtypes`, and `len` methods.
|
|
|
|
>>> schema.names()
|
|
['foo', 'bar', 'baz']
|
|
>>> schema.dtypes()
|
|
[String, Duration(time_unit='us'), Array(Int8, shape=(4,))]
|
|
>>> schema.len()
|
|
3
|
|
|
|
Import a pyarrow schema.
|
|
|
|
>>> import pyarrow as pa
|
|
>>> pl.Schema(pa.schema([pa.field("x", pa.int32())]))
|
|
Schema({'x': Int32})
|
|
|
|
Export a schema to pyarrow.
|
|
|
|
>>> pa.schema(pl.Schema({"x": pl.Int32}))
|
|
x: int32
|
|
""" # noqa: W505
|
|
|
|
def __init__(
|
|
self,
|
|
schema: (
|
|
Mapping[str, SchemaInitDataType]
|
|
| Iterable[tuple[str, SchemaInitDataType] | ArrowSchemaExportable]
|
|
| ArrowSchemaExportable
|
|
| None
|
|
) = None,
|
|
*,
|
|
check_dtypes: bool = True,
|
|
) -> None:
|
|
if hasattr(schema, "__arrow_c_schema__") and not isinstance(schema, Schema):
|
|
init_polars_schema_from_arrow_c_schema(self, schema)
|
|
return
|
|
|
|
input = schema.items() if isinstance(schema, Mapping) else (schema or ())
|
|
for v in input:
|
|
name, tp = (
|
|
polars_schema_field_from_arrow_c_schema(v)
|
|
if hasattr(v, "__arrow_c_schema__") and not isinstance(v, DataType)
|
|
else v
|
|
)
|
|
|
|
if name in self:
|
|
msg = f"iterable passed to pl.Schema contained duplicate name '{name}'"
|
|
raise DuplicateError(msg)
|
|
|
|
if not check_dtypes:
|
|
super().__setitem__(name, tp) # type: ignore[assignment]
|
|
elif is_polars_dtype(tp):
|
|
super().__setitem__(name, _check_dtype(tp))
|
|
else:
|
|
self[name] = tp
|
|
|
|
def __eq__(self, other: object) -> bool:
|
|
if not isinstance(other, Mapping):
|
|
return False
|
|
if len(self) != len(other):
|
|
return False
|
|
for (nm1, tp1), (nm2, tp2) in zip(self.items(), other.items()):
|
|
if nm1 != nm2 or not tp1.is_(tp2):
|
|
return False
|
|
return True
|
|
|
|
def __ne__(self, other: object) -> bool:
|
|
return not self.__eq__(other)
|
|
|
|
def __setitem__(
|
|
self, name: str, dtype: DataType | DataTypeClass | PythonDataType
|
|
) -> None:
|
|
dtype = _check_dtype(parse_into_dtype(dtype))
|
|
super().__setitem__(name, dtype)
|
|
|
|
@unstable()
|
|
def __arrow_c_schema__(self) -> object:
|
|
"""
|
|
Export a Schema via the Arrow PyCapsule Interface.
|
|
|
|
https://arrow.apache.org/docs/dev/format/CDataInterface/PyCapsuleInterface.html
|
|
"""
|
|
return polars_schema_to_pycapsule(self, CompatLevel.newest()._version)
|
|
|
|
def names(self) -> list[str]:
|
|
"""
|
|
Get the column names of the schema.
|
|
|
|
Examples
|
|
--------
|
|
>>> s = pl.Schema({"x": pl.Float64(), "y": pl.Datetime(time_zone="UTC")})
|
|
>>> s.names()
|
|
['x', 'y']
|
|
"""
|
|
return list(self.keys())
|
|
|
|
def dtypes(self) -> list[DataType]:
|
|
"""
|
|
Get the data types of the schema.
|
|
|
|
Examples
|
|
--------
|
|
>>> s = pl.Schema({"x": pl.UInt8(), "y": pl.List(pl.UInt8)})
|
|
>>> s.dtypes()
|
|
[UInt8, List(UInt8)]
|
|
"""
|
|
return list(self.values())
|
|
|
|
@overload
|
|
def to_frame(self, *, eager: Literal[False]) -> LazyFrame: ...
|
|
|
|
@overload
|
|
def to_frame(self, *, eager: Literal[True] = ...) -> DataFrame: ...
|
|
|
|
def to_frame(self, *, eager: bool = True) -> DataFrame | LazyFrame:
|
|
"""
|
|
Create an empty DataFrame (or LazyFrame) from this Schema.
|
|
|
|
Parameters
|
|
----------
|
|
eager
|
|
If True, create a DataFrame; otherwise, create a LazyFrame.
|
|
|
|
Examples
|
|
--------
|
|
>>> s = pl.Schema({"x": pl.Int32(), "y": pl.String()})
|
|
>>> s.to_frame()
|
|
shape: (0, 2)
|
|
┌─────┬─────┐
|
|
│ x ┆ y │
|
|
│ --- ┆ --- │
|
|
│ i32 ┆ str │
|
|
╞═════╪═════╡
|
|
└─────┴─────┘
|
|
>>> s.to_frame(eager=False) # doctest: +IGNORE_RESULT
|
|
<LazyFrame at 0x11BC0AD80>
|
|
"""
|
|
from polars import DataFrame, LazyFrame
|
|
|
|
return DataFrame(schema=self) if eager else LazyFrame(schema=self)
|
|
|
|
def len(self) -> int:
|
|
"""
|
|
Get the number of schema entries.
|
|
|
|
Examples
|
|
--------
|
|
>>> s = pl.Schema({"x": pl.Int32(), "y": pl.List(pl.String)})
|
|
>>> s.len()
|
|
2
|
|
>>> len(s)
|
|
2
|
|
"""
|
|
return len(self)
|
|
|
|
def to_python(self) -> dict[str, type]:
|
|
"""
|
|
Return a dictionary of column names and Python types.
|
|
|
|
Examples
|
|
--------
|
|
>>> s = pl.Schema(
|
|
... {
|
|
... "x": pl.Int8(),
|
|
... "y": pl.String(),
|
|
... "z": pl.Duration("us"),
|
|
... }
|
|
... )
|
|
>>> s.to_python()
|
|
{'x': <class 'int'>, 'y': <class 'str'>, 'z': <class 'datetime.timedelta'>}
|
|
"""
|
|
return {name: tp.to_python() for name, tp in self.items()}
|