2222import safetensors .torch
2323import torch
2424
25- from ..utils import get_logger , is_accelerate_available
25+ from ..utils import get_logger , is_accelerate_available , is_torchao_available
2626from ._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
2727from .hooks import HookRegistry , ModelHook
2828
3535logger = get_logger (__name__ ) # pylint: disable=invalid-name
3636
3737
def _is_torchao_tensor(tensor: torch.Tensor) -> bool:
    """Return True when `tensor` is a TorchAO wrapper-subclass tensor.

    Cheap availability guard first: when torchao is not installed, nothing can
    be a TorchAO tensor, and the import below would fail.
    """
    if is_torchao_available():
        from torchao.utils import TorchAOBaseTensor

        return isinstance(tensor, TorchAOBaseTensor)
    return False
44+
45+
46+ def _get_torchao_inner_tensor_names (tensor : torch .Tensor ) -> list [str ]:
47+ """Get names of all internal tensor data attributes from a TorchAO tensor."""
48+ cls = type (tensor )
49+ names = list (getattr (cls , "tensor_data_names" , []))
50+ for attr_name in getattr (cls , "optional_tensor_data_names" , []):
51+ if getattr (tensor , attr_name , None ) is not None :
52+ names .append (attr_name )
53+ return names
54+
55+
56+ def _swap_torchao_tensor (param : torch .Tensor , source : torch .Tensor ) -> None :
57+ """Move a TorchAO parameter to the device of `source` via `swap_tensors`.
58+
59+ `param.data = source` does not work for `_make_wrapper_subclass` tensors because the `.data` setter only replaces
60+ the outer wrapper storage while leaving the subclass's internal attributes (e.g. `.qdata`, `.scale`) on the
61+ original device. `swap_tensors` swaps the full tensor contents in-place, preserving the parameter's identity so
62+ that any dict keyed by `id(param)` remains valid.
63+
64+ Refer to https://github.com/huggingface/diffusers/pull/13276#discussion_r2944471548 for the full discussion.
65+ """
66+ torch .utils .swap_tensors (param , source )
67+
68+
def _restore_torchao_tensor(param: torch.Tensor, source: torch.Tensor) -> None:
    """Copy `source`'s internal TorchAO tensor attributes onto `param`.

    Unlike `_swap_torchao_tensor`, this never mutates `source`: each internal
    attribute reference is assigned one-by-one via `setattr`. Use it when
    `source` is a cached tensor that must stay intact (e.g. a pinned CPU copy
    held in `cpu_param_dict`).
    """
    attr_names = _get_torchao_inner_tensor_names(source)
    for name in attr_names:
        value = getattr(source, name)
        setattr(param, name, value)
78+
79+
def _record_stream_torchao_tensor(param: torch.Tensor, stream) -> None:
    """Call `record_stream(stream)` on every internal tensor of a TorchAO parameter."""
    for name in _get_torchao_inner_tensor_names(param):
        inner = getattr(param, name)
        inner.record_stream(stream)
