Skip to content

Commit

Permalink
fix: data layers (#1670)
Browse files Browse the repository at this point in the history
* fix: data layers

* chore: add data layer fixes changelog
  • Loading branch information
desaxce authored Jan 10, 2025
1 parent efad742 commit 1ec7198
Show file tree
Hide file tree
Showing 7 changed files with 110 additions and 11 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
- Errors in thread resume (like thread not found) now properly redirect to the home page
- Elements like Dataframe, Plotly or text should now load correctly from cloud storages
- AskFileMessage is now usable even if spontaneous uploads are disabled
- Remove element objects from cloud storage on thread removal (Official & SQLAlchemy data layers)
- Fix custom element `props` storage for the SQLAlchemy data layer

## [2.0.1] - 2025-01-09

Expand Down
26 changes: 25 additions & 1 deletion backend/chainlit/data/chainlit_data_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,10 +260,21 @@ async def get_element(
@queue_until_user_message()
async def delete_element(self, element_id: str, thread_id: Optional[str] = None):
query = """
SELECT * FROM "Element"
WHERE id = $1
"""
elements = await self.execute_query(query, {"id": element_id})

if self.storage_client is not None and len(elements) > 0:
if elements[0]["objectKey"]:
await self.storage_client.delete_file(
object_key=elements[0]["objectKey"]
)
query = """
DELETE FROM "Element"
WHERE id = $1
"""
params = {"element_id": element_id}
params = {"id": element_id}

if thread_id:
query += ' AND "threadId" = $2'
Expand Down Expand Up @@ -375,6 +386,19 @@ async def get_thread_author(self, thread_id: str) -> str:
return results[0]["identifier"]

async def delete_thread(self, thread_id: str):
elements_query = """
SELECT * FROM "Element"
WHERE "threadId" = $1
"""
elements_results = await self.execute_query(
elements_query, {"thread_id": thread_id}
)

if self.storage_client is not None:
for elem in elements_results:
if elem["objectKey"]:
await self.storage_client.delete_file(object_key=elem["objectKey"])

await self.execute_query(
'DELETE FROM "Thread" WHERE id = $1', {"thread_id": thread_id}
)
Expand Down
28 changes: 26 additions & 2 deletions backend/chainlit/data/sql_alchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,14 @@ async def update_thread(
async def delete_thread(self, thread_id: str):
if self.show_logger:
logger.info(f"SQLAlchemy: delete_thread, thread_id={thread_id}")

elements_query = """SELECT * FROM elements WHERE "threadId" = :id"""
elements = await self.execute_sql(elements_query, {"id": thread_id})

if self.storage_provider is not None and isinstance(elements, list):
for elem in filter(lambda x: x["objectKey"], elements):
await self.storage_provider.delete_file(object_key=elem["objectKey"])

# Delete feedbacks/elements/steps/thread
feedbacks_query = """DELETE FROM feedbacks WHERE "forId" IN (SELECT "id" FROM steps WHERE "threadId" = :id)"""
elements_query = """DELETE FROM elements WHERE "threadId" = :id"""
Expand Down Expand Up @@ -439,7 +447,7 @@ async def get_element(
url=element_dict.get("url"),
objectKey=element_dict.get("objectKey"),
name=element_dict["name"],
props=element_dict.get("props"),
props=json.loads(element_dict.get("props", "{}")),
display=element_dict["display"],
size=element_dict.get("size"),
language=element_dict.get("language"),
Expand Down Expand Up @@ -504,7 +512,10 @@ async def create_element(self, element: "Element"):

element_dict["url"] = uploaded_file.get("url")
element_dict["objectKey"] = uploaded_file.get("object_key")

element_dict_cleaned = {k: v for k, v in element_dict.items() if v is not None}
if "props" in element_dict_cleaned:
element_dict_cleaned["props"] = json.dumps(element_dict_cleaned["props"])

columns = ", ".join(f'"{column}"' for column in element_dict_cleaned.keys())
placeholders = ", ".join(f":{column}" for column in element_dict_cleaned.keys())
Expand All @@ -515,8 +526,21 @@ async def create_element(self, element: "Element"):
async def delete_element(self, element_id: str, thread_id: Optional[str] = None):
    """Delete one element row and, when present, its stored object.

    The element is looked up first so that its ``objectKey`` can be handed to
    the configured storage provider before the row itself is removed.

    Args:
        element_id: Identifier of the element to delete.
        thread_id: Accepted for interface compatibility; this implementation
            does not reference it.
    """
    if self.show_logger:
        logger.info(f"SQLAlchemy: delete_element, element_id={element_id}")

    # Fetch the row before deleting it: the object key lives on the row.
    rows = await self.execute_sql(
        """SELECT * FROM elements WHERE "id" = :id""", {"id": element_id}
    )

    if self.storage_provider is not None:
        # Only a non-empty list result carries a row to inspect.
        if isinstance(rows, list) and len(rows) > 0:
            stored_key = rows[0]["objectKey"]
            if stored_key:
                await self.storage_provider.delete_file(object_key=stored_key)

    await self.execute_sql(
        query="""DELETE FROM elements WHERE "id" = :id""",
        parameters={"id": element_id},
    )

async def get_all_user_threads(
Expand Down Expand Up @@ -688,7 +712,7 @@ async def get_all_user_threads(
autoPlay=element.get("element_autoPlay"),
playerConfig=element.get("element_playerconfig"),
page=element.get("element_page"),
props=element.get("element_props"),
props=json.loads(element.get("props", "{}")),
forId=element.get("element_forid"),
mime=element.get("element_mime"),
)
Expand Down
17 changes: 13 additions & 4 deletions backend/chainlit/data/storage_clients/azure_blob.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ def __init__(self, container_name: str, storage_account: str, storage_key: str):
self.service_client = AsyncBlobServiceClient.from_connection_string(
connection_string
)
self.container_client = self.service_client.get_container_client(
self.container_name
)
logger.info("AzureBlobStorageClient initialized")

async def get_read_url(self, object_key: str) -> str:
Expand Down Expand Up @@ -52,10 +55,7 @@ async def upload_file(
overwrite: bool = True,
) -> Dict[str, Any]:
try:
container_client = self.service_client.get_container_client(
self.container_name
)
blob_client = container_client.get_blob_client(object_key)
blob_client = self.container_client.get_blob_client(object_key)

if isinstance(data, str):
data = data.encode("utf-8")
Expand All @@ -78,3 +78,12 @@ async def upload_file(

except Exception as e:
raise Exception(f"Failed to upload file to Azure Blob Storage: {e!s}")

async def delete_file(self, object_key: str) -> bool:
    """Delete the blob stored under ``object_key`` from the container.

    Deletion is best-effort: any failure is logged and reported through the
    return value rather than raised.

    Args:
        object_key: Name of the blob to delete.

    Returns:
        True when the blob was deleted, False when the attempt raised.
    """
    try:
        blob_client = self.container_client.get_blob_client(blob=object_key)
        await blob_client.delete_blob()
        return True
    except Exception as e:
        # Logger.warn is a deprecated alias that emits a DeprecationWarning;
        # Logger.warning is the supported spelling.
        logger.warning("AzureBlobStorageClient, delete_file error: %s", e)
        return False
4 changes: 4 additions & 0 deletions backend/chainlit/data/storage_clients/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ async def upload_file(
) -> Dict[str, Any]:
pass

@abstractmethod
async def delete_file(self, object_key: str) -> bool:
    """Delete the stored object identified by ``object_key``.

    Implementations return a boolean reporting whether the deletion
    succeeded.
    """

@abstractmethod
async def get_read_url(self, object_key: str) -> str:
    """Return a string URL for reading the object stored under ``object_key``."""
14 changes: 12 additions & 2 deletions backend/chainlit/data/storage_clients/gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ def __init__(
)

self.client = storage.Client(project=project_id, credentials=credentials)
self.bucket_name = bucket_name
self.bucket = self.client.bucket(bucket_name)
logger.info("GCSStorageClient initialized")

Expand Down Expand Up @@ -60,7 +59,7 @@ def sync_upload_file(

return {
"object_key": object_key,
"url": f"gs://{self.bucket_name}/{object_key}",
"url": f"gs://{self.bucket.name}/{object_key}",
}

except Exception as e:
Expand All @@ -76,3 +75,14 @@ async def upload_file(
return await make_async(self.sync_upload_file)(
object_key, data, mime, overwrite
)

def sync_delete_file(self, object_key: str) -> bool:
    """Synchronously delete the blob named ``object_key`` from the bucket.

    Deletion is best-effort: any failure is logged and reported through the
    return value rather than raised.

    Args:
        object_key: Name of the blob to delete.

    Returns:
        True when the blob was deleted, False when the attempt raised.
    """
    try:
        self.bucket.blob(object_key).delete()
        return True
    except Exception as e:
        # Logger.warn is a deprecated alias that emits a DeprecationWarning;
        # Logger.warning is the supported spelling.
        logger.warning("GCSStorageClient, delete_file error: %s", e)
        return False

async def delete_file(self, object_key: str) -> bool:
    """Awaitable counterpart of ``sync_delete_file``.

    Wraps the synchronous deletion with ``make_async`` and awaits its result.
    """
    remove_blob = make_async(self.sync_delete_file)
    return await remove_blob(object_key)
30 changes: 28 additions & 2 deletions backend/chainlit/data/storage_clients/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import boto3 # type: ignore

from chainlit import make_async
from chainlit.data.storage_clients.base import EXPIRY_TIME, BaseStorageClient
from chainlit.logger import logger

Expand All @@ -19,7 +20,7 @@ def __init__(self, bucket: str, **kwargs: Any):
except Exception as e:
logger.warn(f"S3StorageClient initialization error: {e}")

async def get_read_url(self, object_key: str) -> str:
def sync_get_read_url(self, object_key: str) -> str:
try:
url = self.client.generate_presigned_url(
"get_object",
Expand All @@ -31,7 +32,10 @@ async def get_read_url(self, object_key: str) -> str:
logger.warn(f"S3StorageClient, get_read_url error: {e}")
return object_key

async def upload_file(
async def get_read_url(self, object_key: str) -> str:
    """Awaitable counterpart of ``sync_get_read_url``.

    Wraps the synchronous URL lookup with ``make_async`` and awaits its result.
    """
    fetch_url = make_async(self.sync_get_read_url)
    return await fetch_url(object_key)

def sync_upload_file(
self,
object_key: str,
data: Union[bytes, str],
Expand All @@ -47,3 +51,25 @@ async def upload_file(
except Exception as e:
logger.warn(f"S3StorageClient, upload_file error: {e}")
return {}

async def upload_file(
    self,
    object_key: str,
    data: Union[bytes, str],
    mime: str = "application/octet-stream",
    overwrite: bool = True,
) -> Dict[str, Any]:
    """Awaitable counterpart of ``sync_upload_file``.

    Wraps the synchronous upload with ``make_async`` and forwards every
    argument positionally, in signature order.
    """
    run_upload = make_async(self.sync_upload_file)
    return await run_upload(object_key, data, mime, overwrite)

def sync_delete_file(self, object_key: str) -> bool:
    """Synchronously delete the object stored under ``object_key``.

    Deletion is best-effort: any failure is logged and reported through the
    return value rather than raised.

    Args:
        object_key: Key of the object to delete from ``self.bucket``.

    Returns:
        True when the delete call completed, False when it raised.
    """
    try:
        self.client.delete_object(Bucket=self.bucket, Key=object_key)
        return True
    except Exception as e:
        # Logger.warn is a deprecated alias that emits a DeprecationWarning;
        # Logger.warning is the supported spelling.
        logger.warning("S3StorageClient, delete_file error: %s", e)
        return False

async def delete_file(self, object_key: str) -> bool:
    """Awaitable counterpart of ``sync_delete_file``.

    Wraps the synchronous deletion with ``make_async`` and awaits its result.
    """
    remove_object = make_async(self.sync_delete_file)
    return await remove_object(object_key)

0 comments on commit 1ec7198

Please sign in to comment.