Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Closing dependency resolution V2 #852

Merged
merged 14 commits into from
Feb 23, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 20 additions & 22 deletions src/dependency_injector/wiring.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,26 @@
"""Wiring module."""

import functools
import inspect
import importlib
import importlib.machinery
import inspect
import pkgutil
import warnings
import sys
import warnings
from types import ModuleType
from typing import (
Optional,
Iterable,
Iterator,
Callable,
Any,
Tuple,
Callable,
Dict,
Generic,
TypeVar,
Iterable,
Iterator,
Optional,
Set,
Tuple,
Type,
TypeVar,
Union,
Set,
cast,
)

@@ -643,21 +643,18 @@ def _fetch_reference_injections( # noqa: C901


def _locate_dependent_closing_args(
provider: providers.Provider,
provider: providers.Provider, closing_deps: Dict[str, providers.Provider]
) -> Dict[str, providers.Provider]:
if not hasattr(provider, "args"):
return {}

closing_deps = {}
for arg in provider.args:
if not isinstance(arg, providers.Provider) or not hasattr(arg, "args"):
for arg in [
*getattr(provider, "args", []),
*getattr(provider, "kwargs", {}).values(),
]:
if not isinstance(arg, providers.Provider):
continue
if isinstance(arg, providers.Resource):
closing_deps[str(id(arg))] = arg

if not arg.args and isinstance(arg, providers.Resource):
return {str(id(arg)): arg}
else:
closing_deps += _locate_dependent_closing_args(arg)
return closing_deps
_locate_dependent_closing_args(arg, closing_deps)


def _bind_injections(fn: Callable[..., Any], providers_map: ProvidersMap) -> None:
@@ -681,7 +678,8 @@ def _bind_injections(fn: Callable[..., Any], providers_map: ProvidersMap) -> Non

if injection in patched_callable.reference_closing:
patched_callable.add_closing(injection, provider)
deps = _locate_dependent_closing_args(provider)
deps = {}
_locate_dependent_closing_args(provider, deps)
for key, dep in deps.items():
patched_callable.add_closing(key, dep)

95 changes: 75 additions & 20 deletions tests/unit/samples/wiringstringids/resourceclosing.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,80 @@
from typing import Any, Dict, List, Optional

from dependency_injector import containers, providers
from dependency_injector.wiring import inject, Provide, Closing
from dependency_injector.wiring import Closing, Provide, inject


class Counter:
def __init__(self) -> None:
self._init = 0
self._shutdown = 0

def init(self) -> None:
self._init += 1

def shutdown(self) -> None:
self._shutdown += 1

def reset(self) -> None:
self._init = 0
self._shutdown = 0


class Service:
init_counter: int = 0
shutdown_counter: int = 0
def __init__(self, counter: Optional[Counter] = None, **dependencies: Any) -> None:
self.counter = counter or Counter()
self.dependencies = dependencies

def init(self) -> None:
self.counter.init()

@classmethod
def reset_counter(cls):
cls.init_counter = 0
cls.shutdown_counter = 0
def shutdown(self) -> None:
self.counter.shutdown()

@classmethod
def init(cls):
cls.init_counter += 1
@property
def init_counter(self) -> int:
return self.counter._init

@classmethod
def shutdown(cls):
cls.shutdown_counter += 1
@property
def shutdown_counter(self) -> int:
return self.counter._shutdown


class FactoryService:
def __init__(self, service: Service):
def __init__(self, service: Service, service2: Service):
self.service = service
self.service2 = service2


class NestedService:
def __init__(self, factory_service: FactoryService):
self.factory_service = factory_service


def init_service():
service = Service()
def init_service(counter: Counter, _list: List[int], _dict: Dict[str, int]):
service = Service(counter, _list=_list, _dict=_dict)
service.init()
yield service
service.shutdown()


class Container(containers.DeclarativeContainer):

service = providers.Resource(init_service)
factory_service = providers.Factory(FactoryService, service)
counter = providers.Singleton(Counter)
_list = providers.List(
providers.Callable(lambda a: a, a=1), providers.Callable(lambda b: b, 2)
)
_dict = providers.Dict(
a=providers.Callable(lambda a: a, a=3), b=providers.Callable(lambda b: b, 4)
)
service = providers.Resource(init_service, counter, _list, _dict=_dict)
service2 = providers.Resource(init_service, counter, _list, _dict=_dict)
factory_service = providers.Factory(FactoryService, service, service2)
factory_service_kwargs = providers.Factory(
FactoryService,
service=service,
service2=service2,
)
nested_service = providers.Factory(NestedService, factory_service)


@inject
@@ -44,5 +83,21 @@ def test_function(service: Service = Closing[Provide["service"]]):


@inject
def test_function_dependency(factory: FactoryService = Closing[Provide["factory_service"]]):
def test_function_dependency(
factory: FactoryService = Closing[Provide["factory_service"]],
):
return factory


@inject
def test_function_dependency_kwargs(
factory: FactoryService = Closing[Provide["factory_service_kwargs"]],
):
return factory


@inject
def test_function_nested_dependency(
nested: NestedService = Closing[Provide["nested_service"]],
):
return nested
57 changes: 40 additions & 17 deletions tests/unit/wiring/string_ids/test_main_py36.py
Original file line number Diff line number Diff line change
@@ -2,13 +2,13 @@

from decimal import Decimal

from dependency_injector import errors
from dependency_injector.wiring import Closing, Provide, Provider, wire
from pytest import fixture, mark, raises

from samples.wiringstringids import module, package, resourceclosing
from samples.wiringstringids.service import Service
from samples.wiringstringids.container import Container, SubContainer
from samples.wiringstringids.service import Service

from dependency_injector import errors
from dependency_injector.wiring import Closing, Provide, Provider, wire


@fixture(autouse=True)
@@ -34,10 +34,11 @@ def subcontainer():


@fixture
def resourceclosing_container():
def resourceclosing_container(request):
container = resourceclosing.Container()
container.wire(modules=[resourceclosing])
yield container
with container.reset_singletons():
yield container
container.unwire()


@@ -274,42 +275,65 @@ def test_wire_multiple_containers():

@mark.usefixtures("resourceclosing_container")
def test_closing_resource():
resourceclosing.Service.reset_counter()

result_1 = resourceclosing.test_function()
assert isinstance(result_1, resourceclosing.Service)
assert result_1.init_counter == 1
assert result_1.shutdown_counter == 1
assert result_1.dependencies == {"_list": [1, 2], "_dict": {"a": 3, "b": 4}}

result_2 = resourceclosing.test_function()
assert isinstance(result_2, resourceclosing.Service)
assert result_2.init_counter == 2
assert result_2.shutdown_counter == 2
assert result_1.dependencies == {"_list": [1, 2], "_dict": {"a": 3, "b": 4}}

assert result_1 is not result_2


@mark.usefixtures("resourceclosing_container")
def test_closing_dependency_resource():
resourceclosing.Service.reset_counter()

result_1 = resourceclosing.test_function_dependency()
assert isinstance(result_1, resourceclosing.FactoryService)
assert result_1.service.init_counter == 1
assert result_1.service.shutdown_counter == 1
assert result_1.service.init_counter == 2
assert result_1.service.shutdown_counter == 2

result_2 = resourceclosing.test_function_dependency()

assert isinstance(result_2, resourceclosing.FactoryService)
assert result_2.service.init_counter == 2
assert result_2.service.shutdown_counter == 2
assert result_2.service.init_counter == 4
assert result_2.service.shutdown_counter == 4


@mark.usefixtures("resourceclosing_container")
def test_closing_dependency_resource_kwargs():
result_1 = resourceclosing.test_function_dependency_kwargs()
assert isinstance(result_1, resourceclosing.FactoryService)
assert result_1.service.init_counter == 2
assert result_1.service.shutdown_counter == 2

result_2 = resourceclosing.test_function_dependency_kwargs()
assert isinstance(result_2, resourceclosing.FactoryService)
assert result_2.service.init_counter == 4
assert result_2.service.shutdown_counter == 4


@mark.usefixtures("resourceclosing_container")
def test_closing_nested_dependency_resource():
result_1 = resourceclosing.test_function_nested_dependency()
assert isinstance(result_1, resourceclosing.NestedService)
assert result_1.factory_service.service.init_counter == 2
assert result_1.factory_service.service.shutdown_counter == 2

result_2 = resourceclosing.test_function_nested_dependency()
assert isinstance(result_2, resourceclosing.NestedService)
assert result_2.factory_service.service.init_counter == 4
assert result_2.factory_service.service.shutdown_counter == 4

assert result_1 is not result_2


@mark.usefixtures("resourceclosing_container")
def test_closing_resource_bypass_marker_injection():
resourceclosing.Service.reset_counter()

result_1 = resourceclosing.test_function(service=Closing[Provide["service"]])
assert isinstance(result_1, resourceclosing.Service)
assert result_1.init_counter == 1
@@ -325,7 +349,6 @@ def test_closing_resource_bypass_marker_injection():

@mark.usefixtures("resourceclosing_container")
def test_closing_resource_context():
resourceclosing.Service.reset_counter()
service = resourceclosing.Service()

result_1 = resourceclosing.test_function(service=service)