Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions src/diffusers/modular_pipelines/qwenimage/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,15 +117,15 @@ def get_qwen_prompt_embeds_edit(
).to(device)

outputs = text_encoder(
input_ids=model_inputs.input_ids,
attention_mask=model_inputs.attention_mask,
pixel_values=model_inputs.pixel_values,
image_grid_thw=model_inputs.image_grid_thw,
input_ids=model_inputs["input_ids"],
attention_mask=model_inputs["attention_mask"],
pixel_values=model_inputs.get("pixel_values"),
image_grid_thw=model_inputs.get("image_grid_thw"),
output_hidden_states=True,
)

hidden_states = outputs.hidden_states[-1]
split_hidden_states = _extract_masked_hidden(hidden_states, model_inputs.attention_mask)
split_hidden_states = _extract_masked_hidden(hidden_states, model_inputs["attention_mask"])
split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
max_seq_len = max([e.size(0) for e in split_hidden_states])
Expand Down Expand Up @@ -173,15 +173,15 @@ def get_qwen_prompt_embeds_edit_plus(
return_tensors="pt",
).to(device)
outputs = text_encoder(
input_ids=model_inputs.input_ids,
attention_mask=model_inputs.attention_mask,
pixel_values=model_inputs.pixel_values,
image_grid_thw=model_inputs.image_grid_thw,
input_ids=model_inputs["input_ids"],
attention_mask=model_inputs["attention_mask"],
pixel_values=model_inputs.get("pixel_values"),
image_grid_thw=model_inputs.get("image_grid_thw"),
output_hidden_states=True,
)

hidden_states = outputs.hidden_states[-1]
split_hidden_states = _extract_masked_hidden(hidden_states, model_inputs.attention_mask)
split_hidden_states = _extract_masked_hidden(hidden_states, model_inputs["attention_mask"])
split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
max_seq_len = max([e.size(0) for e in split_hidden_states])
Expand Down
20 changes: 17 additions & 3 deletions src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,22 @@ def check_inputs(
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)

if prompt_embeds is not None and prompt_embeds_mask is None:
logger.warning(
"`prompt_embeds` is provided and `prompt_embeds_mask` is not provided, so the model will treat all"
" prompt tokens as valid. If `prompt_embeds` contains padding, you should provide the padding mask as"
" `prompt_embeds_mask`. Make sure to generate `prompt_embeds_mask` from the same text encoder that was"
" used to generate `prompt_embeds`."
)
Comment on lines +315 to +320
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
logger.warning(
"If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
)
logger.warning(
"`prompt_embeds` is provided and `prompt_embeds_mask` is not provided, so the model will treat all"
" prompt tokens as valid. If `prompt_embeds` contains padding, you should provide the padding mask as"
" `prompt_embeds_mask`. Make sure to generate `prompt_embeds_mask` from the same text encoder that was"
" used to generate `prompt_embeds`."
)

Not sure about the exact wording, but I think the warning here should explain when it is valid to omit the corresponding mask, and when passing it is necessary.


if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
logger.warning(
"`negative_prompt_embeds` is provided and `negative_prompt_embeds_mask` is not provided, so the model will treat all"
" negative prompt tokens as valid. If `negative_prompt_embeds` contains padding, you should provide the padding mask as"
" `negative_prompt_embeds_mask`. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was"
" used to generate `negative_prompt_embeds`."
)
Comment on lines +323 to +328
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
logger.warning(
"If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
)
logger.warning(
"`negative_prompt_embeds` is provided and `negative_prompt_embeds_mask` is not provided, so the model"
" will treat all negative prompt tokens as valid. If `negative_prompt_embeds` contains padding, you"
" should provide the padding mask as `negative_prompt_embeds_mask`. Make sure to generate"
" `negative_prompt_embeds_mask` from the same text encoder that was used to generate"
" `negative_prompt_embeds`."
)

Analogous suggestion to #13379 (comment) for negative_prompt_embeds/negative_prompt_embeds_mask.


if max_sequence_length is not None and max_sequence_length > 1024:
raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")

