diff --git a/ldclient/__init__.py b/ldclient/__init__.py index 884c3af8..b1341f95 100644 --- a/ldclient/__init__.py +++ b/ldclient/__init__.py @@ -37,17 +37,16 @@ def set_config(config: Config): global __config global __client global __lock - try: - __lock.lock() - if __client: - log.info("Reinitializing LaunchDarkly Client " + VERSION + " with new config") - new_client = LDClient(config=config, start_wait=start_wait) - old_client = __client - __client = new_client - old_client.close() - finally: - __config = config - __lock.unlock() + with __lock.write(): + try: + if __client: + log.info("Reinitializing LaunchDarkly Client " + VERSION + " with new config") + new_client = LDClient(config=config, start_wait=start_wait) + old_client = __client + __client = new_client + old_client.close() + finally: + __config = config def get() -> LDClient: @@ -63,35 +62,27 @@ def get() -> LDClient: global __config global __client global __lock - try: - __lock.rlock() + with __lock.read(): if __client: return __client if __config is None: raise Exception("set_config was not called") - finally: - __lock.runlock() - try: - __lock.lock() + with __lock.write(): if not __client: log.info("Initializing LaunchDarkly Client " + VERSION) __client = LDClient(config=__config, start_wait=start_wait) return __client - finally: - __lock.unlock() # for testing only def _reset_client(): global __client global __lock - try: - __lock.lock() + c = None + with __lock.write(): c = __client __client = None - finally: - __lock.unlock() if c: c.close() diff --git a/ldclient/client.py b/ldclient/client.py index 1becbbad..01007610 100644 --- a/ldclient/client.py +++ b/ldclient/client.py @@ -111,13 +111,10 @@ def __wrapper(self, fn: Callable): raise def __update_availability(self, available: bool): - try: - self.__lock.lock() + with self.__lock.write(): if available == self.__last_available: return self.__last_available = available - finally: - self.__lock.unlock() status = DataStoreStatus(available, False) @@ -127,23 +124,19 @@ def __update_availability(self, available: bool): self.__store_update_sink.update_status(status) if available: - try: - self.__lock.lock() + with self.__lock.write(): if self.__poller is not None: self.__poller.stop() self.__poller = None - finally: - self.__lock.unlock() return log.warn("Detected persistent store unavailability; updates will be cached until it recovers") task = RepeatingTask("ldclient.check-availability", 0.5, 0, self.__check_availability) - self.__lock.lock() - self.__poller = task - self.__poller.start() - self.__lock.unlock() + with self.__lock.write(): + self.__poller = task + self.__poller.start() def __check_availability(self): try: @@ -717,9 +710,8 @@ def add_hook(self, hook: Hook): if not isinstance(hook, Hook): return - self.__hooks_lock.lock() - self.__hooks.append(hook) - self.__hooks_lock.unlock() + with self.__hooks_lock.write(): + self.__hooks.append(hook) def __evaluate_with_hooks(self, key: str, context: Context, default_value: Any, method: str, block: Callable[[], _EvaluationWithHookResult]) -> _EvaluationWithHookResult: """ @@ -733,15 +725,11 @@ def __evaluate_with_hooks(self, key: str, context: Context, default_value: Any, # :return: """ hooks = [] # type: List[Hook] - try: - self.__hooks_lock.rlock() - + with self.__hooks_lock.read(): if len(self.__hooks) == 0: return block() hooks = self.__hooks.copy() - finally: - self.__hooks_lock.runlock() series_context = EvaluationSeriesContext(key=key, context=context, default_value=default_value, method=method) hook_data = self.__execute_before_evaluation(hooks, series_context) diff --git a/ldclient/feature_store.py b/ldclient/feature_store.py index f4340b47..7a8912b2 100644 --- a/ldclient/feature_store.py +++ b/ldclient/feature_store.py @@ -77,8 +77,7 @@ def is_available(self) -> bool: def get(self, kind: VersionedDataKind, key: str, callback: Callable[[Any], Any] = lambda x: x) -> Any: """ """ - try: - self._lock.rlock() + with self._lock.read(): itemsOfKind = self._items[kind] item = itemsOfKind.get(key) if item is None: @@ -88,17 +87,12 @@ def get(self, kind: VersionedDataKind, key: str, callback: Callable[[Any], Any] log.debug("Attempted to get deleted key %s in '%s', returning None", key, kind.namespace) return callback(None) return callback(item) - finally: - self._lock.runlock() def all(self, kind, callback): """ """ - try: - self._lock.rlock() + with self._lock.read(): itemsOfKind = self._items[kind] return callback(dict((k, i) for k, i in itemsOfKind.items() if ('deleted' not in i) or not i['deleted'])) - finally: - self._lock.runlock() def init(self, all_data): """ """ @@ -108,51 +102,39 @@ def init(self, all_data): for key, item in items.items(): items_decoded[key] = kind.decode(item) all_decoded[kind] = items_decoded - try: - self._lock.rlock() + with self._lock.write(): self._items.clear() self._items.update(all_decoded) self._initialized = True for k in all_data: log.debug("Initialized '%s' store with %d items", k.namespace, len(all_data[k])) - finally: - self._lock.runlock() # noinspection PyShadowingNames def delete(self, kind, key: str, version: int): """ """ - try: - self._lock.rlock() + with self._lock.write(): itemsOfKind = self._items[kind] i = itemsOfKind.get(key) if i is None or i['version'] < version: i = {'deleted': True, 'version': version} itemsOfKind[key] = i - finally: - self._lock.runlock() def upsert(self, kind, item): """ """ decoded_item = kind.decode(item) key = item['key'] - try: - self._lock.rlock() + with self._lock.write(): itemsOfKind = self._items[kind] i = itemsOfKind.get(key) if i is None or i['version'] < item['version']: itemsOfKind[key] = decoded_item log.debug("Updated %s in '%s' to version %d", key, kind.namespace, item['version']) - finally: - self._lock.runlock() @property def initialized(self) -> bool: """ """ - try: - self._lock.rlock() + with self._lock.read(): return self._initialized - finally: - self._lock.runlock() def describe_configuration(self, config): return 'memory' diff --git a/ldclient/impl/datasource/status.py b/ldclient/impl/datasource/status.py index 172ffee9..c9813e04 100644 --- a/ldclient/impl/datasource/status.py +++ b/ldclient/impl/datasource/status.py @@ -29,11 +29,8 @@ def __init__(self, store: FeatureStore, status_listeners: Listeners, flag_change @property def status(self) -> DataSourceStatus: - try: - self.__lock.rlock() + with self.__lock.read(): return self.__status - finally: - self.__lock.runlock() def init(self, all_data: Mapping[VersionedDataKind, Mapping[str, dict]]): old_data = None @@ -70,8 +67,7 @@ def delete(self, kind: VersionedDataKind, key: str, version: int): def update_status(self, new_state: DataSourceState, new_error: Optional[DataSourceErrorInfo]): status_to_broadcast = None - try: - self.__lock.lock() + with self.__lock.write(): old_status = self.__status if new_state == DataSourceState.INTERRUPTED and old_status.state == DataSourceState.INITIALIZING: @@ -83,8 +79,6 @@ def update_status(self, new_state: DataSourceState, new_error: Optional[DataSour self.__status = DataSourceStatus(new_state, self.__status.since if new_state == self.__status.state else time.time(), self.__status.error if new_error is None else new_error) status_to_broadcast = self.__status - finally: - self.__lock.unlock() if status_to_broadcast is not None: self.__status_listeners.notify(status_to_broadcast) diff --git a/ldclient/impl/datastore/status.py b/ldclient/impl/datastore/status.py index ee9797dd..1e4f145b 100644 --- a/ldclient/impl/datastore/status.py +++ b/ldclient/impl/datastore/status.py @@ -27,16 +27,12 @@ def listeners(self) -> Listeners: return self.__listeners def status(self) -> DataStoreStatus: - self.__lock.rlock() - status = copy(self.__status) - self.__lock.runlock() - - return status + with self.__lock.read(): + return copy(self.__status) def update_status(self, status: DataStoreStatus): - self.__lock.lock() - old_value, self.__status = self.__status, status - self.__lock.unlock() + with self.__lock.write(): + old_value, self.__status = self.__status, status if old_value != status: self.__listeners.notify(status) diff --git a/ldclient/impl/datasystem/fdv2.py b/ldclient/impl/datasystem/fdv2.py index 7482d3a9..b8229e68 100644 --- a/ldclient/impl/datasystem/fdv2.py +++ b/ldclient/impl/datasystem/fdv2.py @@ -42,17 +42,13 @@ def __init__(self, listeners: Listeners): @property def status(self) -> DataSourceStatus: - self.__lock.rlock() - status = self.__status - self.__lock.runlock() - - return status + with self.__lock.read(): + return self.__status def update_status(self, new_state: DataSourceState, new_error: Optional[DataSourceErrorInfo]): status_to_broadcast = None - try: - self.__lock.lock() + with self.__lock.write(): old_status = self.__status if new_state == DataSourceState.INTERRUPTED and old_status.state == DataSourceState.INITIALIZING: @@ -67,8 +63,6 @@ def update_status(self, new_state: DataSourceState, new_error: Optional[DataSour self.__status = DataSourceStatus(new_state, new_since, new_error) status_to_broadcast = self.__status - finally: - self.__lock.unlock() if status_to_broadcast is not None: self.__listeners.notify(status_to_broadcast) @@ -92,25 +86,20 @@ def update_status(self, status: DataStoreStatus): """ update_status is called from the data store to push a status update. """ - self.__lock.lock() modified = False - if self.__status != status: - self.__status = status - modified = True - - self.__lock.unlock() + with self.__lock.write(): + if self.__status != status: + self.__status = status + modified = True if modified: self.__listeners.notify(status) @property def status(self) -> DataStoreStatus: - self.__lock.rlock() - status = copy(self.__status) - self.__lock.runlock() - - return status + with self.__lock.read(): + return copy(self.__status) def is_monitoring_enabled(self) -> bool: if self.__store is None: @@ -174,8 +163,7 @@ def __update_availability(self, available: bool): poller_to_stop = None task_to_start = None - self.__lock.lock() - try: + with self.__lock.write(): if available == self.__last_available: return @@ -188,8 +176,6 @@ def __update_availability(self, available: bool): elif self.__poller is None: task_to_start = RepeatingTask("ldclient.check-availability", 0.5, 0, self.__check_availability) self.__poller = task_to_start - finally: - self.__lock.unlock() if available: log.warning("Persistent store is available again") @@ -336,13 +322,12 @@ def stop(self): """Stop the FDv2 data system and all associated threads.""" self._stop_event.set() - self._lock.lock() - if self._active_synchronizer is not None: - try: - self._active_synchronizer.stop() - except Exception as e: - log.error("Error stopping active data source: %s", e) - self._lock.unlock() + with self._lock.write(): + if self._active_synchronizer is not None: + try: + self._active_synchronizer.stop() + except Exception as e: + log.error("Error stopping active data source: %s", e) # Wait for all threads to complete for thread in self._threads: @@ -426,12 +411,11 @@ def synchronizer_loop(self: 'FDv2'): while not self._stop_event.is_set() and self._primary_synchronizer_builder is not None: # Try primary synchronizer try: - self._lock.lock() - primary_sync = self._primary_synchronizer_builder(self._config) - if isinstance(primary_sync, DiagnosticSource) and self._diagnostic_accumulator is not None: - primary_sync.set_diagnostic_accumulator(self._diagnostic_accumulator) - self._active_synchronizer = primary_sync - self._lock.unlock() + with self._lock.write(): + primary_sync = self._primary_synchronizer_builder(self._config) + if isinstance(primary_sync, DiagnosticSource) and self._diagnostic_accumulator is not None: + primary_sync.set_diagnostic_accumulator(self._diagnostic_accumulator) + self._active_synchronizer = primary_sync log.info("Primary synchronizer %s is starting", primary_sync.name) @@ -462,13 +446,12 @@ def synchronizer_loop(self: 'FDv2'): if self._secondary_synchronizer_builder is None: continue - self._lock.lock() - secondary_sync = self._secondary_synchronizer_builder(self._config) - if isinstance(secondary_sync, DiagnosticSource) and self._diagnostic_accumulator is not None: - secondary_sync.set_diagnostic_accumulator(self._diagnostic_accumulator) - log.info("Secondary synchronizer %s is starting", secondary_sync.name) - self._active_synchronizer = secondary_sync - self._lock.unlock() + with self._lock.write(): + secondary_sync = self._secondary_synchronizer_builder(self._config) + if isinstance(secondary_sync, DiagnosticSource) and self._diagnostic_accumulator is not None: + secondary_sync.set_diagnostic_accumulator(self._diagnostic_accumulator) + log.info("Secondary synchronizer %s is starting", secondary_sync.name) + self._active_synchronizer = secondary_sync remove_sync, fallback_v1 = self._consume_synchronizer_results( secondary_sync, set_on_ready, self._recovery_condition @@ -497,11 +480,10 @@ def synchronizer_loop(self: 'FDv2'): finally: # Ensure we always set the ready event when exiting set_on_ready.set() - self._lock.lock() - if self._active_synchronizer is not None: - self._active_synchronizer.stop() - self._active_synchronizer = None - self._lock.unlock() + with self._lock.write(): + if self._active_synchronizer is not None: + self._active_synchronizer.stop() + self._active_synchronizer = None sync_thread = Thread( target=synchronizer_loop, diff --git a/ldclient/impl/datasystem/store.py b/ldclient/impl/datasystem/store.py index 0d731e03..6491cf97 100644 --- a/ldclient/impl/datasystem/store.py +++ b/ldclient/impl/datasystem/store.py @@ -50,8 +50,7 @@ def get( key: str, callback: Callable[[Any], Any] = lambda x: x, ) -> Any: - try: - self._lock.rlock() + with self._lock.read(): items_of_kind = self._items[kind] item = items_of_kind.get(key) if item is None: @@ -69,12 +68,9 @@ def get( ) return callback(None) return callback(item) - finally: - self._lock.runlock() def all(self, kind: VersionedDataKind, callback: Callable[[Any], Any] = lambda x: x) -> Any: - try: - self._lock.rlock() + with self._lock.read(): items_of_kind = self._items[kind] return callback( dict( @@ -83,8 +79,6 @@ def all(self, kind: VersionedDataKind, callback: Callable[[Any], Any] = lambda x if ("deleted" not in i) or not i["deleted"] ) ) - finally: - self._lock.runlock() def set_basis(self, collections: Collections) -> bool: """ @@ -95,15 +89,13 @@ def set_basis(self, collections: Collections) -> bool: return False try: - self._lock.lock() - self._items.clear() - self._items.update(all_decoded) - self._initialized = True + with self._lock.write(): + self._items.clear() + self._items.update(all_decoded) + self._initialized = True except Exception as e: log.error("Failed applying set_basis", exc_info=e) return False - finally: - self._lock.unlock() return True @@ -116,20 +108,18 @@ def apply_delta(self, collections: Collections) -> bool: return False try: - self._lock.lock() - for kind, kind_data in all_decoded.items(): - items_of_kind = self._items[kind] - kind_data = all_decoded[kind] - for key, item in kind_data.items(): - items_of_kind[key] = item - log.debug( - "Updated %s in '%s' to version %d", key, kind.namespace, item["version"] - ) + with self._lock.write(): + for kind, kind_data in all_decoded.items(): + items_of_kind = self._items[kind] + kind_data = all_decoded[kind] + for key, item in kind_data.items(): + items_of_kind[key] = item + log.debug( + "Updated %s in '%s' to version %d", key, kind.namespace, item["version"] + ) except Exception as e: log.error("Failed applying apply_delta", exc_info=e) return False - finally: - self._lock.unlock() return True @@ -153,11 +143,8 @@ def initialized(self) -> bool: """ Indicates whether the store has been initialized with data. """ - try: - self._lock.rlock() + with self._lock.read(): return self._initialized - finally: - self._lock.runlock() class Store: diff --git a/ldclient/impl/flag_tracker.py b/ldclient/impl/flag_tracker.py index e7c9b7c2..8ce16b23 100644 --- a/ldclient/impl/flag_tracker.py +++ b/ldclient/impl/flag_tracker.py @@ -22,9 +22,8 @@ def __call__(self, flag_change: FlagChange): new_value = self.__eval_fn(self.__key, self.__context) - self.__lock.lock() - old_value, self.__value = self.__value, new_value - self.__lock.unlock() + with self.__lock.write(): + old_value, self.__value = self.__value, new_value if new_value == old_value: return diff --git a/ldclient/impl/listeners.py b/ldclient/impl/listeners.py index d171d80d..58b88b96 100644 --- a/ldclient/impl/listeners.py +++ b/ldclient/impl/listeners.py @@ -1,6 +1,6 @@ -from threading import RLock from typing import Any, Callable +from ldclient.impl.rwlock import ReadWriteLock from ldclient.impl.util import log @@ -12,25 +12,25 @@ class Listeners: def __init__(self): self.__listeners = [] - self.__lock = RLock() + self.__lock = ReadWriteLock() def has_listeners(self) -> bool: - with self.__lock: + with self.__lock.read(): return len(self.__listeners) > 0 def add(self, listener: Callable): - with self.__lock: + with self.__lock.write(): self.__listeners.append(listener) def remove(self, listener: Callable): - with self.__lock: + with self.__lock.write(): try: self.__listeners.remove(listener) except ValueError: pass # removing a listener that wasn't in the list is a no-op def notify(self, value: Any): - with self.__lock: + with self.__lock.read(): listeners_copy = self.__listeners.copy() for listener in listeners_copy: try: diff --git a/ldclient/impl/rwlock.py b/ldclient/impl/rwlock.py index e394194b..a31a2624 100644 --- a/ldclient/impl/rwlock.py +++ b/ldclient/impl/rwlock.py @@ -1,4 +1,5 @@ import threading +from contextlib import contextmanager class ReadWriteLock: @@ -38,3 +39,33 @@ def lock(self): def unlock(self): """Release a write lock.""" self._read_ready.release() + + @contextmanager + def read(self): + """Context manager for acquiring a read lock. + + Usage: + with lock.read(): + # read lock held here + pass + """ + self.rlock() + try: + yield self + finally: + self.runlock() + + @contextmanager + def write(self): + """Context manager for acquiring a write lock. + + Usage: + with lock.write(): + # write lock held here + pass + """ + self.lock() + try: + yield self + finally: + self.unlock() diff --git a/ldclient/integrations/test_data.py b/ldclient/integrations/test_data.py index 56e06f9a..59d2e048 100644 --- a/ldclient/integrations/test_data.py +++ b/ldclient/integrations/test_data.py @@ -57,11 +57,8 @@ def __init__(self): def __call__(self, config, store, ready): data_source = _TestDataSource(store, self, ready) - try: - self._lock.lock() + with self._lock.write(): self._instances.append(data_source) - finally: - self._lock.unlock() return data_source @@ -89,14 +86,11 @@ def flag(self, key: str) -> 'FlagBuilder': :param str key: the flag key :return: the flag configuration builder object """ - try: - self._lock.rlock() + with self._lock.read(): if key in self._flag_builders and self._flag_builders[key]: return self._flag_builders[key]._copy() else: return FlagBuilder(key).boolean_flag() - finally: - self._lock.runlock() def update(self, flag_builder: 'FlagBuilder') -> 'TestData': """Updates the test data with the specified flag configuration. @@ -113,9 +107,7 @@ def update(self, flag_builder: 'FlagBuilder') -> 'TestData': :param flag_builder: a flag configuration builder :return: self (the TestData object) """ - try: - self._lock.lock() - + with self._lock.write(): old_version = 0 if flag_builder._key in self._current_flags: old_flag = self._current_flags[flag_builder._key] @@ -126,8 +118,6 @@ def update(self, flag_builder: 'FlagBuilder') -> 'TestData': self._current_flags[flag_builder._key] = new_flag self._flag_builders[flag_builder._key] = flag_builder._copy() - finally: - self._lock.unlock() for instance in self._instances: instance.upsert(new_flag) @@ -138,11 +128,8 @@ def _make_init_data(self) -> dict: return {FEATURES: copy.copy(self._current_flags)} def _closed_instance(self, instance): - try: - self._lock.lock() + with self._lock.write(): self._instances.remove(instance) - finally: - self._lock.unlock() class FlagBuilder: diff --git a/ldclient/integrations/test_datav2.py b/ldclient/integrations/test_datav2.py index a2da52db..3b791cf1 100644 --- a/ldclient/integrations/test_datav2.py +++ b/ldclient/integrations/test_datav2.py @@ -617,14 +617,11 @@ def flag(self, key: str) -> FlagBuilderV2: :param str key: the flag key :return: the flag configuration builder object """ - try: - self._lock.rlock() + with self._lock.read(): if key in self._flag_builders and self._flag_builders[key]: return self._flag_builders[key]._copy() return FlagBuilderV2(key).boolean_flag() - finally: - self._lock.runlock() def update(self, flag_builder: FlagBuilderV2) -> TestDataV2: """ @@ -643,9 +640,7 @@ def update(self, flag_builder: FlagBuilderV2) -> TestDataV2: :return: self (the TestDataV2 object) """ instances_copy = [] - try: - self._lock.lock() - + with self._lock.write(): old_version = 0 if flag_builder._key in self._current_flags: old_flag = self._current_flags[flag_builder._key] @@ -659,8 +654,6 @@ def update(self, flag_builder: FlagBuilderV2) -> TestDataV2: # Create a copy of instances while holding the lock to avoid race conditions instances_copy = list(self._instances) - finally: - self._lock.unlock() for instance in instances_copy: instance.upsert_flag(new_flag) @@ -668,35 +661,23 @@ def update(self, flag_builder: FlagBuilderV2) -> TestDataV2: return self def _make_init_data(self) -> Dict[str, Any]: - try: - self._lock.rlock() + with self._lock.read(): return copy.copy(self._current_flags) - finally: - self._lock.runlock() def _get_version(self) -> int: - try: - self._lock.lock() + with self._lock.write(): version = self._version self._version += 1 return version - finally: - self._lock.unlock() def _closed_instance(self, instance): - try: - self._lock.lock() + with self._lock.write(): if instance in self._instances: self._instances.remove(instance) - finally: - self._lock.unlock() def _add_instance(self, instance): - try: - self._lock.lock() + with self._lock.write(): self._instances.append(instance) - finally: - self._lock.unlock() def build_initializer(self, _: Config) -> _TestDataSourceV2: """