1
+ from abc import abstractmethod
1
2
from typing import (
2
3
List ,
3
4
Type ,
6
7
Mapping ,
7
8
Any ,
8
9
Dict ,
10
+ Union ,
9
11
)
10
12
11
13
from pymongo .client_session import ClientSession
12
- from pymongo .results import UpdateResult
14
+ from pymongo .results import UpdateResult , InsertOneResult
13
15
14
16
from beanie .odm .interfaces .session import SessionMethods
15
17
from beanie .odm .interfaces .update import (
@@ -40,6 +42,8 @@ def __init__(
40
42
self .find_query = find_query
41
43
self .update_expressions : List [Mapping [str , Any ]] = []
42
44
self .session = None
45
+ self .is_upsert = False
46
+ self .upsert_insert_doc : Optional ["DocType" ] = None
43
47
44
48
@property
45
49
def update_query (self ) -> Dict [str , Any ]:
@@ -57,7 +61,7 @@ def update(
57
61
self , * args : Mapping [str , Any ], session : Optional [ClientSession ] = None
58
62
) -> "UpdateQuery" :
59
63
"""
60
- Provide modifications to the update query. The same as `update()`
64
+ Provide modifications to the update query.
61
65
62
66
:param args: *Union[dict, Mapping] - the modifications to apply.
63
67
:param session: Optional[ClientSession]
@@ -67,6 +71,48 @@ def update(
67
71
self .update_expressions += args
68
72
return self
69
73
74
+ def upsert (
75
+ self ,
76
+ * args : Mapping [str , Any ],
77
+ on_insert : "DocType" ,
78
+ session : Optional [ClientSession ] = None
79
+ ) -> "UpdateQuery" :
80
+ """
81
+ Provide modifications to the upsert query.
82
+
83
+ :param args: *Union[dict, Mapping] - the modifications to apply.
84
+ :param on_insert: DocType - document to insert if there is no matched
85
+ document in the collection
86
+ :param session: Optional[ClientSession]
87
+ :return: UpdateMany query
88
+ """
89
+ self .upsert_insert_doc = on_insert
90
+ self .update (* args , session = session )
91
+ return self
92
+
93
+ @abstractmethod
94
+ async def _update (self ):
95
+ ...
96
+
97
+ def __await__ (self ) -> Union [UpdateResult , InsertOneResult ]:
98
+ """
99
+ Run the query
100
+ :return:
101
+ """
102
+
103
+ update_result = yield from self ._update ().__await__ ()
104
+ if self .upsert_insert_doc is None :
105
+ return update_result
106
+ else :
107
+ if update_result .matched_count == 0 :
108
+ return (
109
+ yield from self .document_model .insert_one (
110
+ document = self .upsert_insert_doc , session = self .session
111
+ ).__await__ ()
112
+ )
113
+ else :
114
+ return update_result
115
+
70
116
71
117
class UpdateMany (UpdateQuery ):
72
118
"""
@@ -89,12 +135,8 @@ def update_many(
89
135
"""
90
136
return self .update (* args , session = session )
91
137
92
- def __await__ (self ) -> UpdateResult :
93
- """
94
- Run the query
95
- :return:
96
- """
97
- yield from self .document_model .get_motor_collection ().update_many (
138
+ async def _update (self ):
139
+ return await self .document_model .get_motor_collection ().update_many (
98
140
self .find_query , self .update_query , session = self .session
99
141
)
100
142
@@ -120,11 +162,7 @@ def update_one(
120
162
"""
121
163
return self .update (* args , session = session )
122
164
123
- def __await__ (self ) -> UpdateResult :
124
- """
125
- Run the query
126
- :return:
127
- """
128
- yield from self .document_model .get_motor_collection ().update_one (
165
+ async def _update (self ):
166
+ return await self .document_model .get_motor_collection ().update_one (
129
167
self .find_query , self .update_query , session = self .session
130
168
)
0 commit comments