Skip to content

Commit e4d219b

Browse files
authored
[tests] fix training tests (#13442)
* fix textual inversion
* fix rest
1 parent e9c092d commit e4d219b

10 files changed

+111
-56
lines changed

examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py

Lines changed: 9 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -895,19 +895,16 @@ def initialize_new_tokens(self, inserting_toks: List[str]):
895895
self.train_ids_t5 = tokenizer.convert_tokens_to_ids(self.inserting_toks)
896896

897897
# random initialization of new tokens
898-
embeds = (
899-
text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.encoder.embed_tokens
900-
)
898+
text_module = text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
899+
embeds = text_module.embeddings.token_embedding if idx == 0 else text_encoder.encoder.embed_tokens
901900
std_token_embedding = embeds.weight.data.std()
902901

903902
logger.info(f"{idx} text encoder's std_token_embedding: {std_token_embedding}")
904903

905904
train_ids = self.train_ids if idx == 0 else self.train_ids_t5
906905
# if initializer_concept are not provided, token embeddings are initialized randomly
907906
if args.initializer_concept is None:
908-
hidden_size = (
909-
text_encoder.text_model.config.hidden_size if idx == 0 else text_encoder.encoder.config.hidden_size
910-
)
907+
hidden_size = text_module.config.hidden_size if idx == 0 else text_encoder.encoder.config.hidden_size
911908
embeds.weight.data[train_ids] = (
912909
torch.randn(len(train_ids), hidden_size).to(device=self.device).to(dtype=self.dtype)
913910
* std_token_embedding
@@ -940,7 +937,8 @@ def save_embeddings(self, file_path: str):
940937
idx_to_text_encoder_name = {0: "clip_l", 1: "t5"}
941938
for idx, text_encoder in enumerate(self.text_encoders):
942939
train_ids = self.train_ids if idx == 0 else self.train_ids_t5
943-
embeds = text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.shared
940+
text_module = text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
941+
embeds = text_module.embeddings.token_embedding if idx == 0 else text_encoder.shared
944942
assert embeds.weight.data.shape[0] == len(self.tokenizers[idx]), "Tokenizers should be the same."
945943
new_token_embeddings = embeds.weight.data[train_ids]
946944

@@ -962,7 +960,8 @@ def device(self):
962960
@torch.no_grad()
963961
def retract_embeddings(self):
964962
for idx, text_encoder in enumerate(self.text_encoders):
965-
embeds = text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.shared
963+
text_module = text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
964+
embeds = text_module.embeddings.token_embedding if idx == 0 else text_encoder.shared
966965
index_no_updates = self.embeddings_settings[f"index_no_updates_{idx}"]
967966
embeds.weight.data[index_no_updates] = (
968967
self.embeddings_settings[f"original_embeddings_{idx}"][index_no_updates]
@@ -2112,7 +2111,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
21122111
if args.train_text_encoder:
21132112
text_encoder_one.train()
21142113
# set top parameter requires_grad = True for gradient checkpointing works
2115-
unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
2114+
_te_one = unwrap_model(text_encoder_one)
2115+
(_te_one.text_model if hasattr(_te_one, "text_model") else _te_one).embeddings.requires_grad_(True)
21162116
elif args.train_text_encoder_ti: # textual inversion / pivotal tuning
21172117
text_encoder_one.train()
21182118
if args.enable_t5_ti:

examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py

Lines changed: 33 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -763,19 +763,28 @@ def initialize_new_tokens(self, inserting_toks: List[str]):
763763
self.train_ids = tokenizer.convert_tokens_to_ids(self.inserting_toks)
764764

765765
# random initialization of new tokens
766-
std_token_embedding = text_encoder.text_model.embeddings.token_embedding.weight.data.std()
766+
std_token_embedding = (
767+
text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
768+
).embeddings.token_embedding.weight.data.std()
767769

768770
print(f"{idx} text encoder's std_token_embedding: {std_token_embedding}")
769771

770-
text_encoder.text_model.embeddings.token_embedding.weight.data[self.train_ids] = (
771-
torch.randn(len(self.train_ids), text_encoder.text_model.config.hidden_size)
772+
(
773+
text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
774+
).embeddings.token_embedding.weight.data[self.train_ids] = (
775+
torch.randn(
776+
len(self.train_ids),
777+
(
778+
text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
779+
).config.hidden_size,
780+
)
772781
.to(device=self.device)
773782
.to(dtype=self.dtype)
774783
* std_token_embedding
775784
)
776785
self.embeddings_settings[f"original_embeddings_{idx}"] = (
777-
text_encoder.text_model.embeddings.token_embedding.weight.data.clone()
778-
)
786+
text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
787+
).embeddings.token_embedding.weight.data.clone()
779788
self.embeddings_settings[f"std_token_embedding_{idx}"] = std_token_embedding
780789

781790
inu = torch.ones((len(tokenizer),), dtype=torch.bool)
@@ -794,10 +803,14 @@ def save_embeddings(self, file_path: str):
794803
# text_encoder_0 - CLIP ViT-L/14, text_encoder_1 - CLIP ViT-G/14 - TODO - change for sd
795804
idx_to_text_encoder_name = {0: "clip_l", 1: "clip_g"}
796805
for idx, text_encoder in enumerate(self.text_encoders):
797-
assert text_encoder.text_model.embeddings.token_embedding.weight.data.shape[0] == len(
798-
self.tokenizers[0]
799-
), "Tokenizers should be the same."
800-
new_token_embeddings = text_encoder.text_model.embeddings.token_embedding.weight.data[self.train_ids]
806+
assert (
807+
text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
808+
).embeddings.token_embedding.weight.data.shape[0] == len(self.tokenizers[0]), (
809+
"Tokenizers should be the same."
810+
)
811+
new_token_embeddings = (
812+
text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
813+
).embeddings.token_embedding.weight.data[self.train_ids]
801814

802815
# New tokens for each text encoder are saved under "clip_l" (for text_encoder 0), "clip_g" (for
803816
# text_encoder 1) to keep compatible with the ecosystem.
@@ -819,7 +832,9 @@ def device(self):
819832
def retract_embeddings(self):
820833
for idx, text_encoder in enumerate(self.text_encoders):
821834
index_no_updates = self.embeddings_settings[f"index_no_updates_{idx}"]
822-
text_encoder.text_model.embeddings.token_embedding.weight.data[index_no_updates] = (
835+
(
836+
text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
837+
).embeddings.token_embedding.weight.data[index_no_updates] = (
823838
self.embeddings_settings[f"original_embeddings_{idx}"][index_no_updates]
824839
.to(device=text_encoder.device)
825840
.to(dtype=text_encoder.dtype)
@@ -830,11 +845,15 @@ def retract_embeddings(self):
830845
std_token_embedding = self.embeddings_settings[f"std_token_embedding_{idx}"]
831846

832847
index_updates = ~index_no_updates
833-
new_embeddings = text_encoder.text_model.embeddings.token_embedding.weight.data[index_updates]
848+
new_embeddings = (
849+
text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
850+
).embeddings.token_embedding.weight.data[index_updates]
834851
off_ratio = std_token_embedding / new_embeddings.std()
835852

836853
new_embeddings = new_embeddings * (off_ratio**0.1)
837-
text_encoder.text_model.embeddings.token_embedding.weight.data[index_updates] = new_embeddings
854+
(
855+
text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
856+
).embeddings.token_embedding.weight.data[index_updates] = new_embeddings
838857

839858

840859
class DreamBoothDataset(Dataset):
@@ -1704,7 +1723,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
17041723
text_encoder_one.train()
17051724
# set top parameter requires_grad = True for gradient checkpointing works
17061725
if args.train_text_encoder:
1707-
text_encoder_one.text_model.embeddings.requires_grad_(True)
1726+
_te_one = text_encoder_one
1727+
(_te_one.text_model if hasattr(_te_one, "text_model") else _te_one).embeddings.requires_grad_(True)
17081728

17091729
unet.train()
17101730
for step, batch in enumerate(train_dataloader):

examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py

Lines changed: 35 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -929,19 +929,28 @@ def initialize_new_tokens(self, inserting_toks: List[str]):
929929
self.train_ids = tokenizer.convert_tokens_to_ids(self.inserting_toks)
930930

931931
# random initialization of new tokens
932-
std_token_embedding = text_encoder.text_model.embeddings.token_embedding.weight.data.std()
932+
std_token_embedding = (
933+
text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
934+
).embeddings.token_embedding.weight.data.std()
933935

934936
print(f"{idx} text encoder's std_token_embedding: {std_token_embedding}")
935937

936-
text_encoder.text_model.embeddings.token_embedding.weight.data[self.train_ids] = (
937-
torch.randn(len(self.train_ids), text_encoder.text_model.config.hidden_size)
938+
(
939+
text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
940+
).embeddings.token_embedding.weight.data[self.train_ids] = (
941+
torch.randn(
942+
len(self.train_ids),
943+
(
944+
text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
945+
).config.hidden_size,
946+
)
938947
.to(device=self.device)
939948
.to(dtype=self.dtype)
940949
* std_token_embedding
941950
)
942951
self.embeddings_settings[f"original_embeddings_{idx}"] = (
943-
text_encoder.text_model.embeddings.token_embedding.weight.data.clone()
944-
)
952+
text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
953+
).embeddings.token_embedding.weight.data.clone()
945954
self.embeddings_settings[f"std_token_embedding_{idx}"] = std_token_embedding
946955

947956
inu = torch.ones((len(tokenizer),), dtype=torch.bool)
@@ -959,10 +968,14 @@ def save_embeddings(self, file_path: str):
959968
# text_encoder_0 - CLIP ViT-L/14, text_encoder_1 - CLIP ViT-G/14
960969
idx_to_text_encoder_name = {0: "clip_l", 1: "clip_g"}
961970
for idx, text_encoder in enumerate(self.text_encoders):
962-
assert text_encoder.text_model.embeddings.token_embedding.weight.data.shape[0] == len(
963-
self.tokenizers[0]
964-
), "Tokenizers should be the same."
965-
new_token_embeddings = text_encoder.text_model.embeddings.token_embedding.weight.data[self.train_ids]
971+
assert (
972+
text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
973+
).embeddings.token_embedding.weight.data.shape[0] == len(self.tokenizers[0]), (
974+
"Tokenizers should be the same."
975+
)
976+
new_token_embeddings = (
977+
text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
978+
).embeddings.token_embedding.weight.data[self.train_ids]
966979

967980
# New tokens for each text encoder are saved under "clip_l" (for text_encoder 0), "clip_g" (for
968981
# text_encoder 1) to keep compatible with the ecosystem.
@@ -984,7 +997,9 @@ def device(self):
984997
def retract_embeddings(self):
985998
for idx, text_encoder in enumerate(self.text_encoders):
986999
index_no_updates = self.embeddings_settings[f"index_no_updates_{idx}"]
987-
text_encoder.text_model.embeddings.token_embedding.weight.data[index_no_updates] = (
1000+
(
1001+
text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
1002+
).embeddings.token_embedding.weight.data[index_no_updates] = (
9881003
self.embeddings_settings[f"original_embeddings_{idx}"][index_no_updates]
9891004
.to(device=text_encoder.device)
9901005
.to(dtype=text_encoder.dtype)
@@ -995,11 +1010,15 @@ def retract_embeddings(self):
9951010
std_token_embedding = self.embeddings_settings[f"std_token_embedding_{idx}"]
9961011

9971012
index_updates = ~index_no_updates
998-
new_embeddings = text_encoder.text_model.embeddings.token_embedding.weight.data[index_updates]
1013+
new_embeddings = (
1014+
text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
1015+
).embeddings.token_embedding.weight.data[index_updates]
9991016
off_ratio = std_token_embedding / new_embeddings.std()
10001017

10011018
new_embeddings = new_embeddings * (off_ratio**0.1)
1002-
text_encoder.text_model.embeddings.token_embedding.weight.data[index_updates] = new_embeddings
1019+
(
1020+
text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
1021+
).embeddings.token_embedding.weight.data[index_updates] = new_embeddings
10031022

10041023

10051024
class DreamBoothDataset(Dataset):
@@ -2083,8 +2102,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
20832102
text_encoder_two.train()
20842103
# set top parameter requires_grad = True for gradient checkpointing works
20852104
if args.train_text_encoder:
2086-
accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
2087-
accelerator.unwrap_model(text_encoder_two).text_model.embeddings.requires_grad_(True)
2105+
_te_one = accelerator.unwrap_model(text_encoder_one)
2106+
(_te_one.text_model if hasattr(_te_one, "text_model") else _te_one).embeddings.requires_grad_(True)
2107+
_te_two = accelerator.unwrap_model(text_encoder_two)
2108+
(_te_two.text_model if hasattr(_te_two, "text_model") else _te_two).embeddings.requires_grad_(True)
20882109

20892110
for step, batch in enumerate(train_dataloader):
20902111
if pivoted:

examples/custom_diffusion/train_custom_diffusion.py

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -874,10 +874,11 @@ def main(args):
874874
token_embeds[x] = token_embeds[y]
875875

876876
# Freeze all parameters except for the token embeddings in text encoder
877+
text_module = text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
877878
params_to_freeze = itertools.chain(
878-
text_encoder.text_model.encoder.parameters(),
879-
text_encoder.text_model.final_layer_norm.parameters(),
880-
text_encoder.text_model.embeddings.position_embedding.parameters(),
879+
text_module.encoder.parameters(),
880+
text_module.final_layer_norm.parameters(),
881+
text_module.embeddings.position_embedding.parameters(),
881882
)
882883
freeze_params(params_to_freeze)
883884
########################################################

examples/dreambooth/train_dreambooth_lora_flux.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1691,7 +1691,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
16911691
if args.train_text_encoder:
16921692
text_encoder_one.train()
16931693
# set top parameter requires_grad = True for gradient checkpointing works
1694-
unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
1694+
_te_one = unwrap_model(text_encoder_one)
1695+
(_te_one.text_model if hasattr(_te_one, "text_model") else _te_one).embeddings.requires_grad_(True)
16951696

16961697
for step, batch in enumerate(train_dataloader):
16971698
models_to_accumulate = [transformer]

examples/dreambooth/train_dreambooth_lora_flux_kontext.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1896,7 +1896,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
18961896
if args.train_text_encoder:
18971897
text_encoder_one.train()
18981898
# set top parameter requires_grad = True for gradient checkpointing works
1899-
unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
1899+
_te_one = unwrap_model(text_encoder_one)
1900+
(_te_one.text_model if hasattr(_te_one, "text_model") else _te_one).embeddings.requires_grad_(True)
19001901

19011902
for step, batch in enumerate(train_dataloader):
19021903
models_to_accumulate = [transformer]

examples/dreambooth/train_dreambooth_lora_sd3.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1719,8 +1719,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
17191719
text_encoder_two.train()
17201720

17211721
# set top parameter requires_grad = True for gradient checkpointing works
1722-
accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
1723-
accelerator.unwrap_model(text_encoder_two).text_model.embeddings.requires_grad_(True)
1722+
_te_one = accelerator.unwrap_model(text_encoder_one)
1723+
(_te_one.text_model if hasattr(_te_one, "text_model") else _te_one).embeddings.requires_grad_(True)
1724+
_te_two = accelerator.unwrap_model(text_encoder_two)
1725+
(_te_two.text_model if hasattr(_te_two, "text_model") else _te_two).embeddings.requires_grad_(True)
17241726

17251727
for step, batch in enumerate(train_dataloader):
17261728
models_to_accumulate = [transformer]

examples/dreambooth/train_dreambooth_lora_sdxl.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1661,8 +1661,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
16611661
text_encoder_two.train()
16621662

16631663
# set top parameter requires_grad = True for gradient checkpointing works
1664-
accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
1665-
accelerator.unwrap_model(text_encoder_two).text_model.embeddings.requires_grad_(True)
1664+
_te_one = accelerator.unwrap_model(text_encoder_one)
1665+
(_te_one.text_model if hasattr(_te_one, "text_model") else _te_one).embeddings.requires_grad_(True)
1666+
_te_two = accelerator.unwrap_model(text_encoder_two)
1667+
(_te_two.text_model if hasattr(_te_two, "text_model") else _te_two).embeddings.requires_grad_(True)
16661668

16671669
for step, batch in enumerate(train_dataloader):
16681670
with accelerator.accumulate(unet):

examples/textual_inversion/textual_inversion.py

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -702,9 +702,10 @@ def main():
702702
vae.requires_grad_(False)
703703
unet.requires_grad_(False)
704704
# Freeze all parameters except for the token embeddings in text encoder
705-
text_encoder.text_model.encoder.requires_grad_(False)
706-
text_encoder.text_model.final_layer_norm.requires_grad_(False)
707-
text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
705+
text_module = text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
706+
text_module.encoder.requires_grad_(False)
707+
text_module.final_layer_norm.requires_grad_(False)
708+
text_module.embeddings.position_embedding.requires_grad_(False)
708709

709710
if args.gradient_checkpointing:
710711
# Keep unet in train mode if we are using gradient checkpointing to save memory.

0 commit comments

Comments
 (0)