"""Strategies for generating various forms of data.""" from __future__ import annotations import decimal from collections.abc import Mapping from datetime import datetime, timedelta, timezone from typing import TYPE_CHECKING, Any, Literal from zoneinfo import ZoneInfo import hypothesis.strategies as st from hypothesis.errors import InvalidArgument from polars._utils.constants import ( EPOCH, I8_MAX, I8_MIN, I16_MAX, I16_MIN, I32_MAX, I32_MIN, I64_MAX, I64_MIN, I128_MAX, I128_MIN, U8_MAX, U16_MAX, U32_MAX, U64_MAX, U128_MAX, ) from polars.datatypes import ( Array, Binary, Boolean, Categorical, Date, Datetime, Decimal, Duration, Enum, Field, Float32, Float64, Int8, Int16, Int32, Int64, Int128, List, Null, Object, String, Struct, Time, UInt8, UInt16, UInt32, UInt64, UInt128, ) from polars.testing.parametric.strategies._utils import flexhash from polars.testing.parametric.strategies.dtype import ( _DEFAULT_ARRAY_WIDTH_LIMIT, _DEFAULT_ENUM_CATEGORIES_LIMIT, ) if TYPE_CHECKING: from collections.abc import Sequence from datetime import date, time from hypothesis.strategies import SearchStrategy from polars._typing import PolarsDataType, SchemaDict, TimeUnit from polars.datatypes import DataType, DataTypeClass _DEFAULT_LIST_LEN_LIMIT = 3 _DEFAULT_N_CATEGORIES = 10 _INTEGER_STRATEGIES: dict[bool, dict[int, SearchStrategy[int]]] = { True: { 8: st.integers(I8_MIN, I8_MAX), 16: st.integers(I16_MIN, I16_MAX), 32: st.integers(I32_MIN, I32_MAX), 64: st.integers(I64_MIN, I64_MAX), 128: st.integers(I128_MIN, I128_MAX), }, False: { 8: st.integers(0, U8_MAX), 16: st.integers(0, U16_MAX), 32: st.integers(0, U32_MAX), 64: st.integers(0, U64_MAX), 128: st.integers(0, U128_MAX), }, } def integers( bit_width: Literal[8, 16, 32, 64, 128] = 64, *, signed: bool = True ) -> SearchStrategy[int]: """Create a strategy for generating integers.""" return _INTEGER_STRATEGIES[signed][bit_width] def floats( bit_width: Literal[32, 64] = 64, *, allow_nan: bool = True, allow_infinity: bool = True, ) -> SearchStrategy[float]: """Create a strategy for generating integers.""" return st.floats( width=bit_width, allow_nan=allow_nan, allow_infinity=allow_infinity ) def booleans() -> SearchStrategy[bool]: """Create a strategy for generating booleans.""" return st.booleans() def strings() -> SearchStrategy[str]: """Create a strategy for generating string values.""" alphabet = st.characters(max_codepoint=1000, exclude_categories=["Cs", "Cc"]) return st.text(alphabet=alphabet, max_size=8) def binary() -> SearchStrategy[bytes]: """Create a strategy for generating bytes.""" return st.binary() def categories(n_categories: int = _DEFAULT_N_CATEGORIES) -> SearchStrategy[str]: """ Create a strategy for generating category strings. Parameters ---------- n_categories The number of categories. """ categories = [f"c{i}" for i in range(n_categories)] return st.sampled_from(categories) def times() -> SearchStrategy[time]: """Create a strategy for generating `time` objects.""" return st.times() def dates() -> SearchStrategy[date]: """Create a strategy for generating `date` objects.""" return st.dates() def datetimes( time_unit: TimeUnit = "us", time_zone: str | None = None ) -> SearchStrategy[datetime]: """ Create a strategy for generating `datetime` objects in the time unit's range. Parameters ---------- time_unit Time unit for which the datetime objects are valid. time_zone Time zone for which the datetime objects are valid. """ if time_unit in ("us", "ms"): min_value = datetime.min max_value = datetime.max elif time_unit == "ns": min_value = EPOCH + timedelta(microseconds=I64_MIN // 1000 + 1) max_value = EPOCH + timedelta(microseconds=I64_MAX // 1000) else: msg = f"invalid time unit: {time_unit!r}" raise InvalidArgument(msg) if time_zone is None: return st.datetimes(min_value, max_value) time_zone_info = ZoneInfo(time_zone) # Make sure time zone offsets do not cause out-of-bound datetimes if time_unit == "ns": min_value += timedelta(days=1) max_value -= timedelta(days=1) # Return naive datetimes, but make sure they are valid for the given time zone return st.datetimes( min_value=min_value, max_value=max_value, timezones=st.just(time_zone_info), allow_imaginary=False, ).map(lambda dt: dt.astimezone(timezone.utc).replace(tzinfo=None)) def durations(time_unit: TimeUnit = "us") -> SearchStrategy[timedelta]: """ Create a strategy for generating `timedelta` objects in the time unit's range. Parameters ---------- time_unit Time unit for which the timedelta objects are valid. """ if time_unit == "us": return st.timedeltas( min_value=timedelta(microseconds=I64_MIN), max_value=timedelta(microseconds=I64_MAX), ) elif time_unit == "ns": return st.timedeltas( min_value=timedelta(microseconds=I64_MIN // 1000), max_value=timedelta(microseconds=I64_MAX // 1000), ) elif time_unit == "ms": # TODO: Enable full range of millisecond durations # timedelta.min/max fall within the range # return st.timedeltas() return st.timedeltas( min_value=timedelta(microseconds=I64_MIN), max_value=timedelta(microseconds=I64_MAX), ) else: msg = f"invalid time unit: {time_unit!r}" raise InvalidArgument(msg) def decimals( precision: int | None = 38, scale: int = 0 ) -> SearchStrategy[decimal.Decimal]: """ Create a strategy for generating `Decimal` objects. Parameters ---------- precision Maximum number of digits in each number. If set to `None`, the precision is set to 38 (the maximum supported by Polars). scale Number of digits to the right of the decimal point in each number. """ if precision is None: precision = 38 c = decimal.Context(prec=precision) exclusive_limit = c.create_decimal(f"1E+{precision - scale}") max_value = c.next_minus(exclusive_limit) min_value = c.copy_negate(max_value) return st.decimals( min_value=min_value, max_value=max_value, allow_nan=False, allow_infinity=False, places=scale, ) def lists( inner_dtype: DataType, *, select_from: Sequence[Any] | None = None, min_size: int = 0, max_size: int | None = None, unique: bool = False, **kwargs: Any, ) -> SearchStrategy[list[Any]]: """ Create a strategy for generating lists of the given data type. .. warning:: This functionality is currently considered **unstable**. It may be changed at any point without it being considered a breaking change. Parameters ---------- inner_dtype Data type of the list elements. If the data type is not fully instantiated, defaults will be used, e.g. `Datetime` will become `Datetime('us')`. select_from The values to use for the innermost lists. If set to `None` (default), the default strategy associated with the innermost data type is used. min_size The minimum length of the generated lists. max_size The maximum length of the generated lists. If set to `None` (default), the maximum is set based on `min_size`: `3` if `min_size` is zero, otherwise `2 * min_size`. unique Ensure that the generated lists contain unique values. **kwargs Additional arguments that are passed to nested data generation strategies. Examples -------- ... """ if max_size is None: max_size = _DEFAULT_LIST_LEN_LIMIT if min_size == 0 else min_size * 2 if select_from is not None and not inner_dtype.is_nested(): inner_strategy = st.sampled_from(select_from) else: inner_strategy = data( inner_dtype, select_from=select_from, min_size=min_size, max_size=max_size, unique=unique, **kwargs, ) return st.lists( elements=inner_strategy, min_size=min_size, max_size=max_size, unique_by=(flexhash if unique else None), ) def structs( fields: Sequence[Field] | SchemaDict, *, allow_null: bool = True, **kwargs: Any, ) -> SearchStrategy[dict[str, Any]]: """ Create a strategy for generating structs with the given fields. Parameters ---------- fields The fields that make up the struct. Can be either a sequence of Field objects or a mapping of column names to data types. allow_null Allow nulls as possible values. If set to True, the returned dictionaries may miss certain fields and are in random order. **kwargs Additional arguments that are passed to nested data generation strategies. """ if isinstance(fields, Mapping): fields = [Field(name, dtype) for name, dtype in fields.items()] strats = {f.name: data(f.dtype, allow_null=allow_null, **kwargs) for f in fields} if allow_null: return st.fixed_dictionaries({}, optional=strats) else: return st.fixed_dictionaries(strats) def nulls() -> SearchStrategy[None]: """Create a strategy for generating null values.""" return st.none() def objects() -> SearchStrategy[object]: """Create a strategy for generating arbitrary objects.""" return st.builds(object) # Strategies that are not customizable through parameters _STATIC_STRATEGIES: dict[DataTypeClass, SearchStrategy[Any]] = { Boolean: booleans(), Int8: integers(8, signed=True), Int16: integers(16, signed=True), Int32: integers(32, signed=True), Int64: integers(64, signed=True), Int128: integers(128, signed=True), UInt8: integers(8, signed=False), UInt16: integers(16, signed=False), UInt32: integers(32, signed=False), UInt64: integers(64, signed=False), UInt128: integers(128, signed=False), Time: times(), Date: dates(), String: strings(), Binary: binary(), Null: nulls(), Object: objects(), } def data( dtype: PolarsDataType, *, allow_null: bool = False, **kwargs: Any ) -> SearchStrategy[Any]: """ Create a strategy for generating data for the given data type. Parameters ---------- dtype A Polars data type. If the data type is not fully instantiated, defaults will be used, e.g. `Datetime` will become `Datetime('us')`. allow_null Allow nulls as possible values. **kwargs Additional parameters for the strategy associated with the given `dtype`. """ if (strategy := _STATIC_STRATEGIES.get(dtype.base_type())) is not None: strategy = strategy elif dtype == Float32: strategy = floats( 32, allow_nan=kwargs.pop("allow_nan", True), allow_infinity=kwargs.pop("allow_infinity", True), ) elif dtype == Float64: strategy = floats( 64, allow_nan=kwargs.pop("allow_nan", True), allow_infinity=kwargs.pop("allow_infinity", True), ) elif dtype == Datetime: strategy = datetimes( time_unit=getattr(dtype, "time_unit", None) or "us", time_zone=getattr(dtype, "time_zone", None), ) elif dtype == Duration: strategy = durations(time_unit=getattr(dtype, "time_unit", None) or "us") elif dtype == Categorical: strategy = categories( n_categories=kwargs.pop("n_categories", _DEFAULT_N_CATEGORIES) ) elif dtype == Enum: if isinstance(dtype, Enum): if (cats := dtype.categories).is_empty(): strategy = nulls() else: strategy = st.sampled_from(cats.to_list()) else: strategy = categories( n_categories=kwargs.pop("n_categories", _DEFAULT_ENUM_CATEGORIES_LIMIT) ) elif dtype == Decimal: strategy = decimals( getattr(dtype, "precision", None), getattr(dtype, "scale", 0) ) elif dtype == List: inner = getattr(dtype, "inner", None) or Null() strategy = lists(inner, allow_null=allow_null, **kwargs) elif dtype == Array: inner = getattr(dtype, "inner", None) or Null() size = getattr(dtype, "size", _DEFAULT_ARRAY_WIDTH_LIMIT) kwargs = {k: v for k, v in kwargs.items() if k not in ("min_size", "max_size")} strategy = lists( inner, min_size=size, max_size=size, allow_null=allow_null, **kwargs, ) elif dtype == Struct: fields = getattr(dtype, "fields", None) or [Field("f0", Null())] strategy = structs(fields, allow_null=allow_null, **kwargs) else: msg = f"unsupported data type: {dtype}" raise InvalidArgument(msg) if allow_null: strategy = nulls() | strategy return strategy