1
1
"""api request/response models."""
2
2
3
- import abc
4
3
import importlib
5
- from typing import Dict , Optional , Type , Union
4
+ from typing import Optional , Type , Union
6
5
7
6
import attr
8
7
from fastapi import Body , Path
9
8
from pydantic import BaseModel , create_model
10
9
from pydantic .fields import UndefinedType
11
10
12
-
13
- def _create_request_model (model : Type [BaseModel ]) -> Type [BaseModel ]:
11
+ from stac_fastapi .types .extension import ApiExtension
12
+ from stac_fastapi .types .search import (
13
+ APIRequest ,
14
+ BaseSearchGetRequest ,
15
+ BaseSearchPostRequest ,
16
+ )
17
+
18
+
19
+ def create_request_model (
20
+ model_name = "SearchGetRequest" ,
21
+ base_model : Union [Type [BaseModel ], APIRequest ] = BaseSearchGetRequest ,
22
+ extensions : Optional [ApiExtension ] = None ,
23
+ mixins : Optional [Union [BaseModel , APIRequest ]] = None ,
24
+ request_type : Optional [str ] = "GET" ,
25
+ ) -> Union [Type [BaseModel ], APIRequest ]:
14
26
"""Create a pydantic model for validating request bodies."""
15
27
fields = {}
16
- for (k , v ) in model .__fields__ .items ():
17
- # TODO: Filter out fields based on which extensions are present
18
- field_info = v .field_info
19
- body = Body (
20
- None
21
- if isinstance (field_info .default , UndefinedType )
22
- else field_info .default ,
23
- default_factory = field_info .default_factory ,
24
- alias = field_info .alias ,
25
- alias_priority = field_info .alias_priority ,
26
- title = field_info .title ,
27
- description = field_info .description ,
28
- const = field_info .const ,
29
- gt = field_info .gt ,
30
- ge = field_info .ge ,
31
- lt = field_info .lt ,
32
- le = field_info .le ,
33
- multiple_of = field_info .multiple_of ,
34
- min_items = field_info .min_items ,
35
- max_items = field_info .max_items ,
36
- min_length = field_info .min_length ,
37
- max_length = field_info .max_length ,
38
- regex = field_info .regex ,
39
- extra = field_info .extra ,
40
- )
41
- fields [k ] = (v .outer_type_ , body )
42
- return create_model (model .__name__ , ** fields , __base__ = model )
43
-
44
-
45
- @attr .s # type:ignore
46
- class APIRequest (abc .ABC ):
47
- """Generic API Request base class."""
48
-
49
- @abc .abstractmethod
50
- def kwargs (self ) -> Dict :
51
- """Transform api request params into format which matches the signature of the endpoint."""
52
- ...
28
+ extension_models = []
29
+
30
+ # Check extensions for additional parameters to search
31
+ for extension in extensions or []:
32
+ if extension_model := extension .get_request_model (request_type ):
33
+ extension_models .append (extension_model )
34
+
35
+ mixins = mixins or []
36
+
37
+ models = [base_model ] + extension_models + mixins
38
+
39
+ # Handle GET requests
40
+ if all ([issubclass (m , APIRequest ) for m in models ]):
41
+ return attr .make_class (model_name , attrs = {}, bases = tuple (models ))
42
+
43
+ # Handle POST requests
44
+ elif all ([issubclass (m , BaseModel ) for m in models ]):
45
+ for model in models :
46
+ for (k , v ) in model .__fields__ .items ():
47
+ field_info = v .field_info
48
+ body = Body (
49
+ None
50
+ if isinstance (field_info .default , UndefinedType )
51
+ else field_info .default ,
52
+ default_factory = field_info .default_factory ,
53
+ alias = field_info .alias ,
54
+ alias_priority = field_info .alias_priority ,
55
+ title = field_info .title ,
56
+ description = field_info .description ,
57
+ const = field_info .const ,
58
+ gt = field_info .gt ,
59
+ ge = field_info .ge ,
60
+ lt = field_info .lt ,
61
+ le = field_info .le ,
62
+ multiple_of = field_info .multiple_of ,
63
+ min_items = field_info .min_items ,
64
+ max_items = field_info .max_items ,
65
+ min_length = field_info .min_length ,
66
+ max_length = field_info .max_length ,
67
+ regex = field_info .regex ,
68
+ extra = field_info .extra ,
69
+ )
70
+ fields [k ] = (v .outer_type_ , body )
71
+ return create_model (model_name , ** fields , __base__ = base_model )
72
+
73
+ raise TypeError ("Mixed Request Model types. Check extension request types." )
74
+
75
+
76
+ def create_get_request_model (
77
+ extensions , base_model : BaseSearchGetRequest = BaseSearchGetRequest
78
+ ):
79
+ """Wrap create_request_model to create the GET request model."""
80
+ return create_request_model (
81
+ "SearchGetRequest" ,
82
+ base_model = BaseSearchGetRequest ,
83
+ extensions = extensions ,
84
+ request_type = "GET" ,
85
+ )
86
+
87
+
88
+ def create_post_request_model (
89
+ extensions , base_model : BaseSearchPostRequest = BaseSearchGetRequest
90
+ ):
91
+ """Wrap create_request_model to create the POST request model."""
92
+ return create_request_model (
93
+ "SearchPostRequest" ,
94
+ base_model = BaseSearchPostRequest ,
95
+ extensions = extensions ,
96
+ request_type = "POST" ,
97
+ )
53
98
54
99
55
100
@attr .s # type:ignore
@@ -58,76 +103,52 @@ class CollectionUri(APIRequest):
58
103
59
104
collection_id : str = attr .ib (default = Path (..., description = "Collection ID" ))
60
105
61
- def kwargs (self ) -> Dict :
62
- """kwargs."""
63
- return {"id" : self .collection_id }
64
-
65
106
66
107
@attr .s
67
108
class ItemUri (CollectionUri ):
68
109
"""Delete item."""
69
110
70
111
item_id : str = attr .ib (default = Path (..., description = "Item ID" ))
71
112
72
- def kwargs (self ) -> Dict :
73
- """kwargs."""
74
- return {"collection_id" : self .collection_id , "item_id" : self .item_id }
75
-
76
113
77
114
@attr .s
78
115
class EmptyRequest (APIRequest ):
79
116
"""Empty request."""
80
117
81
- def kwargs (self ) -> Dict :
82
- """kwargs."""
83
- return {}
118
+ ...
84
119
85
120
86
121
@attr .s
87
122
class ItemCollectionUri (CollectionUri ):
88
123
"""Get item collection."""
89
124
90
125
limit : int = attr .ib (default = 10 )
91
- token : str = attr .ib (default = None )
92
126
93
- def kwargs (self ) -> Dict :
94
- """kwargs."""
95
- return {
96
- "id" : self .collection_id ,
97
- "limit" : self .limit ,
98
- "token" : self .token ,
99
- }
127
+
128
+ class POSTTokenPagination (BaseModel ):
129
+ """Token pagination model for POST requests."""
130
+
131
+ token : Optional [str ] = None
100
132
101
133
102
134
@attr .s
103
- class SearchGetRequest (APIRequest ):
104
- """GET search request."""
105
-
106
- collections : Optional [str ] = attr .ib (default = None )
107
- ids : Optional [str ] = attr .ib (default = None )
108
- bbox : Optional [str ] = attr .ib (default = None )
109
- datetime : Optional [Union [str ]] = attr .ib (default = None )
110
- limit : Optional [int ] = attr .ib (default = 10 )
111
- query : Optional [str ] = attr .ib (default = None )
135
+ class GETTokenPagination (APIRequest ):
136
+ """Token pagination for GET requests."""
137
+
112
138
token : Optional [str ] = attr .ib (default = None )
113
- fields : Optional [str ] = attr .ib (default = None )
114
- sortby : Optional [str ] = attr .ib (default = None )
115
-
116
- def kwargs (self ) -> Dict :
117
- """kwargs."""
118
- return {
119
- "collections" : self .collections .split ("," )
120
- if self .collections
121
- else self .collections ,
122
- "ids" : self .ids .split ("," ) if self .ids else self .ids ,
123
- "bbox" : self .bbox .split ("," ) if self .bbox else self .bbox ,
124
- "datetime" : self .datetime ,
125
- "limit" : self .limit ,
126
- "query" : self .query ,
127
- "token" : self .token ,
128
- "fields" : self .fields .split ("," ) if self .fields else self .fields ,
129
- "sortby" : self .sortby .split ("," ) if self .sortby else self .sortby ,
130
- }
139
+
140
+
141
+ class POSTPagination (BaseModel ):
142
+ """Page based pagination for POST requests."""
143
+
144
+ page : Optional [str ] = None
145
+
146
+
147
+ @attr .s
148
+ class GETPagination (APIRequest ):
149
+ """Page based pagination for GET requests."""
150
+
151
+ page : Optional [str ] = attr .ib (default = None )
131
152
132
153
133
154
# Test for ORJSON and use it rather than stdlib JSON where supported
0 commit comments