Skip to content

Commit 1ed266d

Browse files
BorisTheBraved-v-b
and authored
Support with_read_only in wrappers (#3700)

* Support with_read_only in LoggingStore and LatencyStore. Fixes #3699
* Support CacheStore.with_read_only
* Add entry to changes/
* Use a dataclass for mutable cache state

---------

Co-authored-by: Davis Bennett <davis.v.bennett@gmail.com>
1 parent 2a032a8 commit 1ed266d

File tree

8 files changed

+268
-100
lines changed

8 files changed

+268
-100
lines changed

changes/3700.bugfix.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
CacheStore, LoggingStore and LatencyStore now support with_read_only.

src/zarr/experimental/cache_store.py

Lines changed: 76 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
import logging
55
import time
66
from collections import OrderedDict
7-
from typing import TYPE_CHECKING, Any, Literal
7+
from dataclasses import dataclass, field
8+
from typing import TYPE_CHECKING, Any, Literal, Self
89

910
from zarr.abc.store import ByteRequest, Store
1011
from zarr.storage._wrapper import WrapperStore
@@ -15,6 +16,18 @@
1516
from zarr.core.buffer.core import Buffer, BufferPrototype
1617

1718

19+
@dataclass(slots=True)
20+
class _CacheState:
21+
cache_order: OrderedDict[str, None] = field(default_factory=OrderedDict)
22+
current_size: int = 0
23+
key_sizes: dict[str, int] = field(default_factory=dict)
24+
lock: asyncio.Lock = field(default_factory=asyncio.Lock)
25+
hits: int = 0
26+
misses: int = 0
27+
evictions: int = 0
28+
key_insert_times: dict[str, float] = field(default_factory=dict)
29+
30+
1831
class CacheStore(WrapperStore[Store]):
1932
"""
2033
A dual-store caching implementation for Zarr stores.
@@ -36,9 +49,6 @@ class CacheStore(WrapperStore[Store]):
3649
Maximum size of the cache in bytes. When exceeded, least recently used
3750
items are evicted. None means unlimited size. Default is None.
3851
Note: Individual values larger than max_size will not be cached.
39-
key_insert_times : dict[str, float] | None, optional
40-
Dictionary to track insertion times (using monotonic time).
41-
Primarily for internal use. Default is None (creates new dict).
4252
cache_set_data : bool, optional
4353
Whether to cache data when it's written to the store. Default is True.
4454
@@ -69,15 +79,8 @@ class CacheStore(WrapperStore[Store]):
6979
_cache: Store
7080
max_age_seconds: int | Literal["infinity"]
7181
max_size: int | None
72-
key_insert_times: dict[str, float]
7382
cache_set_data: bool
74-
_cache_order: OrderedDict[str, None] # Track access order for LRU
75-
_current_size: int # Track current cache size
76-
_key_sizes: dict[str, int] # Track size of each cached key
77-
_lock: asyncio.Lock
78-
_hits: int # Cache hit counter
79-
_misses: int # Cache miss counter
80-
_evictions: int # Cache eviction counter
83+
_state: _CacheState
8184

8285
def __init__(
8386
self,
@@ -86,7 +89,6 @@ def __init__(
8689
cache_store: Store,
8790
max_age_seconds: int | str = "infinity",
8891
max_size: int | None = None,
89-
key_insert_times: dict[str, float] | None = None,
9092
cache_set_data: bool = True,
9193
) -> None:
9294
super().__init__(store)
@@ -107,18 +109,25 @@ def __init__(
107109
else:
108110
self.max_age_seconds = max_age_seconds
109111
self.max_size = max_size
110-
if key_insert_times is None:
111-
self.key_insert_times = {}
112-
else:
113-
self.key_insert_times = key_insert_times
114112
self.cache_set_data = cache_set_data
115-
self._cache_order = OrderedDict()
116-
self._current_size = 0
117-
self._key_sizes = {}
118-
self._lock = asyncio.Lock()
119-
self._hits = 0
120-
self._misses = 0
121-
self._evictions = 0
113+
self._state = _CacheState()
114+
115+
def _with_store(self, store: Store) -> Self:
116+
# Cannot support this operation because it would share a cache, but have a new store
117+
# So cache keys would conflict
118+
raise NotImplementedError("CacheStore does not support this operation.")
119+
120+
def with_read_only(self, read_only: bool = False) -> Self:
121+
# Create a new cache store that shares the same cache and mutable state
122+
store = type(self)(
123+
store=self._store.with_read_only(read_only),
124+
cache_store=self._cache,
125+
max_age_seconds=self.max_age_seconds,
126+
max_size=self.max_size,
127+
cache_set_data=self.cache_set_data,
128+
)
129+
store._state = self._state
130+
return store
122131

123132
def _is_key_fresh(self, key: str) -> bool:
124133
"""Check if a cached key is still fresh based on max_age_seconds.
@@ -128,7 +137,7 @@ def _is_key_fresh(self, key: str) -> bool:
128137
if self.max_age_seconds == "infinity":
129138
return True
130139
now = time.monotonic()
131-
elapsed = now - self.key_insert_times.get(key, 0)
140+
elapsed = now - self._state.key_insert_times.get(key, 0)
132141
return elapsed < self.max_age_seconds
133142

134143
async def _accommodate_value(self, value_size: int) -> None:
@@ -140,9 +149,9 @@ async def _accommodate_value(self, value_size: int) -> None:
140149
return
141150

142151
# Remove least recently used items until we have enough space
143-
while self._current_size + value_size > self.max_size and self._cache_order:
152+
while self._state.current_size + value_size > self.max_size and self._state.cache_order:
144153
# Get the least recently used key (first in OrderedDict)
145-
lru_key = next(iter(self._cache_order))
154+
lru_key = next(iter(self._state.cache_order))
146155
await self._evict_key(lru_key)
147156

148157
async def _evict_key(self, key: str) -> None:
@@ -152,15 +161,15 @@ async def _evict_key(self, key: str) -> None:
152161
Updates size tracking atomically with deletion.
153162
"""
154163
try:
155-
key_size = self._key_sizes.get(key, 0)
164+
key_size = self._state.key_sizes.get(key, 0)
156165

157166
# Delete from cache store
158167
await self._cache.delete(key)
159168

160169
# Update tracking after successful deletion
161170
self._remove_from_tracking(key)
162-
self._current_size = max(0, self._current_size - key_size)
163-
self._evictions += 1
171+
self._state.current_size = max(0, self._state.current_size - key_size)
172+
self._state.evictions += 1
164173

165174
logger.debug("_evict_key: evicted key %s, freed %d bytes", key, key_size)
166175
except Exception:
@@ -183,39 +192,39 @@ async def _cache_value(self, key: str, value: Buffer) -> None:
183192
)
184193
return
185194

186-
async with self._lock:
195+
async with self._state.lock:
187196
# If key already exists, subtract old size first
188-
if key in self._key_sizes:
189-
old_size = self._key_sizes[key]
190-
self._current_size -= old_size
197+
if key in self._state.key_sizes:
198+
old_size = self._state.key_sizes[key]
199+
self._state.current_size -= old_size
191200
logger.debug("_cache_value: updating existing key %s, old size %d", key, old_size)
192201

193202
# Make room for the new value (this calls _evict_key_locked internally)
194203
await self._accommodate_value(value_size)
195204

196205
# Update tracking atomically
197-
self._cache_order[key] = None # OrderedDict to track access order
198-
self._current_size += value_size
199-
self._key_sizes[key] = value_size
200-
self.key_insert_times[key] = time.monotonic()
206+
self._state.cache_order[key] = None # OrderedDict to track access order
207+
self._state.current_size += value_size
208+
self._state.key_sizes[key] = value_size
209+
self._state.key_insert_times[key] = time.monotonic()
201210

202211
logger.debug("_cache_value: cached key %s with size %d bytes", key, value_size)
203212

204213
async def _update_access_order(self, key: str) -> None:
205214
"""Update the access order for LRU tracking."""
206-
if key in self._cache_order:
207-
async with self._lock:
215+
if key in self._state.cache_order:
216+
async with self._state.lock:
208217
# Move to end (most recently used)
209-
self._cache_order.move_to_end(key)
218+
self._state.cache_order.move_to_end(key)
210219

211220
def _remove_from_tracking(self, key: str) -> None:
212221
"""Remove a key from all tracking structures.
213222
214-
Must be called while holding self._lock.
223+
Must be called while holding self._state.lock.
215224
"""
216-
self._cache_order.pop(key, None)
217-
self.key_insert_times.pop(key, None)
218-
self._key_sizes.pop(key, None)
225+
self._state.cache_order.pop(key, None)
226+
self._state.key_insert_times.pop(key, None)
227+
self._state.key_sizes.pop(key, None)
219228

220229
async def _get_try_cache(
221230
self, key: str, prototype: BufferPrototype, byte_range: ByteRequest | None = None
@@ -224,20 +233,20 @@ async def _get_try_cache(
224233
maybe_cached_result = await self._cache.get(key, prototype, byte_range)
225234
if maybe_cached_result is not None:
226235
logger.debug("_get_try_cache: key %s found in cache (HIT)", key)
227-
self._hits += 1
236+
self._state.hits += 1
228237
# Update access order for LRU
229238
await self._update_access_order(key)
230239
return maybe_cached_result
231240
else:
232241
logger.debug(
233242
"_get_try_cache: key %s not found in cache (MISS), fetching from store", key
234243
)
235-
self._misses += 1
244+
self._state.misses += 1
236245
maybe_fresh_result = await super().get(key, prototype, byte_range)
237246
if maybe_fresh_result is None:
238247
# Key doesn't exist in source store
239248
await self._cache.delete(key)
240-
async with self._lock:
249+
async with self._state.lock:
241250
self._remove_from_tracking(key)
242251
else:
243252
# Cache the newly fetched value
@@ -249,12 +258,12 @@ async def _get_no_cache(
249258
self, key: str, prototype: BufferPrototype, byte_range: ByteRequest | None = None
250259
) -> Buffer | None:
251260
"""Get data directly from source store and update cache."""
252-
self._misses += 1
261+
self._state.misses += 1
253262
maybe_fresh_result = await super().get(key, prototype, byte_range)
254263
if maybe_fresh_result is None:
255264
# Key doesn't exist in source, remove from cache and tracking
256265
await self._cache.delete(key)
257-
async with self._lock:
266+
async with self._state.lock:
258267
self._remove_from_tracking(key)
259268
else:
260269
logger.debug("_get_no_cache: key %s found in store, setting in cache", key)
@@ -312,7 +321,7 @@ async def set(self, key: str, value: Buffer) -> None:
312321
else:
313322
logger.debug("set: deleting key %s from cache", key)
314323
await self._cache.delete(key)
315-
async with self._lock:
324+
async with self._state.lock:
316325
self._remove_from_tracking(key)
317326

318327
async def delete(self, key: str) -> None:
@@ -328,7 +337,7 @@ async def delete(self, key: str) -> None:
328337
await super().delete(key)
329338
logger.debug("delete: deleting key %s from cache", key)
330339
await self._cache.delete(key)
331-
async with self._lock:
340+
async with self._state.lock:
332341
self._remove_from_tracking(key)
333342

334343
def cache_info(self) -> dict[str, Any]:
@@ -339,20 +348,20 @@ def cache_info(self) -> dict[str, Any]:
339348
if self.max_age_seconds == "infinity"
340349
else self.max_age_seconds,
341350
"max_size": self.max_size,
342-
"current_size": self._current_size,
351+
"current_size": self._state.current_size,
343352
"cache_set_data": self.cache_set_data,
344-
"tracked_keys": len(self.key_insert_times),
345-
"cached_keys": len(self._cache_order),
353+
"tracked_keys": len(self._state.key_insert_times),
354+
"cached_keys": len(self._state.cache_order),
346355
}
347356

348357
def cache_stats(self) -> dict[str, Any]:
349358
"""Return cache performance statistics."""
350-
total_requests = self._hits + self._misses
351-
hit_rate = self._hits / total_requests if total_requests > 0 else 0.0
359+
total_requests = self._state.hits + self._state.misses
360+
hit_rate = self._state.hits / total_requests if total_requests > 0 else 0.0
352361
return {
353-
"hits": self._hits,
354-
"misses": self._misses,
355-
"evictions": self._evictions,
362+
"hits": self._state.hits,
363+
"misses": self._state.misses,
364+
"evictions": self._state.evictions,
356365
"total_requests": total_requests,
357366
"hit_rate": hit_rate,
358367
}
@@ -364,11 +373,11 @@ async def clear_cache(self) -> None:
364373
await self._cache.clear()
365374

366375
# Reset tracking
367-
async with self._lock:
368-
self.key_insert_times.clear()
369-
self._cache_order.clear()
370-
self._key_sizes.clear()
371-
self._current_size = 0
376+
async with self._state.lock:
377+
self._state.key_insert_times.clear()
378+
self._state.cache_order.clear()
379+
self._state.key_sizes.clear()
380+
self._state.current_size = 0
372381
logger.debug("clear_cache: cleared all cache data")
373382

374383
def __repr__(self) -> str:
@@ -379,6 +388,6 @@ def __repr__(self) -> str:
379388
f"cache_store={self._cache!r}, "
380389
f"max_age_seconds={self.max_age_seconds}, "
381390
f"max_size={self.max_size}, "
382-
f"current_size={self._current_size}, "
383-
f"cached_keys={len(self._cache_order)})"
391+
f"current_size={self._state.current_size}, "
392+
f"cached_keys={len(self._state.cache_order)})"
384393
)

src/zarr/storage/_logging.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,9 @@ def _default_handler(self) -> logging.Handler:
7777
)
7878
return handler
7979

80+
def _with_store(self, store: T_Store) -> Self:
81+
return type(self)(store=store, log_level=self.log_level, log_handler=self.log_handler)
82+
8083
@contextmanager
8184
def log(self, hint: Any = "") -> Generator[None, None, None]:
8285
"""Context manager to log method calls

src/zarr/storage/_wrapper.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING, Generic, TypeVar
3+
from typing import TYPE_CHECKING, Generic, TypeVar, cast
44

55
if TYPE_CHECKING:
66
from collections.abc import AsyncGenerator, AsyncIterator, Iterable
@@ -31,14 +31,23 @@ class WrapperStore(Store, Generic[T_Store]):
3131
def __init__(self, store: T_Store) -> None:
3232
self._store = store
3333

34+
def _with_store(self, store: T_Store) -> Self:
35+
"""
36+
Constructs a new instance of the wrapper store with the same details but a new store.
37+
"""
38+
return type(self)(store=store)
39+
3440
@classmethod
3541
async def open(cls: type[Self], store_cls: type[T_Store], *args: Any, **kwargs: Any) -> Self:
3642
store = store_cls(*args, **kwargs)
3743
await store._open()
3844
return cls(store=store)
3945

46+
def with_read_only(self, read_only: bool = False) -> Self:
47+
return self._with_store(cast(T_Store, self._store.with_read_only(read_only)))
48+
4049
def __enter__(self) -> Self:
41-
return type(self)(self._store.__enter__())
50+
return self._with_store(self._store.__enter__())
4251

4352
def __exit__(
4453
self,

src/zarr/testing/store.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import json
55
import pickle
66
from abc import abstractmethod
7-
from typing import TYPE_CHECKING, Generic, TypeVar
7+
from typing import TYPE_CHECKING, Generic, Self, TypeVar
88

99
from zarr.storage import WrapperStore
1010

@@ -578,10 +578,13 @@ class LatencyStore(WrapperStore[Store]):
578578
get_latency: float
579579
set_latency: float
580580

581-
def __init__(self, cls: Store, *, get_latency: float = 0, set_latency: float = 0) -> None:
581+
def __init__(self, store: Store, *, get_latency: float = 0, set_latency: float = 0) -> None:
582582
self.get_latency = float(get_latency)
583583
self.set_latency = float(set_latency)
584-
self._store = cls
584+
self._store = store
585+
586+
def _with_store(self, store: Store) -> Self:
587+
return type(self)(store, get_latency=self.get_latency, set_latency=self.set_latency)
585588

586589
async def set(self, key: str, value: Buffer) -> None:
587590
"""

0 commit comments

Comments
 (0)