199 lines
8.0 KiB
Python
199 lines
8.0 KiB
Python
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import TYPE_CHECKING, Any, TypedDict, cast
|
|
|
|
from typing_extensions import Self
|
|
|
|
from streamlit.errors import StreamlitAPIException
|
|
from streamlit.logger import get_logger
|
|
from streamlit.runtime.scriptrunner import get_script_run_ctx
|
|
|
|
if TYPE_CHECKING:
|
|
from streamlit.runtime.state import SessionState
|
|
from streamlit.runtime.state.common import WidgetValuePresenter
|
|
|
|
|
|
# Module-level logger; only used here for debug-level diagnostics when merging
# trigger events or persisting write-through state updates fails.
_LOGGER = get_logger(__name__)
|
|
|
|
|
|
class _TriggerPayload(TypedDict, total=False):
|
|
event: str
|
|
value: object
|
|
|
|
|
|
def _read_trigger_values(
    session_state: SessionState, aggregator_id: str
) -> dict[str, object] | None:
    """Return the latest value per event held by a trigger aggregator widget.

    Parameters
    ----------
    session_state
        The session state whose new-widget-state store is inspected.
    aggregator_id
        The ID of the trigger aggregator widget.

    Returns
    -------
    dict[str, object] | None
        ``None`` when ``aggregator_id`` has no widget metadata or is not a
        ``"json_trigger_value"`` widget. Otherwise an ``event -> value``
        mapping, which may be empty. Payloads that are not dicts, or whose
        ``event`` is not a string, are skipped. When several payloads share
        an event, the last one wins.
    """
    widget_state = session_state._new_widget_state

    agg_meta = widget_state.widget_metadata.get(aggregator_id)
    if agg_meta is None or agg_meta.value_type != "json_trigger_value":
        return None

    try:
        raw_payloads: object = widget_state[aggregator_id]
    except KeyError:
        raw_payloads = None

    # The aggregator stores either a list of payloads or a single payload dict.
    payloads: list[_TriggerPayload]
    if isinstance(raw_payloads, list):
        payloads = [
            cast("_TriggerPayload", p) for p in raw_payloads if isinstance(p, dict)
        ]
    elif isinstance(raw_payloads, dict):
        payloads = [cast("_TriggerPayload", raw_payloads)]
    else:
        payloads = []

    event_to_val: dict[str, object] = {}
    for payload in payloads:
        event = payload.get("event")
        if isinstance(event, str):
            event_to_val[event] = payload.get("value")
    return event_to_val


