Skip to content

Commit d3a5959

Browse files
author
Maxim Borisov
committed
Annotate decorators that wrap Document methods (#679)
Type checkers were complaining about missing `self` argument in decorated `Document` methods. This was caused by incomplete annotations of used decorators.
1 parent 72b35f9 commit d3a5959

File tree

4 files changed

+105
-31
lines changed

4 files changed

+105
-31
lines changed

beanie/odm/actions.py

+40-15
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,17 @@
1111
Optional,
1212
Tuple,
1313
Type,
14+
TypeVar,
1415
Union,
1516
)
1617

18+
from typing_extensions import ParamSpec
19+
1720
if TYPE_CHECKING:
18-
from beanie.odm.documents import Document
21+
from beanie.odm.documents import AsyncDocMethod, DocType, Document
22+
23+
P = ParamSpec("P")
24+
R = TypeVar("R")
1925

2026

2127
class EventTypes(str, Enum):
@@ -136,10 +142,14 @@ async def run_actions(
136142
await asyncio.gather(*coros)
137143

138144

145+
# `Any` because there is arbitrary attribute assignment on this type
146+
F = TypeVar("F", bound=Any)
147+
148+
139149
def register_action(
140-
event_types: Tuple[Union[List[EventTypes], EventTypes]],
150+
event_types: Tuple[Union[List[EventTypes], EventTypes], ...],
141151
action_direction: ActionDirections,
142-
):
152+
) -> Callable[[F], F]:
143153
"""
144154
Decorator. Base registration method.
145155
Used inside `before_event` and `after_event`
@@ -154,7 +164,7 @@ def register_action(
154164
else:
155165
final_event_types.append(event_type)
156166

157-
def decorator(f):
167+
def decorator(f: F) -> F:
158168
f.has_action = True
159169
f.event_types = final_event_types
160170
f.action_direction = action_direction
@@ -163,7 +173,9 @@ def decorator(f):
163173
return decorator
164174

165175

166-
def before_event(*args: Union[List[EventTypes], EventTypes]):
176+
def before_event(
177+
*args: Union[List[EventTypes], EventTypes]
178+
) -> Callable[[F], F]:
167179
"""
168180
Decorator. It adds action, which should run before mentioned one
169181
or many events happen
@@ -172,11 +184,13 @@ def before_event(*args: Union[List[EventTypes], EventTypes]):
172184
:return: None
173185
"""
174186
return register_action(
175-
action_direction=ActionDirections.BEFORE, event_types=args # type: ignore
187+
action_direction=ActionDirections.BEFORE, event_types=args
176188
)
177189

178190

179-
def after_event(*args: Union[List[EventTypes], EventTypes]):
191+
def after_event(
192+
*args: Union[List[EventTypes], EventTypes]
193+
) -> Callable[[F], F]:
180194
"""
181195
Decorator. It adds action, which should run after mentioned one
182196
or many events happen
@@ -186,26 +200,32 @@ def after_event(*args: Union[List[EventTypes], EventTypes]):
186200
"""
187201

188202
return register_action(
189-
action_direction=ActionDirections.AFTER, event_types=args # type: ignore
203+
action_direction=ActionDirections.AFTER, event_types=args
190204
)
191205

192206

193-
def wrap_with_actions(event_type: EventTypes):
207+
def wrap_with_actions(
208+
event_type: EventTypes,
209+
) -> Callable[
210+
["AsyncDocMethod[DocType, P, R]"], "AsyncDocMethod[DocType, P, R]"
211+
]:
194212
"""
195213
Helper function to wrap Document methods with
196214
before and after event listeners
197215
:param event_type: EventTypes - event types
198216
:return: None
199217
"""
200218

201-
def decorator(f: Callable):
219+
def decorator(
220+
f: "AsyncDocMethod[DocType, P, R]",
221+
) -> "AsyncDocMethod[DocType, P, R]":
202222
@wraps(f)
203223
async def wrapper(
204-
self,
205-
*args,
224+
self: "Document",
225+
*args: P.args,
206226
skip_actions: Optional[List[Union[ActionDirections, str]]] = None,
207-
**kwargs,
208-
):
227+
**kwargs: P.kwargs,
228+
) -> R:
209229
if skip_actions is None:
210230
skip_actions = []
211231

@@ -216,7 +236,12 @@ async def wrapper(
216236
exclude=skip_actions,
217237
)
218238

219-
result = await f(self, *args, skip_actions=skip_actions, **kwargs)
239+
result = await f(
240+
self,
241+
*args,
242+
skip_actions=skip_actions, # type: ignore[arg-type]
243+
**kwargs,
244+
)
220245

221246
await ActionRegistry.run_actions(
222247
self,

beanie/odm/documents.py

+13-6
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
from enum import Enum
44
from typing import (
55
Any,
6+
Awaitable,
7+
Callable,
68
ClassVar,
79
Dict,
810
Iterable,
@@ -32,6 +34,7 @@
3234
DeleteResult,
3335
InsertManyResult,
3436
)
37+
from typing_extensions import Concatenate, ParamSpec, TypeAlias
3538

3639
from beanie.exceptions import (
3740
CollectionWasNotInitialized,
@@ -104,6 +107,10 @@
104107
from pydantic import model_validator
105108

106109
DocType = TypeVar("DocType", bound="Document")
110+
P = ParamSpec("P")
111+
R = TypeVar("R")
112+
SyncDocMethod: TypeAlias = Callable[Concatenate[DocType, P], R]
113+
AsyncDocMethod: TypeAlias = Callable[Concatenate[DocType, P], Awaitable[R]]
107114
DocumentProjectionType = TypeVar("DocumentProjectionType", bound=BaseModel)
108115

109116

@@ -529,7 +536,7 @@ async def save(
529536
link_rule: WriteRules = WriteRules.DO_NOTHING,
530537
ignore_revision: bool = False,
531538
**kwargs,
532-
) -> None:
539+
) -> DocType:
533540
"""
534541
Update an existing model in the database or
535542
insert it if it does not yet exist.
@@ -605,12 +612,12 @@ async def save(
605612
@wrap_with_actions(EventTypes.SAVE_CHANGES)
606613
@validate_self_before
607614
async def save_changes(
608-
self,
615+
self: DocType,
609616
ignore_revision: bool = False,
610617
session: Optional[ClientSession] = None,
611618
bulk_writer: Optional[BulkWriter] = None,
612619
skip_actions: Optional[List[Union[ActionDirections, str]]] = None,
613-
) -> None:
620+
) -> Optional[DocType]:
614621
"""
615622
Save changes.
616623
State management usage must be turned on
@@ -632,7 +639,7 @@ async def save_changes(
632639
)
633640
else:
634641
return await self.set(
635-
changes, # type: ignore #TODO fix typing
642+
changes,
636643
ignore_revision=ignore_revision,
637644
session=session,
638645
bulk_writer=bulk_writer,
@@ -741,13 +748,13 @@ def update_all(
741748
)
742749

743750
def set(
744-
self,
751+
self: DocType,
745752
expression: Dict[Union[ExpressionField, str], Any],
746753
session: Optional[ClientSession] = None,
747754
bulk_writer: Optional[BulkWriter] = None,
748755
skip_sync: Optional[bool] = None,
749756
**kwargs,
750-
):
757+
) -> Awaitable[DocType]:
751758
"""
752759
Set values
753760

