44import logging
55import time
66from collections import OrderedDict
7- from typing import TYPE_CHECKING , Any , Literal
7+ from dataclasses import dataclass , field
8+ from typing import TYPE_CHECKING , Any , Literal , Self
89
910from zarr .abc .store import ByteRequest , Store
1011from zarr .storage ._wrapper import WrapperStore
1516 from zarr .core .buffer .core import Buffer , BufferPrototype
1617
1718
@dataclass(slots=True)
class _CacheState:
    """Mutable bookkeeping for a CacheStore.

    Bundled into one object so it can be shared between CacheStore
    instances (e.g. a read-only view created by ``with_read_only``),
    keeping LRU order, sizes, and statistics in sync across views.
    """

    # Keys in least-recently-used -> most-recently-used order; values unused.
    cache_order: OrderedDict[str, None] = field(default_factory=OrderedDict)
    # Total bytes currently held in the cache store.
    current_size: int = 0
    # Per-key cached value size in bytes.
    key_sizes: dict[str, int] = field(default_factory=dict)
    # Guards all mutations of the tracking structures above.
    lock: asyncio.Lock = field(default_factory=asyncio.Lock)
    # Cache hit counter.
    hits: int = 0
    # Cache miss counter.
    misses: int = 0
    # Cache eviction counter.
    evictions: int = 0
    # Insertion time per key (time.monotonic), used for max-age freshness checks.
    key_insert_times: dict[str, float] = field(default_factory=dict)
