Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 35 additions & 12 deletions cads_worker/entry_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
from typing import Annotated, TypedDict

import cacholote
import cads_broker.object_storage
import structlog
import typer
from typer import Option

from . import config, utils
from . import config, models

config.configure_logger()
LOGGER = structlog.get_logger(__name__)
Expand Down Expand Up @@ -35,17 +36,18 @@ class CleanerKwargs(TypedDict):

def _cache_cleaner() -> None:
use_database = strtobool(os.environ.get("USE_DATABASE", "1"))
cleaner_kwargs = CleanerKwargs(
maxsize=int(os.environ.get("MAX_SIZE", 1_000_000_000)),
method=os.environ.get("METHOD", "LRU"),
delete_unknown_files=not use_database,
lock_validity_period=float(os.environ.get("LOCK_VALIDITY_PERIOD", 86400)),
use_database=use_database,
depth=int(os.getenv("CACHE_DEPTH", 2)),
batch_size=int(os.getenv("BATCH_SIZE", 0)) or None,
batch_delay=float(os.getenv("BATCH_DELAY", 0)),
)
for cache_files_urlpath in utils.parse_data_volumes_config():
volumes = models.DataVolumes.from_yaml().volumes
for cache_files_urlpath, volume_config in volumes.items():
cleaner_kwargs = CleanerKwargs(
maxsize=volume_config.max_size,
method=os.environ.get("METHOD", "LRU"),
delete_unknown_files=not use_database,
lock_validity_period=float(os.environ.get("LOCK_VALIDITY_PERIOD", 86400)),
use_database=use_database,
depth=int(os.getenv("CACHE_DEPTH", 2)),
batch_size=int(os.getenv("BATCH_SIZE", 0)) or None,
batch_delay=float(os.getenv("BATCH_DELAY", 0)),
)
cacholote.config.set(cache_files_urlpath=cache_files_urlpath)
LOGGER.info(
"Running cache cleaner",
Expand Down Expand Up @@ -115,9 +117,30 @@ def _expire_cache_entries(
return count


def _init_buckets() -> None:
    """Create the S3 download buckets listed in the data-volumes config.

    Reads the object-storage endpoint and admin credentials from the
    OBJECT_STORAGE_URL, STORAGE_ADMIN and STORAGE_PASSWORD environment
    variables, then creates a bucket for every configured volume whose
    urlpath uses the ``s3://`` scheme. Non-S3 volumes are skipped.
    """
    endpoint = os.environ["OBJECT_STORAGE_URL"]
    credentials: dict[str, str] = {
        "aws_access_key_id": os.environ["STORAGE_ADMIN"],
        "aws_secret_access_key": os.environ["STORAGE_PASSWORD"],
    }
    LOGGER.info("Initializing buckets", object_storage_url=endpoint)
    for volume in models.DataVolumes.from_yaml().volumes:
        if not volume.startswith("s3://"):
            continue  # only s3:// urlpaths correspond to object-storage buckets
        LOGGER.info("Initializing bucket", data_volume=volume)
        cads_broker.object_storage.create_download_bucket(
            volume, endpoint, **credentials
        )
    LOGGER.info("Buckets initialized")


def cache_cleaner() -> None:
    """Console-script entry point: run :func:`_cache_cleaner` via typer."""
    typer.run(_cache_cleaner)


def expire_cache_entries() -> None:
    """Console-script entry point: run :func:`_expire_cache_entries` via typer."""
    typer.run(_expire_cache_entries)


def init_buckets() -> None:
    """Console-script entry point: run :func:`_init_buckets` via typer."""
    typer.run(_init_buckets)
27 changes: 27 additions & 0 deletions cads_worker/filesystems.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import fsspec.implementations.local
from fsspec.utils import stringify_path


class CCIFileSystem(fsspec.implementations.local.LocalFileSystem):
    """Local filesystem exposed under a custom ``cci`` protocol name.

    Paths such as ``cci://...`` are handled exactly like ``file://...``
    paths; only the protocol label differs.
    """

    protocol = "cci"

    @classmethod
    def _strip_protocol(cls, path):
        assert isinstance(cls.protocol, str)
        stringified = stringify_path(path)
        prefix = f"{cls.protocol}:"
        if stringified.startswith(prefix):
            # Swap the custom scheme for "file:" so the local-filesystem
            # parent class can strip it the usual way.
            stringified = "file:" + stringified[len(prefix):]
        return super()._strip_protocol(stringified)

    def unstrip_protocol(self, name):
        """Return *name* normalized and re-prefixed with this protocol."""
        assert isinstance(self.protocol, str)
        return f"{self.protocol}://{self._strip_protocol(name)}"


class CCI1FileSystem(CCIFileSystem):
    """CCIFileSystem variant registered under the ``cci1`` protocol."""

    protocol = "cci1"


class CCI2FileSystem(CCIFileSystem):
    """CCIFileSystem variant registered under the ``cci2`` protocol."""

    protocol = "cci2"
41 changes: 41 additions & 0 deletions cads_worker/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import os
import random
from typing import Self

import yaml
from pydantic import BaseModel, Field, NonNegativeFloat, NonNegativeInt


def get_env_max_size() -> int:
    """Return the MAX_SIZE environment variable as an int (default 1 GB).

    int() accepts underscore separators, so the "1_000_000_000" default
    string parses directly.
    """
    raw = os.getenv("MAX_SIZE", "1_000_000_000")
    return int(raw)


class DataVolumeConfig(BaseModel):
    """Per-volume settings parsed from the data-volumes YAML config."""

    # Relative probability of this volume being picked by
    # DataVolumes.get_random_volume; 0 excludes it from selection.
    weight: NonNegativeFloat = 1
    # Cache size budget (bytes); defaults to the MAX_SIZE env var at
    # instantiation time via the default_factory.
    max_size: NonNegativeInt = Field(default_factory=get_env_max_size)


class DataVolumes(BaseModel):
    """Mapping of cache volume urlpaths to their per-volume configuration."""

    volumes: dict[str, DataVolumeConfig]

    def get_random_volume(self) -> str:
        """Pick one volume urlpath at random, weighted by each volume's weight.

        Raises ``ValueError`` (from ``random.choices``) when the volumes
        mapping is empty or every weight is zero.
        """
        (volume,) = random.choices(
            list(self.volumes),
            weights=[config.weight for config in self.volumes.values()],
            k=1,
        )
        return volume

    @classmethod
    def from_yaml(cls, path: str | None = None) -> Self:
        """Load a data-volumes config from the YAML file at *path*.

        When *path* is None, the DATA_VOLUMES_CONFIG environment variable
        names the file. Each top-level YAML key is a volume urlpath; its
        optional mapping value holds DataVolumeConfig fields.
        """
        if path is None:
            path = os.environ["DATA_VOLUMES_CONFIG"]

        with open(path) as f:
            # safe_load returns None for an empty document; treat that as
            # "no volumes configured" instead of crashing on .items().
            raw_dict = yaml.safe_load(f) or {}
        return cls(
            volumes={
                k: DataVolumeConfig(**v) if v else DataVolumeConfig()
                for k, v in raw_dict.items()
            }
        )
12 changes: 0 additions & 12 deletions cads_worker/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,6 @@
from collections.abc import Iterator


def parse_data_volumes_config(path: str | None = None) -> list[str]:
if path is None:
path = os.environ["DATA_VOLUMES_CONFIG"]

data_volumes = []
with open(path) as fp:
for line in fp:
if data_volume := os.path.expandvars(line.rstrip("\n")):
data_volumes.append(data_volume)
return data_volumes


@contextlib.contextmanager
def enter_tmp_working_dir() -> Iterator[str]:
old_cwd = os.getcwd()
Expand Down
15 changes: 10 additions & 5 deletions cads_worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import functools
import logging
import os
import random
import socket
import time
from typing import Any
Expand All @@ -13,10 +12,11 @@
import dask
import dask.config
import distributed.worker
import fsspec.implementations.local
import structlog
from distributed import get_worker

from . import config, utils
from . import config, models, utils

config.configure_logger(os.getenv("WORKER_LOG_LEVEL", "NOT_SET").upper())

Expand Down Expand Up @@ -220,7 +220,8 @@ def submit_workflow(

structlog.contextvars.bind_contextvars(event_type="DATASET_COMPUTE", job_id=job_id)

cache_files_urlpath = random.choice(utils.parse_data_volumes_config())
volumes = models.DataVolumes.from_yaml()
cache_files_urlpath = volumes.get_random_volume()
depth = int(os.getenv("CACHE_DEPTH", 1))
if depth == 2:
cache_files_urlpath = os.path.join(
Expand All @@ -244,7 +245,11 @@ def submit_workflow(
adaptor_class = cads_adaptors.get_adaptor_class(entry_point, setup_code)
try:
with utils.enter_tmp_working_dir() as working_dir:
base_dir = dirname if "file" in fs.protocol else working_dir
base_dir = (
dirname
if isinstance(fs, fsspec.implementations.local.LocalFileSystem)
else working_dir
)
with utils.make_cache_tmp_path(base_dir) as cache_tmp_path:
adaptor = adaptor_class(
form=form,
Expand All @@ -259,7 +264,7 @@ def submit_workflow(
context.error(f"{err.__class__.__name__}: {str(err)}")
raise

if "s3" in fs.protocol:
if result.counter == 1 and "s3" in fs.protocol:
fs.chmod(result.result["args"][0]["file:local_path"], acl="public-read")
with context.session_maker() as session:
request = cads_broker.database.set_request_cache_id(
Expand Down
2 changes: 2 additions & 0 deletions ci/environment-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@ dependencies:
- pytest-cov
- sphinx
- sphinx-autoapi
- types-PyYAML
# DO NOT EDIT ABOVE THIS LINE, ADD DEPENDENCIES BELOW
- pip
- pip:
- git+https://github.com/ecmwf-projects/cacholote
- git+https://github.com/ecmwf-projects/cads-adaptors
- git+https://github.com/ecmwf-projects/cads-broker
2 changes: 2 additions & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ channels:
dependencies:
- distributed
- psycopg2
- pydantic
- pyyaml
# See: https://github.com/fsspec/s3fs/pull/910
- s3fs!=2024.10.0
- s3fs>=2023.12.2
Expand Down
16 changes: 15 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@ classifiers = [
dependencies = [
"cacholote",
"cads-adaptors@git+https://github.com/ecmwf-projects/cads-adaptors.git",
"cads-broker@git+https://github.com/ecmwf-projects/cads-broker.git",
"distributed",
"fsspec",
"pydantic",
"pyyaml",
"structlog",
"typer"
]
Expand All @@ -27,9 +31,14 @@ license = {file = "LICENSE"}
name = "cads-worker"
readme = "README.md"

[project.entry-points."fsspec.specs"]
cci1 = "cads_worker.filesystems:CCI1FileSystem"
cci2 = "cads_worker.filesystems:CCI2FileSystem"

[project.scripts]
cache-cleaner = "cads_worker.entry_points:cache_cleaner"
expire-cache-entries = "cads_worker.entry_points:expire_cache_entries"
init-buckets = "cads_worker.entry_points:init_buckets"

[tool.coverage.run]
branch = true
Expand All @@ -40,9 +49,14 @@ strict = true
[[tool.mypy.overrides]]
ignore_missing_imports = true
module = [
"cads_adaptors.*"
"cads_adaptors.*",
"fsspec.*"
]

[[tool.mypy.overrides]]
ignore_errors = true
module = "cads_worker.filesystems"

[tool.ruff]
# Same as Black.
indent-width = 4
Expand Down
24 changes: 24 additions & 0 deletions tests/test_01_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import pathlib

import pytest

from cads_worker.models import DataVolumes


def test_data_volumes_from_yaml(
    tmp_path: pathlib.Path,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """DataVolumes.from_yaml fills defaults from env and per-volume config."""
    # MAX_SIZE supplies the default max_size for volumes that omit it.
    monkeypatch.setenv("MAX_SIZE", "10")
    config_file = tmp_path / "data-volumes.yaml"
    config_file.write_text("\nfoo:\nbar:\n weight: 0\n max_size: 20\n")

    volumes = DataVolumes.from_yaml(str(config_file))
    expected = {
        "volumes": {
            "foo": {"weight": 1, "max_size": 10},
            "bar": {"weight": 0, "max_size": 20},
        }
    }
    assert volumes.model_dump() == expected

    # "bar" has weight 0, so the weighted choice can only ever pick "foo".
    assert volumes.get_random_volume() == "foo"
10 changes: 10 additions & 0 deletions tests/test_02_filesystems.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import os

import fsspec
import pytest


@pytest.mark.parametrize("protocol", ["cci1", "cci2"])
def test_unstrip_protocol(protocol: str) -> None:
    """Both CCI protocols round-trip '.' to an absolute protocol-prefixed path."""
    filesystem = fsspec.filesystem(protocol)
    expected = f"{protocol}://{os.getcwd()}"
    assert filesystem.unstrip_protocol(".") == expected
4 changes: 2 additions & 2 deletions tests/test_10_cache_cleaner.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ def test_cache_cleaner(
assert cached_path.exists()

# create data nodes config
data_volumes_config = tmp_path / "data-volumes.config"
data_volumes_config.write_text(cache_files_urlpath)
data_volumes_config = tmp_path / "data-volumes.yaml"
data_volumes_config.write_text(f"{cache_files_urlpath}:")
monkeypatch.setenv("DATA_VOLUMES_CONFIG", str(data_volumes_config))

# clean cache
Expand Down
17 changes: 0 additions & 17 deletions tests/test_30_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,9 @@
import pathlib
import tempfile

import pytest

from cads_worker import utils


def test_utils_parse_data_volumes_config(
tmp_path: pathlib.Path,
monkeypatch: pytest.MonkeyPatch,
) -> None:
monkeypatch.setenv("FOO", "foo")
monkeypatch.setenv("BAR", "bar")
monkeypatch.setenv("BAZ", "")
data_volumes_config = tmp_path / "data-volumes.config"
data_volumes_config.write_text("\n\n$FOO\n\n${BAR}\n\n$BAZ\n\n")
assert utils.parse_data_volumes_config(str(data_volumes_config)) == ["foo", "bar"]

monkeypatch.setenv("DATA_VOLUMES_CONFIG", str(data_volumes_config))
assert utils.parse_data_volumes_config(None) == ["foo", "bar"]


def test_utils_enter_tmp_working_dir() -> None:
with utils.enter_tmp_working_dir() as tmp_working_dir:
assert os.getcwd() == tmp_working_dir
Expand Down
Loading