Skip to content

Commit 09efbff

Browse files
ZipFilefederinikjazzthief
authored
Fix Closing dependency resolution (#852)
Co-authored-by: federinik <[email protected]> Co-authored-by: jazzthief <[email protected]>
1 parent 8b625d8 commit 09efbff

File tree

3 files changed

+135
-59
lines changed

3 files changed

+135
-59
lines changed

src/dependency_injector/wiring.py

+20-22
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,26 @@
11
"""Wiring module."""
22

33
import functools
4-
import inspect
54
import importlib
65
import importlib.machinery
6+
import inspect
77
import pkgutil
8-
import warnings
98
import sys
9+
import warnings
1010
from types import ModuleType
1111
from typing import (
12-
Optional,
13-
Iterable,
14-
Iterator,
15-
Callable,
1612
Any,
17-
Tuple,
13+
Callable,
1814
Dict,
1915
Generic,
20-
TypeVar,
16+
Iterable,
17+
Iterator,
18+
Optional,
19+
Set,
20+
Tuple,
2121
Type,
22+
TypeVar,
2223
Union,
23-
Set,
2424
cast,
2525
)
2626

@@ -643,21 +643,18 @@ def _fetch_reference_injections( # noqa: C901
643643

644644

645645
def _locate_dependent_closing_args(
646-
provider: providers.Provider,
646+
provider: providers.Provider, closing_deps: Dict[str, providers.Provider]
647647
) -> Dict[str, providers.Provider]:
648-
if not hasattr(provider, "args"):
649-
return {}
650-
651-
closing_deps = {}
652-
for arg in provider.args:
653-
if not isinstance(arg, providers.Provider) or not hasattr(arg, "args"):
648+
for arg in [
649+
*getattr(provider, "args", []),
650+
*getattr(provider, "kwargs", {}).values(),
651+
]:
652+
if not isinstance(arg, providers.Provider):
654653
continue
654+
if isinstance(arg, providers.Resource):
655+
closing_deps[str(id(arg))] = arg
655656

656-
if not arg.args and isinstance(arg, providers.Resource):
657-
return {str(id(arg)): arg}
658-
else:
659-
closing_deps += _locate_dependent_closing_args(arg)
660-
return closing_deps
657+
_locate_dependent_closing_args(arg, closing_deps)
661658

662659

663660
def _bind_injections(fn: Callable[..., Any], providers_map: ProvidersMap) -> None:
@@ -681,7 +678,8 @@ def _bind_injections(fn: Callable[..., Any], providers_map: ProvidersMap) -> Non
681678

682679
if injection in patched_callable.reference_closing:
683680
patched_callable.add_closing(injection, provider)
684-
deps = _locate_dependent_closing_args(provider)
681+
deps = {}
682+
_locate_dependent_closing_args(provider, deps)
685683
for key, dep in deps.items():
686684
patched_callable.add_closing(key, dep)
687685

Original file line numberDiff line numberDiff line change
@@ -1,41 +1,80 @@
1+
from typing import Any, Dict, List, Optional
2+
13
from dependency_injector import containers, providers
2-
from dependency_injector.wiring import inject, Provide, Closing
4+
from dependency_injector.wiring import Closing, Provide, inject
5+
6+
7+
class Counter:
8+
def __init__(self) -> None:
9+
self._init = 0
10+
self._shutdown = 0
11+
12+
def init(self) -> None:
13+
self._init += 1
14+
15+
def shutdown(self) -> None:
16+
self._shutdown += 1
17+
18+
def reset(self) -> None:
19+
self._init = 0
20+
self._shutdown = 0
321

422

523
class Service:
6-
init_counter: int = 0
7-
shutdown_counter: int = 0
24+
def __init__(self, counter: Optional[Counter] = None, **dependencies: Any) -> None:
25+
self.counter = counter or Counter()
26+
self.dependencies = dependencies
27+
28+
def init(self) -> None:
29+
self.counter.init()
830

9-
@classmethod
10-
def reset_counter(cls):
11-
cls.init_counter = 0
12-
cls.shutdown_counter = 0
31+
def shutdown(self) -> None:
32+
self.counter.shutdown()
1333

14-
@classmethod
15-
def init(cls):
16-
cls.init_counter += 1
34+
@property
35+
def init_counter(self) -> int:
36+
return self.counter._init
1737

18-
@classmethod
19-
def shutdown(cls):
20-
cls.shutdown_counter += 1
38+
@property
39+
def shutdown_counter(self) -> int:
40+
return self.counter._shutdown
2141

2242

2343
class FactoryService:
24-
def __init__(self, service: Service):
44+
def __init__(self, service: Service, service2: Service):
2545
self.service = service
46+
self.service2 = service2
47+
48+
49+
class NestedService:
50+
def __init__(self, factory_service: FactoryService):
51+
self.factory_service = factory_service
2652

2753

28-
def init_service():
29-
service = Service()
54+
def init_service(counter: Counter, _list: List[int], _dict: Dict[str, int]):
55+
service = Service(counter, _list=_list, _dict=_dict)
3056
service.init()
3157
yield service
3258
service.shutdown()
3359

3460

3561
class Container(containers.DeclarativeContainer):
36-
37-
service = providers.Resource(init_service)
38-
factory_service = providers.Factory(FactoryService, service)
62+
counter = providers.Singleton(Counter)
63+
_list = providers.List(
64+
providers.Callable(lambda a: a, a=1), providers.Callable(lambda b: b, 2)
65+
)
66+
_dict = providers.Dict(
67+
a=providers.Callable(lambda a: a, a=3), b=providers.Callable(lambda b: b, 4)
68+
)
69+
service = providers.Resource(init_service, counter, _list, _dict=_dict)
70+
service2 = providers.Resource(init_service, counter, _list, _dict=_dict)
71+
factory_service = providers.Factory(FactoryService, service, service2)
72+
factory_service_kwargs = providers.Factory(
73+
FactoryService,
74+
service=service,
75+
service2=service2,
76+
)
77+
nested_service = providers.Factory(NestedService, factory_service)
3978

4079

4180
@inject
@@ -44,5 +83,21 @@ def test_function(service: Service = Closing[Provide["service"]]):
4483

4584

4685
@inject
47-
def test_function_dependency(factory: FactoryService = Closing[Provide["factory_service"]]):
86+
def test_function_dependency(
87+
factory: FactoryService = Closing[Provide["factory_service"]],
88+
):
89+
return factory
90+
91+
92+
@inject
93+
def test_function_dependency_kwargs(
94+
factory: FactoryService = Closing[Provide["factory_service_kwargs"]],
95+
):
4896
return factory
97+
98+
99+
@inject
100+
def test_function_nested_dependency(
101+
nested: NestedService = Closing[Provide["nested_service"]],
102+
):
103+
return nested

tests/unit/wiring/string_ids/test_main_py36.py

+40-17
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@
22

33
from decimal import Decimal
44

5-
from dependency_injector import errors
6-
from dependency_injector.wiring import Closing, Provide, Provider, wire
75
from pytest import fixture, mark, raises
8-
96
from samples.wiringstringids import module, package, resourceclosing
10-
from samples.wiringstringids.service import Service
117
from samples.wiringstringids.container import Container, SubContainer
8+
from samples.wiringstringids.service import Service
9+
10+
from dependency_injector import errors
11+
from dependency_injector.wiring import Closing, Provide, Provider, wire
1212

1313

1414
@fixture(autouse=True)
@@ -34,10 +34,11 @@ def subcontainer():
3434

3535

3636
@fixture
37-
def resourceclosing_container():
37+
def resourceclosing_container(request):
3838
container = resourceclosing.Container()
3939
container.wire(modules=[resourceclosing])
40-
yield container
40+
with container.reset_singletons():
41+
yield container
4142
container.unwire()
4243

4344

@@ -274,42 +275,65 @@ def test_wire_multiple_containers():
274275

275276
@mark.usefixtures("resourceclosing_container")
276277
def test_closing_resource():
277-
resourceclosing.Service.reset_counter()
278-
279278
result_1 = resourceclosing.test_function()
280279
assert isinstance(result_1, resourceclosing.Service)
281280
assert result_1.init_counter == 1
282281
assert result_1.shutdown_counter == 1
282+
assert result_1.dependencies == {"_list": [1, 2], "_dict": {"a": 3, "b": 4}}
283283

284284
result_2 = resourceclosing.test_function()
285285
assert isinstance(result_2, resourceclosing.Service)
286286
assert result_2.init_counter == 2
287287
assert result_2.shutdown_counter == 2
288+
assert result_1.dependencies == {"_list": [1, 2], "_dict": {"a": 3, "b": 4}}
288289

289290
assert result_1 is not result_2
290291

291292

292293
@mark.usefixtures("resourceclosing_container")
293294
def test_closing_dependency_resource():
294-
resourceclosing.Service.reset_counter()
295-
296295
result_1 = resourceclosing.test_function_dependency()
297296
assert isinstance(result_1, resourceclosing.FactoryService)
298-
assert result_1.service.init_counter == 1
299-
assert result_1.service.shutdown_counter == 1
297+
assert result_1.service.init_counter == 2
298+
assert result_1.service.shutdown_counter == 2
300299

301300
result_2 = resourceclosing.test_function_dependency()
301+
302302
assert isinstance(result_2, resourceclosing.FactoryService)
303-
assert result_2.service.init_counter == 2
304-
assert result_2.service.shutdown_counter == 2
303+
assert result_2.service.init_counter == 4
304+
assert result_2.service.shutdown_counter == 4
305+
306+
307+
@mark.usefixtures("resourceclosing_container")
308+
def test_closing_dependency_resource_kwargs():
309+
result_1 = resourceclosing.test_function_dependency_kwargs()
310+
assert isinstance(result_1, resourceclosing.FactoryService)
311+
assert result_1.service.init_counter == 2
312+
assert result_1.service.shutdown_counter == 2
313+
314+
result_2 = resourceclosing.test_function_dependency_kwargs()
315+
assert isinstance(result_2, resourceclosing.FactoryService)
316+
assert result_2.service.init_counter == 4
317+
assert result_2.service.shutdown_counter == 4
318+
319+
320+
@mark.usefixtures("resourceclosing_container")
321+
def test_closing_nested_dependency_resource():
322+
result_1 = resourceclosing.test_function_nested_dependency()
323+
assert isinstance(result_1, resourceclosing.NestedService)
324+
assert result_1.factory_service.service.init_counter == 2
325+
assert result_1.factory_service.service.shutdown_counter == 2
326+
327+
result_2 = resourceclosing.test_function_nested_dependency()
328+
assert isinstance(result_2, resourceclosing.NestedService)
329+
assert result_2.factory_service.service.init_counter == 4
330+
assert result_2.factory_service.service.shutdown_counter == 4
305331

306332
assert result_1 is not result_2
307333

308334

309335
@mark.usefixtures("resourceclosing_container")
310336
def test_closing_resource_bypass_marker_injection():
311-
resourceclosing.Service.reset_counter()
312-
313337
result_1 = resourceclosing.test_function(service=Closing[Provide["service"]])
314338
assert isinstance(result_1, resourceclosing.Service)
315339
assert result_1.init_counter == 1
@@ -325,7 +349,6 @@ def test_closing_resource_bypass_marker_injection():
325349

326350
@mark.usefixtures("resourceclosing_container")
327351
def test_closing_resource_context():
328-
resourceclosing.Service.reset_counter()
329352
service = resourceclosing.Service()
330353

331354
result_1 = resourceclosing.test_function(service=service)

0 commit comments

Comments
 (0)