def make_bidi_component_presenter(
    aggregator_id: str,
    component_id: str | None = None,
    allowed_state_keys: set[str] | None = None,
) -> WidgetValuePresenter:
    """Return a presenter that merges trigger events into CCv2 state.

    This function returns a callable that takes a component's persistent state
    value and the current `SessionState` instance. The callable returns the
    user-visible value that should appear in `st.session_state`.

    Presenting never mutates stored state. When the base state is a dict and
    the aggregator is a trigger widget, the presenter returns a fresh
    write-through dict. Assignments made through that dict (item assignment,
    attribute assignment, item deletion) are persisted back to the
    component's widget state. Only the component's own state is persisted;
    transient trigger event values are never written back.

    The presenter is intended to be attached to the persistent state widget
    via the generic `presenter` hook.

    Parameters
    ----------
    aggregator_id
        The ID of the trigger aggregator widget that holds the event payloads.
    component_id
        The widget ID of the component's persistent state. It is used in two
        ways: to reject modifications after the component was instantiated in
        the current run, and as the target for write-through persistence.
        When ``None``, writes only update the returned dict.
    allowed_state_keys
        If given, only these keys can be assigned through the returned dict.
        Assignments to other keys are silently ignored. ``None`` allows any
        key.

    Returns
    -------
    WidgetValuePresenter
        A callable that merges the trigger event values into the component's
        base state for presentation in `st.session_state`. Some values are
        returned unchanged:

        - non-dict base values;
        - dict base values whose aggregator is not a ``"json_trigger_value"``
          widget;
        - any base value when reading the triggers fails unexpectedly.
    """

    def _present(base_value: object, session_state: SessionState) -> object:
        # Base state must be a flat mapping; otherwise, present as-is.
        if not isinstance(base_value, dict):
            return base_value
        base_map = cast("dict[str, object]", base_value)

        try:
            trigger_values = _read_trigger_values(session_state, aggregator_id)
        except Exception as e:
            # On any error, fall back to the base value
            _LOGGER.debug(
                "Failed to merge trigger events into component state: %s",
                e,
                exc_info=e,
            )
            return base_value

        if trigger_values is None:
            return base_value

        # Merge triggers into a flat view: triggers first, then base (base wins).
        flat: dict[str, object] = {**trigger_values, **base_map}

        # The component's own (persistent) state, kept separately from the
        # merged view. Copy it so the stored base value is never mutated in
        # place, and so trigger values never leak into persisted state.
        persisted_state: dict[str, object] = dict(base_map)

        def _check_modification(k: str) -> None:
            """Raise if the component was already instantiated in this run."""
            if component_id is None:
                return
            ctx = get_script_run_ctx()
            if ctx is None:
                return
            user_key = session_state._key_id_mapper.get_key_from_id(component_id)
            if (
                component_id in ctx.widget_ids_this_run
                or user_key in ctx.form_ids_this_run
            ):
                raise StreamlitAPIException(
                    f"`st.session_state.{user_key}.{k}` cannot be modified after the component"
                    f" with key `{user_key}` is instantiated."
                )

        def _persist(failure_msg: str) -> None:
            """Best-effort write of the component's own state to the widget store."""
            if component_id is None:
                return
            try:
                session_state._new_widget_state.set_from_value(
                    component_id, dict(persisted_state)
                )
            except Exception as e:
                _LOGGER.debug(failure_msg, e)

        class _WriteThrough(dict[str, object]):
            """Merged state view that writes user edits back to the component.

            Using a dict subclass keeps pretty-printing and JSON serialization
            working as expected for st.write and logs. Public attribute access
            (``view.name``) reads and writes keys; a missing key reads as
            ``None``. Names starting with ``_`` are treated as real attributes.

            NOTE: only item assignment, item deletion, and attribute assignment
            are written through. Other mutators inherited from ``dict``
            (``update``, ``pop``, ``setdefault``, ...) do not call these
            overrides and so change only this view.

            Writes raise ``StreamlitAPIException`` if the component was already
            instantiated in the current script run.
            """

            def __getattr__(self, name: str) -> object:
                # Only reached when normal attribute lookup fails. Private and
                # dunder names are never state keys (mirrors __setattr__), so
                # report them as missing. This keeps hasattr() honest and stops
                # protocol probes (e.g. __getstate__ used by copy/pickle on
                # Python < 3.11) from receiving None and calling it.
                if name.startswith("_"):
                    raise AttributeError(name)
                return self.get(name)

            def __setattr__(self, name: str, value: object) -> None:
                if name.startswith("_"):
                    super().__setattr__(name, value)
                    return
                self[name] = value

            def __deepcopy__(self, memo: dict[int, Any]) -> Self:
                # This object is a proxy to the real state. Don't copy it.
                memo[id(self)] = self
                return self

            def __setitem__(self, k: str, v: object) -> None:
                _check_modification(k)

                if allowed_state_keys is not None and k not in allowed_state_keys:
                    # Silently ignore invalid keys to match permissive session_state semantics
                    return

                # Update this view and the component's own state, then persist
                # only the latter.
                super().__setitem__(k, v)
                persisted_state[k] = v
                _persist("Failed to persist CCv2 state update: %s")

            def __delitem__(self, k: str) -> None:
                _check_modification(k)

                super().__delitem__(k)
                # The key may exist only as a trigger value, so it may not be
                # in the component's own state.
                persisted_state.pop(k, None)
                _persist("Failed to persist CCv2 state deletion: %s")

        return _WriteThrough(flat)

    return _present
|