Skip to content

Commit 33611f7

Browse files
author
Richard Smith
authored
refactoring to create dynamic request model (#213)
* refactoring to create dynamic request model * Correcting missed reference to configurable pagination class and updating conformance classes. * Changing pydantic limit model to have max/min [skip ci] * fixing against test suite for sqlalchemy * wrapping function calls to make get and post request models * working to pass tests for pgstac * changed constraints on limit parameter to pass test * fixing against pgstac tests * adding changelog entry
1 parent 2d78ca0 commit 33611f7

File tree

37 files changed

+791
-407
lines changed

37 files changed

+791
-407
lines changed

CHANGES.md

+3
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66

77
### Changed
88

9+
* Refactor to remove hardcoded search request models. Request models are now dynamically created based on the enabled extensions.
10+
([#213](https://github.com/stac-utils/stac-fastapi/pull/213))
11+
912
### Removed
1013

1114
### Fixed

stac_fastapi/api/stac_fastapi/api/app.py

+20-14
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from fastapi.openapi.utils import get_openapi
88
from pydantic import BaseModel
99
from stac_pydantic import Collection, Item, ItemCollection
10-
from stac_pydantic.api import ConformanceClasses, LandingPage, Search
10+
from stac_pydantic.api import ConformanceClasses, LandingPage
1111
from stac_pydantic.api.collections import Collections
1212
from stac_pydantic.version import STAC_VERSION
1313
from starlette.responses import JSONResponse, Response
@@ -20,18 +20,17 @@
2020
GeoJSONResponse,
2121
ItemCollectionUri,
2222
ItemUri,
23-
SearchGetRequest,
24-
_create_request_model,
23+
create_request_model,
2524
)
2625
from stac_fastapi.api.openapi import update_openapi
2726
from stac_fastapi.api.routes import create_async_endpoint, create_sync_endpoint
2827

2928
# TODO: make this module not depend on `stac_fastapi.extensions`
30-
from stac_fastapi.extensions.core import FieldsExtension
29+
from stac_fastapi.extensions.core import FieldsExtension, TokenPaginationExtension
3130
from stac_fastapi.types.config import ApiSettings, Settings
3231
from stac_fastapi.types.core import AsyncBaseCoreClient, BaseCoreClient
3332
from stac_fastapi.types.extension import ApiExtension
34-
from stac_fastapi.types.search import STACSearch
33+
from stac_fastapi.types.search import BaseSearchGetRequest, BaseSearchPostRequest
3534

3635

3736
@attr.s
@@ -76,9 +75,13 @@ class StacApi:
7675
api_version: str = attr.ib(default="0.1")
7776
stac_version: str = attr.ib(default=STAC_VERSION)
7877
description: str = attr.ib(default="stac-fastapi")
79-
search_request_model: Type[Search] = attr.ib(default=STACSearch)
80-
search_get_request: Type[SearchGetRequest] = attr.ib(default=SearchGetRequest)
81-
item_collection_uri: Type[ItemCollectionUri] = attr.ib(default=ItemCollectionUri)
78+
search_get_request_model: Type[BaseSearchGetRequest] = attr.ib(
79+
default=BaseSearchGetRequest
80+
)
81+
search_post_request_model: Type[BaseSearchPostRequest] = attr.ib(
82+
default=BaseSearchPostRequest
83+
)
84+
pagination_extension = attr.ib(default=TokenPaginationExtension)
8285
response_class: Type[Response] = attr.ib(default=JSONResponse)
8386
middlewares: List = attr.ib(default=attr.Factory(lambda: [BrotliMiddleware]))
8487

@@ -176,7 +179,6 @@ def register_post_search(self):
176179
Returns:
177180
None
178181
"""
179-
search_request_model = _create_request_model(self.search_request_model)
180182
fields_ext = self.get_extension(FieldsExtension)
181183
self.router.add_api_route(
182184
name="Search",
@@ -189,7 +191,7 @@ def register_post_search(self):
189191
response_model_exclude_none=True,
190192
methods=["POST"],
191193
endpoint=self._create_endpoint(
192-
self.client.post_search, search_request_model, GeoJSONResponse
194+
self.client.post_search, self.search_post_request_model, GeoJSONResponse
193195
),
194196
)
195197

@@ -211,7 +213,7 @@ def register_get_search(self):
211213
response_model_exclude_none=True,
212214
methods=["GET"],
213215
endpoint=self._create_endpoint(
214-
self.client.get_search, self.search_get_request, GeoJSONResponse
216+
self.client.get_search, self.search_get_request_model, GeoJSONResponse
215217
),
216218
)
217219

@@ -261,6 +263,12 @@ def register_get_item_collection(self):
261263
Returns:
262264
None
263265
"""
266+
get_pagination_model = self.get_extension(self.pagination_extension).GET
267+
request_model = create_request_model(
268+
"ItemCollectionURI",
269+
base_model=ItemCollectionUri,
270+
mixins=[get_pagination_model],
271+
)
264272
self.router.add_api_route(
265273
name="Get ItemCollection",
266274
path="/collections/{collection_id}/items",
@@ -272,9 +280,7 @@ def register_get_item_collection(self):
272280
response_model_exclude_none=True,
273281
methods=["GET"],
274282
endpoint=self._create_endpoint(
275-
self.client.item_collection,
276-
self.item_collection_uri,
277-
self.response_class,
283+
self.client.item_collection, request_model, self.response_class
278284
),
279285
)
280286

stac_fastapi/api/stac_fastapi/api/models.py

+108-87
Original file line numberDiff line numberDiff line change
@@ -1,55 +1,100 @@
11
"""api request/response models."""
22

3-
import abc
43
import importlib
5-
from typing import Dict, Optional, Type, Union
4+
from typing import Optional, Type, Union
65

76
import attr
87
from fastapi import Body, Path
98
from pydantic import BaseModel, create_model
109
from pydantic.fields import UndefinedType
1110

12-
13-
def _create_request_model(model: Type[BaseModel]) -> Type[BaseModel]:
11+
from stac_fastapi.types.extension import ApiExtension
12+
from stac_fastapi.types.search import (
13+
APIRequest,
14+
BaseSearchGetRequest,
15+
BaseSearchPostRequest,
16+
)
17+
18+
19+
def create_request_model(
20+
model_name="SearchGetRequest",
21+
base_model: Union[Type[BaseModel], APIRequest] = BaseSearchGetRequest,
22+
extensions: Optional[ApiExtension] = None,
23+
mixins: Optional[Union[BaseModel, APIRequest]] = None,
24+
request_type: Optional[str] = "GET",
25+
) -> Union[Type[BaseModel], APIRequest]:
1426
"""Create a pydantic model for validating request bodies."""
1527
fields = {}
16-
for (k, v) in model.__fields__.items():
17-
# TODO: Filter out fields based on which extensions are present
18-
field_info = v.field_info
19-
body = Body(
20-
None
21-
if isinstance(field_info.default, UndefinedType)
22-
else field_info.default,
23-
default_factory=field_info.default_factory,
24-
alias=field_info.alias,
25-
alias_priority=field_info.alias_priority,
26-
title=field_info.title,
27-
description=field_info.description,
28-
const=field_info.const,
29-
gt=field_info.gt,
30-
ge=field_info.ge,
31-
lt=field_info.lt,
32-
le=field_info.le,
33-
multiple_of=field_info.multiple_of,
34-
min_items=field_info.min_items,
35-
max_items=field_info.max_items,
36-
min_length=field_info.min_length,
37-
max_length=field_info.max_length,
38-
regex=field_info.regex,
39-
extra=field_info.extra,
40-
)
41-
fields[k] = (v.outer_type_, body)
42-
return create_model(model.__name__, **fields, __base__=model)
43-
44-
45-
@attr.s # type:ignore
46-
class APIRequest(abc.ABC):
47-
"""Generic API Request base class."""
48-
49-
@abc.abstractmethod
50-
def kwargs(self) -> Dict:
51-
"""Transform api request params into format which matches the signature of the endpoint."""
52-
...
28+
extension_models = []
29+
30+
# Check extensions for additional parameters to search
31+
for extension in extensions or []:
32+
if extension_model := extension.get_request_model(request_type):
33+
extension_models.append(extension_model)
34+
35+
mixins = mixins or []
36+
37+
models = [base_model] + extension_models + mixins
38+
39+
# Handle GET requests
40+
if all([issubclass(m, APIRequest) for m in models]):
41+
return attr.make_class(model_name, attrs={}, bases=tuple(models))
42+
43+
# Handle POST requests
44+
elif all([issubclass(m, BaseModel) for m in models]):
45+
for model in models:
46+
for (k, v) in model.__fields__.items():
47+
field_info = v.field_info
48+
body = Body(
49+
None
50+
if isinstance(field_info.default, UndefinedType)
51+
else field_info.default,
52+
default_factory=field_info.default_factory,
53+
alias=field_info.alias,
54+
alias_priority=field_info.alias_priority,
55+
title=field_info.title,
56+
description=field_info.description,
57+
const=field_info.const,
58+
gt=field_info.gt,
59+
ge=field_info.ge,
60+
lt=field_info.lt,
61+
le=field_info.le,
62+
multiple_of=field_info.multiple_of,
63+
min_items=field_info.min_items,
64+
max_items=field_info.max_items,
65+
min_length=field_info.min_length,
66+
max_length=field_info.max_length,
67+
regex=field_info.regex,
68+
extra=field_info.extra,
69+
)
70+
fields[k] = (v.outer_type_, body)
71+
return create_model(model_name, **fields, __base__=base_model)
72+
73+
raise TypeError("Mixed Request Model types. Check extension request types.")
74+
75+
76+
def create_get_request_model(
77+
extensions, base_model: BaseSearchGetRequest = BaseSearchGetRequest
78+
):
79+
"""Wrap create_request_model to create the GET request model."""
80+
return create_request_model(
81+
"SearchGetRequest",
82+
base_model=BaseSearchGetRequest,
83+
extensions=extensions,
84+
request_type="GET",
85+
)
86+
87+
88+
def create_post_request_model(
89+
extensions, base_model: BaseSearchPostRequest = BaseSearchGetRequest
90+
):
91+
"""Wrap create_request_model to create the POST request model."""
92+
return create_request_model(
93+
"SearchPostRequest",
94+
base_model=BaseSearchPostRequest,
95+
extensions=extensions,
96+
request_type="POST",
97+
)
5398

5499

55100
@attr.s # type:ignore
@@ -58,76 +103,52 @@ class CollectionUri(APIRequest):
58103

59104
collection_id: str = attr.ib(default=Path(..., description="Collection ID"))
60105

61-
def kwargs(self) -> Dict:
62-
"""kwargs."""
63-
return {"id": self.collection_id}
64-
65106

66107
@attr.s
67108
class ItemUri(CollectionUri):
68109
"""Delete item."""
69110

70111
item_id: str = attr.ib(default=Path(..., description="Item ID"))
71112

72-
def kwargs(self) -> Dict:
73-
"""kwargs."""
74-
return {"collection_id": self.collection_id, "item_id": self.item_id}
75-
76113

77114
@attr.s
78115
class EmptyRequest(APIRequest):
79116
"""Empty request."""
80117

81-
def kwargs(self) -> Dict:
82-
"""kwargs."""
83-
return {}
118+
...
84119

85120

86121
@attr.s
87122
class ItemCollectionUri(CollectionUri):
88123
"""Get item collection."""
89124

90125
limit: int = attr.ib(default=10)
91-
token: str = attr.ib(default=None)
92126

93-
def kwargs(self) -> Dict:
94-
"""kwargs."""
95-
return {
96-
"id": self.collection_id,
97-
"limit": self.limit,
98-
"token": self.token,
99-
}
127+
128+
class POSTTokenPagination(BaseModel):
129+
"""Token pagination model for POST requests."""
130+
131+
token: Optional[str] = None
100132

101133

102134
@attr.s
103-
class SearchGetRequest(APIRequest):
104-
"""GET search request."""
105-
106-
collections: Optional[str] = attr.ib(default=None)
107-
ids: Optional[str] = attr.ib(default=None)
108-
bbox: Optional[str] = attr.ib(default=None)
109-
datetime: Optional[Union[str]] = attr.ib(default=None)
110-
limit: Optional[int] = attr.ib(default=10)
111-
query: Optional[str] = attr.ib(default=None)
135+
class GETTokenPagination(APIRequest):
136+
"""Token pagination for GET requests."""
137+
112138
token: Optional[str] = attr.ib(default=None)
113-
fields: Optional[str] = attr.ib(default=None)
114-
sortby: Optional[str] = attr.ib(default=None)
115-
116-
def kwargs(self) -> Dict:
117-
"""kwargs."""
118-
return {
119-
"collections": self.collections.split(",")
120-
if self.collections
121-
else self.collections,
122-
"ids": self.ids.split(",") if self.ids else self.ids,
123-
"bbox": self.bbox.split(",") if self.bbox else self.bbox,
124-
"datetime": self.datetime,
125-
"limit": self.limit,
126-
"query": self.query,
127-
"token": self.token,
128-
"fields": self.fields.split(",") if self.fields else self.fields,
129-
"sortby": self.sortby.split(",") if self.sortby else self.sortby,
130-
}
139+
140+
141+
class POSTPagination(BaseModel):
142+
"""Page based pagination for POST requests."""
143+
144+
page: Optional[str] = None
145+
146+
147+
@attr.s
148+
class GETPagination(APIRequest):
149+
"""Page based pagination for GET requests."""
150+
151+
page: Optional[str] = attr.ib(default=None)
131152

132153

133154
# Test for ORJSON and use it rather than stdlib JSON where supported

stac_fastapi/extensions/stac_fastapi/extensions/core/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from .context import ContextExtension
55
from .fields import FieldsExtension
66
from .filter import FilterExtension
7+
from .pagination import PaginationExtension, TokenPaginationExtension
78
from .query import QueryExtension
89
from .sort import SortExtension
910
from .transaction import TransactionExtension
@@ -12,8 +13,10 @@
1213
"ContextExtension",
1314
"FieldsExtension",
1415
"FilterExtension",
16+
"PaginationExtension",
1517
"QueryExtension",
1618
"SortExtension",
1719
"TilesExtension",
20+
"TokenPaginationExtension",
1821
"TransactionExtension",
1922
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
"""Fields extension module."""
2+
3+
4+
from .fields import FieldsExtension
5+
6+
__all__ = ["FieldsExtension"]

0 commit comments

Comments
 (0)