Skip to content

Commit ee1e84b

Browse files
committed
Revert torch_device to cpu for QwenImageEdit tests that run out of memory on MPS
1 parent 7fa33ee commit ee1e84b

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

tests/pipelines/qwenimage/test_qwenimage_edit.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -245,16 +245,16 @@ def test_encode_prompt_works_in_isolation(self, extra_required_param_value_dict=
245245
def test_true_cfg_without_negative_prompt_embeds_mask(self):
246246
components = self.get_dummy_components()
247247
pipe = self.pipeline_class(**components)
248-
pipe.to(torch_device)
248+
pipe.to("cpu")
249249
pipe.set_progress_bar_config(disable=None)
250250

251-
inputs = self.get_dummy_inputs(torch_device)
251+
inputs = self.get_dummy_inputs("cpu")
252252
prompt = inputs.pop("prompt")
253253

254254
prompt_embeds, prompt_embeds_mask = pipe.encode_prompt(
255255
prompt=prompt,
256256
image=inputs.get("image") if "image" in inputs else None,
257-
device=torch_device,
257+
device="cpu",
258258
num_images_per_prompt=1,
259259
max_sequence_length=inputs.get("max_sequence_length", 16),
260260
)

tests/pipelines/qwenimage/test_qwenimage_edit_plus.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -255,16 +255,16 @@ def test_inference_batch_single_identical():
255255
def test_true_cfg_without_negative_prompt_embeds_mask(self):
256256
components = self.get_dummy_components()
257257
pipe = self.pipeline_class(**components)
258-
pipe.to(torch_device)
258+
pipe.to("cpu")
259259
pipe.set_progress_bar_config(disable=None)
260260

261-
inputs = self.get_dummy_inputs(torch_device)
261+
inputs = self.get_dummy_inputs("cpu")
262262
prompt = inputs.pop("prompt")
263263

264264
prompt_embeds, prompt_embeds_mask = pipe.encode_prompt(
265265
prompt=prompt,
266266
image=inputs.get("image") if "image" in inputs else None,
267-
device=torch_device,
267+
device="cpu",
268268
num_images_per_prompt=1,
269269
max_sequence_length=inputs.get("max_sequence_length", 16),
270270
)

0 commit comments

Comments
 (0)