Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 103 additions & 50 deletions src/stac_fastapi/geoparquet/api.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import asyncio
import json
import urllib.parse
from collections.abc import AsyncIterator
from collections.abc import AsyncIterator, Awaitable, Callable
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Any, TypedDict
from typing import Any, TypedDict, cast

import obstore.store
import pystac.utils
from fastapi import FastAPI, HTTPException
from fastapi import FastAPI, HTTPException, Request, Response
from rustac import DuckdbClient
from stac_fastapi.api.app import StacApi

Expand All @@ -17,6 +18,46 @@

GEOPARQUET_MEDIA_TYPE = "application/vnd.apache.parquet"

# Global cache for collections and reload control
# Most recently loaded collections; stays None until a first load has been stored.
_collections_cache: list[dict[str, Any]] | None = None
# Held by the coroutines in this module while they read or replace the cache.
_collections_cache_lock = asyncio.Lock()
# Event-loop clock reading (``loop.time()``) recorded at the last load; 0.0 before any load.
_collections_cache_last_load: float = 0.0


async def load_collections(settings: Settings) -> list[dict[str, Any]]:
    """Load the collections document referenced by ``settings``.

    Reads the JSON file at ``settings.stac_fastapi_collections_href`` through
    ``obstore`` and returns its decoded content. An href without a URL scheme
    is treated as a local path and turned into an absolute ``file://`` URL.

    The object is fetched with obstore's async API so the event loop keeps
    serving other coroutines while the read is in flight.

    Args:
        settings: Application settings; only
            ``stac_fastapi_collections_href`` is read here.

    Returns:
        The decoded JSON document (expected to be a list of collection
        dicts; the content is not validated), or an empty list when no
        collections href is configured.
    """
    collections_href = settings.stac_fastapi_collections_href
    if not collections_href:
        return []

    if urllib.parse.urlparse(collections_href).scheme:
        href = collections_href
    else:
        href = "file://" + str(Path(collections_href).absolute())

    # The store is opened on the parent "directory" of the href and the
    # file is then fetched by its name relative to that prefix.
    prefix, file_name = href.rsplit("/", 1)
    store = obstore.store.from_url(prefix)
    result = await store.get_async(file_name)
    payload = await result.bytes_async()
    return cast(list[dict[str, Any]], json.loads(bytes(payload)))


async def collections_cache_refresher(settings: Settings) -> None:
    """Reload the collections cache forever, once per configured interval.

    Intended to run as a background task; it only stops when cancelled.

    Each iteration loads the collections *before* taking the cache lock, so
    other coroutines waiting on the lock are not held up for the duration of
    the I/O. The lock is held only while the module-level cache variables are
    swapped.

    A failed reload is logged and the previously cached collections are kept;
    the loop then sleeps and retries, so a transient error does not end the
    refresher. Cancellation is not swallowed: ``asyncio.CancelledError`` is
    not an ``Exception`` subclass and propagates normally.

    Args:
        settings: Application settings; passed to ``load_collections`` and
            read for ``stac_fastapi_collections_reload_seconds``.
    """
    global _collections_cache, _collections_cache_last_load
    import logging

    logger = logging.getLogger(__name__)
    while True:
        try:
            fresh = await load_collections(settings)
        except Exception:
            logger.exception(
                "failed to reload collections; keeping previously cached value"
            )
        else:
            async with _collections_cache_lock:
                _collections_cache = fresh
                _collections_cache_last_load = asyncio.get_running_loop().time()
        await asyncio.sleep(settings.stac_fastapi_collections_reload_seconds)


async def get_cached_collections(settings: Settings) -> list[dict[str, Any]]:
    """Return the cached collections, loading them once if the cache is empty.

    When the cache is already populated it is returned immediately without
    acquiring the lock, so callers are not queued behind another coroutine
    that currently holds it (for example while a reload is in progress).

    When the cache is still empty, the lock is taken and the emptiness check
    is repeated inside it, so that concurrent first callers trigger a single
    load instead of one each.

    Args:
        settings: Application settings, forwarded to ``load_collections``
            when a first load is required.

    Returns:
        The list currently stored in the module-level cache.
    """
    global _collections_cache, _collections_cache_last_load

    cached = _collections_cache
    if cached is not None:
        return cached

    async with _collections_cache_lock:
        # Re-check: another coroutine may have filled the cache while this
        # one was waiting for the lock.
        if _collections_cache is None:
            _collections_cache = await load_collections(settings)
            _collections_cache_last_load = asyncio.get_running_loop().time()
        return _collections_cache


class State(TypedDict):
"""Application state."""
Expand All @@ -34,37 +75,57 @@ class State(TypedDict):
"""A mapping of collection id to geoparquet href."""


