DriverTrac/venv/lib/python3.12/site-packages/polars/io/database/_inference.py
2025-11-28 09:08:33 +05:30

315 lines
9.2 KiB
Python

from __future__ import annotations
import functools
import re
from contextlib import suppress
from inspect import isclass
from typing import TYPE_CHECKING, Any
from polars.datatypes import (
Binary,
Boolean,
Date,
Datetime,
Decimal,
Duration,
Float32,
Float64,
Int8,
Int16,
Int32,
Int64,
Int128,
List,
Null,
String,
Time,
UInt8,
UInt16,
UInt32,
UInt64,
)
from polars.datatypes._parse import parse_py_type_into_dtype
from polars.datatypes.group import (
INTEGER_DTYPES,
UNSIGNED_INTEGER_DTYPES,
)
if TYPE_CHECKING:
from polars._typing import PolarsDataType
def dtype_from_database_typename(
    value: str,
    *,
    raise_unmatched: bool = True,
) -> PolarsDataType | None:
    """
    Attempt to infer Polars dtype from database cursor `type_code` string value.

    Parameters
    ----------
    value
        Database/SQL type name (eg: "VARCHAR(64)", "INT2", "ARRAY<BIGINT>").
        Matching is done on the upper-cased name with any "TYPE" substring
        removed; a trailing "(...)" or "[...]" section is split off and used
        as a modifier (size, precision/scale, time unit or nested type).
    raise_unmatched
        If True (the default), raise when no dtype can be inferred; otherwise
        return None in that case.

    Returns
    -------
    PolarsDataType or None
        The inferred dtype. None is returned when nothing matched and
        `raise_unmatched` is False, and also (regardless of `raise_unmatched`)
        for datetime/timestamp names that mention a timezone without "WITHOUT".

    Raises
    ------
    ValueError
        If no dtype could be inferred and `raise_unmatched` is True.

    Examples
    --------
    >>> dtype_from_database_typename("INT2")
    Int16
    >>> dtype_from_database_typename("NVARCHAR")
    String
    >>> dtype_from_database_typename("NUMERIC(10,2)")
    Decimal(precision=10, scale=2)
    >>> dtype_from_database_typename("TIMESTAMP WITHOUT TZ")
    Datetime(time_unit='us', time_zone=None)
    """
    dtype: PolarsDataType | None = None

    # normalise string name/case (eg: 'IntegerType' -> 'INTEGER')
    original_value = value
    value = value.upper().replace("TYPE", "")

    # extract optional type modifier (eg: 'VARCHAR(64)' -> '64')
    if re.search(r"\([\w,: ]+\)$", value):
        modifier = value[value.find("(") + 1 : -1]
        value = value.split("(")[0]
    elif (
        not value.startswith(("<", ">")) and re.search(r"\[[\w,\]\[: ]+]$", value)
    ) or value.endswith(("[S]", "[MS]", "[US]", "[NS]")):
        modifier = value[value.find("[") + 1 : -1]
        value = value.split("[")[0]
    else:
        modifier = ""

    # array dtypes
    array_aliases = ("ARRAY", "LIST", "[]")
    if value.endswith(array_aliases) or value.startswith(array_aliases):
        # strip the first occurrence of each alias; what remains (if anything)
        # describes the inner type
        for a in array_aliases:
            value = value.replace(a, "", 1) if value else ""

        nested: PolarsDataType | None = None
        if value:
            # only index into `value` when it is non-empty: a bare alias such
            # as "ARRAY" or "[]" leaves nothing behind, and must fall through
            # to the unmatched handling instead of raising IndexError
            inner_name = (
                value[1:-1]
                if (value[0], value[-1]) == ("<", ">")
                else re.sub(r"\W", "", re.sub(r"\WOF\W", "", value))
            )
            nested = dtype_from_database_typename(
                inner_name,
                raise_unmatched=False,
            )
        if not nested and modifier:
            # no usable inner type name; try the modifier (eg: 'ARRAY(INT)')
            nested = dtype_from_database_typename(
                value=modifier,
                raise_unmatched=False,
            )
        if nested:
            dtype = List(nested)

    # float dtypes
    elif value.startswith("FLOAT") or ("DOUBLE" in value) or (value == "REAL"):
        dtype = (
            Float32
            if value == "FLOAT4"
            or (value.endswith(("16", "32")) or (modifier in ("16", "32")))
            else Float64
        )

    # integer dtypes
    elif ("INTERVAL" not in value) and (
        value.startswith(("INT", "UINT", "UNSIGNED"))
        or value.endswith(("INT", "SERIAL"))
        or ("INTEGER" in value)
        or value in ("TINY", "SHORT", "LONG", "LONGLONG", "ROWID")
    ):
        sz: Any
        if "HUGEINT" in value:
            sz = 128
        elif (
            "LARGE" in value or value.startswith("BIG") or value in ("INT8", "LONGLONG")
        ):
            sz = 64
        elif "MEDIUM" in value or value in ("INT4", "UINT4", "LONG", "SERIAL"):
            sz = 32
        elif "SMALL" in value or value in ("INT2", "UINT2", "SHORT"):
            sz = 16
        elif "TINY" in value:
            sz = 8
        elif n := re.sub(r"^\D+", "", value):
            # numeric suffix: values up to 8 are read as a byte count,
            # anything larger is taken as a bit count
            if (sz := int(n)) <= 8:
                sz = sz * 8
        else:
            sz = None

        # fall back to the modifier as the size when the name gave none
        sz = modifier if (not sz and modifier) else sz
        if not isinstance(sz, int):
            sz = int(sz) if isinstance(sz, str) and sz.isdigit() else None
        if (
            ("U" in value and "MEDIUM" not in value)
            or ("UNSIGNED" in value)
            or value == "ROWID"
        ):
            dtype = integer_dtype_from_nbits(sz, unsigned=True, default=UInt64)
        else:
            dtype = integer_dtype_from_nbits(sz, unsigned=False, default=Int64)

    # number types (note: 'number' alone is not that helpful and requires refinement)
    elif "NUMBER" in value and "CARDINAL" in value:
        dtype = UInt64

    # decimal dtypes
    elif (is_dec := ("DECIMAL" in value)) or ("NUMERIC" in value):
        if "," in modifier:
            prec, scale = modifier.split(",")
            dtype = Decimal(int(prec), int(scale))
        else:
            dtype = Decimal if is_dec else Float64

    # string dtypes
    elif (
        any(tp in value for tp in ("VARCHAR", "STRING", "TEXT", "UNICODE"))
        or value.startswith(("STR", "CHAR", "BPCHAR", "NCHAR", "UTF"))
        or value.endswith(("_UTF8", "_UTF16", "_UTF32"))
    ):
        dtype = String

    # binary dtypes
    elif value in ("BYTEA", "BYTES", "BLOB", "CLOB", "BINARY"):
        dtype = Binary

    # boolean dtypes
    elif value.startswith("BOOL"):
        dtype = Boolean

    # null dtype; odd, but valid
    elif value == "NULL":
        dtype = Null

    # temporal dtypes
    elif value.startswith(("DATETIME", "TIMESTAMP")) and not (value.endswith("[D]")):
        if any((tz in value.replace(" ", "")) for tz in ("TZ", "TIMEZONE")):
            if "WITHOUT" not in value:
                return None  # there's a timezone, but we don't know what it is
        unit = timeunit_from_precision(modifier) if modifier else "us"
        dtype = Datetime(time_unit=(unit or "us"))  # type: ignore[arg-type]
    else:
        # drop any digits before matching the remaining temporal names
        value = re.sub(r"\d", "", value)
        if value in ("INTERVAL", "TIMEDELTA", "DURATION"):
            dtype = Duration
        elif value == "DATE":
            dtype = Date
        elif value == "TIME":
            dtype = Time

    if not dtype and raise_unmatched:
        msg = f"cannot infer dtype from {original_value!r} string value"
        raise ValueError(msg)

    return dtype
