@@ -929,19 +929,28 @@ def initialize_new_tokens(self, inserting_toks: List[str]):
             self.train_ids = tokenizer.convert_tokens_to_ids(self.inserting_toks)
 
             # random initialization of new tokens
-            std_token_embedding = text_encoder.text_model.embeddings.token_embedding.weight.data.std()
+            std_token_embedding = (
+                text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+            ).embeddings.token_embedding.weight.data.std()
 
             print(f"{idx} text encoder's std_token_embedding: {std_token_embedding}")
 
-            text_encoder.text_model.embeddings.token_embedding.weight.data[self.train_ids] = (
-                torch.randn(len(self.train_ids), text_encoder.text_model.config.hidden_size)
+            (
+                text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+            ).embeddings.token_embedding.weight.data[self.train_ids] = (
+                torch.randn(
+                    len(self.train_ids),
+                    (
+                        text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+                    ).config.hidden_size,
+                )
                 .to(device=self.device)
                 .to(dtype=self.dtype)
                 * std_token_embedding
             )
             self.embeddings_settings[f"original_embeddings_{idx}"] = (
-                text_encoder.text_model.embeddings.token_embedding.weight.data.clone()
-            )
+                text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+            ).embeddings.token_embedding.weight.data.clone()
             self.embeddings_settings[f"std_token_embedding_{idx}"] = std_token_embedding
 
             inu = torch.ones((len(tokenizer),), dtype=torch.bool)
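The hunk above repeats one conditional at every access site: use text_encoder.text_model when the encoder wraps an inner transformer (as CLIPTextModel / CLIPTextModelWithProjection do) and fall back to the encoder itself otherwise. A minimal sketch of that pattern factored into helpers; the names _text_model and token_embedding are hypothetical and not part of this change, which inlines the conditional instead:

# Hypothetical helpers illustrating the inlined hasattr pattern; not part of the PR.
def _text_model(text_encoder):
    # CLIP text encoders expose the transformer under `.text_model`;
    # other encoders may expose `.embeddings` / `.config` directly.
    return text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder


def token_embedding(text_encoder):
    # nn.Embedding holding the token embedding table of the resolved text model.
    return _text_model(text_encoder).embeddings.token_embedding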
@@ -959,10 +968,14 @@ def save_embeddings(self, file_path: str):
         # text_encoder_0 - CLIP ViT-L/14, text_encoder_1 - CLIP ViT-G/14
         idx_to_text_encoder_name = {0: "clip_l", 1: "clip_g"}
         for idx, text_encoder in enumerate(self.text_encoders):
-            assert text_encoder.text_model.embeddings.token_embedding.weight.data.shape[0] == len(
-                self.tokenizers[0]
-            ), "Tokenizers should be the same."
-            new_token_embeddings = text_encoder.text_model.embeddings.token_embedding.weight.data[self.train_ids]
+            assert (
+                text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+            ).embeddings.token_embedding.weight.data.shape[0] == len(self.tokenizers[0]), (
+                "Tokenizers should be the same."
+            )
+            new_token_embeddings = (
+                text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+            ).embeddings.token_embedding.weight.data[self.train_ids]
 
             # New tokens for each text encoder are saved under "clip_l" (for text_encoder 0), "clip_g" (for
             # text_encoder 1) to keep compatible with the ecosystem.
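For context, a sketch of how the rest of save_embeddings presumably assembles and writes these rows, assuming safetensors is used as elsewhere in the diffusers training scripts; the tensors dict and the final save_file call are not shown in this hunk:

# Sketch only: collect the new-token rows per encoder and save them under the keys above.
from safetensors.torch import save_file

tensors = {}
for idx, text_encoder in enumerate(self.text_encoders):
    text_model = text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
    tensors[idx_to_text_encoder_name[idx]] = text_model.embeddings.token_embedding.weight.data[self.train_ids]
save_file(tensors, file_path)  # keys: "clip_l", "clip_g"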
@@ -984,7 +997,9 @@ def device(self):
     def retract_embeddings(self):
         for idx, text_encoder in enumerate(self.text_encoders):
             index_no_updates = self.embeddings_settings[f"index_no_updates_{idx}"]
-            text_encoder.text_model.embeddings.token_embedding.weight.data[index_no_updates] = (
+            (
+                text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+            ).embeddings.token_embedding.weight.data[index_no_updates] = (
                 self.embeddings_settings[f"original_embeddings_{idx}"][index_no_updates]
                 .to(device=text_encoder.device)
                 .to(dtype=text_encoder.dtype)
@@ -995,11 +1010,15 @@ def retract_embeddings(self):
             std_token_embedding = self.embeddings_settings[f"std_token_embedding_{idx}"]
 
             index_updates = ~index_no_updates
-            new_embeddings = text_encoder.text_model.embeddings.token_embedding.weight.data[index_updates]
+            new_embeddings = (
+                text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+            ).embeddings.token_embedding.weight.data[index_updates]
             off_ratio = std_token_embedding / new_embeddings.std()
 
             new_embeddings = new_embeddings * (off_ratio**0.1)
-            text_encoder.text_model.embeddings.token_embedding.weight.data[index_updates] = new_embeddings
+            (
+                text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+            ).embeddings.token_embedding.weight.data[index_updates] = new_embeddings
 
 
 class DreamBoothDataset(Dataset):
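The rescaling in retract_embeddings keeps the standard deviation of the trained token rows from drifting away from the original embedding table's std: each call multiplies them by (target_std / current_std) ** 0.1, a soft correction rather than a hard reset. A self-contained toy illustration (the values below are made up, not from the PR):

import torch

target_std = 0.014                          # std of the original embedding table (example value)
new_embeddings = torch.randn(2, 8) * 0.05   # trained rows whose std has drifted upward

off_ratio = target_std / new_embeddings.std()
corrected = new_embeddings * (off_ratio**0.1)
# corrected.std() lies between new_embeddings.std() and target_std, so repeated
# application gently pulls the trained rows back toward the original scale.
print(new_embeddings.std().item(), corrected.std().item())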
@@ -2083,8 +2102,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
             text_encoder_two.train()
             # set top parameter requires_grad = True for gradient checkpointing works
             if args.train_text_encoder:
-                accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
-                accelerator.unwrap_model(text_encoder_two).text_model.embeddings.requires_grad_(True)
+                _te_one = accelerator.unwrap_model(text_encoder_one)
+                (_te_one.text_model if hasattr(_te_one, "text_model") else _te_one).embeddings.requires_grad_(True)
+                _te_two = accelerator.unwrap_model(text_encoder_two)
+                (_te_two.text_model if hasattr(_te_two, "text_model") else _te_two).embeddings.requires_grad_(True)
 
         for step, batch in enumerate(train_dataloader):
             if pivoted:
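With the hypothetical _text_model helper sketched after the first hunk, the gradient-checkpointing lines above could be expressed more compactly; this is illustrative only, not what the change actually does:

# Sketch: the same hasattr fallback, routed through the hypothetical helper.
for te in (text_encoder_one, text_encoder_two):
    _text_model(accelerator.unwrap_model(te)).embeddings.requires_grad_(True)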