Skip to content

Commit e2e8d10

Browse files
authored
Use caching for the etag using stored hash if available (#182)
1 parent d683131 commit e2e8d10

File tree

3 files changed

+62
-10
lines changed

3 files changed

+62
-10
lines changed

Diff for: flask_pymongo/__init__.py

+39-8
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
__all__ = ("PyMongo", "ASCENDING", "DESCENDING", "BSONObjectIdConverter", "BSONProvider")
2828

2929
import hashlib
30+
import warnings
3031
from mimetypes import guess_type
3132
from typing import Any
3233

@@ -183,12 +184,24 @@ def get_upload(filename):
183184
response.headers["Content-Disposition"] = f"attachment; filename={filename}"
184185
response.content_length = fileobj.length
185186
response.last_modified = fileobj.upload_date
186-
# Compute the sha1 sum of the file for the etag.
187-
pos = fileobj.tell()
188-
raw_data = fileobj.read()
189-
fileobj.seek(pos)
190-
digest = hashlib.sha1(raw_data).hexdigest()
191-
response.set_etag(digest)
187+
188+
# GridFS does not manage its own checksum.
189+
# Try to use a sha1 sum that we have added during a save_file.
190+
# Fall back to a legacy md5 sum if it exists.
191+
# Otherwise, compute the sha1 sum directly.
192+
try:
193+
etag = fileobj.sha1
194+
except AttributeError:
195+
with warnings.catch_warnings():
196+
warnings.simplefilter("ignore")
197+
etag = fileobj.md5
198+
if etag is None:
199+
pos = fileobj.tell()
200+
raw_data = fileobj.read()
201+
fileobj.seek(pos)
202+
etag = hashlib.sha1(raw_data).hexdigest()
203+
response.set_etag(etag)
204+
192205
response.cache_control.max_age = cache_for
193206
response.cache_control.public = True
194207
response.make_conditional(request)
@@ -237,5 +250,23 @@ def save_upload(filename):
237250
db_obj = self.db
238251
assert db_obj is not None, "Please initialize the app before calling save_file!"
239252
storage = GridFS(db_obj, base)
240-
id = storage.put(fileobj, filename=filename, content_type=content_type, **kwargs)
241-
return id
253+
254+
# GridFS does not manage its own checksum, so we attach a sha1 to the file
255+
# for use as an etag.
256+
hashingfile = _Wrapper(fileobj)
257+
with storage.new_file(filename=filename, content_type=content_type, **kwargs) as grid_file:
258+
grid_file.write(hashingfile)
259+
grid_file.sha1 = hashingfile.hash.hexdigest()
260+
return grid_file._id
261+
262+
263+
class _Wrapper:
264+
def __init__(self, file):
265+
self.file = file
266+
self.hash = hashlib.sha1()
267+
268+
def read(self, n):
269+
data = self.file.read(n)
270+
if data:
271+
self.hash.update(data)
272+
return data

Diff for: tests/test_gridfs.py

+22-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

3-
from hashlib import sha1
3+
import warnings
4+
from hashlib import md5, sha1
45
from io import BytesIO
56

67
import pytest
@@ -95,6 +96,26 @@ def test_it_sets_supports_conditional_gets(self):
9596
resp = self.mongo.send_file("myfile.txt")
9697
assert resp.status_code == 304
9798

99+
def test_it_sets_supports_conditional_gets_md5(self):
100+
# a basic conditional GET
101+
md5_hash = md5(self.myfile.getvalue()).hexdigest()
102+
environ_args = {
103+
"method": "GET",
104+
"headers": {
105+
"If-None-Match": md5_hash,
106+
},
107+
}
108+
storage = storage = GridFS(self.mongo.db)
109+
with storage.new_file(filename="myfile.txt") as grid_file:
110+
grid_file.write(self.myfile.getvalue())
111+
with warnings.catch_warnings():
112+
warnings.simplefilter("ignore")
113+
grid_file.set("md5", md5_hash)
114+
115+
with self.app.test_request_context(**environ_args):
116+
resp = self.mongo.send_file("myfile.txt")
117+
assert resp.status_code == 304
118+
98119
def test_it_sets_cache_headers(self):
99120
resp = self.mongo.send_file("myfile.txt", cache_for=60)
100121
assert resp.cache_control.max_age == 60

Diff for: tests/util.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -33,5 +33,5 @@ def setUp(self):
3333
def tearDown(self):
3434
assert self.mongo.cx is not None
3535
self.mongo.cx.drop_database(self.dbname)
36-
36+
self.mongo.cx.close()
3737
super().tearDown()

0 commit comments

Comments
 (0)