Skip to content

Commit 4548e68

Browse files
akshan-maingithub-actions[bot]dg845sayakpaul
authored
Cache RoPE freqs on device to avoid repeated CPU-GPU copy in QwenImage (#13406)
* Cache RoPE freqs on device to avoid repeated CPU-GPU copy in QwenImage * Apply style fixes * use lru_cache_unless_export --------- Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent b80d3f6 commit 4548e68

File tree

1 file changed

+25
-10
lines changed

1 file changed

+25
-10
lines changed

src/diffusers/models/transformers/transformer_qwenimage.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,11 @@ def rope_params(self, index, dim, theta=10000):
233233
freqs = torch.polar(torch.ones_like(freqs), freqs)
234234
return freqs
235235

236+
@lru_cache_unless_export(maxsize=None)
237+
def _get_device_freqs(self, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
238+
"""Return pos_freqs and neg_freqs on the given device."""
239+
return self.pos_freqs.to(device), self.neg_freqs.to(device)
240+
236241
def forward(
237242
self,
238243
video_fhw: tuple[int, int, int, list[tuple[int, int, int]]],
@@ -300,8 +305,9 @@ def forward(
300305
max_vid_index = max(height, width, max_vid_index)
301306

302307
max_txt_seq_len_int = int(max_txt_seq_len)
303-
# Create device-specific copy for text freqs without modifying self.pos_freqs
304-
txt_freqs = self.pos_freqs.to(device)[max_vid_index : max_vid_index + max_txt_seq_len_int, ...]
308+
# Use cached device-transferred freqs to avoid CPU→GPU sync every forward call
309+
pos_freqs_device, _ = self._get_device_freqs(device)
310+
txt_freqs = pos_freqs_device[max_vid_index : max_vid_index + max_txt_seq_len_int, ...]
305311
vid_freqs = torch.cat(vid_freqs, dim=0)
306312

307313
return vid_freqs, txt_freqs
@@ -311,8 +317,9 @@ def _compute_video_freqs(
311317
self, frame: int, height: int, width: int, idx: int = 0, device: torch.device = None
312318
) -> torch.Tensor:
313319
seq_lens = frame * height * width
314-
pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
315-
neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs
320+
pos_freqs, neg_freqs = (
321+
self._get_device_freqs(device) if device is not None else (self.pos_freqs, self.neg_freqs)
322+
)
316323

317324
freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
318325
freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
@@ -367,6 +374,11 @@ def rope_params(self, index, dim, theta=10000):
367374
freqs = torch.polar(torch.ones_like(freqs), freqs)
368375
return freqs
369376

377+
@lru_cache_unless_export(maxsize=None)
378+
def _get_device_freqs(self, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
379+
"""Return pos_freqs and neg_freqs on the given device."""
380+
return self.pos_freqs.to(device), self.neg_freqs.to(device)
381+
370382
def forward(
371383
self,
372384
video_fhw: tuple[int, int, int, list[tuple[int, int, int]]],
@@ -421,17 +433,19 @@ def forward(
421433

422434
max_vid_index = max(max_vid_index, layer_num)
423435
max_txt_seq_len_int = int(max_txt_seq_len)
424-
# Create device-specific copy for text freqs without modifying self.pos_freqs
425-
txt_freqs = self.pos_freqs.to(device)[max_vid_index : max_vid_index + max_txt_seq_len_int, ...]
436+
# Use cached device-transferred freqs to avoid CPU→GPU sync every forward call
437+
pos_freqs_device, _ = self._get_device_freqs(device)
438+
txt_freqs = pos_freqs_device[max_vid_index : max_vid_index + max_txt_seq_len_int, ...]
426439
vid_freqs = torch.cat(vid_freqs, dim=0)
427440

428441
return vid_freqs, txt_freqs
429442

430443
@lru_cache_unless_export(maxsize=None)
431444
def _compute_video_freqs(self, frame, height, width, idx=0, device: torch.device = None):
432445
seq_lens = frame * height * width
433-
pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
434-
neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs
446+
pos_freqs, neg_freqs = (
447+
self._get_device_freqs(device) if device is not None else (self.pos_freqs, self.neg_freqs)
448+
)
435449

436450
freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
437451
freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
@@ -452,8 +466,9 @@ def _compute_video_freqs(self, frame, height, width, idx=0, device: torch.device
452466
@lru_cache_unless_export(maxsize=None)
453467
def _compute_condition_freqs(self, frame, height, width, device: torch.device = None):
454468
seq_lens = frame * height * width
455-
pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
456-
neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs
469+
pos_freqs, neg_freqs = (
470+
self._get_device_freqs(device) if device is not None else (self.pos_freqs, self.neg_freqs)
471+
)
457472

458473
freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
459474
freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)

0 commit comments

Comments
 (0)