diff --git a/pyproject.toml b/pyproject.toml index 9a06ec21..3cdc91c8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ where = ["src"] [project] name = "dishka" -version = "0.1" +version = "0.8.0" readme = "README.md" authors = [ { name = "Andrey Tikhonov", email = "17@itishka.org" }, diff --git a/src/dishka/async_container.py b/src/dishka/async_container.py index 9e2af074..5de8f31c 100644 --- a/src/dishka/async_container.py +++ b/src/dishka/async_container.py @@ -1,6 +1,7 @@ from asyncio import Lock from collections.abc import Callable -from typing import Any, Optional, TypeVar +from contextlib import suppress +from typing import Any, Iterable, Literal, Optional, TypeVar, overload from dishka.entities.component import DEFAULT_COMPONENT, Component from dishka.entities.key import DependencyKey @@ -9,6 +10,7 @@ from .dependency_source import FactoryType from .exceptions import ( ExitError, + NoContextValueError, NoFactoryError, ) from .provider import BaseProvider @@ -109,6 +111,45 @@ async def get( async with lock: return await self._get_unlocked(key) + @overload + async def resolve_all(self, components: None = None) -> None: ... + @overload + async def resolve_all(self, components: Literal[True]) -> None: ... + @overload + async def resolve_all(self, components: Iterable[Component]) -> None: ... + + async def resolve_all(self, components: Any = None) -> None: + """ + Resolve all container dependencies in the current scope for the given + components. + + Examples: + >>> container.resolve_all() + Resolve all dependencies for the default component. + + >>> container.resolve_all(True) + Resolve all dependencies for all components. + + >>> container.resolve_all(['component1', 'component2']) + Resolve dependencies for 'component1' and 'component2'. + """ + if not components: + + def component_check(k: DependencyKey) -> bool: + return k.component == DEFAULT_COMPONENT + elif components is True: + + def component_check(k: DependencyKey) -> bool: + return True + else: + + def component_check(k: DependencyKey) -> bool: + return k.component in components + + for key in filter(component_check, self.registry.factories): + with suppress(NoContextValueError): + await self._get_unlocked(key) + async def _get_unlocked(self, key: DependencyKey) -> Any: if key in self.context: return self.context[key] diff --git a/src/dishka/container.py b/src/dishka/container.py index 2952673d..ac437533 100644 --- a/src/dishka/container.py +++ b/src/dishka/container.py @@ -1,6 +1,7 @@ from collections.abc import Callable +from contextlib import suppress from threading import Lock -from typing import Any, Optional, TypeVar +from typing import Any, Iterable, Literal, Optional, TypeVar, overload from dishka.entities.component import DEFAULT_COMPONENT, Component from dishka.entities.key import DependencyKey @@ -9,6 +10,7 @@ from .dependency_source import FactoryType from .exceptions import ( ExitError, + NoContextValueError, NoFactoryError, ) from .provider import BaseProvider @@ -107,6 +109,45 @@ def get( with lock: return self._get_unlocked(key) + @overload + def resolve_all(self, components: None = None) -> None: ... + @overload + def resolve_all(self, components: Literal[True]) -> None: ... + @overload + def resolve_all(self, components: Iterable[Component]) -> None: ... + + def resolve_all(self, components: Any = None) -> None: + """ + Resolve all container dependencies in the current scope for the given + components. + + Examples: + >>> container.resolve_all() + Resolve all dependencies for the default component. + + >>> container.resolve_all(True) + Resolve all dependencies for all components. + + >>> container.resolve_all(['component1', 'component2']) + Resolve dependencies for 'component1' and 'component2'. + """ + if not components: + + def component_check(k: DependencyKey) -> bool: + return k.component == DEFAULT_COMPONENT + elif components is True: + + def component_check(k: DependencyKey) -> bool: + return True + else: + + def component_check(k: DependencyKey) -> bool: + return k.component in components + + for key in filter(component_check, self.registry.factories): + with suppress(NoContextValueError): + self._get_unlocked(key) + def _get_unlocked(self, key: DependencyKey) -> Any: if key in self.context: return self.context[key] diff --git a/tests/unit/container/test_components.py b/tests/unit/container/test_components.py index 492c7a07..4f1407e4 100644 --- a/tests/unit/container/test_components.py +++ b/tests/unit/container/test_components.py @@ -1,4 +1,4 @@ -from typing import Annotated +from typing import Annotated, Literal import pytest @@ -55,6 +55,24 @@ def foo(self, a: Annotated[int, FromComponent()]) -> float: return a + 1 +class YProvider(Provider): + scope = Scope.APP + component = "Y" + + @provide + def foo(self) -> float: + return 42 + + +class ZProvider(Provider): + scope = Scope.APP + component = "Z" + + @provide + def foo(self) -> bool: + return True + + def test_from_component(): container = make_container(MainProvider(20), XProvider()) assert container.get(complex) == 210 @@ -63,6 +81,31 @@ def test_from_component(): container.get(float) +@pytest.mark.parametrize( + ("component", "expected_count"), + [ + (None, 4), + (("",), 4), + (True, 6), + (("X",), 3), + (("X", ""), 4), + (("X", "Y"), 4), + (("X", "Y", ""), 5), + (("X", "Y", "Z"), 5), + (("X", "Y", "Z", ""), 6), + ], +) +def test_from_component_resolve_all( + component: Literal[True] | tuple[Component] | None, expected_count: int +): + container = make_container( + MainProvider(20), XProvider(), YProvider(), ZProvider() + ) + assert len(container.context) == 1 + container.resolve_all(component) + assert len(container.context) == expected_count + + @pytest.mark.asyncio() async def test_from_component_async(): container = make_async_container(MainProvider(20), XProvider()) @@ -72,6 +115,32 @@ async def test_from_component_async(): await container.get(float) +@pytest.mark.parametrize( + ("component", "expected_count"), + [ + (None, 4), + (("",), 4), + (True, 6), + (("X",), 3), + (("X", ""), 4), + (("X", "Y"), 4), + (("X", "Y", ""), 5), + (("X", "Y", "Z"), 5), + (("X", "Y", "Z", ""), 6), + ], +) +@pytest.mark.asyncio +async def test_from_component_resolve_all_async( + component: Literal[True] | tuple[Component] | None, expected_count: int +): + container = make_async_container( + MainProvider(20), XProvider(), YProvider(), ZProvider() + ) + assert len(container.context) == 1 + await container.resolve_all(component) + assert len(container.context) == expected_count + + class SingleProvider(Provider): scope = Scope.APP diff --git a/tests/unit/container/test_context_vars.py b/tests/unit/container/test_context_vars.py index 1cac5877..12be5f28 100644 --- a/tests/unit/container/test_context_vars.py +++ b/tests/unit/container/test_context_vars.py @@ -1,3 +1,5 @@ +from typing import Any + import pytest from dishka import ( @@ -9,6 +11,7 @@ ) from dishka.dependency_source import from_context from dishka.exceptions import NoContextValueError +from ..sample_providers import ClassA def test_simple(): @@ -18,6 +21,35 @@ def test_simple(): assert container.get(int) == 1 +class AProvider(Provider): + scope = Scope.APP + a = from_context(provides=int) + b = from_context(provides=str) + + @provide + def foo(self, a: int) -> ClassA: + return ClassA(a) + + @provide + def bar(self, a: str) -> bool: + return bool(a) + + +@pytest.mark.parametrize( + ("context", "expected_count"), + [ + ({}, 1), + ({int: 1}, 3), + ({int: 1, str: "1"}, 5), + ], +) +def test_simple_resolve_all(context: dict[type, Any], expected_count: int): + provider = AProvider() + container = make_container(provider, context=context) + container.resolve_all() + assert len(container.context) == expected_count + + @pytest.mark.asyncio async def test_simple_async(): provider = Provider() @@ -26,6 +58,24 @@ async def test_simple_async(): assert await container.get(int) == 1 +@pytest.mark.parametrize( + ("context", "expected_count"), + [ + ({}, 1), + ({int: 1}, 3), + ({int: 1, str: "1"}, 5), + ], +) +@pytest.mark.asyncio +async def test_simple_resolve_all_async( + context: dict[type, Any], expected_count: int +): + provider = AProvider() + container = make_async_container(provider, context=context) + await container.resolve_all() + assert len(container.context) == expected_count + + def test_not_found(): provider = Provider() provider.from_context(provides=int, scope=Scope.APP) diff --git a/tests/unit/container/test_resolve.py b/tests/unit/container/test_resolve.py index 3954dfa0..57f43528 100644 --- a/tests/unit/container/test_resolve.py +++ b/tests/unit/container/test_resolve.py @@ -1,3 +1,5 @@ +from typing import Any, Callable + import pytest from dishka import ( @@ -129,3 +131,66 @@ def test_external_method(method): container = make_container(provider) assert container.get(ClassA) is A_VALUE + + +@pytest.mark.parametrize( + ("factory", "cache", "expected_count"), + [ + (ClassA, True, 3), + (ClassA, False, 2), + (sync_func_a, True, 3), + (sync_func_a, False, 2), + (sync_iter_a, True, 3), + (sync_iter_a, False, 2), + (sync_gen_a, True, 3), + (sync_gen_a, False, 2), + ], +) +def test_sync_resolve_all( + factory: Callable[..., Any], cache: bool, expected_count: int +): + class MyProvider(Provider): + a = provide(factory, scope=Scope.APP, cache=cache) + + @provide(scope=Scope.APP) + def get_int(self) -> int: + return 100 + + container = make_container(MyProvider()) + assert container.registry.scope is Scope.APP + assert len(container.context) == 1 + container.resolve_all() + assert len(container.context) == expected_count + container.close() + + +@pytest.mark.parametrize( + ("factory", "cache", "expected_count"), + [ + (ClassA, True, 3), + (ClassA, False, 2), + (async_func_a, True, 3), + (async_func_a, False, 2), + (async_iter_a, True, 3), + (async_iter_a, False, 2), + (async_gen_a, True, 3), + (async_gen_a, False, 2), + ], +) +@pytest.mark.asyncio +async def test_async_resolve_all( + factory: Callable[..., Any], cache: bool, expected_count: int +): + class MyProvider(Provider): + a = provide(factory, scope=Scope.APP, cache=cache) + + @provide(scope=Scope.APP) + def get_int(self) -> int: + return 100 + + container = make_async_container(MyProvider()) + assert container.registry.scope is Scope.APP + assert len(container.context) == 1 + await container.resolve_all() + assert len(container.context) == expected_count + await container.close()