Skip to content

Added bulk item inserts for pgstac implementation #411

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from Jun 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
* Add hook to allow adding dependencies to routes. ([#295](https://github.com/stac-utils/stac-fastapi/pull/295))
* Ability to POST an ItemCollection to the collections/{collectionId}/items route. ([#367](https://github.com/stac-utils/stac-fastapi/pull/367))
* Add STAC API - Collections conformance class. ([383](https://github.com/stac-utils/stac-fastapi/pull/383))
* Bulk item inserts for pgstac implementation. ([411](https://github.com/stac-utils/stac-fastapi/pull/411))

### Changed

Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
"""bulk transactions extension."""
import abc
from typing import Any, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional, Type, Union

import attr
from fastapi import APIRouter, FastAPI
from pydantic import BaseModel

from stac_fastapi.api.models import create_request_model
from stac_fastapi.api.routes import create_sync_endpoint
from stac_fastapi.api.routes import create_async_endpoint, create_sync_endpoint
from stac_fastapi.types.extension import ApiExtension
from stac_fastapi.types.search import APIRequest


class Items(BaseModel):
Expand Down Expand Up @@ -51,6 +52,24 @@ def bulk_item_insert(
raise NotImplementedError


@attr.s # type: ignore
class AsyncBaseBulkTransactionsClient(abc.ABC):
"""BulkTransactionsClient."""

@abc.abstractmethod
async def bulk_item_insert(self, items: Items, **kwargs) -> str:
"""Bulk creation of items.

Args:
items: list of items.

Returns:
Message indicating the status of the insert.

"""
raise NotImplementedError


@attr.s
class BulkTransactionExtension(ApiExtension):
"""Bulk Transaction Extension.
Expand All @@ -68,10 +87,24 @@ class BulkTransactionExtension(ApiExtension):

"""

client: BaseBulkTransactionsClient = attr.ib()
client: Union[
AsyncBaseBulkTransactionsClient, BaseBulkTransactionsClient
] = attr.ib()
conformance_classes: List[str] = attr.ib(default=list())
schema_href: Optional[str] = attr.ib(default=None)

def _create_endpoint(
self,
func: Callable,
request_type: Union[Type[APIRequest], Type[BaseModel], Dict],
) -> Callable:
"""Create a FastAPI endpoint."""
if isinstance(self.client, AsyncBaseBulkTransactionsClient):
return create_async_endpoint(func, request_type)
elif isinstance(self.client, BaseBulkTransactionsClient):
return create_sync_endpoint(func, request_type)
raise NotImplementedError

def register(self, app: FastAPI) -> None:
"""Register the extension with a FastAPI application.

Expand All @@ -91,7 +124,7 @@ def register(self, app: FastAPI) -> None:
response_model_exclude_unset=True,
response_model_exclude_none=True,
methods=["POST"],
endpoint=create_sync_endpoint(
endpoint=self._create_endpoint(
self.client.bulk_item_insert, items_request_model
),
)
Expand Down
4 changes: 3 additions & 1 deletion stac_fastapi/pgstac/stac_fastapi/pgstac/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@
TokenPaginationExtension,
TransactionExtension,
)
from stac_fastapi.extensions.third_party import BulkTransactionExtension
from stac_fastapi.pgstac.config import Settings
from stac_fastapi.pgstac.core import CoreCrudClient
from stac_fastapi.pgstac.db import close_db_connection, connect_to_db
from stac_fastapi.pgstac.extensions import QueryExtension
from stac_fastapi.pgstac.transactions import TransactionsClient
from stac_fastapi.pgstac.transactions import BulkTransactionsClient, TransactionsClient
from stac_fastapi.pgstac.types.search import PgstacSearch

settings = Settings()
Expand All @@ -29,6 +30,7 @@
FieldsExtension(),
TokenPaginationExtension(),
ContextExtension(),
BulkTransactionExtension(client=BulkTransactionsClient()),
]

post_request_model = create_post_request_model(extensions, base_model=PgstacSearch)
Expand Down
19 changes: 19 additions & 0 deletions stac_fastapi/pgstac/stac_fastapi/pgstac/transactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
import attr
from starlette.responses import JSONResponse, Response

from stac_fastapi.extensions.third_party.bulk_transactions import (
AsyncBaseBulkTransactionsClient,
Items,
)
from stac_fastapi.pgstac.db import dbfunc
from stac_fastapi.types import stac as stac_types
from stac_fastapi.types.core import AsyncBaseTransactionsClient
Expand Down Expand Up @@ -71,3 +75,18 @@ async def delete_collection(
pool = request.app.state.writepool
await dbfunc(pool, "delete_collection", collection_id)
return JSONResponse({"deleted collection": collection_id})


@attr.s
class BulkTransactionsClient(AsyncBaseBulkTransactionsClient):
"""Postgres bulk transactions."""

async def bulk_item_insert(self, items: Items, **kwargs) -> str:
"""Bulk item insertion using pgstac."""
request = kwargs["request"]
pool = request.app.state.writepool
items = list(items.items.values())
await dbfunc(pool, "create_items", items)

return_msg = f"Successfully added {len(items)} items."
return return_msg
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same format for bulk inserts as used in sqlalchemy implementation

27 changes: 27 additions & 0 deletions stac_fastapi/pgstac/tests/clients/test_postgres.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import uuid
from copy import deepcopy
from typing import Callable

from stac_pydantic import Collection, Item
Expand Down Expand Up @@ -117,6 +118,32 @@ async def test_get_collection_items(app_client, load_test_collection, load_test_
assert len(fc["features"]) == 5


async def test_create_bulk_items(
app_client, load_test_data: Callable, load_test_collection
):
coll = load_test_collection
item = load_test_data("test_item.json")

items = {}
for _ in range(2):
_item = deepcopy(item)
_item["id"] = str(uuid.uuid4())
items[_item["id"]] = _item

payload = {"items": items}

resp = await app_client.post(
f"/collections/{coll.id}/bulk_items",
json=payload,
)
assert resp.status_code == 200
assert resp.text == '"Successfully added 2 items."'

for item_id in items.keys():
resp = await app_client.get(f"/collections/{coll.id}/items/{item_id}")
assert resp.status_code == 200


# TODO since right now puts implement upsert
# test_create_collection_already_exists
# test create_item_already_exists
Expand Down
4 changes: 3 additions & 1 deletion stac_fastapi/pgstac/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,12 @@
TokenPaginationExtension,
TransactionExtension,
)
from stac_fastapi.extensions.third_party import BulkTransactionExtension
from stac_fastapi.pgstac.config import Settings
from stac_fastapi.pgstac.core import CoreCrudClient
from stac_fastapi.pgstac.db import close_db_connection, connect_to_db
from stac_fastapi.pgstac.extensions import QueryExtension
from stac_fastapi.pgstac.transactions import TransactionsClient
from stac_fastapi.pgstac.transactions import BulkTransactionsClient, TransactionsClient
from stac_fastapi.pgstac.types.search import PgstacSearch

DATA_DIR = os.path.join(os.path.dirname(__file__), "data")
Expand Down Expand Up @@ -117,6 +118,7 @@ def api_client(request, pg):
SortExtension(),
FieldsExtension(),
TokenPaginationExtension(),
BulkTransactionExtension(client=BulkTransactionsClient()),
]
post_request_model = create_post_request_model(extensions, base_model=PgstacSearch)
api = StacApi(
Expand Down