# Middleware to inject latest collections/hrefs into request.state
def collections_hot_reload_middleware(
    settings: Settings,
) -> Callable[[Request, Callable[[Request], Awaitable[Response]]], Awaitable[Response]]:
    """Build an HTTP middleware that attaches the cached collections to a request.

    For every request the returned middleware fetches the cached collections,
    indexes them, and stores two mappings on ``request.state`` before calling
    the next handler:

    * ``request.state.collections``: collection id -> collection dict
    * ``request.state.hrefs``: collection id -> absolute href of that
      collection's geoparquet asset

    Args:
        settings: Application settings; used to fetch the cached collections
            and as the base (``stac_fastapi_collections_href``) against which
            asset hrefs are made absolute.

    Returns:
        An async callable taking ``(request, call_next)`` and returning the
        response produced by ``call_next``.

    Raises:
        HTTPException: (from the returned middleware, status 500) when two
            collections share an id, or when one collection has more than one
            asset whose type is the geoparquet media type.
    """

    def index_collections(
        collections: list[dict[str, Any]],
    ) -> tuple[dict[str, dict[str, Any]], dict[str, str]]:
        """Index collections by id and collect their geoparquet asset hrefs."""
        by_id: dict[str, dict[str, Any]] = {}
        hrefs: dict[str, str] = {}
        for collection in collections:
            collection_id = collection["id"]
            if collection_id in by_id:
                raise HTTPException(
                    500, f"two collections with the same id: {collection_id}"
                )
            by_id[collection_id] = collection

            # "assets" is an optional member of a STAC Collection; a missing
            # or null value is treated as "no assets" rather than an error.
            for asset in (collection.get("assets") or {}).values():
                if asset.get("type") != GEOPARQUET_MEDIA_TYPE:
                    continue
                if collection_id in hrefs:
                    raise HTTPException(
                        500, f"two hrefs for one collection: {collection_id}"
                    )
                hrefs[collection_id] = pystac.utils.make_absolute_href(
                    asset["href"],
                    settings.stac_fastapi_collections_href,
                    start_is_dir=False,
                )
        return by_id, hrefs

    async def middleware(
        request: Request, call_next: Callable[[Request], Awaitable[Response]]
    ) -> Response:
        collections = await get_cached_collections(settings)
        collection_dict, hrefs = index_collections(collections)
        request.state.collections = collection_dict
        request.state.hrefs = hrefs
        return await call_next(request)

    return middleware


@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator[State]:
client = app.extra["duckdb_client"]
settings: Settings = app.extra["settings"]
collections = app.extra["collections"]
collection_dict = dict()
hrefs = dict()
for collection in collections:
if collection["id"] in collection_dict:
raise HTTPException(
500, f"two collections with the same id: {collection['id']}"
)
else:
collection_dict[collection["id"]] = collection
for key, asset in collection["assets"].items():
if asset.get("type") == GEOPARQUET_MEDIA_TYPE:
if collection["id"] in hrefs:
raise HTTPException(
500, f"two hrefs for one collection: {collection['id']}"
)
else:
hrefs[collection["id"]] = pystac.utils.make_absolute_href(
asset["href"],
settings.stac_fastapi_collections_href,
start_is_dir=False,
)
# Start background refresher
app.state._collections_refresher = asyncio.create_task(
collections_cache_refresher(settings)
)
Comment on lines +120 to +122
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 on creating a function to list and store the collections in the state https://github.com/developmentseed/tipg/blob/47e1f091ecd8e8962808753e891e9bb0d6947c31/tipg/main.py#L40-L41

yield {
"client": client,
"collections": collection_dict,
"hrefs": hrefs,
"collections": {},
"hrefs": {},
}
app.state._collections_refresher.cancel()


def create(
Expand All @@ -80,20 +141,8 @@ def create(
stac_fastapi_description="A stac-fastapi server backend by stac-geoparquet",
)

if settings.stac_fastapi_collections_href:
if urllib.parse.urlparse(settings.stac_fastapi_collections_href).scheme:
href = settings.stac_fastapi_collections_href
else:
href = "file://" + str(
Path(settings.stac_fastapi_collections_href).absolute()
)
prefix, file_name = href.rsplit("/", 1)
store = obstore.store.from_url(prefix)
result = store.get(file_name)
collections = json.loads(bytes(result.bytes()))
else:
collections = []

# collections will be loaded and cached by the refresher
collections = []
if settings.stac_fastapi_geoparquet_href:
collections.extend(
collections_from_geoparquet_href(
Expand All @@ -102,18 +151,22 @@ def create(
)
)

app = FastAPI(
lifespan=lifespan,
openapi_url=settings.openapi_url,
docs_url=settings.docs_url,
redoc_url=settings.docs_url,
settings=settings,
collections=collections,
duckdb_client=duckdb_client,
)
# Add hot-reload middleware
app.middleware("http")(collections_hot_reload_middleware(settings))

api = StacApi(
settings=settings,
client=Client(),
app=FastAPI(
lifespan=lifespan,
openapi_url=settings.openapi_url,
docs_url=settings.docs_url,
redoc_url=settings.docs_url,
settings=settings,
collections=collections,
duckdb_client=duckdb_client,
),
app=app,
search_get_request_model=GetSearchRequestModel,
search_post_request_model=PostSearchRequestModel,
extensions=EXTENSIONS,
Expand Down
3 changes: 3 additions & 0 deletions src/stac_fastapi/geoparquet/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
class Settings(ApiSettings):
"""stac-fastapi-geoparquet settings"""

stac_fastapi_collections_reload_seconds: int = 60
"""Interval in seconds to reload collections.json (default: 60)."""

stac_fastapi_collections_href: str | None = None
"""The href of a file containing JSON list of collections.

Expand Down