315 lines
9.2 KiB
Python
315 lines
9.2 KiB
Python
from __future__ import annotations
|
|
|
|
import functools
|
|
import re
|
|
from contextlib import suppress
|
|
from inspect import isclass
|
|
from typing import TYPE_CHECKING, Any
|
|
|
|
from polars.datatypes import (
|
|
Binary,
|
|
Boolean,
|
|
Date,
|
|
Datetime,
|
|
Decimal,
|
|
Duration,
|
|
Float32,
|
|
Float64,
|
|
Int8,
|
|
Int16,
|
|
Int32,
|
|
Int64,
|
|
Int128,
|
|
List,
|
|
Null,
|
|
String,
|
|
Time,
|
|
UInt8,
|
|
UInt16,
|
|
UInt32,
|
|
UInt64,
|
|
)
|
|
from polars.datatypes._parse import parse_py_type_into_dtype
|
|
from polars.datatypes.group import (
|
|
INTEGER_DTYPES,
|
|
UNSIGNED_INTEGER_DTYPES,
|
|
)
|
|
|
|
if TYPE_CHECKING:
|
|
from polars._typing import PolarsDataType
|
|
|
|
|
|
def dtype_from_database_typename(
|
|
value: str,
|
|
*,
|
|
raise_unmatched: bool = True,
|
|
) -> PolarsDataType | None:
|
|
"""
|
|
Attempt to infer Polars dtype from database cursor `type_code` string value.
|
|
|
|
Examples
|
|
--------
|
|
>>> dtype_from_database_typename("INT2")
|
|
Int16
|
|
>>> dtype_from_database_typename("NVARCHAR")
|
|
String
|
|
>>> dtype_from_database_typename("NUMERIC(10,2)")
|
|
Decimal(precision=10, scale=2)
|
|
>>> dtype_from_database_typename("TIMESTAMP WITHOUT TZ")
|
|
Datetime(time_unit='us', time_zone=None)
|
|
"""
|
|
dtype: PolarsDataType | None = None
|
|
|
|
# normalise string name/case (eg: 'IntegerType' -> 'INTEGER')
|
|
original_value = value
|
|
value = value.upper().replace("TYPE", "")
|
|
|
|
# extract optional type modifier (eg: 'VARCHAR(64)' -> '64')
|
|
if re.search(r"\([\w,: ]+\)$", value):
|
|
modifier = value[value.find("(") + 1 : -1]
|
|
value = value.split("(")[0]
|
|
elif (
|
|
not value.startswith(("<", ">")) and re.search(r"\[[\w,\]\[: ]+]$", value)
|
|
) or value.endswith(("[S]", "[MS]", "[US]", "[NS]")):
|
|
modifier = value[value.find("[") + 1 : -1]
|
|
value = value.split("[")[0]
|
|
else:
|
|
modifier = ""
|
|
|
|
# array dtypes
|
|
array_aliases = ("ARRAY", "LIST", "[]")
|
|
if value.endswith(array_aliases) or value.startswith(array_aliases):
|
|
for a in array_aliases:
|
|
value = value.replace(a, "", 1) if value else ""
|
|
|
|
nested: PolarsDataType | None = None
|
|
if not value and modifier:
|
|
nested = dtype_from_database_typename(
|
|
value=modifier,
|
|
raise_unmatched=False,
|
|
)
|
|
else:
|
|
if inner_value := dtype_from_database_typename(
|
|
value[1:-1]
|
|
if (value[0], value[-1]) == ("<", ">")
|
|
else re.sub(r"\W", "", re.sub(r"\WOF\W", "", value)),
|
|
raise_unmatched=False,
|
|
):
|
|
nested = inner_value
|
|
elif modifier:
|
|
nested = dtype_from_database_typename(
|
|
value=modifier,
|
|
raise_unmatched=False,
|
|
)
|
|
if nested:
|
|
dtype = List(nested)
|
|
|
|
# float dtypes
|
|
elif value.startswith("FLOAT") or ("DOUBLE" in value) or (value == "REAL"):
|
|
dtype = (
|
|
Float32
|
|
if value == "FLOAT4"
|
|
or (value.endswith(("16", "32")) or (modifier in ("16", "32")))
|
|
else Float64
|
|
)
|
|
|
|
# integer dtypes
|
|
elif ("INTERVAL" not in value) and (
|
|
value.startswith(("INT", "UINT", "UNSIGNED"))
|
|
or value.endswith(("INT", "SERIAL"))
|
|
or ("INTEGER" in value)
|
|
or value in ("TINY", "SHORT", "LONG", "LONGLONG", "ROWID")
|
|
):
|
|
sz: Any
|
|
if "HUGEINT" in value:
|
|
sz = 128
|
|
elif (
|
|
"LARGE" in value or value.startswith("BIG") or value in ("INT8", "LONGLONG")
|
|
):
|
|
sz = 64
|
|
elif "MEDIUM" in value or value in ("INT4", "UINT4", "LONG", "SERIAL"):
|
|
sz = 32
|
|
elif "SMALL" in value or value in ("INT2", "UINT2", "SHORT"):
|
|
sz = 16
|
|
elif "TINY" in value:
|
|
sz = 8
|
|
elif n := re.sub(r"^\D+", "", value):
|
|
if (sz := int(n)) <= 8:
|
|
sz = sz * 8
|
|
else:
|
|
sz = None
|
|
|
|
sz = modifier if (not sz and modifier) else sz
|
|
if not isinstance(sz, int):
|
|
sz = int(sz) if isinstance(sz, str) and sz.isdigit() else None
|
|
if (
|
|
("U" in value and "MEDIUM" not in value)
|
|
or ("UNSIGNED" in value)
|
|
or value == "ROWID"
|
|
):
|
|
dtype = integer_dtype_from_nbits(sz, unsigned=True, default=UInt64)
|
|
else:
|
|
dtype = integer_dtype_from_nbits(sz, unsigned=False, default=Int64)
|
|
|
|
# number types (note: 'number' alone is not that helpful and requires refinement)
|
|
elif "NUMBER" in value and "CARDINAL" in value:
|
|
dtype = UInt64
|
|
|
|
# decimal dtypes
|
|
elif (is_dec := ("DECIMAL" in value)) or ("NUMERIC" in value):
|
|
if "," in modifier:
|
|
prec, scale = modifier.split(",")
|
|
dtype = Decimal(int(prec), int(scale))
|
|
else:
|
|
dtype = Decimal if is_dec else Float64
|
|
|
|
# string dtypes
|
|
elif (
|
|
any(tp in value for tp in ("VARCHAR", "STRING", "TEXT", "UNICODE"))
|
|
or value.startswith(("STR", "CHAR", "BPCHAR", "NCHAR", "UTF"))
|
|
or value.endswith(("_UTF8", "_UTF16", "_UTF32"))
|
|
):
|
|
dtype = String
|
|
|
|
# binary dtypes
|
|
elif value in ("BYTEA", "BYTES", "BLOB", "CLOB", "BINARY"):
|
|
dtype = Binary
|
|
|
|
# boolean dtypes
|
|
elif value.startswith("BOOL"):
|
|
dtype = Boolean
|
|
|
|
# null dtype; odd, but valid
|
|
elif value == "NULL":
|
|
dtype = Null
|
|
|
|
# temporal dtypes
|
|
elif value.startswith(("DATETIME", "TIMESTAMP")) and not (value.endswith("[D]")):
|
|
if any((tz in value.replace(" ", "")) for tz in ("TZ", "TIMEZONE")):
|
|
if "WITHOUT" not in value:
|
|
return None # there's a timezone, but we don't know what it is
|
|
unit = timeunit_from_precision(modifier) if modifier else "us"
|
|
dtype = Datetime(time_unit=(unit or "us")) # type: ignore[arg-type]
|
|
else:
|
|
value = re.sub(r"\d", "", value)
|
|
if value in ("INTERVAL", "TIMEDELTA", "DURATION"):
|
|
dtype = Duration
|
|
elif value == "DATE":
|
|
dtype = Date
|
|
elif value == "TIME":
|
|
dtype = Time
|
|
|
|
if not dtype and raise_unmatched:
|
|
msg = f"cannot infer dtype from {original_value!r} string value"
|
|
raise ValueError(msg)
|
|
|
|
return dtype
|
|
|
|
|
|
def dtype_from_cursor_description(
|
|
cursor: Any,
|
|
description: tuple[Any, ...],
|
|
) -> PolarsDataType | None:
|
|
"""Attempt to infer Polars dtype from database cursor description `type_code`."""
|
|
type_code, _disp_size, internal_size, precision, scale, *_ = description
|
|
dtype: PolarsDataType | None = None
|
|
|
|
if isclass(type_code):
|
|
# python types, eg: int, float, str, etc
|
|
with suppress(TypeError):
|
|
dtype = parse_py_type_into_dtype(type_code) # type: ignore[arg-type]
|
|
|
|
elif isinstance(type_code, str):
|
|
# database/sql type names, eg: "VARCHAR", "NUMERIC", "BLOB", etc
|
|
dtype = dtype_from_database_typename(
|
|
value=type_code,
|
|
raise_unmatched=False,
|
|
)
|
|
|
|
# check additional cursor attrs to refine dtype specification
|
|
if dtype is not None:
|
|
if dtype == Float64 and internal_size == 4:
|
|
dtype = Float32
|
|
|
|
elif dtype in INTEGER_DTYPES and internal_size in (2, 4, 8):
|
|
bits = internal_size * 8
|
|
dtype = integer_dtype_from_nbits(
|
|
bits,
|
|
unsigned=(dtype in UNSIGNED_INTEGER_DTYPES),
|
|
default=dtype,
|
|
)
|
|
elif (
|
|
dtype == Decimal
|
|
and isinstance(precision, int)
|
|
and isinstance(scale, int)
|
|
and precision <= 38
|
|
and scale <= 38
|
|
):
|
|
dtype = Decimal(precision, scale)
|
|
|
|
return dtype
|
|
|
|
|
|
@functools.lru_cache(8)
|
|
def integer_dtype_from_nbits(
|
|
bits: int,
|
|
*,
|
|
unsigned: bool,
|
|
default: PolarsDataType | None = None,
|
|
) -> PolarsDataType | None:
|
|
"""
|
|
Return matching Polars integer dtype from num bits and signed/unsigned flag.
|
|
|
|
Examples
|
|
--------
|
|
>>> integer_dtype_from_nbits(8, unsigned=False)
|
|
Int8
|
|
>>> integer_dtype_from_nbits(32, unsigned=True)
|
|
UInt32
|
|
"""
|
|
dtype = {
|
|
(8, False): Int8,
|
|
(8, True): UInt8,
|
|
(16, False): Int16,
|
|
(16, True): UInt16,
|
|
(32, False): Int32,
|
|
(32, True): UInt32,
|
|
(64, False): Int64,
|
|
(64, True): UInt64,
|
|
(128, False): Int128,
|
|
(128, True): Int128, # UInt128 not (yet?) supported
|
|
}.get((bits, unsigned), None)
|
|
|
|
if dtype is None and default is not None:
|
|
return default
|
|
return dtype
|
|
|
|
|
|
def timeunit_from_precision(precision: int | str | None) -> str | None:
|
|
"""
|
|
Return `time_unit` from integer precision value.
|
|
|
|
Examples
|
|
--------
|
|
>>> timeunit_from_precision(3)
|
|
'ms'
|
|
>>> timeunit_from_precision(5)
|
|
'us'
|
|
>>> timeunit_from_precision(7)
|
|
'ns'
|
|
"""
|
|
from math import ceil
|
|
|
|
if not precision:
|
|
return None
|
|
elif isinstance(precision, str):
|
|
if precision.isdigit():
|
|
precision = int(precision)
|
|
elif (precision := precision.lower()) in ("s", "ms", "us", "ns"):
|
|
return "ms" if precision == "s" else precision
|
|
try:
|
|
n = min(max(3, int(ceil(precision / 3)) * 3), 9) # type: ignore[operator]
|
|
return {3: "ms", 6: "us", 9: "ns"}.get(n)
|
|
except TypeError:
|
|
return None
|