def dtype_from_cursor_description(
    cursor: Any,
    description: tuple[Any, ...],
) -> PolarsDataType | None:
    """
    Attempt to infer Polars dtype from database cursor description `type_code`.

    Parameters
    ----------
    cursor
        The originating cursor object (not referenced by this function).
    description
        Sequence whose first five items are unpacked as `type_code`,
        display size, internal size, precision and scale; any further
        items are ignored.

    Returns
    -------
    PolarsDataType or None
        The inferred dtype, refined with the size/precision/scale values
        where applicable, or None if `type_code` could not be resolved.
    """
    type_code, _disp_size, internal_size, precision, scale, *_ = description

    inferred: PolarsDataType | None = None
    if isclass(type_code):
        # python types, eg: int, float, str, etc
        with suppress(TypeError):
            inferred = parse_py_type_into_dtype(type_code)  # type: ignore[arg-type]
    elif isinstance(type_code, str):
        # database/sql type names, eg: "VARCHAR", "NUMERIC", "BLOB", etc
        inferred = dtype_from_database_typename(
            value=type_code,
            raise_unmatched=False,
        )

    if inferred is None:
        return None

    # use the remaining description fields to refine the dtype specification
    if inferred == Float64 and internal_size == 4:
        return Float32

    if inferred in INTEGER_DTYPES and internal_size in (2, 4, 8):
        # internal size is treated as a byte count; keep the signedness
        return integer_dtype_from_nbits(
            internal_size * 8,
            unsigned=(inferred in UNSIGNED_INTEGER_DTYPES),
            default=inferred,
        )

    usable_decimal_spec = (
        inferred == Decimal
        and isinstance(precision, int)
        and isinstance(scale, int)
        and precision <= 38
        and scale <= 38
    )
    if usable_decimal_spec:
        return Decimal(precision, scale)

    return inferred
