Skip to content

Commit 33a1317

Browse files
authored
[core] fix autoencoderkl qwenimage for xla (#13480)
fix autoencoderkl qwenimage for xla
1 parent 947bc23 commit 33a1317

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

src/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ def forward(self, x, feat_cache=None, feat_idx=[0]):
180180
feat_cache[idx] = "Rep"
181181
feat_idx[0] += 1
182182
else:
183-
cache_x = x[:, :, -CACHE_T:, :, :].clone()
183+
cache_x = x[:, :, -min(CACHE_T, x.shape[2]) :, :, :].clone()
184184
if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep":
185185
# cache last frame of last two chunk
186186
cache_x = torch.cat(
@@ -258,7 +258,7 @@ def forward(self, x, feat_cache=None, feat_idx=[0]):
258258

259259
if feat_cache is not None:
260260
idx = feat_idx[0]
261-
cache_x = x[:, :, -CACHE_T:, :, :].clone()
261+
cache_x = x[:, :, -min(CACHE_T, x.shape[2]) :, :, :].clone()
262262
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
263263
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
264264

@@ -277,7 +277,7 @@ def forward(self, x, feat_cache=None, feat_idx=[0]):
277277

278278
if feat_cache is not None:
279279
idx = feat_idx[0]
280-
cache_x = x[:, :, -CACHE_T:, :, :].clone()
280+
cache_x = x[:, :, -min(CACHE_T, x.shape[2]) :, :, :].clone()
281281
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
282282
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
283283

@@ -446,7 +446,7 @@ def __init__(
446446
def forward(self, x, feat_cache=None, feat_idx=[0]):
447447
if feat_cache is not None:
448448
idx = feat_idx[0]
449-
cache_x = x[:, :, -CACHE_T:, :, :].clone()
449+
cache_x = x[:, :, -min(CACHE_T, x.shape[2]) :, :, :].clone()
450450
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
451451
# cache last frame of last two chunk
452452
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
@@ -471,7 +471,7 @@ def forward(self, x, feat_cache=None, feat_idx=[0]):
471471
x = self.nonlinearity(x)
472472
if feat_cache is not None:
473473
idx = feat_idx[0]
474-
cache_x = x[:, :, -CACHE_T:, :, :].clone()
474+
cache_x = x[:, :, -min(CACHE_T, x.shape[2]) :, :, :].clone()
475475
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
476476
# cache last frame of last two chunk
477477
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
@@ -636,7 +636,7 @@ def forward(self, x, feat_cache=None, feat_idx=[0]):
636636
## conv1
637637
if feat_cache is not None:
638638
idx = feat_idx[0]
639-
cache_x = x[:, :, -CACHE_T:, :, :].clone()
639+
cache_x = x[:, :, -min(CACHE_T, x.shape[2]) :, :, :].clone()
640640
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
641641
# cache last frame of last two chunk
642642
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
@@ -658,7 +658,7 @@ def forward(self, x, feat_cache=None, feat_idx=[0]):
658658
x = self.nonlinearity(x)
659659
if feat_cache is not None:
660660
idx = feat_idx[0]
661-
cache_x = x[:, :, -CACHE_T:, :, :].clone()
661+
cache_x = x[:, :, -min(CACHE_T, x.shape[2]) :, :, :].clone()
662662
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
663663
# cache last frame of last two chunk
664664
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)

0 commit comments

Comments (0)