Skip to content

Commit

Permalink
feat: add http_cookie to WebsocketSession and UserSession (#1653)
Browse files Browse the repository at this point in the history
Co-authored-by: Willy Douhard <willy.douhard@gmail.com>
  • Loading branch information
5enxia and willydouhard authored Jan 10, 2025
1 parent 5d006f4 commit 25489c6
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 0 deletions.
9 changes: 9 additions & 0 deletions backend/chainlit/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ def __init__(
chat_profile: Optional[str] = None,
# Origin of the request
http_referer: Optional[str] = None,
# Cookie
http_cookie: Optional[str] = None,
):
if thread_id:
self.thread_id_to_resume = thread_id
Expand All @@ -75,6 +77,7 @@ def __init__(
self.user_env = user_env or {}
self.chat_profile = chat_profile
self.http_referer = http_referer
self.http_cookie = http_cookie

self.files: Dict[str, FileDict] = {}

Expand Down Expand Up @@ -167,6 +170,8 @@ def __init__(
user_env: Optional[Dict[str, str]] = None,
# Origin of the request
http_referer: Optional[str] = None,
# Cookie
http_cookie: Optional[str] = None,
):
super().__init__(
id=id,
Expand All @@ -176,6 +181,7 @@ def __init__(
client_type=client_type,
user_env=user_env,
http_referer=http_referer,
http_cookie=http_cookie,
)

def delete(self):
Expand Down Expand Up @@ -226,6 +232,8 @@ def __init__(
languages: Optional[str] = None,
# Origin of the request
http_referer: Optional[str] = None,
# Cookie
http_cookie: Optional[str] = None,
):
super().__init__(
id=id,
Expand All @@ -236,6 +244,7 @@ def __init__(
client_type=client_type,
chat_profile=chat_profile,
http_referer=http_referer,
http_cookie=http_cookie,
)

self.socket_id = socket_id
Expand Down
2 changes: 2 additions & 0 deletions backend/chainlit/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ def emit_call_fn(event: Literal["ask", "call_fn"], data, timeout):

client_type = auth.get("clientType")
http_referer = environ.get("HTTP_REFERER")
http_cookie = environ.get("HTTP_COOKIE")
url_encoded_chat_profile = auth.get("chatProfile")
chat_profile = (
unquote(url_encoded_chat_profile) if url_encoded_chat_profile else None
Expand All @@ -154,6 +155,7 @@ def emit_call_fn(event: Literal["ask", "call_fn"], data, timeout):
thread_id=auth.get("threadId"),
languages=environ.get("HTTP_ACCEPT_LANGUAGE"),
http_referer=http_referer,
http_cookie=http_cookie,
)

trace_event("connection_successful")
Expand Down
1 change: 1 addition & 0 deletions backend/chainlit/user_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def get(self, key, default=None):
user_session["chat_profile"] = context.session.chat_profile
user_session["http_referer"] = context.session.http_referer
user_session["client_type"] = context.session.client_type
user_session["http_cookie"] = context.session.http_cookie

if isinstance(context.session, WebsocketSession):
user_session["languages"] = context.session.languages
Expand Down
1 change: 1 addition & 0 deletions backend/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def create_mock_session(**kwargs) -> Mock:
mock.chat_settings = kwargs.get("chat_settings", {})
mock.chat_profile = kwargs.get("chat_profile", None)
mock.http_referer = kwargs.get("http_referer", None)
mock.http_cookie = kwargs.get("http_cookie", None)
mock.client_type = kwargs.get("client_type", "webapp")
mock.languages = kwargs.get("languages", ["en"])
mock.thread_id = kwargs.get("thread_id", "test_thread_id")
Expand Down
1 change: 1 addition & 0 deletions backend/tests/test_user_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ async def test_user_session_set_get(mock_chainlit_context, user_session):
assert user_session.get("id") == context.session.id
assert user_session.get("env") == context.session.user_env
assert user_session.get("languages") == context.session.languages
assert user_session.get("http_cookie") == context.session.http_cookie

0 comments on commit 25489c6

Please sign in to comment.