|
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | 15 | import inspect |
16 | | -import unittest |
17 | 16 | from importlib import import_module |
18 | 17 |
|
| 18 | +import pytest |
19 | 19 |
|
20 | | -class DependencyTester(unittest.TestCase): |
| 20 | + |
| 21 | +class TestDependencies: |
21 | 22 | def test_diffusers_import(self): |
22 | | - try: |
23 | | - import diffusers # noqa: F401 |
24 | | - except ImportError: |
25 | | - assert False |
| 23 | + import diffusers # noqa: F401 |
26 | 24 |
|
27 | 25 | def test_backend_registration(self): |
28 | 26 | import diffusers |
@@ -52,3 +50,36 @@ def test_pipeline_imports(self): |
52 | 50 | if hasattr(diffusers.pipelines, cls_name): |
53 | 51 | pipeline_folder_module = ".".join(str(cls_module.__module__).split(".")[:3]) |
54 | 52 | _ = import_module(pipeline_folder_module, str(cls_name)) |
| 53 | + |
| 54 | + def test_pipeline_module_imports(self): |
| 55 | + """Import every pipeline submodule whose dependencies are satisfied, |
| 56 | + to catch unguarded optional-dep imports (e.g., torchvision). |
| 57 | +
|
| 58 | + Uses inspect.getmembers to discover classes that the lazy loader can |
| 59 | + actually resolve (same self-filtering as test_pipeline_imports), then |
| 60 | + imports the full module path instead of truncating to the folder level. |
| 61 | + """ |
| 62 | + import diffusers |
| 63 | + import diffusers.pipelines |
| 64 | + |
| 65 | + failures = [] |
| 66 | + all_classes = inspect.getmembers(diffusers, inspect.isclass) |
| 67 | + |
| 68 | + for cls_name, cls_module in all_classes: |
| 69 | + if not hasattr(diffusers.pipelines, cls_name): |
| 70 | + continue |
| 71 | + if "dummy_" in cls_module.__module__: |
| 72 | + continue |
| 73 | + |
| 74 | + full_module_path = cls_module.__module__ |
| 75 | + try: |
| 76 | + import_module(full_module_path) |
| 77 | + except ImportError as e: |
| 78 | + failures.append(f"{full_module_path}: {e}") |
| 79 | + except Exception: |
| 80 | + # Non-import errors (e.g., missing config) are fine; we only |
| 81 | + # care about unguarded import statements. |
| 82 | + pass |
| 83 | + |
| 84 | + if failures: |
| 85 | + pytest.fail("Unguarded optional-dependency imports found:\n" + "\n".join(failures)) |
0 commit comments