466 lines
13 KiB
Python
466 lines
13 KiB
Python
"""Strategies for generating various forms of data."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import decimal
|
|
from collections.abc import Mapping
|
|
from datetime import datetime, timedelta, timezone
|
|
from typing import TYPE_CHECKING, Any, Literal
|
|
from zoneinfo import ZoneInfo
|
|
|
|
import hypothesis.strategies as st
|
|
from hypothesis.errors import InvalidArgument
|
|
|
|
from polars._utils.constants import (
|
|
EPOCH,
|
|
I8_MAX,
|
|
I8_MIN,
|
|
I16_MAX,
|
|
I16_MIN,
|
|
I32_MAX,
|
|
I32_MIN,
|
|
I64_MAX,
|
|
I64_MIN,
|
|
I128_MAX,
|
|
I128_MIN,
|
|
U8_MAX,
|
|
U16_MAX,
|
|
U32_MAX,
|
|
U64_MAX,
|
|
U128_MAX,
|
|
)
|
|
from polars.datatypes import (
|
|
Array,
|
|
Binary,
|
|
Boolean,
|
|
Categorical,
|
|
Date,
|
|
Datetime,
|
|
Decimal,
|
|
Duration,
|
|
Enum,
|
|
Field,
|
|
Float32,
|
|
Float64,
|
|
Int8,
|
|
Int16,
|
|
Int32,
|
|
Int64,
|
|
Int128,
|
|
List,
|
|
Null,
|
|
Object,
|
|
String,
|
|
Struct,
|
|
Time,
|
|
UInt8,
|
|
UInt16,
|
|
UInt32,
|
|
UInt64,
|
|
UInt128,
|
|
)
|
|
from polars.testing.parametric.strategies._utils import flexhash
|
|
from polars.testing.parametric.strategies.dtype import (
|
|
_DEFAULT_ARRAY_WIDTH_LIMIT,
|
|
_DEFAULT_ENUM_CATEGORIES_LIMIT,
|
|
)
|
|
|
|
if TYPE_CHECKING:
|
|
from collections.abc import Sequence
|
|
from datetime import date, time
|
|
|
|
from hypothesis.strategies import SearchStrategy
|
|
|
|
from polars._typing import PolarsDataType, SchemaDict, TimeUnit
|
|
from polars.datatypes import DataType, DataTypeClass
|
|
|
|
# Upper bound on generated list length that `lists` falls back to when no
# `max_size` is given and `min_size` is zero.
_DEFAULT_LIST_LEN_LIMIT: int = 3
# Number of distinct category strings `categories` samples from by default.
_DEFAULT_N_CATEGORIES: int = 10
|
|
|
|
# Pre-built integer strategies, keyed first by signedness (True = signed) and
# then by bit width. Unsigned ranges start at zero.
_INTEGER_STRATEGIES: dict[bool, dict[int, SearchStrategy[int]]] = {
    is_signed: {
        width: st.integers(lower, upper) for width, (lower, upper) in bounds.items()
    }
    for is_signed, bounds in (
        (
            True,
            {
                8: (I8_MIN, I8_MAX),
                16: (I16_MIN, I16_MAX),
                32: (I32_MIN, I32_MAX),
                64: (I64_MIN, I64_MAX),
                128: (I128_MIN, I128_MAX),
            },
        ),
        (
            False,
            {
                8: (0, U8_MAX),
                16: (0, U16_MAX),
                32: (0, U32_MAX),
                64: (0, U64_MAX),
                128: (0, U128_MAX),
            },
        ),
    )
}
|
|
|
|
|
|
def integers(
    bit_width: Literal[8, 16, 32, 64, 128] = 64, *, signed: bool = True
) -> SearchStrategy[int]:
    """
    Create a strategy for generating integers.

    Parameters
    ----------
    bit_width
        Bit width of the integer type whose value range should be covered.
    signed
        Use the signed range if True, otherwise the unsigned range.
    """
    # The strategies are pre-built at import time; this is a two-level lookup.
    strategies_by_width = _INTEGER_STRATEGIES[signed]
    return strategies_by_width[bit_width]
|
|
|
|
|
|
def floats(
    bit_width: Literal[32, 64] = 64,
    *,
    allow_nan: bool = True,
    allow_infinity: bool = True,
) -> SearchStrategy[float]:
    """
    Create a strategy for generating floats.

    Parameters
    ----------
    bit_width
        Width of the generated floats in bits; forwarded to Hypothesis as `width`.
    allow_nan
        Allow NaN as a possible value.
    allow_infinity
        Allow infinite values as possible values.
    """
    return st.floats(
        width=bit_width, allow_nan=allow_nan, allow_infinity=allow_infinity
    )
|
|
|
|
|
|
def booleans() -> SearchStrategy[bool]:
    """Create a strategy for generating boolean values."""
    bool_strategy = st.booleans()
    return bool_strategy
|
|
|
|
|
|
def strings() -> SearchStrategy[str]:
    """
    Create a strategy for generating string values.

    Generated strings hold at most 8 characters, drawn from code points up to
    1000 with the Unicode categories "Cs" and "Cc" excluded.
    """
    return st.text(
        alphabet=st.characters(max_codepoint=1000, exclude_categories=["Cs", "Cc"]),
        max_size=8,
    )
|
|
|
|
|
|
def binary() -> SearchStrategy[bytes]:
    """Create a strategy for generating `bytes` values."""
    bytes_strategy = st.binary()
    return bytes_strategy
|
|
|
|
|
|
def categories(n_categories: int = _DEFAULT_N_CATEGORIES) -> SearchStrategy[str]:
    """
    Create a strategy for generating category strings.

    The category strings are `"c0"`, `"c1"`, ... up to index `n_categories - 1`.

    Parameters
    ----------
    n_categories
        The number of categories.
    """
    labels = [f"c{index}" for index in range(n_categories)]
    return st.sampled_from(labels)
|
|
|
|
|
|
def times() -> SearchStrategy[time]:
    """Create a strategy for generating `time` objects."""
    time_strategy = st.times()
    return time_strategy
|
|
|
|
|
|
def dates() -> SearchStrategy[date]:
    """Create a strategy for generating `date` objects."""
    date_strategy = st.dates()
    return date_strategy
|
|
|
|
|
|
def datetimes(
    time_unit: TimeUnit = "us", time_zone: str | None = None
) -> SearchStrategy[datetime]:
    """
    Create a strategy for generating `datetime` objects in the time unit's range.

    The generated datetimes are always naive. When a time zone is given, values
    are drawn as valid wall-clock times in that zone and then converted to UTC
    with the time zone information removed.

    Parameters
    ----------
    time_unit
        Time unit for which the datetime objects are valid.
    time_zone
        Time zone for which the datetime objects are valid.

    Raises
    ------
    InvalidArgument
        If `time_unit` is not one of `"us"`, `"ms"` or `"ns"`.
    """
    if time_unit == "ns":
        # The signed 64-bit limits divided by 1000 are used as microsecond
        # offsets from the epoch.
        lower = EPOCH + timedelta(microseconds=I64_MIN // 1000 + 1)
        upper = EPOCH + timedelta(microseconds=I64_MAX // 1000)
    elif time_unit in ("us", "ms"):
        lower, upper = datetime.min, datetime.max
    else:
        msg = f"invalid time unit: {time_unit!r}"
        raise InvalidArgument(msg)

    if time_zone is None:
        return st.datetimes(lower, upper)

    zone = ZoneInfo(time_zone)

    if time_unit == "ns":
        # Make sure time zone offsets do not cause out-of-bound datetimes
        one_day = timedelta(days=1)
        lower, upper = lower + one_day, upper - one_day

    def _as_naive_utc(dt: datetime) -> datetime:
        # Convert to UTC, then strip the time zone information.
        return dt.astimezone(timezone.utc).replace(tzinfo=None)

    # Return naive datetimes, but make sure they are valid for the given time zone
    zoned_datetimes = st.datetimes(
        min_value=lower,
        max_value=upper,
        timezones=st.just(zone),
        allow_imaginary=False,
    )
    return zoned_datetimes.map(_as_naive_utc)
|
|
|
|
|
|
def durations(time_unit: TimeUnit = "us") -> SearchStrategy[timedelta]:
    """
    Create a strategy for generating `timedelta` objects in the time unit's range.

    Parameters
    ----------
    time_unit
        Time unit for which the timedelta objects are valid.

    Raises
    ------
    InvalidArgument
        If `time_unit` is not one of `"us"`, `"ms"` or `"ns"`.
    """
    if time_unit in ("us", "ms"):
        # TODO: Enable full range of millisecond durations
        # timedelta.min/max fall within the range, so `st.timedeltas()` without
        # bounds could be used for "ms"; for now it shares the "us" bounds.
        lower_us, upper_us = I64_MIN, I64_MAX
    elif time_unit == "ns":
        lower_us, upper_us = I64_MIN // 1000, I64_MAX // 1000
    else:
        msg = f"invalid time unit: {time_unit!r}"
        raise InvalidArgument(msg)

    return st.timedeltas(
        min_value=timedelta(microseconds=lower_us),
        max_value=timedelta(microseconds=upper_us),
    )
|
|
|
|
|
|
def decimals(
    precision: int | None = 38, scale: int = 0
) -> SearchStrategy[decimal.Decimal]:
    """
    Create a strategy for generating `Decimal` objects.

    Parameters
    ----------
    precision
        Maximum number of digits in each number.
        If set to `None`, the precision is set to 38 (the maximum supported by Polars).
    scale
        Number of digits to the right of the decimal point in each number.
    """
    precision = 38 if precision is None else precision

    # The largest allowed value is the one just below 10 ** (precision - scale)
    # at the given precision; the smallest is its negation.
    ctx = decimal.Context(prec=precision)
    upper = ctx.next_minus(ctx.create_decimal(f"1E+{precision - scale}"))
    lower = ctx.copy_negate(upper)

    return st.decimals(
        min_value=lower,
        max_value=upper,
        allow_nan=False,
        allow_infinity=False,
        places=scale,
    )
|
|
|
|
|
|
def lists(
    inner_dtype: DataType,
    *,
    select_from: Sequence[Any] | None = None,
    min_size: int = 0,
    max_size: int | None = None,
    unique: bool = False,
    **kwargs: Any,
) -> SearchStrategy[list[Any]]:
    """
    Create a strategy for generating lists of the given data type.

    .. warning::
        This functionality is currently considered **unstable**. It may be
        changed at any point without it being considered a breaking change.

    Parameters
    ----------
    inner_dtype
        Data type of the list elements. If the data type is not fully instantiated,
        defaults will be used, e.g. `Datetime` will become `Datetime('us')`.
    select_from
        The values to use for the innermost lists. If set to `None` (default),
        the default strategy associated with the innermost data type is used.
    min_size
        The minimum length of the generated lists.
    max_size
        The maximum length of the generated lists. If set to `None` (default), the
        maximum is set based on `min_size`: `3` if `min_size` is zero,
        otherwise `2 * min_size`.
    unique
        Ensure that the generated lists contain unique values.
    **kwargs
        Additional arguments that are passed to nested data generation strategies.
    """
    if max_size is None:
        max_size = min_size * 2 if min_size != 0 else _DEFAULT_LIST_LEN_LIMIT

    sample_directly = select_from is not None and not inner_dtype.is_nested()
    if sample_directly:
        element_strategy = st.sampled_from(select_from)
    else:
        # Forward the size/uniqueness settings so that nested list levels are
        # generated with the same constraints.
        element_strategy = data(
            inner_dtype,
            select_from=select_from,
            min_size=min_size,
            max_size=max_size,
            unique=unique,
            **kwargs,
        )

    uniqueness_key = flexhash if unique else None
    return st.lists(
        elements=element_strategy,
        min_size=min_size,
        max_size=max_size,
        unique_by=uniqueness_key,
    )
|
|
|
|
|
|
def structs(
    fields: Sequence[Field] | SchemaDict,
    *,
    allow_null: bool = True,
    **kwargs: Any,
) -> SearchStrategy[dict[str, Any]]:
    """
    Create a strategy for generating structs with the given fields.

    Parameters
    ----------
    fields
        The fields that make up the struct. Can be either a sequence of Field
        objects or a mapping of column names to data types.
    allow_null
        Allow nulls as possible values. If set to True, the returned dictionaries
        may miss certain fields and are in random order.
    **kwargs
        Additional arguments that are passed to nested data generation strategies.
    """
    # Normalize a name -> dtype mapping into a list of Field objects.
    if isinstance(fields, Mapping):
        fields = [Field(name, dtype) for name, dtype in fields.items()]

    field_strategies: dict[str, SearchStrategy[Any]] = {}
    for field in fields:
        field_strategies[field.name] = data(
            field.dtype, allow_null=allow_null, **kwargs
        )

    if not allow_null:
        return st.fixed_dictionaries(field_strategies)

    # With nulls allowed, every field is optional and may be left out.
    return st.fixed_dictionaries({}, optional=field_strategies)
|
|
|
|
|
|
def nulls() -> SearchStrategy[None]:
    """Create a strategy that always generates `None` (null)."""
    null_strategy = st.none()
    return null_strategy
|
|
|
|
|
|
def objects() -> SearchStrategy[object]:
    """Create a strategy for generating plain `object` instances."""
    object_strategy = st.builds(object)
    return object_strategy
|
|
|
|
|
|
# Strategies that are not customizable through parameters.
# Keyed by data type class; `data` looks these up via `dtype.base_type()`.
_STATIC_STRATEGIES: dict[DataTypeClass, SearchStrategy[Any]] = {
    Boolean: booleans(),
    # Signed integers
    Int8: integers(bit_width=8, signed=True),
    Int16: integers(bit_width=16, signed=True),
    Int32: integers(bit_width=32, signed=True),
    Int64: integers(bit_width=64, signed=True),
    Int128: integers(bit_width=128, signed=True),
    # Unsigned integers
    UInt8: integers(bit_width=8, signed=False),
    UInt16: integers(bit_width=16, signed=False),
    UInt32: integers(bit_width=32, signed=False),
    UInt64: integers(bit_width=64, signed=False),
    UInt128: integers(bit_width=128, signed=False),
    # Temporal types without parameters
    Time: times(),
    Date: dates(),
    # Everything else
    String: strings(),
    Binary: binary(),
    Null: nulls(),
    Object: objects(),
}
|
|
|
|
|
|
def data(
    dtype: PolarsDataType, *, allow_null: bool = False, **kwargs: Any
) -> SearchStrategy[Any]:
    """
    Create a strategy for generating data for the given data type.

    Parameters
    ----------
    dtype
        A Polars data type. If the data type is not fully instantiated, defaults will
        be used, e.g. `Datetime` will become `Datetime('us')`.
    allow_null
        Allow nulls as possible values.
    **kwargs
        Additional parameters for the strategy associated with the given `dtype`.
        `allow_nan` and `allow_infinity` are used for floats, and `n_categories`
        for `Categorical` and for an `Enum` without categories. List, Array and
        Struct types forward the remaining arguments to their nested strategies;
        for Array, `min_size` and `max_size` are replaced by the array size.

    Raises
    ------
    InvalidArgument
        If no strategy is available for the given data type.
    """
    strategy: SearchStrategy[Any]

    static_strategy = _STATIC_STRATEGIES.get(dtype.base_type())
    if static_strategy is not None:
        strategy = static_strategy
    elif dtype == Float32:
        nan_ok = kwargs.pop("allow_nan", True)
        inf_ok = kwargs.pop("allow_infinity", True)
        strategy = floats(32, allow_nan=nan_ok, allow_infinity=inf_ok)
    elif dtype == Float64:
        nan_ok = kwargs.pop("allow_nan", True)
        inf_ok = kwargs.pop("allow_infinity", True)
        strategy = floats(64, allow_nan=nan_ok, allow_infinity=inf_ok)
    elif dtype == Datetime:
        unit = getattr(dtype, "time_unit", None) or "us"
        zone = getattr(dtype, "time_zone", None)
        strategy = datetimes(time_unit=unit, time_zone=zone)
    elif dtype == Duration:
        unit = getattr(dtype, "time_unit", None) or "us"
        strategy = durations(time_unit=unit)
    elif dtype == Categorical:
        n_cats = kwargs.pop("n_categories", _DEFAULT_N_CATEGORIES)
        strategy = categories(n_categories=n_cats)
    elif dtype == Enum:
        if not isinstance(dtype, Enum):
            # Not an Enum instance, so there are no categories to sample from:
            # fall back to generated category strings.
            n_cats = kwargs.pop("n_categories", _DEFAULT_ENUM_CATEGORIES_LIMIT)
            strategy = categories(n_categories=n_cats)
        elif (enum_values := dtype.categories).is_empty():
            # Nothing to sample from, so only nulls can be generated.
            strategy = nulls()
        else:
            strategy = st.sampled_from(enum_values.to_list())
    elif dtype == Decimal:
        decimal_precision = getattr(dtype, "precision", None)
        decimal_scale = getattr(dtype, "scale", 0)
        strategy = decimals(decimal_precision, decimal_scale)
    elif dtype == List:
        element_dtype = getattr(dtype, "inner", None) or Null()
        strategy = lists(element_dtype, allow_null=allow_null, **kwargs)
    elif dtype == Array:
        element_dtype = getattr(dtype, "inner", None) or Null()
        width = getattr(dtype, "size", _DEFAULT_ARRAY_WIDTH_LIMIT)
        # The array size dictates the length, so drop any caller-supplied bounds.
        kwargs.pop("min_size", None)
        kwargs.pop("max_size", None)
        strategy = lists(
            element_dtype,
            min_size=width,
            max_size=width,
            allow_null=allow_null,
            **kwargs,
        )
    elif dtype == Struct:
        struct_fields = getattr(dtype, "fields", None) or [Field("f0", Null())]
        strategy = structs(struct_fields, allow_null=allow_null, **kwargs)
    else:
        msg = f"unsupported data type: {dtype}"
        raise InvalidArgument(msg)

    return (nulls() | strategy) if allow_null else strategy
|