Skip to content

Commit b8ec64c

Browse files
authored
[core] fix group offloading when using torchao (#13276)
* fix group offloading when using torchao
* switch to swap_tensors.
* up
* address feedback.
* error out for the offload to disk option.
1 parent c39fba2 commit b8ec64c

File tree

1 file changed

+105
-14
lines changed

1 file changed

+105
-14
lines changed

src/diffusers/hooks/group_offloading.py

Lines changed: 105 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import safetensors.torch
2323
import torch
2424

25-
from ..utils import get_logger, is_accelerate_available
25+
from ..utils import get_logger, is_accelerate_available, is_torchao_available
2626
from ._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
2727
from .hooks import HookRegistry, ModelHook
2828

@@ -35,6 +35,54 @@
3535
logger = get_logger(__name__) # pylint: disable=invalid-name
3636

3737

38+
def _is_torchao_tensor(tensor: torch.Tensor) -> bool:
    """Whether `tensor` is a TorchAO wrapper-subclass tensor.

    Returns False outright when torchao is not installed, so callers may use
    this check unconditionally without guarding the import themselves.
    """
    if is_torchao_available():
        from torchao.utils import TorchAOBaseTensor

        return isinstance(tensor, TorchAOBaseTensor)
    return False
44+
45+
46+
def _get_torchao_inner_tensor_names(tensor: torch.Tensor) -> list[str]:
47+
"""Get names of all internal tensor data attributes from a TorchAO tensor."""
48+
cls = type(tensor)
49+
names = list(getattr(cls, "tensor_data_names", []))
50+
for attr_name in getattr(cls, "optional_tensor_data_names", []):
51+
if getattr(tensor, attr_name, None) is not None:
52+
names.append(attr_name)
53+
return names
54+
55+
56+
def _swap_torchao_tensor(param: torch.Tensor, source: torch.Tensor) -> None:
57+
"""Move a TorchAO parameter to the device of `source` via `swap_tensors`.
58+
59+
`param.data = source` does not work for `_make_wrapper_subclass` tensors because the `.data` setter only replaces
60+
the outer wrapper storage while leaving the subclass's internal attributes (e.g. `.qdata`, `.scale`) on the
61+
original device. `swap_tensors` swaps the full tensor contents in-place, preserving the parameter's identity so
62+
that any dict keyed by `id(param)` remains valid.
63+
64+
Refer to https://github.com/huggingface/diffusers/pull/13276#discussion_r2944471548 for the full discussion.
65+
"""
66+
torch.utils.swap_tensors(param, source)
67+
68+
69+
def _restore_torchao_tensor(param: torch.Tensor, source: torch.Tensor) -> None:
    """Point `param`'s internal TorchAO tensors at those of `source`.

    In contrast to `_swap_torchao_tensor`, this leaves `source` untouched: each
    internal attribute reference is copied over one at a time with `setattr`.
    Use it when `source` must stay valid afterwards — e.g. a pinned CPU copy
    held in `cpu_param_dict`.
    """
    names = _get_torchao_inner_tensor_names(source)
    for name in names:
        value = getattr(source, name)
        setattr(param, name, value)
78+
79+
80+
def _record_stream_torchao_tensor(param: torch.Tensor, stream) -> None:
    """Call `record_stream(stream)` on every internal tensor of a TorchAO parameter."""
    for name in _get_torchao_inner_tensor_names(param):
        inner = getattr(param, name)
        inner.record_stream(stream)
84+
85+
3886
# fmt: off
3987
_GROUP_OFFLOADING = "group_offloading"
4088
_LAYER_EXECUTION_TRACKER = "layer_execution_tracker"
@@ -124,24 +172,29 @@ def __init__(
124172
else torch.cuda
125173
)
126174

175+
@staticmethod
def _to_cpu(tensor, low_cpu_mem_usage):
    """Return a CPU copy of `tensor`, pinned unless `low_cpu_mem_usage` is set.

    TorchAO tensors must be moved via `.cpu()` on the tensor itself: their
    `.data` is an incomplete wrapper missing the internal attributes
    (e.g. `.qdata`, `.scale`).
    """
    if _is_torchao_tensor(tensor):
        cpu_tensor = tensor.cpu()
    else:
        cpu_tensor = tensor.data.cpu()
    if low_cpu_mem_usage:
        return cpu_tensor
    return cpu_tensor.pin_memory()
181+
127182
def _init_cpu_param_dict(self):
    """Build the tensor -> CPU-copy mapping used for memory offloading.

    Returns an empty dict when no stream is configured; otherwise maps every
    parameter and buffer (from `self.modules`, `self.parameters`, and
    `self.buffers`) to its CPU copy produced by `_to_cpu`.
    """
    if self.stream is None:
        return {}

    tensors = []
    for module in self.modules:
        tensors.extend(module.parameters())
        tensors.extend(module.buffers())
    tensors.extend(self.parameters)
    tensors.extend(self.buffers)

    return {t: self._to_cpu(t, self.low_cpu_mem_usage) for t in tensors}
147200

@@ -157,9 +210,16 @@ def _pinned_memory_tensors(self):
157210
pinned_dict = None
158211

159212
def _transfer_tensor_to_device(self, tensor, source_tensor, default_stream):
    """Move `source_tensor` onto `self.onload_device` and install the result into `tensor`.

    For TorchAO wrapper-subclass tensors, `tensor.data = ...` would only update
    the outer wrapper and leave internal tensors (e.g. `.qdata`, `.scale`) on
    the old device, so `_swap_torchao_tensor` is used instead. When
    `self.record_stream` is set, the moved data is registered with
    `default_stream` to keep it alive across stream boundaries.
    """
    # Decide the code path once: `_swap_torchao_tensor` mutates `tensor` in
    # place, so re-checking `_is_torchao_tensor(tensor)` afterwards would
    # inspect the swapped-in tensor. Hoisting also avoids a redundant call.
    is_torchao = _is_torchao_tensor(tensor)
    moved = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
    if is_torchao:
        _swap_torchao_tensor(tensor, moved)
    else:
        tensor.data = moved
    if self.record_stream:
        if is_torchao:
            # Record every internal tensor, not just the outer wrapper.
            _record_stream_torchao_tensor(tensor, default_stream)
        else:
            tensor.data.record_stream(default_stream)
163223

164224
def _process_tensors_from_modules(self, pinned_memory=None, default_stream=None):
165225
for group_module in self.modules:
@@ -178,7 +238,19 @@ def _process_tensors_from_modules(self, pinned_memory=None, default_stream=None)
178238
source = pinned_memory[buffer] if pinned_memory else buffer.data
179239
self._transfer_tensor_to_device(buffer, source, default_stream)
180240

241+
def _check_disk_offload_torchao(self):
    """Reject disk offloading when any tracked tensor is a TorchAO tensor.

    Disk offloading round-trips tensors through safetensors, which cannot
    serialize TorchAO wrapper-subclass tensors, so the combination must be
    rejected up front.

    Raises:
        ValueError: if at least one tensor in `self.tensor_to_key` is a
            TorchAO tensor.
    """
    # Iterate the dict directly; `any` short-circuits on the first match and
    # no throwaway list of keys is built.
    if any(_is_torchao_tensor(t) for t in self.tensor_to_key):
        raise ValueError(
            "Disk offloading is not supported for TorchAO quantized tensors because safetensors "
            "cannot serialize TorchAO subclass tensors. Use memory offloading instead by not "
            "setting `offload_to_disk_path`."
        )
250+
181251
def _onload_from_disk(self):
252+
self._check_disk_offload_torchao()
253+
182254
if self.stream is not None:
183255
# Wait for previous Host->Device transfer to complete
184256
self.stream.synchronize()
@@ -221,6 +293,8 @@ def _onload_from_memory(self):
221293
self._process_tensors_from_modules(None)
222294

223295
def _offload_to_disk(self):
296+
self._check_disk_offload_torchao()
297+
224298
# TODO: we can potentially optimize this code path by checking if the _all_ the desired
225299
# safetensor files exist on the disk and if so, skip this step entirely, reducing IO
226300
# overhead. Currently, we just check if the given `safetensors_file_path` exists and if not
@@ -245,18 +319,35 @@ def _offload_to_memory(self):
245319

246320
for group_module in self.modules:
247321
for param in group_module.parameters():
248-
param.data = self.cpu_param_dict[param]
322+
if _is_torchao_tensor(param):
323+
_restore_torchao_tensor(param, self.cpu_param_dict[param])
324+
else:
325+
param.data = self.cpu_param_dict[param]
249326
for param in self.parameters:
250-
param.data = self.cpu_param_dict[param]
327+
if _is_torchao_tensor(param):
328+
_restore_torchao_tensor(param, self.cpu_param_dict[param])
329+
else:
330+
param.data = self.cpu_param_dict[param]
251331
for buffer in self.buffers:
252-
buffer.data = self.cpu_param_dict[buffer]
332+
if _is_torchao_tensor(buffer):
333+
_restore_torchao_tensor(buffer, self.cpu_param_dict[buffer])
334+
else:
335+
buffer.data = self.cpu_param_dict[buffer]
253336
else:
254337
for group_module in self.modules:
255338
group_module.to(self.offload_device, non_blocking=False)
256339
for param in self.parameters:
257-
param.data = param.data.to(self.offload_device, non_blocking=False)
340+
if _is_torchao_tensor(param):
341+
moved = param.to(self.offload_device, non_blocking=False)
342+
_swap_torchao_tensor(param, moved)
343+
else:
344+
param.data = param.data.to(self.offload_device, non_blocking=False)
258345
for buffer in self.buffers:
259-
buffer.data = buffer.data.to(self.offload_device, non_blocking=False)
346+
if _is_torchao_tensor(buffer):
347+
moved = buffer.to(self.offload_device, non_blocking=False)
348+
_swap_torchao_tensor(buffer, moved)
349+
else:
350+
buffer.data = buffer.data.to(self.offload_device, non_blocking=False)
260351

261352
@torch.compiler.disable()
262353
def onload_(self):

0 commit comments

Comments
 (0)