Skip to content

Commit 6c71a89

Browse files
authored
Add support for db in GridFS functions (#178)
1 parent f66cdb9 commit 6c71a89

File tree

2 files changed

+35
-5
lines changed

2 files changed

+35
-5
lines changed

flask_pymongo/__init__.py

+22-5
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,12 @@ def init_app(self, app: Flask, uri: str | None = None, *args: Any, **kwargs: Any
123123

124124
# view helpers
125125
def send_file(
126-
self, filename: str, base: str = "fs", version: int = -1, cache_for: int = 31536000
126+
self,
127+
filename: str,
128+
base: str = "fs",
129+
version: int = -1,
130+
cache_for: int = 31536000,
131+
db: str | None = None,
127132
) -> Response:
128133
"""Respond with a file from GridFS.
129134
@@ -144,6 +149,7 @@ def get_upload(filename):
144149
revision. If no such version exists, return with HTTP status 404.
145150
:param int cache_for: number of seconds that browsers should be
146151
instructed to cache responses
152+
:param str db: the target database, if different from the default database.
147153
"""
148154
if not isinstance(base, str):
149155
raise TypeError("'base' must be string or unicode")
@@ -152,8 +158,13 @@ def get_upload(filename):
152158
if not isinstance(cache_for, int):
153159
raise TypeError("'cache_for' must be an integer")
154160

155-
assert self.db is not None, "Please initialize the app before calling send_file!"
156-
storage = GridFS(self.db, base)
161+
if db:
162+
db_obj = self.cx[db]
163+
else:
164+
db_obj = self.db
165+
166+
assert db_obj is not None, "Please initialize the app before calling send_file!"
167+
storage = GridFS(db_obj, base)
157168

158169
try:
159170
fileobj = storage.get_version(filename=filename, version=version)
@@ -189,6 +200,7 @@ def save_file(
189200
fileobj: Any,
190201
base: str = "fs",
191202
content_type: str | None = None,
203+
db: str | None = None,
192204
**kwargs: Any,
193205
) -> Any:
194206
"""Save a file-like object to GridFS using the given filename.
@@ -207,6 +219,7 @@ def save_upload(filename):
207219
:param str content_type: the MIME content-type of the file. If
208220
``None``, the content-type is guessed from the filename using
209221
:func:`~mimetypes.guess_type`
222+
:param str db: the target database, if different from the default database.
210223
:param kwargs: extra attributes to be stored in the file's document,
211224
passed directly to :meth:`gridfs.GridFS.put`
212225
"""
@@ -218,7 +231,11 @@ def save_upload(filename):
218231
if content_type is None:
219232
content_type, _ = guess_type(filename)
220233

221-
assert self.db is not None, "Please initialize the app before calling save_file!"
222-
storage = GridFS(self.db, base)
234+
if db:
235+
db_obj = self.cx[db]
236+
else:
237+
db_obj = self.db
238+
assert db_obj is not None, "Please initialize the app before calling save_file!"
239+
storage = GridFS(db_obj, base)
223240
id = storage.put(fileobj, filename=filename, content_type=content_type, **kwargs)
224241
return id

tests/test_gridfs.py

+13
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,14 @@ def test_it_saves_files(self):
3030
gridfs = GridFS(self.mongo.db)
3131
assert gridfs.exists({"filename": "my-file"})
3232

33+
def test_it_saves_files_to_another_db(self):
34+
fileobj = BytesIO(b"these are the bytes")
35+
36+
self.mongo.save_file("my-file", fileobj, db="other")
37+
assert self.mongo.db is not None
38+
gridfs = GridFS(self.mongo.cx["other"])
39+
assert gridfs.exists({"filename": "my-file"})
40+
3341
def test_it_saves_files_with_props(self):
3442
fileobj = BytesIO(b"these are the bytes")
3543

@@ -56,6 +64,7 @@ def setUp(self):
5664
# make it bigger than 1 gridfs chunk
5765
self.myfile = BytesIO(b"a" * 500 * 1024)
5866
self.mongo.save_file("myfile.txt", self.myfile)
67+
self.mongo.save_file("my_other_file.txt", self.myfile, db="other")
5968

6069
def test_it_404s_for_missing_files(self):
6170
with pytest.raises(NotFound):
@@ -65,6 +74,10 @@ def test_it_sets_content_type(self):
6574
resp = self.mongo.send_file("myfile.txt")
6675
assert resp.content_type.startswith("text/plain")
6776

77+
def test_it_sends_file_to_another_db(self):
78+
resp = self.mongo.send_file("my_other_file.txt", db="other")
79+
assert resp.content_type.startswith("text/plain")
80+
6881
def test_it_sets_content_length(self):
6982
resp = self.mongo.send_file("myfile.txt")
7083
assert resp.content_length == len(self.myfile.getvalue())

0 commit comments

Comments
 (0)