@@ -180,7 +180,7 @@ def forward(self, x, feat_cache=None, feat_idx=[0]):
180180 feat_cache [idx ] = "Rep"
181181 feat_idx [0 ] += 1
182182 else :
183- cache_x = x [:, :, - CACHE_T :, :, :].clone ()
183+ cache_x = x [:, :, - min ( CACHE_T , x . shape [ 2 ]) :, :, :].clone ()
184184 if cache_x .shape [2 ] < 2 and feat_cache [idx ] is not None and feat_cache [idx ] != "Rep" :
185185 # cache last frame of last two chunk
186186 cache_x = torch .cat (
@@ -258,7 +258,7 @@ def forward(self, x, feat_cache=None, feat_idx=[0]):
258258
259259 if feat_cache is not None :
260260 idx = feat_idx [0 ]
261- cache_x = x [:, :, - CACHE_T :, :, :].clone ()
261+ cache_x = x [:, :, - min ( CACHE_T , x . shape [ 2 ]) :, :, :].clone ()
262262 if cache_x .shape [2 ] < 2 and feat_cache [idx ] is not None :
263263 cache_x = torch .cat ([feat_cache [idx ][:, :, - 1 , :, :].unsqueeze (2 ).to (cache_x .device ), cache_x ], dim = 2 )
264264
@@ -277,7 +277,7 @@ def forward(self, x, feat_cache=None, feat_idx=[0]):
277277
278278 if feat_cache is not None :
279279 idx = feat_idx [0 ]
280- cache_x = x [:, :, - CACHE_T :, :, :].clone ()
280+ cache_x = x [:, :, - min ( CACHE_T , x . shape [ 2 ]) :, :, :].clone ()
281281 if cache_x .shape [2 ] < 2 and feat_cache [idx ] is not None :
282282 cache_x = torch .cat ([feat_cache [idx ][:, :, - 1 , :, :].unsqueeze (2 ).to (cache_x .device ), cache_x ], dim = 2 )
283283
@@ -446,7 +446,7 @@ def __init__(
446446 def forward (self , x , feat_cache = None , feat_idx = [0 ]):
447447 if feat_cache is not None :
448448 idx = feat_idx [0 ]
449- cache_x = x [:, :, - CACHE_T :, :, :].clone ()
449+ cache_x = x [:, :, - min ( CACHE_T , x . shape [ 2 ]) :, :, :].clone ()
450450 if cache_x .shape [2 ] < 2 and feat_cache [idx ] is not None :
451451 # cache last frame of last two chunk
452452 cache_x = torch .cat ([feat_cache [idx ][:, :, - 1 , :, :].unsqueeze (2 ).to (cache_x .device ), cache_x ], dim = 2 )
@@ -471,7 +471,7 @@ def forward(self, x, feat_cache=None, feat_idx=[0]):
471471 x = self .nonlinearity (x )
472472 if feat_cache is not None :
473473 idx = feat_idx [0 ]
474- cache_x = x [:, :, - CACHE_T :, :, :].clone ()
474+ cache_x = x [:, :, - min ( CACHE_T , x . shape [ 2 ]) :, :, :].clone ()
475475 if cache_x .shape [2 ] < 2 and feat_cache [idx ] is not None :
476476 # cache last frame of last two chunk
477477 cache_x = torch .cat ([feat_cache [idx ][:, :, - 1 , :, :].unsqueeze (2 ).to (cache_x .device ), cache_x ], dim = 2 )
@@ -636,7 +636,7 @@ def forward(self, x, feat_cache=None, feat_idx=[0]):
636636 ## conv1
637637 if feat_cache is not None :
638638 idx = feat_idx [0 ]
639- cache_x = x [:, :, - CACHE_T :, :, :].clone ()
639+ cache_x = x [:, :, - min ( CACHE_T , x . shape [ 2 ]) :, :, :].clone ()
640640 if cache_x .shape [2 ] < 2 and feat_cache [idx ] is not None :
641641 # cache last frame of last two chunk
642642 cache_x = torch .cat ([feat_cache [idx ][:, :, - 1 , :, :].unsqueeze (2 ).to (cache_x .device ), cache_x ], dim = 2 )
@@ -658,7 +658,7 @@ def forward(self, x, feat_cache=None, feat_idx=[0]):
658658 x = self .nonlinearity (x )
659659 if feat_cache is not None :
660660 idx = feat_idx [0 ]
661- cache_x = x [:, :, - CACHE_T :, :, :].clone ()
661+ cache_x = x [:, :, - min ( CACHE_T , x . shape [ 2 ]) :, :, :].clone ()
662662 if cache_x .shape [2 ] < 2 and feat_cache [idx ] is not None :
663663 # cache last frame of last two chunk
664664 cache_x = torch .cat ([feat_cache [idx ][:, :, - 1 , :, :].unsqueeze (2 ).to (cache_x .device ), cache_x ], dim = 2 )
0 commit comments