1313# See the License for the specific language governing permissions and
1414# limitations under the License.
1515
16- import unittest
17-
1816import torch
1917
2018from diffusers import HunyuanDiT2DModel
19+ from diffusers .utils .torch_utils import randn_tensor
2120
22- from ...testing_utils import (
23- enable_full_determinism ,
24- torch_device ,
21+ from ...testing_utils import enable_full_determinism , torch_device
22+ from ..testing_utils import (
23+ BaseModelTesterConfig ,
24+ ModelTesterMixin ,
25+ TrainingTesterMixin ,
2526)
26- from ..test_modeling_common import ModelTesterMixin
2727
2828
2929enable_full_determinism ()
3030
3131
32- class HunyuanDiTTests (ModelTesterMixin , unittest .TestCase ):
33- model_class = HunyuanDiT2DModel
34- main_input_name = "hidden_states"
class HunyuanDiTTesterConfig(BaseModelTesterConfig):
    """Shared test configuration for ``HunyuanDiT2DModel``.

    Mixed into the concrete test classes below; supplies the model class,
    a tiny init config, and deterministic dummy-input plumbing.
    """

    @property
    def model_class(self):
        return HunyuanDiT2DModel

    @property
    def pretrained_model_name_or_path(self):
        return "hf-internal-testing/tiny-hunyuan-dit-pipe"

    @property
    def pretrained_model_kwargs(self):
        # The tiny test repo stores the transformer weights under this subfolder.
        return {"subfolder": "transformer"}

    @property
    def main_input_name(self) -> str:
        return "hidden_states"

    @property
    def output_shape(self) -> tuple:
        # Per-sample output shape (channels, height, width); the batch
        # dimension is prepended by the tests themselves.
        return (8, 8, 8)

    @property
    def input_shape(self) -> tuple:
        return (4, 8, 8)

    @property
    def generator(self):
        # A fresh, fixed-seed CPU generator on every access so each
        # dummy-input build draws identical noise.
        return torch.Generator("cpu").manual_seed(0)

    def get_init_dict(self) -> dict:
        """Return a minimal ``HunyuanDiT2DModel`` config for fast tests."""
        return {
            "sample_size": 8,
            "patch_size": 2,
            "in_channels": 4,
            "num_layers": 1,
            "attention_head_dim": 8,
            "num_attention_heads": 2,
            "cross_attention_dim": 8,
            "cross_attention_dim_t5": 8,
            "pooled_projection_dim": 4,
            "hidden_size": 16,
            "text_len": 4,
            "text_len_t5": 4,
            "activation_fn": "gelu-approximate",
        }
77+
78+ def get_dummy_inputs (self , batch_size : int = 2 ) -> dict [str , torch .Tensor ]:
3979 num_channels = 4
4080 height = width = 8
4181 embedding_dim = 8
4282 sequence_length = 4
4383 sequence_length_t5 = 4
4484
45- hidden_states = torch .randn ((batch_size , num_channels , height , width )).to (torch_device )
46- encoder_hidden_states = torch .randn ((batch_size , sequence_length , embedding_dim )).to (torch_device )
85+ hidden_states = randn_tensor (
86+ (batch_size , num_channels , height , width ), generator = self .generator , device = torch_device
87+ )
88+ encoder_hidden_states = randn_tensor (
89+ (batch_size , sequence_length , embedding_dim ), generator = self .generator , device = torch_device
90+ )
4791 text_embedding_mask = torch .ones (size = (batch_size , sequence_length )).to (torch_device )
48- encoder_hidden_states_t5 = torch .randn ((batch_size , sequence_length_t5 , embedding_dim )).to (torch_device )
92+ encoder_hidden_states_t5 = randn_tensor (
93+ (batch_size , sequence_length_t5 , embedding_dim ), generator = self .generator , device = torch_device
94+ )
4995 text_embedding_mask_t5 = torch .ones (size = (batch_size , sequence_length_t5 )).to (torch_device )
50- timestep = torch .randint (0 , 1000 , size = (batch_size ,), dtype = encoder_hidden_states . dtype ).to (torch_device )
96+ timestep = torch .randint (0 , 1000 , size = (batch_size ,), generator = self . generator ). float ( ).to (torch_device )
5197
5298 original_size = [1024 , 1024 ]
5399 target_size = [16 , 16 ]
54100 crops_coords_top_left = [0 , 0 ]
55101 add_time_ids = list (original_size + target_size + crops_coords_top_left )
56- add_time_ids = torch .tensor ([add_time_ids , add_time_ids ] , dtype = encoder_hidden_states . dtype ).to (torch_device )
102+ add_time_ids = torch .tensor ([add_time_ids ] * batch_size , dtype = torch . float32 ).to (torch_device )
57103 style = torch .zeros (size = (batch_size ,), dtype = int ).to (torch_device )
58104 image_rotary_emb = [
59- torch .ones (size = (1 , 8 ), dtype = encoder_hidden_states . dtype ),
60- torch .zeros (size = (1 , 8 ), dtype = encoder_hidden_states . dtype ),
105+ torch .ones (size = (1 , 8 ), dtype = torch . float32 ),
106+ torch .zeros (size = (1 , 8 ), dtype = torch . float32 ),
61107 ]
62108
63109 return {
@@ -72,42 +118,14 @@ def dummy_input(self):
72118 "image_rotary_emb" : image_rotary_emb ,
73119 }
74120
75- @property
76- def input_shape (self ):
77- return (4 , 8 , 8 )
78-
79- @property
80- def output_shape (self ):
81- return (8 , 8 , 8 )
82-
83- def prepare_init_args_and_inputs_for_common (self ):
84- init_dict = {
85- "sample_size" : 8 ,
86- "patch_size" : 2 ,
87- "in_channels" : 4 ,
88- "num_layers" : 1 ,
89- "attention_head_dim" : 8 ,
90- "num_attention_heads" : 2 ,
91- "cross_attention_dim" : 8 ,
92- "cross_attention_dim_t5" : 8 ,
93- "pooled_projection_dim" : 4 ,
94- "hidden_size" : 16 ,
95- "text_len" : 4 ,
96- "text_len_t5" : 4 ,
97- "activation_fn" : "gelu-approximate" ,
98- }
99- inputs_dict = self .dummy_input
100- return init_dict , inputs_dict
101121
class TestHunyuanDiT(HunyuanDiTTesterConfig, ModelTesterMixin):
    """Core model tests for ``HunyuanDiT2DModel``."""

    def test_output(self):
        # ``output_shape`` is per-sample; prepend the dummy batch size so the
        # mixin compares against the full (batch, C, H, W) output.
        batch_size = self.get_dummy_inputs()[self.main_input_name].shape[0]
        super().test_output(expected_output_shape=(batch_size,) + self.output_shape)
class TestHunyuanDiTTraining(HunyuanDiTTesterConfig, TrainingTesterMixin):
    """Training-specific tests for ``HunyuanDiT2DModel``."""

    def test_gradient_checkpointing_is_applied(self):
        # Only the top-level transformer module is expected to toggle
        # gradient checkpointing for this architecture.
        expected_set = {"HunyuanDiT2DModel"}
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
0 commit comments