-
Notifications
You must be signed in to change notification settings - Fork 6.9k
fix(qwen): fix CFG failing when passing neg prompt embeds with none mask #13379
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
Changes from all commits
dbe4f95
d7814c3
f600a36
5b2a1d2
c774885
908c304
eb86a2e
42587ae
7fa33ee
17938f5
30e10dd
4f736f7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -311,6 +311,22 @@ def check_inputs( | |||||||||||||||||||||
| f" {negative_prompt_embeds}. Please make sure to only forward one of the two." | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| if prompt_embeds is not None and prompt_embeds_mask is None: | ||||||||||||||||||||||
| logger.warning( | ||||||||||||||||||||||
| "`prompt_embeds` is provided and `prompt_embeds_mask` is not provided, so the model will treat all" | ||||||||||||||||||||||
| " prompt tokens as valid. If `prompt_embeds` contains padding, you should provide the padding mask as" | ||||||||||||||||||||||
| " `prompt_embeds_mask`. Make sure to generate `prompt_embeds_mask` from the same text encoder that was" | ||||||||||||||||||||||
| " used to generate `prompt_embeds`." | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None: | ||||||||||||||||||||||
| logger.warning( | ||||||||||||||||||||||
| "`negative_prompt_embeds` is provided and `negative_prompt_embeds_mask` is not provided, so the model will treat all" | ||||||||||||||||||||||
| " negative prompt tokens as valid. If `negative_prompt_embeds` contains padding, you should provide the padding mask as" | ||||||||||||||||||||||
| " `negative_prompt_embeds_mask`. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was" | ||||||||||||||||||||||
| " used to generate `negative_prompt_embeds`." | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
|
Comment on lines
+323
to
+328
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
Analogous suggestion to #13379 (comment) for |
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| if max_sequence_length is not None and max_sequence_length > 1024: | ||||||||||||||||||||||
| raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}") | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
@@ -584,9 +600,7 @@ def __call__( | |||||||||||||||||||||
|
|
||||||||||||||||||||||
| device = self._execution_device | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| has_neg_prompt = negative_prompt is not None or ( | ||||||||||||||||||||||
| negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| has_neg_prompt = negative_prompt is not None or negative_prompt_embeds is not None | ||||||||||||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How do we handle diffusers/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py Lines 261 to 267 in 71a6fd9
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If Interestingly, while working on this, I found that some other variants (like |
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| if true_cfg_scale > 1 and not has_neg_prompt: | ||||||||||||||||||||||
| logger.warning( | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
|
|
@@ -74,7 +74,7 @@ | |||
| """ | ||||
|
|
||||
|
|
||||
| # Coped from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift | ||||
| # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift | ||||
| def calculate_shift( | ||||
| image_seq_len, | ||||
| base_seq_len: int = 256, | ||||
|
|
@@ -221,7 +221,7 @@ def __init__( | |||
| self.prompt_template_encode_start_idx = 34 | ||||
| self.default_sample_size = 128 | ||||
|
|
||||
| # Coped from diffusers.pipelines.qwenimage.pipeline_qwenimage.extract_masked_hidden | ||||
| # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._extract_masked_hidden | ||||
| def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor): | ||||
| bool_mask = mask.bool() | ||||
| valid_lengths = bool_mask.sum(dim=1) | ||||
|
|
@@ -230,7 +230,7 @@ def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor | |||
|
|
||||
| return split_result | ||||
|
|
||||
| # Coped from diffusers.pipelines.qwenimage.pipeline_qwenimage.get_qwen_prompt_embeds | ||||
| # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._get_qwen_prompt_embeds | ||||
| def _get_qwen_prompt_embeds( | ||||
| self, | ||||
| prompt: str | list[str] = None, | ||||
|
|
@@ -247,7 +247,7 @@ def _get_qwen_prompt_embeds( | |||
| txt = [template.format(e) for e in prompt] | ||||
| txt_tokens = self.tokenizer( | ||||
| txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt" | ||||
| ).to(self.device) | ||||
| ).to(device) | ||||
| encoder_hidden_states = self.text_encoder( | ||||
| input_ids=txt_tokens.input_ids, | ||||
| attention_mask=txt_tokens.attention_mask, | ||||
|
|
@@ -269,7 +269,7 @@ def _get_qwen_prompt_embeds( | |||
|
|
||||
| return prompt_embeds, encoder_attention_mask | ||||
|
|
||||
| # Coped from diffusers.pipelines.qwenimage.pipeline_qwenimage.encode_prompt | ||||
| # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline.encode_prompt | ||||
| def encode_prompt( | ||||
| self, | ||||
| prompt: str | list[str], | ||||
|
|
@@ -280,6 +280,7 @@ def encode_prompt( | |||
| max_sequence_length: int = 1024, | ||||
| ): | ||||
| r""" | ||||
|
|
||||
| Args: | ||||
| prompt (`str` or `list[str]`, *optional*): | ||||
| prompt to be encoded | ||||
|
|
@@ -299,14 +300,18 @@ def encode_prompt( | |||
| if prompt_embeds is None: | ||||
| prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device) | ||||
|
|
||||
| prompt_embeds = prompt_embeds[:, :max_sequence_length] | ||||
| _, seq_len, _ = prompt_embeds.shape | ||||
| prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) | ||||
| prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) | ||||
| prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) | ||||
| prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) | ||||
|
|
||||
| if prompt_embeds_mask is not None and prompt_embeds_mask.all(): | ||||
| prompt_embeds_mask = None | ||||
| if prompt_embeds_mask is not None: | ||||
| prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length] | ||||
| prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) | ||||
| prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) | ||||
|
|
||||
| if prompt_embeds_mask.all(): | ||||
| prompt_embeds_mask = None | ||||
|
Comment on lines
+308
to
+314
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Seems like an unrelated change? Since it already has diffusers/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py Line 272 in 71a6fd9
if we run
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'll revert the manual edits on the copied methods, run make fix-copies to sync them up properly, and update the PR. |
||||
|
|
||||
| return prompt_embeds, prompt_embeds_mask | ||||
|
|
||||
|
|
@@ -354,12 +359,19 @@ def check_inputs( | |||
| ) | ||||
|
|
||||
| if prompt_embeds is not None and prompt_embeds_mask is None: | ||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What is the reasoning behind this? What happens when users pass the embeds and not the masks? Maybe we should warn them?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I removed the hard `ValueError` here because the base `QwenImagePipeline` actually allows passing embeds without masks, and the underlying transformer natively handles a `None` mask by treating all tokens as valid. Throwing a strict exception here was completely blocking users from doing exactly what this PR is fixing (passing `negative_prompt_embeds` without a mask to trigger True CFG). That said, I totally agree we shouldn't just let it pass silently, especially since the text encoder's output often relies on masks for sequence packing. I'll swap these hard exceptions out for a `logger.warning` to let users know they should probably pass the mask if they have it. I'll get that updated! |
||||
| raise ValueError( | ||||
| "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`." | ||||
| logger.warning( | ||||
| "`prompt_embeds` is provided and `prompt_embeds_mask` is not provided, so the model will treat all" | ||||
| " prompt tokens as valid. If `prompt_embeds` contains padding, you should provide the padding mask as" | ||||
| " `prompt_embeds_mask`. Make sure to generate `prompt_embeds_mask` from the same text encoder that was" | ||||
| " used to generate `prompt_embeds`." | ||||
| ) | ||||
|
|
||||
| if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None: | ||||
| raise ValueError( | ||||
| "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`." | ||||
| logger.warning( | ||||
| "`negative_prompt_embeds` is provided and `negative_prompt_embeds_mask` is not provided, so the model will treat all" | ||||
| " negative prompt tokens as valid. If `negative_prompt_embeds` contains padding, you should provide the padding mask as" | ||||
| " `negative_prompt_embeds_mask`. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was" | ||||
| " used to generate `negative_prompt_embeds`." | ||||
| ) | ||||
|
|
||||
| if max_sequence_length is not None and max_sequence_length > 1024: | ||||
|
|
@@ -739,9 +751,7 @@ def __call__( | |||
|
|
||||
| device = self._execution_device | ||||
|
|
||||
| has_neg_prompt = negative_prompt is not None or ( | ||||
| negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None | ||||
| ) | ||||
| has_neg_prompt = negative_prompt is not None or negative_prompt_embeds is not None | ||||
| do_true_cfg = true_cfg_scale > 1 and has_neg_prompt | ||||
| prompt_embeds, prompt_embeds_mask = self.encode_prompt( | ||||
| prompt=prompt, | ||||
|
|
||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure about the exact wording, but I think the warning here should explain when not passing the corresponding mask is valid, and when passing it is necessary.