@@ -233,6 +233,11 @@ def rope_params(self, index, dim, theta=10000):
233233 freqs = torch .polar (torch .ones_like (freqs ), freqs )
234234 return freqs
235235
236+ @lru_cache_unless_export (maxsize = None )
237+ def _get_device_freqs (self , device : torch .device ) -> tuple [torch .Tensor , torch .Tensor ]:
238+ """Return pos_freqs and neg_freqs on the given device."""
239+ return self .pos_freqs .to (device ), self .neg_freqs .to (device )
240+
236241 def forward (
237242 self ,
238243 video_fhw : tuple [int , int , int , list [tuple [int , int , int ]]],
@@ -300,8 +305,9 @@ def forward(
300305 max_vid_index = max (height , width , max_vid_index )
301306
302307 max_txt_seq_len_int = int (max_txt_seq_len )
303- # Create device-specific copy for text freqs without modifying self.pos_freqs
304- txt_freqs = self .pos_freqs .to (device )[max_vid_index : max_vid_index + max_txt_seq_len_int , ...]
308+ # Use cached device-transferred freqs to avoid CPU→GPU sync every forward call
309+ pos_freqs_device , _ = self ._get_device_freqs (device )
310+ txt_freqs = pos_freqs_device [max_vid_index : max_vid_index + max_txt_seq_len_int , ...]
305311 vid_freqs = torch .cat (vid_freqs , dim = 0 )
306312
307313 return vid_freqs , txt_freqs
@@ -311,8 +317,9 @@ def _compute_video_freqs(
311317 self , frame : int , height : int , width : int , idx : int = 0 , device : torch .device = None
312318 ) -> torch .Tensor :
313319 seq_lens = frame * height * width
314- pos_freqs = self .pos_freqs .to (device ) if device is not None else self .pos_freqs
315- neg_freqs = self .neg_freqs .to (device ) if device is not None else self .neg_freqs
320+ pos_freqs , neg_freqs = (
321+ self ._get_device_freqs (device ) if device is not None else (self .pos_freqs , self .neg_freqs )
322+ )
316323
317324 freqs_pos = pos_freqs .split ([x // 2 for x in self .axes_dim ], dim = 1 )
318325 freqs_neg = neg_freqs .split ([x // 2 for x in self .axes_dim ], dim = 1 )
@@ -367,6 +374,11 @@ def rope_params(self, index, dim, theta=10000):
367374 freqs = torch .polar (torch .ones_like (freqs ), freqs )
368375 return freqs
369376
377+ @lru_cache_unless_export (maxsize = None )
378+ def _get_device_freqs (self , device : torch .device ) -> tuple [torch .Tensor , torch .Tensor ]:
379+ """Return pos_freqs and neg_freqs on the given device."""
380+ return self .pos_freqs .to (device ), self .neg_freqs .to (device )
381+
370382 def forward (
371383 self ,
372384 video_fhw : tuple [int , int , int , list [tuple [int , int , int ]]],
@@ -421,17 +433,19 @@ def forward(
421433
422434 max_vid_index = max (max_vid_index , layer_num )
423435 max_txt_seq_len_int = int (max_txt_seq_len )
424- # Create device-specific copy for text freqs without modifying self.pos_freqs
425- txt_freqs = self .pos_freqs .to (device )[max_vid_index : max_vid_index + max_txt_seq_len_int , ...]
436+ # Use cached device-transferred freqs to avoid CPU→GPU sync every forward call
437+ pos_freqs_device , _ = self ._get_device_freqs (device )
438+ txt_freqs = pos_freqs_device [max_vid_index : max_vid_index + max_txt_seq_len_int , ...]
426439 vid_freqs = torch .cat (vid_freqs , dim = 0 )
427440
428441 return vid_freqs , txt_freqs
429442
430443 @lru_cache_unless_export (maxsize = None )
431444 def _compute_video_freqs (self , frame , height , width , idx = 0 , device : torch .device = None ):
432445 seq_lens = frame * height * width
433- pos_freqs = self .pos_freqs .to (device ) if device is not None else self .pos_freqs
434- neg_freqs = self .neg_freqs .to (device ) if device is not None else self .neg_freqs
446+ pos_freqs , neg_freqs = (
447+ self ._get_device_freqs (device ) if device is not None else (self .pos_freqs , self .neg_freqs )
448+ )
435449
436450 freqs_pos = pos_freqs .split ([x // 2 for x in self .axes_dim ], dim = 1 )
437451 freqs_neg = neg_freqs .split ([x // 2 for x in self .axes_dim ], dim = 1 )
@@ -452,8 +466,9 @@ def _compute_video_freqs(self, frame, height, width, idx=0, device: torch.device
452466 @lru_cache_unless_export (maxsize = None )
453467 def _compute_condition_freqs (self , frame , height , width , device : torch .device = None ):
454468 seq_lens = frame * height * width
455- pos_freqs = self .pos_freqs .to (device ) if device is not None else self .pos_freqs
456- neg_freqs = self .neg_freqs .to (device ) if device is not None else self .neg_freqs
469+ pos_freqs , neg_freqs = (
470+ self ._get_device_freqs (device ) if device is not None else (self .pos_freqs , self .neg_freqs )
471+ )
457472
458473 freqs_pos = pos_freqs .split ([x // 2 for x in self .axes_dim ], dim = 1 )
459474 freqs_neg = neg_freqs .split ([x // 2 for x in self .axes_dim ], dim = 1 )
0 commit comments