diff --git a/modern_di/providers/factory.py b/modern_di/providers/factory.py index 03cb1fa..0b82f7d 100644 --- a/modern_di/providers/factory.py +++ b/modern_di/providers/factory.py @@ -73,7 +73,7 @@ def _find_dep_provider(self, container: "Container", v: SignatureItem) -> "Abstr return provider for x in v.args: provider = container.providers_registry.find_provider(x) - if provider: + if provider is not None and provider is not self: return provider return None diff --git a/tests/providers/test_factory.py b/tests/providers/test_factory.py index 44e1e31..215e7cc 100644 --- a/tests/providers/test_factory.py +++ b/tests/providers/test_factory.py @@ -180,6 +180,23 @@ def second_creator(first_factory: str) -> str: assert app_container.resolve_provider(second_factory) == "one two" +def test_factory_self_reference_in_union_falls_through_to_default() -> None: + @dataclasses.dataclass(kw_only=True, slots=True) + class SelfRef: + x: int = 1 + + def make(x: int | SelfRef = 1) -> SelfRef: + return SelfRef(x=x if isinstance(x, int) else x.x) + + factory = providers.Factory(creator=make) + app_container = Container() + app_container.providers_registry.add_providers(factory) + + result = app_container.resolve(SelfRef) + assert isinstance(result, SelfRef) + assert result.x == 1 + + def test_factory_repr() -> None: provider = providers.Factory(creator=str, scope=Scope.APP) assert repr(provider) == "Factory(creator=, scope=, cached=False)"