Skip to content

Commit 896fec3

Browse files
authored
[tests] tighten dependency testing. (#13332)
* tighten dependency testing. * invoke dependency testing temporarily. * f
1 parent 4548e68 commit 896fec3

File tree

4 files changed

+46
-10
lines changed

4 files changed

+46
-10
lines changed

.github/workflows/pr_dependency_test.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ on:
66
- main
77
paths:
88
- "src/diffusers/**.py"
9+
- "tests/**.py"
910
push:
1011
branches:
1112
- main

.github/workflows/pr_torch_dependency_test.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ on:
66
- main
77
paths:
88
- "src/diffusers/**.py"
9+
- "tests/**.py"
910
push:
1011
branches:
1112
- main
@@ -26,7 +27,7 @@ jobs:
2627
- name: Install dependencies
2728
run: |
2829
pip install -e .
29-
pip install torch torchvision torchaudio pytest
30+
pip install torch pytest
3031
- name: Check for soft dependencies
3132
run: |
3233
pytest tests/others/test_dependencies.py

src/diffusers/pipelines/consisid/consisid_utils.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,13 @@
55
import numpy as np
66
import torch
77
from PIL import Image, ImageOps
8-
from torchvision.transforms import InterpolationMode
9-
from torchvision.transforms.functional import normalize, resize
108

11-
from ...utils import get_logger, load_image
9+
from ...utils import get_logger, is_torchvision_available, load_image
10+
11+
12+
if is_torchvision_available():
13+
from torchvision.transforms import InterpolationMode
14+
from torchvision.transforms.functional import normalize, resize
1215

1316

1417
logger = get_logger(__name__)

tests/others/test_dependencies.py

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,14 @@
1313
# limitations under the License.
1414

1515
import inspect
16-
import unittest
1716
from importlib import import_module
1817

18+
import pytest
1919

20-
class DependencyTester(unittest.TestCase):
20+
21+
class TestDependencies:
2122
def test_diffusers_import(self):
22-
try:
23-
import diffusers # noqa: F401
24-
except ImportError:
25-
assert False
23+
import diffusers # noqa: F401
2624

2725
def test_backend_registration(self):
2826
import diffusers
@@ -52,3 +50,36 @@ def test_pipeline_imports(self):
5250
if hasattr(diffusers.pipelines, cls_name):
5351
pipeline_folder_module = ".".join(str(cls_module.__module__).split(".")[:3])
5452
_ = import_module(pipeline_folder_module, str(cls_name))
53+
54+
def test_pipeline_module_imports(self):
55+
"""Import every pipeline submodule whose dependencies are satisfied,
56+
to catch unguarded optional-dep imports (e.g., torchvision).
57+
58+
Uses inspect.getmembers to discover classes that the lazy loader can
59+
actually resolve (same self-filtering as test_pipeline_imports), then
60+
imports the full module path instead of truncating to the folder level.
61+
"""
62+
import diffusers
63+
import diffusers.pipelines
64+
65+
failures = []
66+
all_classes = inspect.getmembers(diffusers, inspect.isclass)
67+
68+
for cls_name, cls_module in all_classes:
69+
if not hasattr(diffusers.pipelines, cls_name):
70+
continue
71+
if "dummy_" in cls_module.__module__:
72+
continue
73+
74+
full_module_path = cls_module.__module__
75+
try:
76+
import_module(full_module_path)
77+
except ImportError as e:
78+
failures.append(f"{full_module_path}: {e}")
79+
except Exception:
80+
# Non-import errors (e.g., missing config) are fine; we only
81+
# care about unguarded import statements.
82+
pass
83+
84+
if failures:
85+
pytest.fail("Unguarded optional-dependency imports found:\n" + "\n".join(failures))

0 commit comments

Comments
 (0)