beanie/odm/utils/self_validation.py

+11-4
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,20 @@
11
from functools import wraps
2-
from typing import TYPE_CHECKING, Callable
2+
from typing import TYPE_CHECKING, TypeVar
3+
4+
from typing_extensions import ParamSpec
35

46
if TYPE_CHECKING:
5-
from beanie.odm.documents import DocType
7+
from beanie.odm.documents import AsyncDocMethod, DocType
8+
9+
P = ParamSpec("P")
10+
R = TypeVar("R")
611

712

8-
def validate_self_before(f: Callable):
13+
def validate_self_before(
14+
f: "AsyncDocMethod[DocType, P, R]",
15+
) -> "AsyncDocMethod[DocType, P, R]":
916
@wraps(f)
10-
async def wrapper(self: "DocType", *args, **kwargs):
17+
async def wrapper(self: "DocType", *args: P.args, **kwargs: P.kwargs) -> R:
1118
await self.validate_self(*args, **kwargs)
1219
return await f(self, *args, **kwargs)
1320

beanie/odm/utils/state.py

+41-6
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
11
import inspect
22
from functools import wraps
3-
from typing import TYPE_CHECKING, Callable
3+
from typing import TYPE_CHECKING, Callable, TypeVar, overload
4+
5+
from typing_extensions import ParamSpec
46

57
from beanie.exceptions import StateManagementIsTurnedOff, StateNotSaved
68

79
if TYPE_CHECKING:
8-
from beanie.odm.documents import DocType
10+
from beanie.odm.documents import AsyncDocMethod, DocType, SyncDocMethod
11+
12+
P = ParamSpec("P")
13+
R = TypeVar("R")
914

1015

1116
def check_if_state_saved(self: "DocType"):
@@ -17,7 +22,21 @@ def check_if_state_saved(self: "DocType"):
1722
raise StateNotSaved("No state was saved")
1823

1924

20-
def saved_state_needed(f: Callable):
25+
@overload
26+
def saved_state_needed(
27+
f: "AsyncDocMethod[DocType, P, R]",
28+
) -> "AsyncDocMethod[DocType, P, R]":
29+
...
30+
31+
32+
@overload
33+
def saved_state_needed(
34+
f: "SyncDocMethod[DocType, P, R]",
35+
) -> "SyncDocMethod[DocType, P, R]":
36+
...
37+
38+
39+
def saved_state_needed(f: Callable) -> Callable:
2140
@wraps(f)
2241
def sync_wrapper(self: "DocType", *args, **kwargs):
2342
check_if_state_saved(self)
@@ -44,7 +63,21 @@ def check_if_previous_state_saved(self: "DocType"):
4463
)
4564

4665

47-
def previous_saved_state_needed(f: Callable):
66+
@overload
67+
def previous_saved_state_needed(
68+
f: "AsyncDocMethod[DocType, P, R]",
69+
) -> "AsyncDocMethod[DocType, P, R]":
70+
...
71+
72+
73+
@overload
74+
def previous_saved_state_needed(
75+
f: "SyncDocMethod[DocType, P, R]",
76+
) -> "SyncDocMethod[DocType, P, R]":
77+
...
78+
79+
80+
def previous_saved_state_needed(f: Callable) -> Callable:
4881
@wraps(f)
4982
def sync_wrapper(self: "DocType", *args, **kwargs):
5083
check_if_previous_state_saved(self)
@@ -60,9 +93,11 @@ async def async_wrapper(self: "DocType", *args, **kwargs):
6093
return sync_wrapper
6194

6295

63-
def save_state_after(f: Callable):
96+
def save_state_after(
97+
f: "AsyncDocMethod[DocType, P, R]",
98+
) -> "AsyncDocMethod[DocType, P, R]":
6499
@wraps(f)
65-
async def wrapper(self: "DocType", *args, **kwargs):
100+
async def wrapper(self: "DocType", *args: P.args, **kwargs: P.kwargs) -> R:
66101
result = await f(self, *args, **kwargs)
67102
self._save_state()
68103
return result

0 commit comments

Comments
 (0)