Skip to content

Commit f167bf5

Browse files
authored
Upsert (#76)
Upsert
1 parent 51f1533 commit f167bf5

File tree

10 files changed

+192
-22
lines changed

10 files changed

+192
-22
lines changed

beanie/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from beanie.odm.utils.general import init_beanie
55
from beanie.odm.documents import Document
66

7-
__version__ = "1.1.6"
7+
__version__ = "1.2.0"
88
__all__ = [
99
# ODM
1010
"Document",

beanie/odm/documents.py

+4
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,10 @@ async def insert_one(
118118
:param session: ClientSession - pymongo session
119119
:return: InsertOneResult
120120
"""
121+
if not isinstance(document, cls):
122+
raise TypeError(
123+
"Inserting document must be of the original document class"
124+
)
121125
return await cls.get_motor_collection().insert_one(
122126
get_dict(document), session=session
123127
)

beanie/odm/queries/find.py

+31-4
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@
1212
cast,
1313
)
1414

15+
from pydantic import BaseModel
16+
from pymongo.client_session import ClientSession
17+
from pymongo.results import UpdateResult
18+
1519
from beanie.exceptions import DocumentNotFound
1620
from beanie.odm.enums import SortDirection
1721
from beanie.odm.interfaces.aggregate import AggregateMethods
@@ -33,9 +37,6 @@
3337
UpdateOne,
3438
)
3539
from beanie.odm.utils.projection import get_projection
36-
from pydantic import BaseModel
37-
from pymongo.client_session import ClientSession
38-
from pymongo.results import UpdateResult
3940

4041
if TYPE_CHECKING:
4142
from beanie.odm.documents import DocType
@@ -83,7 +84,7 @@ def update(
8384
:param session: Optional[ClientSession]
8485
:return: UpdateMany query
8586
"""
86-
self.set_session(session=session)
87+
self.set_session(session)
8788
return (
8889
self.UpdateQueryType(
8990
document_model=self.document_model,
@@ -93,6 +94,32 @@ def update(
9394
.set_session(session=self.session)
9495
)
9596

97+
def upsert(
98+
self,
99+
*args: Mapping[str, Any],
100+
on_insert: "DocType",
101+
session: Optional[ClientSession] = None
102+
):
103+
"""
104+
Create Update with modifications query
105+
and provide search criteria there
106+
107+
:param args: *Mapping[str,Any] - the modifications to apply.
108+
:param on_insert: DocType - document to insert if there is no matched
109+
document in the collection
110+
:param session: Optional[ClientSession]
111+
:return: UpdateMany query
112+
"""
113+
self.set_session(session)
114+
return (
115+
self.UpdateQueryType(
116+
document_model=self.document_model,
117+
find_query=self.get_filter_query(),
118+
)
119+
.upsert(*args, on_insert=on_insert)
120+
.set_session(session=self.session)
121+
)
122+
96123
def delete(
97124
self, session: Optional[ClientSession] = None
98125
) -> Union[DeleteOne, DeleteMany]:

beanie/odm/queries/update.py

+52-14
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from abc import abstractmethod
12
from typing import (
23
List,
34
Type,
@@ -6,10 +7,11 @@
67
Mapping,
78
Any,
89
Dict,
10+
Union,
911
)
1012

1113
from pymongo.client_session import ClientSession
12-
from pymongo.results import UpdateResult
14+
from pymongo.results import UpdateResult, InsertOneResult
1315

1416
from beanie.odm.interfaces.session import SessionMethods
1517
from beanie.odm.interfaces.update import (
@@ -40,6 +42,8 @@ def __init__(
4042
self.find_query = find_query
4143
self.update_expressions: List[Mapping[str, Any]] = []
4244
self.session = None
45+
self.is_upsert = False
46+
self.upsert_insert_doc: Optional["DocType"] = None
4347

4448
@property
4549
def update_query(self) -> Dict[str, Any]:
@@ -57,7 +61,7 @@ def update(
5761
self, *args: Mapping[str, Any], session: Optional[ClientSession] = None
5862
) -> "UpdateQuery":
5963
"""
60-
Provide modifications to the update query. The same as `update()`
64+
Provide modifications to the update query.
6165
6266
:param args: *Union[dict, Mapping] - the modifications to apply.
6367
:param session: Optional[ClientSession]
@@ -67,6 +71,48 @@ def update(
6771
self.update_expressions += args
6872
return self
6973

74+
def upsert(
75+
self,
76+
*args: Mapping[str, Any],
77+
on_insert: "DocType",
78+
session: Optional[ClientSession] = None
79+
) -> "UpdateQuery":
80+
"""
81+
Provide modifications to the upsert query.
82+
83+
:param args: *Union[dict, Mapping] - the modifications to apply.
84+
:param on_insert: DocType - document to insert if there is no matched
85+
document in the collection
86+
:param session: Optional[ClientSession]
87+
:return: UpdateMany query
88+
"""
89+
self.upsert_insert_doc = on_insert
90+
self.update(*args, session=session)
91+
return self
92+
93+
@abstractmethod
94+
async def _update(self):
95+
...
96+
97+
def __await__(self) -> Union[UpdateResult, InsertOneResult]:
98+
"""
99+
Run the query
100+
:return:
101+
"""
102+
103+
update_result = yield from self._update().__await__()
104+
if self.upsert_insert_doc is None:
105+
return update_result
106+
else:
107+
if update_result.matched_count == 0:
108+
return (
109+
yield from self.document_model.insert_one(
110+
document=self.upsert_insert_doc, session=self.session
111+
).__await__()
112+
)
113+
else:
114+
return update_result
115+
70116

71117
class UpdateMany(UpdateQuery):
72118
"""
@@ -89,12 +135,8 @@ def update_many(
89135
"""
90136
return self.update(*args, session=session)
91137

92-
def __await__(self) -> UpdateResult:
93-
"""
94-
Run the query
95-
:return:
96-
"""
97-
yield from self.document_model.get_motor_collection().update_many(
138+
async def _update(self):
139+
return await self.document_model.get_motor_collection().update_many(
98140
self.find_query, self.update_query, session=self.session
99141
)
100142

@@ -120,11 +162,7 @@ def update_one(
120162
"""
121163
return self.update(*args, session=session)
122164

123-
def __await__(self) -> UpdateResult:
124-
"""
125-
Run the query
126-
:return:
127-
"""
128-
yield from self.document_model.get_motor_collection().update_one(
165+
async def _update(self):
166+
return await self.document_model.get_motor_collection().update_one(
129167
self.find_query, self.update_query, session=self.session
130168
)

docs/changelog.md

+10
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,16 @@
22

33
Beanie project changes
44

5+
## [1.2.0] - 2021-06-25
6+
7+
### Added
8+
9+
- Upsert
10+
11+
### Implementation
12+
13+
- Issue - <https://github.com/roman-right/beanie/issues/64>
14+
515
## [1.1.6] - 2021-06-21
616

717
### Fix

docs/tutorial/update.md

+9
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,15 @@ Native MongoDB syntax is also accepted:
4444
await Product.find_one(Product.name == "Tony's").update({"$set": {Product.price: 3.33}})
4545
```
4646

47+
## Upsert
48+
49+
To insert a document if no one document was matched the search cruteria during update query, the `upsert` method can be used:
50+
```python
51+
await Product.find_one(Product.name == "Tony's").upsert(
52+
Set({Product.price: 3.33}),
53+
on_insert=Product(name="Tony's", price=3.33, category=chocolate)
54+
)
55+
```
4756

4857
## Deleting documents
4958

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "beanie"
3-
version = "1.1.6"
3+
version = "1.2.0"
44
description = "Asynchronous Python ODM for MongoDB"
55
authors = ["Roman <roman-right@protonmail.com>"]
66
license = "Apache-2.0"

tests/odm/conftest.py

+25-1
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,31 @@ async def preset_documents(point):
7474
await Sample.insert_many(documents=docs)
7575

7676

77-
object_storage = {}
77+
@pytest.fixture()
78+
def sample_doc_not_saved(point):
79+
nested = Nested(
80+
integer=0,
81+
option_1=Option1(s="TEST"),
82+
union=Option1(s="TEST"),
83+
optional=None,
84+
)
85+
geo = GeoObject(
86+
coordinates=[
87+
point["longitude"],
88+
point["latitude"],
89+
]
90+
)
91+
return Sample(
92+
timestamp=datetime.utcnow(),
93+
increment=0,
94+
integer=0,
95+
float_num=0,
96+
string="TEST_NOT_SAVED",
97+
nested=nested,
98+
optional=None,
99+
union=Option1(s="TEST"),
100+
geo=geo,
101+
)
78102

79103

80104
@pytest.fixture()

tests/odm/query/test_update.py

+58
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import asyncio
2+
13
import pytest
24

35
from beanie.odm.operators.update.general import Set, Max
@@ -136,3 +138,59 @@ async def test_update_many_with_session(preset_documents, session):
136138
assert len(result) == 3
137139
for sample in result:
138140
assert sample.increment == 100
141+
142+
143+
async def test_update_many_upsert_with_insert(
144+
preset_documents, sample_doc_not_saved
145+
):
146+
await Sample.find_many(Sample.integer > 100000).upsert(
147+
Set({Sample.integer: 100}), on_insert=sample_doc_not_saved
148+
)
149+
await asyncio.sleep(2)
150+
new_docs = await Sample.find_many(
151+
Sample.string == sample_doc_not_saved.string
152+
).to_list()
153+
assert len(new_docs) == 1
154+
doc = new_docs[0]
155+
assert doc.integer == sample_doc_not_saved.integer
156+
157+
158+
async def test_update_many_upsert_without_insert(
159+
preset_documents, sample_doc_not_saved
160+
):
161+
await Sample.find_many(Sample.integer > 1).upsert(
162+
Set({Sample.integer: 100}), on_insert=sample_doc_not_saved
163+
)
164+
await asyncio.sleep(2)
165+
new_docs = await Sample.find_many(
166+
Sample.string == sample_doc_not_saved.string
167+
).to_list()
168+
assert len(new_docs) == 0
169+
170+
171+
async def test_update_one_upsert_with_insert(
172+
preset_documents, sample_doc_not_saved
173+
):
174+
await Sample.find_one(Sample.integer > 100000).upsert(
175+
Set({Sample.integer: 100}), on_insert=sample_doc_not_saved
176+
)
177+
await asyncio.sleep(2)
178+
new_docs = await Sample.find_many(
179+
Sample.string == sample_doc_not_saved.string
180+
).to_list()
181+
assert len(new_docs) == 1
182+
doc = new_docs[0]
183+
assert doc.integer == sample_doc_not_saved.integer
184+
185+
186+
async def test_update_one_upsert_without_insert(
187+
preset_documents, sample_doc_not_saved
188+
):
189+
await Sample.find_one(Sample.integer > 1).upsert(
190+
Set({Sample.integer: 100}), on_insert=sample_doc_not_saved
191+
)
192+
await asyncio.sleep(2)
193+
new_docs = await Sample.find_many(
194+
Sample.string == sample_doc_not_saved.string
195+
).to_list()
196+
assert len(new_docs) == 0

tests/test_beanie.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22

33

44
def test_version():
5-
assert __version__ == "1.1.6"
5+
assert __version__ == "1.2.0"

0 commit comments

Comments
 (0)