Skip to content

Commit ee3c352

Browse files
authored
[CI] Hunyuan Transformer Tests Refactor (#13342)
* update
* update
* update
* update
* update
* update
* update
1 parent 357b681 commit ee3c352

File tree

5 files changed: +379 -346 lines changed

5 files changed

+379
-346
lines changed

src/diffusers/models/transformers/transformer_hunyuan_video.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -888,6 +888,8 @@ class HunyuanVideoTransformer3DModel(
888888
_no_split_modules = [
889889
"HunyuanVideoTransformerBlock",
890890
"HunyuanVideoSingleTransformerBlock",
891+
"HunyuanVideoTokenReplaceTransformerBlock",
892+
"HunyuanVideoTokenReplaceSingleTransformerBlock",
891893
"HunyuanVideoPatchEmbed",
892894
"HunyuanVideoTokenRefiner",
893895
]

tests/models/transformers/test_models_transformer_hunyuan_1_5.py

Lines changed: 56 additions & 43 deletions
Original file line number | Diff line number | Diff line change
@@ -12,71 +12,53 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import unittest
16-
1715
import torch
1816

1917
from diffusers import HunyuanVideo15Transformer3DModel
18+
from diffusers.utils.torch_utils import randn_tensor
2019

2120
from ...testing_utils import enable_full_determinism, torch_device
22-
from ..test_modeling_common import ModelTesterMixin
21+
from ..testing_utils import (
22+
BaseModelTesterConfig,
23+
ModelTesterMixin,
24+
TrainingTesterMixin,
25+
)
2326

2427

2528
enable_full_determinism()
2629

2730

28-
class HunyuanVideo15Transformer3DTests(ModelTesterMixin, unittest.TestCase):
29-
model_class = HunyuanVideo15Transformer3DModel
30-
main_input_name = "hidden_states"
31-
uses_custom_attn_processor = True
32-
model_split_percents = [0.99, 0.99, 0.99]
33-
31+
class HunyuanVideo15TransformerTesterConfig(BaseModelTesterConfig):
3432
text_embed_dim = 16
3533
text_embed_2_dim = 8
3634
image_embed_dim = 12
3735

3836
@property
39-
def dummy_input(self):
40-
batch_size = 1
41-
num_channels = 4
42-
num_frames = 1
43-
height = 8
44-
width = 8
45-
sequence_length = 6
46-
sequence_length_2 = 4
47-
image_sequence_length = 3
37+
def model_class(self):
38+
return HunyuanVideo15Transformer3DModel
4839

49-
hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
50-
timestep = torch.tensor([1.0]).to(torch_device)
51-
encoder_hidden_states = torch.randn((batch_size, sequence_length, self.text_embed_dim), device=torch_device)
52-
encoder_hidden_states_2 = torch.randn(
53-
(batch_size, sequence_length_2, self.text_embed_2_dim), device=torch_device
54-
)
55-
encoder_attention_mask = torch.ones((batch_size, sequence_length), device=torch_device)
56-
encoder_attention_mask_2 = torch.ones((batch_size, sequence_length_2), device=torch_device)
57-
# All zeros for inducing T2V path in the model.
58-
image_embeds = torch.zeros((batch_size, image_sequence_length, self.image_embed_dim), device=torch_device)
40+
@property
41+
def main_input_name(self) -> str:
42+
return "hidden_states"
5943

60-
return {
61-
"hidden_states": hidden_states,
62-
"timestep": timestep,
63-
"encoder_hidden_states": encoder_hidden_states,
64-
"encoder_attention_mask": encoder_attention_mask,
65-
"encoder_hidden_states_2": encoder_hidden_states_2,
66-
"encoder_attention_mask_2": encoder_attention_mask_2,
67-
"image_embeds": image_embeds,
68-
}
44+
@property
45+
def model_split_percents(self) -> list:
46+
return [0.99, 0.99, 0.99]
6947

7048
@property
71-
def input_shape(self):
49+
def output_shape(self) -> tuple:
7250
return (4, 1, 8, 8)
7351

7452
@property
75-
def output_shape(self):
53+
def input_shape(self) -> tuple:
7654
return (4, 1, 8, 8)
7755

78-
def prepare_init_args_and_inputs_for_common(self):
79-
init_dict = {
56+
@property
57+
def generator(self):
58+
return torch.Generator("cpu").manual_seed(0)
59+
60+
def get_init_dict(self) -> dict:
61+
return {
8062
"in_channels": 4,
8163
"out_channels": 4,
8264
"num_attention_heads": 2,
@@ -93,9 +75,40 @@ def prepare_init_args_and_inputs_for_common(self):
9375
"target_size": 16,
9476
"task_type": "t2v",
9577
}
96-
inputs_dict = self.dummy_input
97-
return init_dict, inputs_dict
9878

79+
def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
80+
num_channels = 4
81+
num_frames = 1
82+
height = 8
83+
width = 8
84+
sequence_length = 6
85+
sequence_length_2 = 4
86+
image_sequence_length = 3
87+
88+
return {
89+
"hidden_states": randn_tensor(
90+
(batch_size, num_channels, num_frames, height, width), generator=self.generator, device=torch_device
91+
),
92+
"timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
93+
"encoder_hidden_states": randn_tensor(
94+
(batch_size, sequence_length, self.text_embed_dim), generator=self.generator, device=torch_device
95+
),
96+
"encoder_hidden_states_2": randn_tensor(
97+
(batch_size, sequence_length_2, self.text_embed_2_dim), generator=self.generator, device=torch_device
98+
),
99+
"encoder_attention_mask": torch.ones((batch_size, sequence_length), device=torch_device),
100+
"encoder_attention_mask_2": torch.ones((batch_size, sequence_length_2), device=torch_device),
101+
"image_embeds": torch.zeros(
102+
(batch_size, image_sequence_length, self.image_embed_dim), device=torch_device
103+
),
104+
}
105+
106+
107+
class TestHunyuanVideo15Transformer(HunyuanVideo15TransformerTesterConfig, ModelTesterMixin):
108+
pass
109+
110+
111+
class TestHunyuanVideo15TransformerTraining(HunyuanVideo15TransformerTesterConfig, TrainingTesterMixin):
99112
def test_gradient_checkpointing_is_applied(self):
100113
expected_set = {"HunyuanVideo15Transformer3DModel"}
101114
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

tests/models/transformers/test_models_transformer_hunyuan_dit.py

Lines changed: 71 additions & 53 deletions
Original file line number | Diff line number | Diff line change
@@ -13,51 +13,97 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
import unittest
17-
1816
import torch
1917

2018
from diffusers import HunyuanDiT2DModel
19+
from diffusers.utils.torch_utils import randn_tensor
2120

22-
from ...testing_utils import (
23-
enable_full_determinism,
24-
torch_device,
21+
from ...testing_utils import enable_full_determinism, torch_device
22+
from ..testing_utils import (
23+
BaseModelTesterConfig,
24+
ModelTesterMixin,
25+
TrainingTesterMixin,
2526
)
26-
from ..test_modeling_common import ModelTesterMixin
2727

2828

2929
enable_full_determinism()
3030

3131

32-
class HunyuanDiTTests(ModelTesterMixin, unittest.TestCase):
33-
model_class = HunyuanDiT2DModel
34-
main_input_name = "hidden_states"
32+
class HunyuanDiTTesterConfig(BaseModelTesterConfig):
33+
@property
34+
def model_class(self):
35+
return HunyuanDiT2DModel
36+
37+
@property
38+
def pretrained_model_name_or_path(self):
39+
return "hf-internal-testing/tiny-hunyuan-dit-pipe"
40+
41+
@property
42+
def pretrained_model_kwargs(self):
43+
return {"subfolder": "transformer"}
44+
45+
@property
46+
def main_input_name(self) -> str:
47+
return "hidden_states"
48+
49+
@property
50+
def output_shape(self) -> tuple:
51+
return (8, 8, 8)
52+
53+
@property
54+
def input_shape(self) -> tuple:
55+
return (4, 8, 8)
3556

3657
@property
37-
def dummy_input(self):
38-
batch_size = 2
58+
def generator(self):
59+
return torch.Generator("cpu").manual_seed(0)
60+
61+
def get_init_dict(self) -> dict:
62+
return {
63+
"sample_size": 8,
64+
"patch_size": 2,
65+
"in_channels": 4,
66+
"num_layers": 1,
67+
"attention_head_dim": 8,
68+
"num_attention_heads": 2,
69+
"cross_attention_dim": 8,
70+
"cross_attention_dim_t5": 8,
71+
"pooled_projection_dim": 4,
72+
"hidden_size": 16,
73+
"text_len": 4,
74+
"text_len_t5": 4,
75+
"activation_fn": "gelu-approximate",
76+
}
77+
78+
def get_dummy_inputs(self, batch_size: int = 2) -> dict[str, torch.Tensor]:
3979
num_channels = 4
4080
height = width = 8
4181
embedding_dim = 8
4282
sequence_length = 4
4383
sequence_length_t5 = 4
4484

45-
hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device)
46-
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
85+
hidden_states = randn_tensor(
86+
(batch_size, num_channels, height, width), generator=self.generator, device=torch_device
87+
)
88+
encoder_hidden_states = randn_tensor(
89+
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
90+
)
4791
text_embedding_mask = torch.ones(size=(batch_size, sequence_length)).to(torch_device)
48-
encoder_hidden_states_t5 = torch.randn((batch_size, sequence_length_t5, embedding_dim)).to(torch_device)
92+
encoder_hidden_states_t5 = randn_tensor(
93+
(batch_size, sequence_length_t5, embedding_dim), generator=self.generator, device=torch_device
94+
)
4995
text_embedding_mask_t5 = torch.ones(size=(batch_size, sequence_length_t5)).to(torch_device)
50-
timestep = torch.randint(0, 1000, size=(batch_size,), dtype=encoder_hidden_states.dtype).to(torch_device)
96+
timestep = torch.randint(0, 1000, size=(batch_size,), generator=self.generator).float().to(torch_device)
5197

5298
original_size = [1024, 1024]
5399
target_size = [16, 16]
54100
crops_coords_top_left = [0, 0]
55101
add_time_ids = list(original_size + target_size + crops_coords_top_left)
56-
add_time_ids = torch.tensor([add_time_ids, add_time_ids], dtype=encoder_hidden_states.dtype).to(torch_device)
102+
add_time_ids = torch.tensor([add_time_ids] * batch_size, dtype=torch.float32).to(torch_device)
57103
style = torch.zeros(size=(batch_size,), dtype=int).to(torch_device)
58104
image_rotary_emb = [
59-
torch.ones(size=(1, 8), dtype=encoder_hidden_states.dtype),
60-
torch.zeros(size=(1, 8), dtype=encoder_hidden_states.dtype),
105+
torch.ones(size=(1, 8), dtype=torch.float32),
106+
torch.zeros(size=(1, 8), dtype=torch.float32),
61107
]
62108

63109
return {
@@ -72,42 +118,14 @@ def dummy_input(self):
72118
"image_rotary_emb": image_rotary_emb,
73119
}
74120

75-
@property
76-
def input_shape(self):
77-
return (4, 8, 8)
78-
79-
@property
80-
def output_shape(self):
81-
return (8, 8, 8)
82-
83-
def prepare_init_args_and_inputs_for_common(self):
84-
init_dict = {
85-
"sample_size": 8,
86-
"patch_size": 2,
87-
"in_channels": 4,
88-
"num_layers": 1,
89-
"attention_head_dim": 8,
90-
"num_attention_heads": 2,
91-
"cross_attention_dim": 8,
92-
"cross_attention_dim_t5": 8,
93-
"pooled_projection_dim": 4,
94-
"hidden_size": 16,
95-
"text_len": 4,
96-
"text_len_t5": 4,
97-
"activation_fn": "gelu-approximate",
98-
}
99-
inputs_dict = self.dummy_input
100-
return init_dict, inputs_dict
101121

122+
class TestHunyuanDiT(HunyuanDiTTesterConfig, ModelTesterMixin):
102123
def test_output(self):
103-
super().test_output(
104-
expected_output_shape=(self.dummy_input[self.main_input_name].shape[0],) + self.output_shape
105-
)
124+
batch_size = self.get_dummy_inputs()[self.main_input_name].shape[0]
125+
super().test_output(expected_output_shape=(batch_size,) + self.output_shape)
106126

107-
@unittest.skip("HunyuanDIT use a custom processor HunyuanAttnProcessor2_0")
108-
def test_set_xformers_attn_processor_for_determinism(self):
109-
pass
110127

111-
@unittest.skip("HunyuanDIT use a custom processor HunyuanAttnProcessor2_0")
112-
def test_set_attn_processor_for_determinism(self):
113-
pass
128+
class TestHunyuanDiTTraining(HunyuanDiTTesterConfig, TrainingTesterMixin):
129+
def test_gradient_checkpointing_is_applied(self):
130+
expected_set = {"HunyuanDiT2DModel"}
131+
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

0 commit comments

Comments
 (0)