1313# See the License for the specific language governing permissions and
1414# limitations under the License.
1515
16- import unittest
16+ import pytest
17+ import torch
1718
1819from diffusers import AutoencoderDC
20+ from diffusers .utils .torch_utils import randn_tensor
1921
20- from ...testing_utils import IS_GITHUB_ACTIONS , enable_full_determinism , floats_tensor , torch_device
21- from ..test_modeling_common import ModelTesterMixin
22- from .testing_utils import AutoencoderTesterMixin
22+ from ...testing_utils import IS_GITHUB_ACTIONS , enable_full_determinism , torch_device
23+ from ..testing_utils import BaseModelTesterConfig , MemoryTesterMixin , ModelTesterMixin , TrainingTesterMixin
24+ from .testing_utils import NewAutoencoderTesterMixin
2325
2426
2527enable_full_determinism ()
2628
2729
28- class AutoencoderDCTests (ModelTesterMixin , AutoencoderTesterMixin , unittest .TestCase ):
29- model_class = AutoencoderDC
30- main_input_name = "sample"
31- base_precision = 1e-2
30+ class AutoencoderDCTesterConfig (BaseModelTesterConfig ):
31+ @property
32+ def model_class (self ):
33+ return AutoencoderDC
34+
35+ @property
36+ def output_shape (self ):
37+ return (3 , 32 , 32 )
38+
39+ @property
40+ def generator (self ):
41+ return torch .Generator ("cpu" ).manual_seed (0 )
3242
33- def get_autoencoder_dc_config (self ):
43+ def get_init_dict (self ):
3444 return {
3545 "in_channels" : 3 ,
3646 "latent_channels" : 4 ,
@@ -56,33 +66,29 @@ def get_autoencoder_dc_config(self):
5666 "scaling_factor" : 0.41407 ,
5767 }
5868
59- @property
60- def dummy_input (self ):
69+ def get_dummy_inputs (self ):
6170 batch_size = 4
6271 num_channels = 3
6372 sizes = (32 , 32 )
73+ image = randn_tensor ((batch_size , num_channels , * sizes ), generator = self .generator , device = torch_device )
74+ return {"sample" : image }
6475
65- image = floats_tensor ((batch_size , num_channels ) + sizes ).to (torch_device )
6676
67- return {"sample" : image }
77+ class TestAutoencoderDC (AutoencoderDCTesterConfig , ModelTesterMixin ):
78+ base_precision = 1e-2
6879
69- @property
70- def input_shape (self ):
71- return (3 , 32 , 32 )
7280
73- @property
74- def output_shape (self ):
75- return (3 , 32 , 32 )
81+ class TestAutoencoderDCTraining (AutoencoderDCTesterConfig , TrainingTesterMixin ):
82+ """Training tests for AutoencoderDC."""
7683
77- def prepare_init_args_and_inputs_for_common (self ):
78- init_dict = self .get_autoencoder_dc_config ()
79- inputs_dict = self .dummy_input
80- return init_dict , inputs_dict
8184
82- @unittest .skipIf (IS_GITHUB_ACTIONS , reason = "Skipping test inside GitHub Actions environment" )
83- def test_layerwise_casting_inference (self ):
84- super ().test_layerwise_casting_inference ()
85+ class TestAutoencoderDCMemory (AutoencoderDCTesterConfig , MemoryTesterMixin ):
86+ """Memory optimization tests for AutoencoderDC."""
8587
86- @unittest . skipIf (IS_GITHUB_ACTIONS , reason = "Skipping test inside GitHub Actions environment" )
88+ @pytest . mark . skipif (IS_GITHUB_ACTIONS , reason = "Skipping test inside GitHub Actions environment" )
8789 def test_layerwise_casting_memory (self ):
8890 super ().test_layerwise_casting_memory ()
91+
92+
93+ class TestAutoencoderDCSlicingTiling (AutoencoderDCTesterConfig , NewAutoencoderTesterMixin ):
94+ """Slicing and tiling tests for AutoencoderDC."""
0 commit comments