fix: cache warmup RuntimeError on mps #46239

Merged
ArthurZucker merged 1 commit into main from fix/cache_warmup_runtime_error_mps
Jun 1, 2026

Conversation

@McPatate McPatate commented May 27, 2026

Member

Skip warmup on MPS: MPS caps the size of a single buffer, and from testing the cap seems to be about 2/3 of the total device memory (tested on Apple silicon). This causes the warmup function to raise a RuntimeError: Invalid buffer size: XX.XX GiB.

NOTE: not tested on Intel Macs, but I assume the same issue arises there since they also use the MPS backend.

EDIT: according to this old thread, the MTLBuffer limit on Intel Macs is capped at a hardcoded value, so this PR is needed even more on that platform.

Running:

import torch
from transformers import AutoModelForCausalLM


AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-14B", dtype=torch.bfloat16, device_map="mps")

yields:

>>> transformers.AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-14B", dtype=torch.bfloat16, device_map="mps")
W0528 14:58:14.629000 91467 torch/distributed/elastic/multiprocessing/redirects.py:35] NOTE: Redirects are currently not supported in MacOs.
Traceback (most recent call last):
  File "<python-input-4>", line 1, in <module>
    transformers.AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-14B", dtype=torch.bfloat16, device_map="mps")
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/transformers/src/transformers/models/auto/auto_factory.py", line 405, in from_pretrained
    return model_class.from_pretrained(
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~^
        pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/transformers/src/transformers/modeling_utils.py", line 4382, in from_pretrained
    loading_info, disk_offload_index = cls._load_pretrained_model(model, state_dict, checkpoint_files, load_config)
                                       ~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/transformers/src/transformers/modeling_utils.py", line 4463, in _load_pretrained_model
    caching_allocator_warmup(model, expanded_device_map, load_config.hf_quantizer)
    ~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/transformers/src/transformers/modeling_utils.py", line 5142, in caching_allocator_warmup
    _ = torch.empty(int(byte_count // 2), dtype=torch.float16, device=device, requires_grad=False)
RuntimeError: Invalid buffer size: 27.51 GiB
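
For context, the 27.51 GiB in the error roughly matches the size of the whole model in bfloat16. The following is a back-of-the-envelope check, assuming about 14.77 billion parameters for Qwen2.5-14B (an approximate figure that is not stated in this PR) and 2 bytes per parameter. The `// 2` mirrors the float16 element count seen in the traceback.

```python
GIB = 1024**3

# Assumption: approximate parameter count for Qwen2.5-14B (not taken from this PR).
approx_params = 14.77e9
bytes_per_param = 2  # bfloat16 and float16 both use 2 bytes per element

byte_count = int(approx_params * bytes_per_param)
# The warmup call in the traceback allocates float16 elements, hence byte_count // 2.
num_fp16_elements = byte_count // 2

size_gib = byte_count / GIB
print(f"estimated single-buffer warmup size: {size_gib:.2f} GiB")
```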

You can run the following on your machine to query the hard limits of your system:

import ctypes

# Load the Metal framework and the Objective-C runtime.
metal = ctypes.CDLL("/System/Library/Frameworks/Metal.framework/Metal")
objc = ctypes.CDLL("/usr/lib/libobjc.dylib")
metal.MTLCreateSystemDefaultDevice.restype = ctypes.c_void_p
objc.sel_registerName.restype = ctypes.c_void_p
objc.sel_registerName.argtypes = [ctypes.c_char_p]
# Both properties queried below are 64-bit unsigned byte counts.
objc.objc_msgSend.restype = ctypes.c_ulonglong
objc.objc_msgSend.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
dev = metal.MTLCreateSystemDefaultDevice()


def u(sel):
    # Send a zero-argument selector to the default Metal device.
    return objc.objc_msgSend(dev, objc.sel_registerName(sel))


maxbuf = u(b"maxBufferLength")
wss = u(b"recommendedMaxWorkingSetSize")
GB = 1024**3
print(f"maxBufferLength              = {maxbuf / GB:6.2f} GiB")
print(f"recommendedMaxWorkingSetSize = {wss / GB:6.2f} GiB")
# 27.51 GiB is the size reported in the RuntimeError above.
print("warmup tried to allocate     =  27.51 GiB (single buffer)")
print(f"exceeds maxBufferLength?     = {maxbuf / GB < 27.51}")

Mine shows:

maxBufferLength              =  21.06 GiB
recommendedMaxWorkingSetSize =  28.08 GiB
warmup tried to allocate     =  27.51 GiB (single buffer)
exceeds maxBufferLength?     = True

which is consistent with the RuntimeError I get in the caching_allocator_warmup function. Note that I have enough memory on my machine to load such a model (I'm at 36 GB), and even if I were to go a little above that, I know I can fit a large amount of data in swap anyway (~36 GB before my OS OOM-kills the process).
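
As a quick sanity check on the figures above, the short calculation below compares the reported limits. The values are copied from the output in this description; reading the machine's "36gb" as 36 GiB is an assumption.

```python
# Values reported above, in GiB.
max_buffer_gib = 21.06
working_set_gib = 28.08
warmup_request_gib = 27.51
total_memory_gib = 36.0  # assumption: "36gb" read as 36 GiB

exceeds_single_buffer_limit = warmup_request_gib > max_buffer_gib
fits_in_working_set = warmup_request_gib <= working_set_gib

print(f"exceeds maxBufferLength:        {exceeds_single_buffer_limit}")
print(f"fits recommended working set:   {fits_in_working_set}")
print(f"maxBufferLength / working set:  {max_buffer_gib / working_set_gib:.2f}")
print(f"maxBufferLength / total memory: {max_buffer_gib / total_memory_gib:.2f}")
```

The request fits within the recommended working set but not within a single buffer, which is consistent with one large `torch.empty` call failing even though the machine has room for the model.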

@McPatate McPatate requested a review from ArthurZucker May 27, 2026 12:17
@HuggingFaceDocBuilderDev


The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Comment on lines +5142 to +5146
# Skip warmup on MPS: there is a limit of the maximum size a single buffer can have on MPS,
# which from testing seems to be about 2/3 of the total device memory (tested on apple silicon).
# This causes the warmup function to return a `RuntimeError: Invalid buffer size: XX.XX GiB`.
# NOTE: not tested on intel macs
continue
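
To illustrate the shape of this change, here is a minimal, torch-free sketch of skipping the warmup for MPS devices. The function name `devices_needing_warmup` and the plain-string device representation are hypothetical and used only for illustration; the actual change lives inside `caching_allocator_warmup` in `modeling_utils.py`.

```python
def devices_needing_warmup(byte_count_per_device):
    """Return the per-device byte counts that should still be pre-allocated.

    Hypothetical helper: devices are plain strings such as "cuda:0" or "mps".
    """
    selected = {}
    for device, byte_count in byte_count_per_device.items():
        # Skip MPS: one buffer of this size may exceed the per-buffer limit.
        if device.split(":")[0] == "mps":
            continue
        selected[device] = byte_count
    return selected


planned = devices_needing_warmup({"cuda:0": 8 * 1024**3, "mps": 27 * 1024**3})
print(planned)
```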

Collaborator

Please provide a repro; it does not fail for me AFAIK, even when loading a big Mixtral at max capacity!
Also, rather than skipping, maybe reduce the allocation, and please benchmark the speed loss.
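
One way to read the suggestion above (reduce the allocation instead of skipping it) is to cap each warmup buffer at the device's maximum buffer length. The sketch below only plans the chunk sizes and does not allocate anything. `plan_warmup_chunks` is a hypothetical helper, not part of transformers, and whether several smaller buffers would warm the MPS allocator in a useful way is untested here.

```python
def plan_warmup_chunks(byte_count, max_buffer_bytes):
    """Split byte_count into chunk sizes that each fit in max_buffer_bytes."""
    if max_buffer_bytes <= 0:
        raise ValueError("max_buffer_bytes must be positive")
    chunks = []
    remaining = byte_count
    while remaining > 0:
        chunk = min(remaining, max_buffer_bytes)
        chunks.append(chunk)
        remaining -= chunk
    return chunks


GIB = 1024**3
# Figures from the PR description, rounded down to whole bytes.
chunks = plan_warmup_chunks(int(27.51 * GIB), int(21.06 * GIB))
print([round(c / GIB, 2) for c in chunks])
```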

Member Author

We won't need pre-allocation with safetensors on MPS once safetensors/safetensors#767 is merged. We allocate the MTLBuffers, fill them with pread, and then hand them zero-copy to torch with DLPack. Since we don't go through torch's allocation stack, the warmup is going to become unnecessary, at least for MPS.
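
The flow described above (allocate a buffer, fill it with `pread`, then share it without copying) can be illustrated on the CPU with the standard library alone. This is only an analogy: it uses a `bytearray` in place of an MTLBuffer and a `memoryview` in place of a DLPack hand-off, and it is not the safetensors implementation.

```python
import os
import tempfile

# Write a small file standing in for a checkpoint shard.
payload = bytes(range(16))
with tempfile.NamedTemporaryFile(delete=False) as tmp:
    tmp.write(payload)
    path = tmp.name

# 1. Allocate the destination buffer up front (stand-in for an MTLBuffer).
buffer = bytearray(8)

# 2. Fill it with a positional read: 8 bytes starting at file offset 4.
#    os.pread (Unix only) returns new bytes; copying them in is a simplification.
fd = os.open(path, os.O_RDONLY)
try:
    data = os.pread(fd, len(buffer), 4)
    buffer[: len(data)] = data
finally:
    os.close(fd)
    os.remove(path)

# 3. Hand the buffer over without copying (stand-in for the DLPack hand-off).
view = memoryview(buffer)
buffer[0] = 255  # a write through the owner is visible through the view
print(view[0], bytes(view[1:]))
```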

As we discussed by message, you in fact cannot allocate a single buffer larger than 58 GB out of your 96 GB available. It'd be interesting to see what total_byte_count's value is when you load your Mixtral model.

Member Author

Added more details in the PR desc

Member Author

The only reason I can see for your load not crashing is that you don't set device_map="mps", which skips the cache warmup function altogether.

@hclsys

This comment was marked as low quality.

@ArthurZucker ArthurZucker left a comment

Collaborator

Ty, confirmed: on MPS it's not slower anyway!

@ArthurZucker ArthurZucker merged commit 5390fc3 into main Jun 1, 2026
116 of 162 checks passed
@ArthurZucker ArthurZucker deleted the fix/cache_warmup_runtime_error_mps branch June 1, 2026 09:59
kashif pushed a commit to kashif/transformers that referenced this pull request Jun 1, 2026
khushali9 pushed a commit to khushali9/transformers that referenced this pull request Jun 8, 2026

4 participants