-
Notifications
You must be signed in to change notification settings - Fork 108
/
Copy pathSessionState.py
129 lines (100 loc) · 3.45 KB
/
SessionState.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
"""Adds pre-session state to StreamLit.
This file is borrowed from
https://gist.github.com/tvst/036da038ab3e999a64497f42de966a92
"""
# pylint: disable=protected-access
try:
import streamlit.ReportThread as ReportThread
from streamlit.server.Server import Server
except ModuleNotFoundError:
# Streamlit >= 0.65.0
import streamlit.report_thread as ReportThread
from streamlit.server.server import Server
class SessionState(object):
"""Hack to add per-session state to Streamlit.
Usage
-----
>>> import SessionState
>>>
>>> session_state = SessionState.get(user_name='', favorite_color='black')
>>> session_state.user_name
''
>>> session_state.user_name = 'Mary'
>>> session_state.favorite_color
'black'
Since you set user_name above, next time your script runs this will be the
result:
>>> session_state = get(user_name='', favorite_color='black')
>>> session_state.user_name
'Mary'
"""
def __init__(self, **kwargs):
"""A new SessionState object.
Parameters
----------
**kwargs : any
Default values for the session state.
Example
-------
>>> session_state = SessionState(user_name='', favorite_color='black')
>>> session_state.user_name = 'Mary'
''
>>> session_state.favorite_color
'black'
"""
for key, val in kwargs.items():
setattr(self, key, val)
def get(**kwargs):
"""Gets a SessionState object for the current session.
Creates a new object if necessary.
Parameters
----------
**kwargs : any
Default values you want to add to the session state, if we're creating a
new one.
Example
-------
>>> session_state = get(user_name='', favorite_color='black')
>>> session_state.user_name
''
>>> session_state.user_name = 'Mary'
>>> session_state.favorite_color
'black'
Since you set user_name above, next time your script runs this will be the
result:
>>> session_state = get(user_name='', favorite_color='black')
>>> session_state.user_name
'Mary'
"""
# Hack to get the session object from Streamlit.
ctx = ReportThread.get_report_ctx()
this_session = None
current_server = Server.get_current()
if hasattr(current_server, '_session_infos'):
# Streamlit < 0.56
session_infos = Server.get_current()._session_infos.values()
else:
session_infos = Server.get_current()._session_info_by_id.values()
for session_info in session_infos:
s = session_info.session
if (
# Streamlit < 0.54.0
(hasattr(s, '_main_dg') and s._main_dg == ctx.main_dg)
or
# Streamlit >= 0.54.0
(not hasattr(s, '_main_dg') and s.enqueue == ctx.enqueue)
or
# Streamlit >= 0.65.2
(not hasattr(s, '_main_dg') and
s._uploaded_file_mgr == ctx.uploaded_file_mgr)
):
this_session = s
if this_session is None:
raise RuntimeError(
"Oh noes. Couldn't get your Streamlit Session object. "
'Are you doing something fancy with threads?')
# Got the session object! Now let's attach some state into it.
if not hasattr(this_session, '_custom_session_state'):
this_session._custom_session_state = SessionState(**kwargs)
return this_session._custom_session_state
# pylint: enable=protected-access