diff --git a/src/stac_fastapi/geoparquet/api.py b/src/stac_fastapi/geoparquet/api.py index d4a33e5..4552d37 100644 --- a/src/stac_fastapi/geoparquet/api.py +++ b/src/stac_fastapi/geoparquet/api.py @@ -1,13 +1,14 @@ +import asyncio import json import urllib.parse -from collections.abc import AsyncIterator +from collections.abc import AsyncIterator, Awaitable, Callable from contextlib import asynccontextmanager from pathlib import Path -from typing import Any, TypedDict +from typing import Any, TypedDict, cast import obstore.store import pystac.utils -from fastapi import FastAPI, HTTPException +from fastapi import FastAPI, HTTPException, Request, Response from rustac import DuckdbClient from stac_fastapi.api.app import StacApi @@ -17,6 +18,46 @@ GEOPARQUET_MEDIA_TYPE = "application/vnd.apache.parquet" +# Global cache for collections and reload control +_collections_cache = None +_collections_cache_lock = asyncio.Lock() +_collections_cache_last_load = 0.0 + + +async def load_collections(settings: Settings) -> list[dict[str, Any]]: + if settings.stac_fastapi_collections_href: + if urllib.parse.urlparse(settings.stac_fastapi_collections_href).scheme: + href = settings.stac_fastapi_collections_href + else: + href = "file://" + str( + Path(settings.stac_fastapi_collections_href).absolute() + ) + prefix, file_name = href.rsplit("/", 1) + store = obstore.store.from_url(prefix) + result = store.get(file_name) + collections = cast(list[dict[str, Any]], json.loads(bytes(result.bytes()))) + else: + collections = [] + return collections + + +async def collections_cache_refresher(settings: Settings) -> None: + global _collections_cache, _collections_cache_last_load + while True: + async with _collections_cache_lock: + _collections_cache = await load_collections(settings) + _collections_cache_last_load = asyncio.get_event_loop().time() + await asyncio.sleep(settings.stac_fastapi_collections_reload_seconds) + + +async def get_cached_collections(settings: Settings) -> list[dict[str, Any]]: + global _collections_cache, _collections_cache_last_load + async with _collections_cache_lock: + if _collections_cache is None: + _collections_cache = await load_collections(settings) + _collections_cache_last_load = asyncio.get_event_loop().time() + return _collections_cache + class State(TypedDict): """Application state.""" @@ -34,37 +75,57 @@ class State(TypedDict): """A mapping of collection id to geoparquet href.""" +# Middleware to inject latest collections/hrefs into request.state +def collections_hot_reload_middleware( + settings: Settings, +) -> Callable[[Request, Callable[[Request], Awaitable[Response]]], Awaitable[Response]]: + async def middleware( + request: Request, call_next: Callable[[Request], Awaitable[Response]] + ) -> Response: + collections = await get_cached_collections(settings) + collection_dict = dict() + hrefs = dict() + for collection in collections: + if collection["id"] in collection_dict: + raise HTTPException( + 500, f"two collections with the same id: {collection['id']}" + ) + else: + collection_dict[collection["id"]] = collection + for key, asset in collection["assets"].items(): + if asset.get("type") == GEOPARQUET_MEDIA_TYPE: + if collection["id"] in hrefs: + raise HTTPException( + 500, f"two hrefs for one collection: {collection['id']}" + ) + else: + hrefs[collection["id"]] = pystac.utils.make_absolute_href( + asset["href"], + settings.stac_fastapi_collections_href, + start_is_dir=False, + ) + request.state.collections = collection_dict + request.state.hrefs = hrefs + response = await call_next(request) + return response + + return middleware + + @asynccontextmanager async def lifespan(app: FastAPI) -> AsyncIterator[State]: client = app.extra["duckdb_client"] settings: Settings = app.extra["settings"] - collections = app.extra["collections"] - collection_dict = dict() - hrefs = dict() - for collection in collections: - if collection["id"] in collection_dict: - raise HTTPException( - 500, f"two collections with the same id: {collection['id']}" - ) - else: - collection_dict[collection["id"]] = collection - for key, asset in collection["assets"].items(): - if asset.get("type") == GEOPARQUET_MEDIA_TYPE: - if collection["id"] in hrefs: - raise HTTPException( - 500, f"two hrefs for one collection: {collection['id']}" - ) - else: - hrefs[collection["id"]] = pystac.utils.make_absolute_href( - asset["href"], - settings.stac_fastapi_collections_href, - start_is_dir=False, - ) + # Start background refresher + app.state._collections_refresher = asyncio.create_task( + collections_cache_refresher(settings) + ) yield { "client": client, - "collections": collection_dict, - "hrefs": hrefs, + "collections": {}, + "hrefs": {}, } + app.state._collections_refresher.cancel() def create( @@ -80,20 +141,8 @@ def create( stac_fastapi_description="A stac-fastapi server backend by stac-geoparquet", ) - if settings.stac_fastapi_collections_href: - if urllib.parse.urlparse(settings.stac_fastapi_collections_href).scheme: - href = settings.stac_fastapi_collections_href - else: - href = "file://" + str( - Path(settings.stac_fastapi_collections_href).absolute() - ) - prefix, file_name = href.rsplit("/", 1) - store = obstore.store.from_url(prefix) - result = store.get(file_name) - collections = json.loads(bytes(result.bytes())) - else: - collections = [] - + # collections will be loaded and cached by the refresher + collections = [] if settings.stac_fastapi_geoparquet_href: collections.extend( collections_from_geoparquet_href( @@ -102,18 +151,22 @@ def create( ) ) + app = FastAPI( + lifespan=lifespan, + openapi_url=settings.openapi_url, + docs_url=settings.docs_url, + redoc_url=settings.docs_url, + settings=settings, + collections=collections, + duckdb_client=duckdb_client, + ) + # Add hot-reload middleware + app.middleware("http")(collections_hot_reload_middleware(settings)) + api = StacApi( settings=settings, client=Client(), - app=FastAPI( - lifespan=lifespan, - openapi_url=settings.openapi_url, - docs_url=settings.docs_url, - redoc_url=settings.docs_url, - settings=settings, - collections=collections, - duckdb_client=duckdb_client, - ), + app=app, search_get_request_model=GetSearchRequestModel, search_post_request_model=PostSearchRequestModel, extensions=EXTENSIONS, diff --git a/src/stac_fastapi/geoparquet/settings.py b/src/stac_fastapi/geoparquet/settings.py index 150f9f4..e0eb3f6 100644 --- a/src/stac_fastapi/geoparquet/settings.py +++ b/src/stac_fastapi/geoparquet/settings.py @@ -4,6 +4,9 @@ class Settings(ApiSettings): """stac-fastapi-geoparquet settings""" + stac_fastapi_collections_reload_seconds: int = 60 + """Interval in seconds to reload collections.json (default: 60).""" + stac_fastapi_collections_href: str | None = None """The href of a file containing JSON list of collections.