29+
30+
1831class CacheStore (WrapperStore [Store ]):
1932 """
2033 A dual-store caching implementation for Zarr stores.
@@ -36,9 +49,6 @@ class CacheStore(WrapperStore[Store]):
3649 Maximum size of the cache in bytes. When exceeded, least recently used
3750 items are evicted. None means unlimited size. Default is None.
3851 Note: Individual values larger than max_size will not be cached.
39- key_insert_times : dict[str, float] | None, optional
40- Dictionary to track insertion times (using monotonic time).
41- Primarily for internal use. Default is None (creates new dict).
4252 cache_set_data : bool, optional
4353 Whether to cache data when it's written to the store. Default is True.
4454
@@ -69,15 +79,8 @@ class CacheStore(WrapperStore[Store]):
6979 _cache : Store
7080 max_age_seconds : int | Literal ["infinity" ]
7181 max_size : int | None
72- key_insert_times : dict [str , float ]
7382 cache_set_data : bool
74- _cache_order : OrderedDict [str , None ] # Track access order for LRU
75- _current_size : int # Track current cache size
76- _key_sizes : dict [str , int ] # Track size of each cached key
77- _lock : asyncio .Lock
78- _hits : int # Cache hit counter
79- _misses : int # Cache miss counter
80- _evictions : int # Cache eviction counter
83+ _state : _CacheState
8184
8285 def __init__ (
8386 self ,
@@ -86,7 +89,6 @@ def __init__(
8689 cache_store : Store ,
8790 max_age_seconds : int | str = "infinity" ,
8891 max_size : int | None = None ,
89- key_insert_times : dict [str , float ] | None = None ,
9092 cache_set_data : bool = True ,
9193 ) -> None :
9294 super ().__init__ (store )
@@ -107,18 +109,25 @@ def __init__(
107109 else :
108110 self .max_age_seconds = max_age_seconds
109111 self .max_size = max_size
110- if key_insert_times is None :
111- self .key_insert_times = {}
112- else :
113- self .key_insert_times = key_insert_times
114112 self .cache_set_data = cache_set_data
115- self ._cache_order = OrderedDict ()
116- self ._current_size = 0
117- self ._key_sizes = {}
118- self ._lock = asyncio .Lock ()
119- self ._hits = 0
120- self ._misses = 0
121- self ._evictions = 0
113+ self ._state = _CacheState ()
114+
def _with_store(self, store: Store) -> Self:
    """Unsupported for CacheStore.

    Swapping in a different source store while keeping the shared cache
    would let stale cache keys from the old store conflict with the new
    store's contents, so this operation is rejected outright.
    """
    raise NotImplementedError("CacheStore does not support this operation.")
119+
def with_read_only(self, read_only: bool = False) -> Self:
    """Return a CacheStore view whose source store has the given read-only mode.

    The new instance wraps the same cache store and, crucially, shares this
    instance's mutable ``_CacheState`` so that LRU order, size accounting,
    freshness timestamps, and hit/miss statistics stay consistent across
    both views.
    """
    clone = type(self)(
        store=self._store.with_read_only(read_only),
        cache_store=self._cache,
        max_age_seconds=self.max_age_seconds,
        max_size=self.max_size,
        cache_set_data=self.cache_set_data,
    )
    # Share — do not copy — the bookkeeping so both views observe one cache.
    clone._state = self._state
    return clone
122131
123132 def _is_key_fresh (self , key : str ) -> bool :
124133 """Check if a cached key is still fresh based on max_age_seconds.
@@ -128,7 +137,7 @@ def _is_key_fresh(self, key: str) -> bool:
128137 if self .max_age_seconds == "infinity" :
129138 return True
130139 now = time .monotonic ()
131- elapsed = now - self .key_insert_times .get (key , 0 )
140+ elapsed = now - self ._state . key_insert_times .get (key , 0 )
132141 return elapsed < self .max_age_seconds
133142
134143 async def _accommodate_value (self , value_size : int ) -> None :
@@ -140,9 +149,9 @@ async def _accommodate_value(self, value_size: int) -> None:
140149 return
141150
142151 # Remove least recently used items until we have enough space
143- while self ._current_size + value_size > self .max_size and self ._cache_order :
152+ while self ._state . current_size + value_size > self .max_size and self ._state . cache_order :
144153 # Get the least recently used key (first in OrderedDict)
145- lru_key = next (iter (self ._cache_order ))
154+ lru_key = next (iter (self ._state . cache_order ))
146155 await self ._evict_key (lru_key )
147156
148157 async def _evict_key (self , key : str ) -> None :
@@ -152,15 +161,15 @@ async def _evict_key(self, key: str) -> None:
152161 Updates size tracking atomically with deletion.
153162 """
154163 try :
155- key_size = self ._key_sizes .get (key , 0 )
164+ key_size = self ._state . key_sizes .get (key , 0 )
156165
157166 # Delete from cache store
158167 await self ._cache .delete (key )
159168
160169 # Update tracking after successful deletion
161170 self ._remove_from_tracking (key )
162- self ._current_size = max (0 , self ._current_size - key_size )
163- self ._evictions += 1
171+ self ._state . current_size = max (0 , self ._state . current_size - key_size )
172+ self ._state . evictions += 1
164173
165174 logger .debug ("_evict_key: evicted key %s, freed %d bytes" , key , key_size )
166175 except Exception :
@@ -183,39 +192,39 @@ async def _cache_value(self, key: str, value: Buffer) -> None:
183192 )
184193 return
185194
186- async with self ._lock :
195+ async with self ._state . lock :
187196 # If key already exists, subtract old size first
188- if key in self ._key_sizes :
189- old_size = self ._key_sizes [key ]
190- self ._current_size -= old_size
197+ if key in self ._state . key_sizes :
198+ old_size = self ._state . key_sizes [key ]
199+ self ._state . current_size -= old_size
191200 logger .debug ("_cache_value: updating existing key %s, old size %d" , key , old_size )
192201
193202 # Make room for the new value (this calls _evict_key_locked internally)
194203 await self ._accommodate_value (value_size )
195204
196205 # Update tracking atomically
197- self ._cache_order [key ] = None # OrderedDict to track access order
198- self ._current_size += value_size
199- self ._key_sizes [key ] = value_size
200- self .key_insert_times [key ] = time .monotonic ()
206+ self ._state . cache_order [key ] = None # OrderedDict to track access order
207+ self ._state . current_size += value_size
208+ self ._state . key_sizes [key ] = value_size
209+ self ._state . key_insert_times [key ] = time .monotonic ()
201210
202211 logger .debug ("_cache_value: cached key %s with size %d bytes" , key , value_size )
203212
async def _update_access_order(self, key: str) -> None:
    """Mark *key* as most recently used for LRU tracking.

    The membership check must happen while holding the lock: checking
    before acquiring it left a window in which a concurrent eviction
    (``_cache_value`` -> ``_accommodate_value`` -> ``_evict_key``) could
    remove the key, making ``move_to_end`` raise ``KeyError``. Keys no
    longer present are silently ignored.
    """
    async with self._state.lock:
        if key in self._state.cache_order:
            # Move to end (most recently used)
            self._state.cache_order.move_to_end(key)
210219
def _remove_from_tracking(self, key: str) -> None:
    """Drop *key* from every tracking structure; missing keys are ignored.

    Must be called while holding ``self._state.lock``.
    """
    state = self._state
    for tracker in (state.cache_order, state.key_insert_times, state.key_sizes):
        tracker.pop(key, None)
219228
220229 async def _get_try_cache (
221230 self , key : str , prototype : BufferPrototype , byte_range : ByteRequest | None = None
@@ -224,20 +233,20 @@ async def _get_try_cache(
224233 maybe_cached_result = await self ._cache .get (key , prototype , byte_range )
225234 if maybe_cached_result is not None :
226235 logger .debug ("_get_try_cache: key %s found in cache (HIT)" , key )
227- self ._hits += 1
236+ self ._state . hits += 1
228237 # Update access order for LRU
229238 await self ._update_access_order (key )
230239 return maybe_cached_result
231240 else :
232241 logger .debug (
233242 "_get_try_cache: key %s not found in cache (MISS), fetching from store" , key
234243 )
235- self ._misses += 1
244+ self ._state . misses += 1
236245 maybe_fresh_result = await super ().get (key , prototype , byte_range )
237246 if maybe_fresh_result is None :
238247 # Key doesn't exist in source store
239248 await self ._cache .delete (key )
240- async with self ._lock :
249+ async with self ._state . lock :
241250 self ._remove_from_tracking (key )
242251 else :
243252 # Cache the newly fetched value
@@ -249,12 +258,12 @@ async def _get_no_cache(
249258 self , key : str , prototype : BufferPrototype , byte_range : ByteRequest | None = None
250259 ) -> Buffer | None :
251260 """Get data directly from source store and update cache."""
252- self ._misses += 1
261+ self ._state . misses += 1
253262 maybe_fresh_result = await super ().get (key , prototype , byte_range )
254263 if maybe_fresh_result is None :
255264 # Key doesn't exist in source, remove from cache and tracking
256265 await self ._cache .delete (key )
257- async with self ._lock :
266+ async with self ._state . lock :
258267 self ._remove_from_tracking (key )
259268 else :
260269 logger .debug ("_get_no_cache: key %s found in store, setting in cache" , key )
@@ -312,7 +321,7 @@ async def set(self, key: str, value: Buffer) -> None:
312321 else :
313322 logger .debug ("set: deleting key %s from cache" , key )
314323 await self ._cache .delete (key )
315- async with self ._lock :
324+ async with self ._state . lock :
316325 self ._remove_from_tracking (key )
317326
318327 async def delete (self , key : str ) -> None :
@@ -328,7 +337,7 @@ async def delete(self, key: str) -> None:
328337 await super ().delete (key )
329338 logger .debug ("delete: deleting key %s from cache" , key )
330339 await self ._cache .delete (key )
331- async with self ._lock :
340+ async with self ._state . lock :
332341 self ._remove_from_tracking (key )
333342
334343 def cache_info (self ) -> dict [str , Any ]:
@@ -339,20 +348,20 @@ def cache_info(self) -> dict[str, Any]:
339348 if self .max_age_seconds == "infinity"
340349 else self .max_age_seconds ,
341350 "max_size" : self .max_size ,
342- "current_size" : self ._current_size ,
351+ "current_size" : self ._state . current_size ,
343352 "cache_set_data" : self .cache_set_data ,
344- "tracked_keys" : len (self .key_insert_times ),
345- "cached_keys" : len (self ._cache_order ),
353+ "tracked_keys" : len (self ._state . key_insert_times ),
354+ "cached_keys" : len (self ._state . cache_order ),
346355 }
347356
def cache_stats(self) -> dict[str, Any]:
    """Return cache performance statistics.

    ``hit_rate`` is ``hits / (hits + misses)``, reported as 0.0 when no
    requests have been made yet.
    """
    state = self._state
    requests = state.hits + state.misses
    return {
        "hits": state.hits,
        "misses": state.misses,
        "evictions": state.evictions,
        "total_requests": requests,
        "hit_rate": state.hits / requests if requests else 0.0,
    }
@@ -364,11 +373,11 @@ async def clear_cache(self) -> None:
364373 await self ._cache .clear ()
365374
366375 # Reset tracking
367- async with self ._lock :
368- self .key_insert_times .clear ()
369- self ._cache_order .clear ()
370- self ._key_sizes .clear ()
371- self ._current_size = 0
376+ async with self ._state . lock :
377+ self ._state . key_insert_times .clear ()
378+ self ._state . cache_order .clear ()
379+ self ._state . key_sizes .clear ()
380+ self ._state . current_size = 0
372381 logger .debug ("clear_cache: cleared all cache data" )
373382
374383 def __repr__ (self ) -> str :
@@ -379,6 +388,6 @@ def __repr__(self) -> str:
379388 f"cache_store={ self ._cache !r} , "
380389 f"max_age_seconds={ self .max_age_seconds } , "
381390 f"max_size={ self .max_size } , "
382- f"current_size={ self ._current_size } , "
383- f"cached_keys={ len (self ._cache_order )} )"
391+ f"current_size={ self ._state . current_size } , "
392+ f"cached_keys={ len (self ._state . cache_order )} )"
384393 )
0 commit comments