Expand Down Expand Up @@ -584,9 +600,7 @@ def __call__(

device = self._execution_device

has_neg_prompt = negative_prompt is not None or (
negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
)
has_neg_prompt = negative_prompt is not None or negative_prompt_embeds is not None
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How do we handle negative_prompt_embeds_mask? 👀 Is this block affected?

if prompt_embeds_mask is not None:
prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
if prompt_embeds_mask.all():
prompt_embeds_mask = None

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If negative_prompt_embeds_mask isn't provided (is None), that block is safely skipped, and encode_prompt just returns None for the mask. The transformer model later receives this None mask and handles it natively by treating all tokens as valid.

Interestingly, while working on this, I found that some other variants (like edit and inpaint) were actually missing this exact if prompt_embeds_mask is not None: check and were crashing on .repeat(). I've added the same check to those pipelines in this PR as well so they all handle None masks gracefully now.


if true_cfg_scale > 1 and not has_neg_prompt:
logger.warning(
Expand Down
30 changes: 23 additions & 7 deletions src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@
"""


# Coped from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift
def calculate_shift(
image_seq_len,
base_seq_len: int = 256,
Expand Down Expand Up @@ -239,7 +239,7 @@ def __init__(
self.prompt_template_encode_start_idx = 34
self.default_sample_size = 128

# Coped from diffusers.pipelines.qwenimage.pipeline_qwenimage.extract_masked_hidden
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._extract_masked_hidden
def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
bool_mask = mask.bool()
valid_lengths = bool_mask.sum(dim=1)
Expand All @@ -248,7 +248,7 @@ def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor

return split_result

# Coped from diffusers.pipelines.qwenimage.pipeline_qwenimage.get_qwen_prompt_embeds
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._get_qwen_prompt_embeds
def _get_qwen_prompt_embeds(
self,
prompt: str | list[str] = None,
Expand Down Expand Up @@ -287,7 +287,7 @@ def _get_qwen_prompt_embeds(

return prompt_embeds, encoder_attention_mask

# Coped from diffusers.pipelines.qwenimage.pipeline_qwenimage.encode_prompt
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline.encode_prompt
def encode_prompt(
self,
prompt: str | list[str],
Expand Down Expand Up @@ -318,11 +318,13 @@ def encode_prompt(
if prompt_embeds is None:
prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device)

prompt_embeds = prompt_embeds[:, :max_sequence_length]
_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

if prompt_embeds_mask is not None:
prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)

Expand Down Expand Up @@ -374,6 +376,22 @@ def check_inputs(
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)

if prompt_embeds is not None and prompt_embeds_mask is None:
logger.warning(
"`prompt_embeds` is provided and `prompt_embeds_mask` is not provided, so the model will treat all"
" prompt tokens as valid. If `prompt_embeds` contains padding, you should provide the padding mask as"
" `prompt_embeds_mask`. Make sure to generate `prompt_embeds_mask` from the same text encoder that was"
" used to generate `prompt_embeds`."
)

if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
logger.warning(
"`negative_prompt_embeds` is provided and `negative_prompt_embeds_mask` is not provided, so the model will treat all"
" negative prompt tokens as valid. If `negative_prompt_embeds` contains padding, you should provide the padding mask as"
" `negative_prompt_embeds_mask`. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was"
" used to generate `negative_prompt_embeds`."
)

if max_sequence_length is not None and max_sequence_length > 1024:
raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")

Expand Down Expand Up @@ -700,9 +718,7 @@ def __call__(

device = self._execution_device

has_neg_prompt = negative_prompt is not None or (
negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
)
has_neg_prompt = negative_prompt is not None or negative_prompt_embeds is not None

if true_cfg_scale > 1 and not has_neg_prompt:
logger.warning(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@
"""


# Coped from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift
def calculate_shift(
image_seq_len,
base_seq_len: int = 256,
Expand Down Expand Up @@ -221,7 +221,7 @@ def __init__(
self.prompt_template_encode_start_idx = 34
self.default_sample_size = 128

# Coped from diffusers.pipelines.qwenimage.pipeline_qwenimage.extract_masked_hidden
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._extract_masked_hidden
def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
bool_mask = mask.bool()
valid_lengths = bool_mask.sum(dim=1)
Expand All @@ -230,7 +230,7 @@ def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor

return split_result

# Coped from diffusers.pipelines.qwenimage.pipeline_qwenimage.get_qwen_prompt_embeds
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._get_qwen_prompt_embeds
def _get_qwen_prompt_embeds(
self,
prompt: str | list[str] = None,
Expand All @@ -247,7 +247,7 @@ def _get_qwen_prompt_embeds(
txt = [template.format(e) for e in prompt]
txt_tokens = self.tokenizer(
txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt"
).to(self.device)
).to(device)
encoder_hidden_states = self.text_encoder(
input_ids=txt_tokens.input_ids,
attention_mask=txt_tokens.attention_mask,
Expand All @@ -269,7 +269,7 @@ def _get_qwen_prompt_embeds(

return prompt_embeds, encoder_attention_mask

# Coped from diffusers.pipelines.qwenimage.pipeline_qwenimage.encode_prompt
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline.encode_prompt
def encode_prompt(
self,
prompt: str | list[str],
Expand All @@ -280,6 +280,7 @@ def encode_prompt(
max_sequence_length: int = 1024,
):
r"""

Args:
prompt (`str` or `list[str]`, *optional*):
prompt to be encoded
Expand All @@ -299,14 +300,18 @@ def encode_prompt(
if prompt_embeds is None:
prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device)

prompt_embeds = prompt_embeds[:, :max_sequence_length]
_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)

if prompt_embeds_mask is not None and prompt_embeds_mask.all():
prompt_embeds_mask = None
if prompt_embeds_mask is not None:
prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)

if prompt_embeds_mask.all():
prompt_embeds_mask = None
Comment on lines +308 to +314
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like an unrelated change? Since it already has

# Coped from diffusers.pipelines.qwenimage.pipeline_qwenimage.encode_prompt

if we run make fix-copies, the changes would be propagated automatically.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll revert the manual edits on the copied methods, run make fix-copies to sync them up properly, and update the PR.


return prompt_embeds, prompt_embeds_mask

Expand Down Expand Up @@ -354,12 +359,19 @@ def check_inputs(
)

if prompt_embeds is not None and prompt_embeds_mask is None:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the reasoning behind this?

What happens when users pass the embeds and not the masks? Maybe we should warn them?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I removed the hard ValueError here because the base QwenImagePipeline actually allows passing embeds without masks, and the underlying transformer natively handles a None mask by treating all tokens as valid. Throwing a strict exception here was completely blocking users from doing exactly what this PR is fixing (passing negative_prompt_embeds without a mask to trigger True CFG).

That said, I totally agree we shouldn't just let it pass silently, especially since the text encoder's output often relies on masks for sequence packing. I'll swap these hard exceptions out for a logger.warning to let users know they should probably pass the mask if they have it. I'll get that updated!

raise ValueError(
"If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
logger.warning(
"`prompt_embeds` is provided and `prompt_embeds_mask` is not provided, so the model will treat all"
" prompt tokens as valid. If `prompt_embeds` contains padding, you should provide the padding mask as"
" `prompt_embeds_mask`. Make sure to generate `prompt_embeds_mask` from the same text encoder that was"
" used to generate `prompt_embeds`."
)

if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
raise ValueError(
"If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
logger.warning(
"`negative_prompt_embeds` is provided and `negative_prompt_embeds_mask` is not provided, so the model will treat all"
" negative prompt tokens as valid. If `negative_prompt_embeds` contains padding, you should provide the padding mask as"
" `negative_prompt_embeds_mask`. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was"
" used to generate `negative_prompt_embeds`."
)

if max_sequence_length is not None and max_sequence_length > 1024:
Expand Down Expand Up @@ -739,9 +751,7 @@ def __call__(

device = self._execution_device

has_neg_prompt = negative_prompt is not None or (
negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
)
has_neg_prompt = negative_prompt is not None or negative_prompt_embeds is not None
do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
prompt_embeds, prompt_embeds_mask = self.encode_prompt(
prompt=prompt,
Expand Down
39 changes: 23 additions & 16 deletions src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,15 +247,15 @@ def _get_qwen_prompt_embeds(
).to(device)

outputs = self.text_encoder(
input_ids=model_inputs.input_ids,
attention_mask=model_inputs.attention_mask,
pixel_values=model_inputs.pixel_values,
image_grid_thw=model_inputs.image_grid_thw,
input_ids=model_inputs["input_ids"],
attention_mask=model_inputs["attention_mask"],
pixel_values=model_inputs.get("pixel_values"),
image_grid_thw=model_inputs.get("image_grid_thw"),
output_hidden_states=True,
)

hidden_states = outputs.hidden_states[-1]
split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask)
split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs["attention_mask"])
split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
max_seq_len = max([e.size(0) for e in split_hidden_states])
Expand Down Expand Up @@ -306,11 +306,13 @@ def encode_prompt(
_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)

if prompt_embeds_mask is not None and prompt_embeds_mask.all():
prompt_embeds_mask = None
if prompt_embeds_mask is not None:
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)

if prompt_embeds_mask.all():
prompt_embeds_mask = None

return prompt_embeds, prompt_embeds_mask

Expand Down Expand Up @@ -358,12 +360,19 @@ def check_inputs(
)

if prompt_embeds is not None and prompt_embeds_mask is None:
raise ValueError(
"If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
logger.warning(
"`prompt_embeds` is provided and `prompt_embeds_mask` is not provided, so the model will treat all"
" prompt tokens as valid. If `prompt_embeds` contains padding, you should provide the padding mask as"
" `prompt_embeds_mask`. Make sure to generate `prompt_embeds_mask` from the same text encoder that was"
" used to generate `prompt_embeds`."
)

if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
raise ValueError(
"If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
logger.warning(
"`negative_prompt_embeds` is provided and `negative_prompt_embeds_mask` is not provided, so the model will treat all"
" negative prompt tokens as valid. If `negative_prompt_embeds` contains padding, you should provide the padding mask as"
" `negative_prompt_embeds_mask`. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was"
" used to generate `negative_prompt_embeds`."
)

if max_sequence_length is not None and max_sequence_length > 1024:
Expand Down Expand Up @@ -705,9 +714,7 @@ def __call__(
image = self.image_processor.preprocess(image, calculated_height, calculated_width)
image = image.unsqueeze(2)

has_neg_prompt = negative_prompt is not None or (
negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
)
has_neg_prompt = negative_prompt is not None or negative_prompt_embeds is not None

if true_cfg_scale > 1 and not has_neg_prompt:
logger.warning(
Expand Down
Loading
Loading