84+
85+
3886# fmt: off
3987_GROUP_OFFLOADING = "group_offloading"
4088_LAYER_EXECUTION_TRACKER = "layer_execution_tracker"
@@ -124,24 +172,29 @@ def __init__(
124172 else torch .cuda
125173 )
126174
175+ @staticmethod
176+ def _to_cpu (tensor , low_cpu_mem_usage ):
177+ # For TorchAO tensors, `.data` returns an incomplete wrapper without internal attributes
178+ # (e.g. `.qdata`, `.scale`), so we must call `.cpu()` on the tensor directly.
179+ t = tensor .cpu () if _is_torchao_tensor (tensor ) else tensor .data .cpu ()
180+ return t if low_cpu_mem_usage else t .pin_memory ()
181+
127182 def _init_cpu_param_dict (self ):
128183 cpu_param_dict = {}
129184 if self .stream is None :
130185 return cpu_param_dict
131186
132187 for module in self .modules :
133188 for param in module .parameters ():
134- cpu_param_dict [param ] = param . data . cpu () if self .low_cpu_mem_usage else param . data . cpu (). pin_memory ( )
189+ cpu_param_dict [param ] = self . _to_cpu ( param , self .low_cpu_mem_usage )
135190 for buffer in module .buffers ():
136- cpu_param_dict [buffer ] = (
137- buffer .data .cpu () if self .low_cpu_mem_usage else buffer .data .cpu ().pin_memory ()
138- )
191+ cpu_param_dict [buffer ] = self ._to_cpu (buffer , self .low_cpu_mem_usage )
139192
140193 for param in self .parameters :
141- cpu_param_dict [param ] = param . data . cpu () if self .low_cpu_mem_usage else param . data . cpu (). pin_memory ( )
194+ cpu_param_dict [param ] = self . _to_cpu ( param , self .low_cpu_mem_usage )
142195
143196 for buffer in self .buffers :
144- cpu_param_dict [buffer ] = buffer . data . cpu () if self .low_cpu_mem_usage else buffer . data . cpu (). pin_memory ( )
197+ cpu_param_dict [buffer ] = self . _to_cpu ( buffer , self .low_cpu_mem_usage )
145198
146199 return cpu_param_dict
147200
@@ -157,9 +210,16 @@ def _pinned_memory_tensors(self):
157210 pinned_dict = None
158211
159212 def _transfer_tensor_to_device (self , tensor , source_tensor , default_stream ):
160- tensor .data = source_tensor .to (self .onload_device , non_blocking = self .non_blocking )
213+ moved = source_tensor .to (self .onload_device , non_blocking = self .non_blocking )
214+ if _is_torchao_tensor (tensor ):
215+ _swap_torchao_tensor (tensor , moved )
216+ else :
217+ tensor .data = moved
161218 if self .record_stream :
162- tensor .data .record_stream (default_stream )
219+ if _is_torchao_tensor (tensor ):
220+ _record_stream_torchao_tensor (tensor , default_stream )
221+ else :
222+ tensor .data .record_stream (default_stream )
163223
164224 def _process_tensors_from_modules (self , pinned_memory = None , default_stream = None ):
165225 for group_module in self .modules :
@@ -178,7 +238,19 @@ def _process_tensors_from_modules(self, pinned_memory=None, default_stream=None)
178238 source = pinned_memory [buffer ] if pinned_memory else buffer .data
179239 self ._transfer_tensor_to_device (buffer , source , default_stream )
180240
241+ def _check_disk_offload_torchao (self ):
242+ all_tensors = list (self .tensor_to_key .keys ())
243+ has_torchao = any (_is_torchao_tensor (t ) for t in all_tensors )
244+ if has_torchao :
245+ raise ValueError (
246+ "Disk offloading is not supported for TorchAO quantized tensors because safetensors "
247+ "cannot serialize TorchAO subclass tensors. Use memory offloading instead by not "
248+ "setting `offload_to_disk_path`."
249+ )
250+
181251 def _onload_from_disk (self ):
252+ self ._check_disk_offload_torchao ()
253+
182254 if self .stream is not None :
183255 # Wait for previous Host->Device transfer to complete
184256 self .stream .synchronize ()
@@ -221,6 +293,8 @@ def _onload_from_memory(self):
221293 self ._process_tensors_from_modules (None )
222294
223295 def _offload_to_disk (self ):
296+ self ._check_disk_offload_torchao ()
297+
224298 # TODO: we can potentially optimize this code path by checking if the _all_ the desired
225299 # safetensor files exist on the disk and if so, skip this step entirely, reducing IO
226300 # overhead. Currently, we just check if the given `safetensors_file_path` exists and if not
@@ -245,18 +319,35 @@ def _offload_to_memory(self):
245319
246320 for group_module in self .modules :
247321 for param in group_module .parameters ():
248- param .data = self .cpu_param_dict [param ]
322+ if _is_torchao_tensor (param ):
323+ _restore_torchao_tensor (param , self .cpu_param_dict [param ])
324+ else :
325+ param .data = self .cpu_param_dict [param ]
249326 for param in self .parameters :
250- param .data = self .cpu_param_dict [param ]
327+ if _is_torchao_tensor (param ):
328+ _restore_torchao_tensor (param , self .cpu_param_dict [param ])
329+ else :
330+ param .data = self .cpu_param_dict [param ]
251331 for buffer in self .buffers :
252- buffer .data = self .cpu_param_dict [buffer ]
332+ if _is_torchao_tensor (buffer ):
333+ _restore_torchao_tensor (buffer , self .cpu_param_dict [buffer ])
334+ else :
335+ buffer .data = self .cpu_param_dict [buffer ]
253336 else :
254337 for group_module in self .modules :
255338 group_module .to (self .offload_device , non_blocking = False )
256339 for param in self .parameters :
257- param .data = param .data .to (self .offload_device , non_blocking = False )
340+ if _is_torchao_tensor (param ):
341+ moved = param .to (self .offload_device , non_blocking = False )
342+ _swap_torchao_tensor (param , moved )
343+ else :
344+ param .data = param .data .to (self .offload_device , non_blocking = False )
258345 for buffer in self .buffers :
259- buffer .data = buffer .data .to (self .offload_device , non_blocking = False )
346+ if _is_torchao_tensor (buffer ):
347+ moved = buffer .to (self .offload_device , non_blocking = False )
348+ _swap_torchao_tensor (buffer , moved )
349+ else :
350+ buffer .data = buffer .data .to (self .offload_device , non_blocking = False )
260351
261352 @torch .compiler .disable ()
262353 def onload_ (self ):
0 commit comments