@functools.lru_cache(8)
def integer_dtype_from_nbits(
    bits: int,
    *,
    unsigned: bool,
    default: PolarsDataType | None = None,
) -> PolarsDataType | None:
    """
    Return matching Polars integer dtype from num bits and signed/unsigned flag.

    Parameters
    ----------
    bits
        Integer width in bits; 8, 16, 32, 64 and 128 have a mapping.
    unsigned
        Whether the unsigned variant is wanted.
    default
        Value returned when there is no mapping for `bits`/`unsigned`.

    Returns
    -------
    PolarsDataType or None
        The mapped integer dtype, else `default` (which may itself be None).

    Examples
    --------
    >>> integer_dtype_from_nbits(8, unsigned=False)
    Int8
    >>> integer_dtype_from_nbits(32, unsigned=True)
    UInt32
    """
    widths = (8, 16, 32, 64, 128)
    signed_dtypes = (Int8, Int16, Int32, Int64, Int128)
    # UInt128 not (yet?) supported, so the unsigned 128-bit slot maps to Int128
    unsigned_dtypes = (UInt8, UInt16, UInt32, UInt64, Int128)

    lookup = {}
    for width, signed_tp, unsigned_tp in zip(widths, signed_dtypes, unsigned_dtypes):
        lookup[(width, False)] = signed_tp
        lookup[(width, True)] = unsigned_tp

    matched = lookup.get((bits, unsigned))
    return default if matched is None else matched
def timeunit_from_precision(precision: int | str | None) -> str | None:
    """
    Return `time_unit` from integer precision value.

    Parameters
    ----------
    precision
        Number of fractional-second digits, given as an integer or a string
        of decimal digits, or a unit string ("s", "ms", "us", "ns";
        case-insensitive). Surrounding whitespace in strings is ignored.

    Returns
    -------
    str or None
        One of "ms", "us" or "ns"; the digit count is rounded up to the next
        multiple of three and clamped to the range 3..9 (so "s" and anything
        below three digits maps to "ms"). None is returned for empty/zero
        input and for values that cannot be interpreted.

    Examples
    --------
    >>> timeunit_from_precision(3)
    'ms'
    >>> timeunit_from_precision(5)
    'us'
    >>> timeunit_from_precision(7)
    'ns'
    """
    from math import ceil

    if not precision:
        return None

    if isinstance(precision, str):
        text = precision.strip()
        unit = text.lower()
        if unit in ("s", "ms", "us", "ns"):
            return "ms" if unit == "s" else unit
        # `isdecimal` (unlike `isdigit`) only accepts characters that `int`
        # can actually parse, so strings such as "²" are rejected here
        # instead of raising ValueError from the conversion
        if not text.isdecimal():
            return None
        precision = int(text)

    try:
        n = min(max(3, int(ceil(precision / 3)) * 3), 9)  # type: ignore[operator]
    except (TypeError, ValueError, OverflowError):
        # unsupported operand type, or a non-finite float (nan/inf)
        return None
    return {3: "ms", 6: "us", 9: "ns"}.get(n)