From 354052021de053ab54ffd77fabbfb7f3da4874cd Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Sun, 31 May 2026 23:26:28 -0400 Subject: [PATCH 01/33] feat(app): parallel multi-GPU session execution Run one generation session per configured GPU concurrently, with a tiled progress preview. Multi-user isolation is unchanged. Backed by five seams: - Per-thread device context (TorchDevice.set/get/clear_session_device); choose_torch_device() consults it first, so all device-selecting call sites resolve to the calling worker's GPU with no per-node changes. - Per-device model caches: build_model_manager builds one ModelCache per generation device; ModelLoadService.ram_cache resolves by current thread device; ram_caches fans out clear/drop/shutdown. - Atomic concurrent dequeue: a dequeue lock makes select+claim atomic so concurrent workers never claim the same item (works on FIFO; round-robin from #9086 slots in later). - Worker pool: one _SessionWorker per device, each pinning torch.cuda.set_device and its session device, with its own runner and cancel event; cancellation routes via an {item_id -> worker} lookup. Single-device installs keep the exact legacy single-worker behavior. Profiling disabled when >1 worker. - New config `generation_devices`; unset = legacy single-worker mode. Frontend: the canvas staging area already tiles per queue item; the main ImageViewer now tracks progress per session and renders a tile grid (ProgressImageTiles) when more than one session is active. Also adds a lock to ObjectSerializerForwardCache for concurrent access. Co-Authored-By: Claude Opus 4.8 (1M context) --- docs/src/generated/settings.json | 11 + invokeai/app/api/routers/model_manager.py | 11 +- .../app/services/config/config_default.py | 14 ++ .../services/model_load/model_load_base.py | 16 +- .../services/model_load/model_load_default.py | 42 +++- .../model_manager/model_manager_default.py | 48 ++-- .../object_serializer_forward_cache.py | 23 +- .../session_processor_default.py | 208 +++++++++++++----- .../session_queue/session_queue_sqlite.py | 54 +++-- .../load/model_cache/model_cache.py | 5 + invokeai/backend/util/devices.py | 27 +++ .../ImageViewer/CurrentImagePreview.tsx | 19 +- .../ImageViewer/ProgressImageTiles.tsx | 39 ++++ .../components/ImageViewer/context.tsx | 42 +++- .../test_model_load_device_routing.py | 81 +++++++ .../test_session_queue_dequeue_concurrency.py | 70 ++++++ tests/backend/util/test_devices.py | 45 ++++ 17 files changed, 634 insertions(+), 121 deletions(-) create mode 100644 invokeai/frontend/web/src/features/gallery/components/ImageViewer/ProgressImageTiles.tsx create mode 100644 tests/app/services/model_load/test_model_load_device_routing.py create mode 100644 tests/app/services/session_queue/test_session_queue_dequeue_concurrency.py diff --git a/docs/src/generated/settings.json b/docs/src/generated/settings.json index 88a42f8fbcf..eb26d39960f 100644 --- a/docs/src/generated/settings.json +++ b/docs/src/generated/settings.json @@ -490,6 +490,17 @@ "type": "", "validation": {} }, + { + "category": "DEVICE", + "default": null, + "description": "List of devices to use for parallel generation, e.g. `[cuda:0, cuda:1]`. When set, the session processor runs one generation session per listed device concurrently, distributing jobs fairly across users. When unset (the default), generation runs serially on the single `device`.
Valid values for each entry: `cpu`, `cuda`, `mps`, `cuda:N` (where N is a device number)", + "env_var": "INVOKEAI_GENERATION_DEVICES", + "literal_values": [], + "name": "generation_devices", + "required": false, + "type": "typing.Optional[list[str]]", + "validation": {} + }, { "category": "DEVICE", "default": "auto", diff --git a/invokeai/app/api/routers/model_manager.py b/invokeai/app/api/routers/model_manager.py index bdd2e406444..53c4c68981f 100644 --- a/invokeai/app/api/routers/model_manager.py +++ b/invokeai/app/api/routers/model_manager.py @@ -443,7 +443,11 @@ async def update_model_record( # nn.Module at load time, so toggling them on a cached model is otherwise silently a no-op until # the entry is evicted. Drop any unlocked cached entries for this model so the next load rebuilds. if _load_settings_changed(previous_config, config): - dropped = ApiDependencies.invoker.services.model_manager.load.ram_cache.drop_model(key) + # Drop the model from every per-device cache so the next load on any GPU rebuilds it. + dropped = sum( + cache.drop_model(key) + for cache in ApiDependencies.invoker.services.model_manager.load.ram_caches.values() + ) if dropped: logger.info( f"Dropped {dropped} cached entr{'y' if dropped == 1 else 'ies'} for model {key} after settings change." @@ -1304,9 +1308,10 @@ async def get_stats() -> Optional[CacheStats]: ) async def empty_model_cache(current_admin: AdminUserOrDefault) -> None: """Drop all models from the model cache to free RAM/VRAM. 'Locked' models that are in active use will not be dropped.""" - # Request 1000GB of room in order to force the cache to drop all models. + # Request 1000GB of room in order to force each per-device cache to drop all models. ApiDependencies.invoker.services.logger.info("Emptying model cache.") - ApiDependencies.invoker.services.model_manager.load.ram_cache.make_room(1000 * 2**30) + for cache in ApiDependencies.invoker.services.model_manager.load.ram_caches.values(): + cache.make_room(1000 * 2**30) class HFTokenStatus(str, Enum): diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py index 57004efca39..a70f5f7e97c 100644 --- a/invokeai/app/services/config/config_default.py +++ b/invokeai/app/services/config/config_default.py @@ -205,6 +205,7 @@ class InvokeAIAppConfig(BaseSettings): # DEVICE device: str = Field(default="auto", description="Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.
Valid values: `auto`, `cpu`, `cuda`, `mps`, `cuda:N` (where N is a device number)", pattern=r"^(auto|cpu|mps|cuda(:\d+)?)$") + generation_devices: Optional[list[str]] = Field(default=None, description="List of devices to use for parallel generation, e.g. `[cuda:0, cuda:1]`. When set, the session processor runs one generation session per listed device concurrently, distributing jobs fairly across users. When unset (the default), generation runs serially on the single `device`.
Valid values for each entry: `cpu`, `cuda`, `mps`, `cuda:N` (where N is a device number)") precision: PRECISION = Field(default="auto", description="Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.") # GENERATION @@ -257,6 +258,19 @@ class InvokeAIAppConfig(BaseSettings): model_config = SettingsConfigDict(env_prefix="INVOKEAI_", env_ignore_empty=True) + @field_validator("generation_devices") + @classmethod + def validate_generation_devices(cls, v: Optional[list[str]]) -> Optional[list[str]]: + if v is None: + return v + pattern = re.compile(r"^(cpu|mps|cuda(:\d+)?)$") + for device in v: + if not pattern.match(device): + raise ValueError( + f"Invalid generation device '{device}'. Valid values are 'cpu', 'mps', 'cuda', or 'cuda:N'." + ) + return v + def update_config(self, config: dict[str, Any] | InvokeAIAppConfig, clobber: bool = True) -> None: """Updates the config, overwriting existing values. diff --git a/invokeai/app/services/model_load/model_load_base.py b/invokeai/app/services/model_load/model_load_base.py index 87a405b4ea4..8fc9823328d 100644 --- a/invokeai/app/services/model_load/model_load_base.py +++ b/invokeai/app/services/model_load/model_load_base.py @@ -26,7 +26,21 @@ def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubMo @property @abstractmethod def ram_cache(self) -> ModelCache: - """Return the RAM cache used by this loader.""" + """Return the RAM cache for the current thread's execution device. + + In multi-GPU mode, each session-processor worker is pinned to a device and gets its own + cache; this resolves to the calling thread's cache. Outside a worker (e.g. API threads), + it resolves to the default device's cache. + """ + + @property + @abstractmethod + def ram_caches(self) -> dict[str, ModelCache]: + """Return all per-device RAM caches, keyed by normalized device string. + + Use this for maintenance operations that must apply to every device (clear cache, drop a + model from all devices, shutdown). + """ @abstractmethod def load_model_from_path( diff --git a/invokeai/app/services/model_load/model_load_default.py b/invokeai/app/services/model_load/model_load_default.py index 2e2d2ae219d..45d0c354278 100644 --- a/invokeai/app/services/model_load/model_load_default.py +++ b/invokeai/app/services/model_load/model_load_default.py @@ -33,13 +33,25 @@ def __init__( app_config: InvokeAIAppConfig, ram_cache: ModelCache, registry: Optional[Type[ModelLoaderRegistryBase]] = ModelLoaderRegistry, + ram_caches: Optional[dict[str, ModelCache]] = None, ): - """Initialize the model load service.""" + """Initialize the model load service. + + Args: + ram_cache: The default RAM cache, used when no per-device cache matches the calling + thread (e.g. single-device installs, or API threads). + ram_caches: Optional map of normalized device string -> ModelCache for multi-GPU mode. + One cache per generation device. The default `ram_cache` is always included. + """ logger = InvokeAILogger.get_logger(self.__class__.__name__) logger.setLevel(app_config.log_level.upper()) self._logger = logger self._app_config = app_config - self._ram_cache = ram_cache + self._default_ram_cache = ram_cache + # Map normalized device string -> cache. Always includes the default cache so that callers + # without a pinned device (API threads) resolve to a valid cache. + self._ram_caches: dict[str, ModelCache] = dict(ram_caches) if ram_caches else {} + self._ram_caches.setdefault(str(TorchDevice.normalize(ram_cache.execution_device)), ram_cache) self._registry = registry def start(self, invoker: Invoker) -> None: @@ -47,8 +59,18 @@ def start(self, invoker: Invoker) -> None: @property def ram_cache(self) -> ModelCache: - """Return the RAM cache used by this loader.""" - return self._ram_cache + """Return the RAM cache for the calling thread's execution device. + + `choose_torch_device()` is thread-local-aware: a session-processor worker pinned to a GPU + gets that GPU's cache; everything else falls back to the default cache. + """ + key = str(TorchDevice.choose_torch_device()) + return self._ram_caches.get(key, self._default_ram_cache) + + @property + def ram_caches(self) -> dict[str, ModelCache]: + """Return all per-device RAM caches, keyed by normalized device string.""" + return dict(self._ram_caches) def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel: """ @@ -67,7 +89,7 @@ def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubMo loaded_model: LoadedModel = implementation( app_config=self._app_config, logger=self._logger, - ram_cache=self._ram_cache, + ram_cache=self.ram_cache, ).load_model(model_config, submodel_type) if hasattr(self, "_invoker"): @@ -78,9 +100,11 @@ def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubMo def load_model_from_path( self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None ) -> LoadedModelWithoutConfig: + # Resolve the calling thread's cache once so the whole load uses a single device's cache. + ram_cache = self.ram_cache cache_key = str(model_path) try: - return LoadedModelWithoutConfig(cache_record=self._ram_cache.get(key=cache_key), cache=self._ram_cache) + return LoadedModelWithoutConfig(cache_record=ram_cache.get(key=cache_key), cache=ram_cache) except IndexError: pass @@ -110,7 +134,7 @@ def diffusers_load_directory(directory: Path) -> AnyModel: load_class = GenericDiffusersLoader( app_config=self._app_config, logger=self._logger, - ram_cache=self._ram_cache, + ram_cache=ram_cache, convert_cache=self.convert_cache, ).get_hf_load_class(directory) return load_class.from_pretrained(model_path, torch_dtype=TorchDevice.choose_torch_dtype()) @@ -124,5 +148,5 @@ def diffusers_load_directory(directory: Path) -> AnyModel: ) assert loader is not None raw_model = loader(model_path) - self._ram_cache.put(key=cache_key, model=raw_model) - return LoadedModelWithoutConfig(cache_record=self._ram_cache.get(key=cache_key), cache=self._ram_cache) + ram_cache.put(key=cache_key, model=raw_model) + return LoadedModelWithoutConfig(cache_record=ram_cache.get(key=cache_key), cache=ram_cache) diff --git a/invokeai/app/services/model_manager/model_manager_default.py b/invokeai/app/services/model_manager/model_manager_default.py index 6141a635f4d..eaeb5d4e612 100644 --- a/invokeai/app/services/model_manager/model_manager_default.py +++ b/invokeai/app/services/model_manager/model_manager_default.py @@ -60,9 +60,10 @@ def start(self, invoker: Invoker) -> None: service.start(invoker) def stop(self, invoker: Invoker) -> None: - # Shutdown the model cache to cancel any pending timers - if hasattr(self._load, "ram_cache"): - self._load.ram_cache.shutdown() + # Shutdown every per-device model cache to cancel any pending keep-alive timers. + if hasattr(self._load, "ram_caches"): + for cache in self._load.ram_caches.values(): + cache.shutdown() for service in [self._store, self._install, self._load]: if hasattr(service, "stop"): @@ -85,22 +86,39 @@ def build_model_manager( logger = InvokeAILogger.get_logger(cls.__name__) logger.setLevel(app_config.log_level.upper()) - ram_cache = ModelCache( - execution_device_working_mem_gb=app_config.device_working_mem_gb, - enable_partial_loading=app_config.enable_partial_loading, - keep_ram_copy_of_weights=app_config.keep_ram_copy_of_weights, - max_ram_cache_size_gb=app_config.max_cache_ram_gb, - max_vram_cache_size_gb=app_config.max_cache_vram_gb, - execution_device=execution_device or TorchDevice.choose_torch_device(), - storage_device="cpu", - log_memory_usage=app_config.log_memory_usage, - logger=logger, - keep_alive_minutes=app_config.model_cache_keep_alive_min, - ) + def build_cache(device: torch.device) -> ModelCache: + return ModelCache( + execution_device_working_mem_gb=app_config.device_working_mem_gb, + enable_partial_loading=app_config.enable_partial_loading, + keep_ram_copy_of_weights=app_config.keep_ram_copy_of_weights, + max_ram_cache_size_gb=app_config.max_cache_ram_gb, + max_vram_cache_size_gb=app_config.max_cache_vram_gb, + execution_device=device, + storage_device="cpu", + log_memory_usage=app_config.log_memory_usage, + logger=logger, + keep_alive_minutes=app_config.model_cache_keep_alive_min, + ) + + # The default cache for callers without a pinned device (API threads, single-device installs). + default_device = execution_device or TorchDevice.choose_torch_device() + ram_cache = build_cache(default_device) + + # In multi-GPU mode, build one independent cache per generation device. Each session-processor + # worker is pinned to a device (see TorchDevice.set_session_device) and resolves to its own + # cache. The default cache is always included by ModelLoadService. + ram_caches: dict[str, ModelCache] = {str(TorchDevice.normalize(default_device)): ram_cache} + if app_config.generation_devices: + for device_str in app_config.generation_devices: + key = str(TorchDevice.normalize(device_str)) + if key not in ram_caches: + ram_caches[key] = build_cache(torch.device(key)) + loader = ModelLoadService( app_config=app_config, ram_cache=ram_cache, registry=ModelLoaderRegistry, + ram_caches=ram_caches, ) installer = ModelInstallService( app_config=app_config, diff --git a/invokeai/app/services/object_serializer/object_serializer_forward_cache.py b/invokeai/app/services/object_serializer/object_serializer_forward_cache.py index b361259a4b1..ae00173e422 100644 --- a/invokeai/app/services/object_serializer/object_serializer_forward_cache.py +++ b/invokeai/app/services/object_serializer/object_serializer_forward_cache.py @@ -1,4 +1,5 @@ from queue import Queue +from threading import Lock from typing import TYPE_CHECKING, Optional, TypeVar from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase @@ -21,6 +22,9 @@ def __init__(self, underlying_storage: ObjectSerializerBase[T], max_cache_size: self._cache: dict[str, T] = {} self._cache_ids = Queue[str]() self._max_cache_size = max_cache_size + # Guards the in-memory cache so concurrent session-processor workers (multi-GPU) can't race + # the check-then-evict in `_set_cache` (which could otherwise raise KeyError on eviction). + self._cache_lock = Lock() def start(self, invoker: "Invoker") -> None: self._invoker = invoker @@ -50,16 +54,19 @@ def save(self, obj: T) -> str: def delete(self, name: str) -> None: self._underlying_storage.delete(name) - if name in self._cache: - del self._cache[name] + with self._cache_lock: + if name in self._cache: + del self._cache[name] self._on_deleted(name) def _get_cache(self, name: str) -> Optional[T]: - return None if name not in self._cache else self._cache[name] + with self._cache_lock: + return None if name not in self._cache else self._cache[name] def _set_cache(self, name: str, data: T): - if name not in self._cache: - self._cache[name] = data - self._cache_ids.put(name) - if self._cache_ids.qsize() > self._max_cache_size: - self._cache.pop(self._cache_ids.get()) + with self._cache_lock: + if name not in self._cache: + self._cache[name] = data + self._cache_ids.put(name) + if self._cache_ids.qsize() > self._max_cache_size: + self._cache.pop(self._cache_ids.get()) diff --git a/invokeai/app/services/session_processor/session_processor_default.py b/invokeai/app/services/session_processor/session_processor_default.py index 7159c19e746..c6d566255b2 100644 --- a/invokeai/app/services/session_processor/session_processor_default.py +++ b/invokeai/app/services/session_processor/session_processor_default.py @@ -5,6 +5,8 @@ from threading import Event as ThreadEvent from typing import Optional +import torch + from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput from invokeai.app.services.events.events_common import ( BatchEnqueuedEvent, @@ -31,6 +33,7 @@ from invokeai.app.services.shared.graph import NodeInputError from invokeai.app.services.shared.invocation_context import InvocationContextData, build_invocation_context from invokeai.app.util.profiler import Profiler +from invokeai.backend.util.devices import TorchDevice class DefaultSessionRunner(SessionRunnerBase): @@ -305,6 +308,26 @@ def _on_node_error( ) +class _SessionWorker: + """A single generation worker: one thread, optionally pinned to one device. + + In single-device (legacy) mode there is exactly one worker with `device=None`. In multi-GPU + mode there is one worker per configured device, each with its own session runner and cancel + event so concurrent sessions can be canceled independently. + """ + + def __init__(self, device: Optional[torch.device], runner: SessionRunnerBase) -> None: + self.device = device + self.runner = runner + self.cancel_event = ThreadEvent() + self.queue_item: Optional[SessionQueueItem] = None + self.thread: Optional[Thread] = None + + @property + def label(self) -> str: + return str(self.device) if self.device is not None else "default device" + + class DefaultSessionProcessor(SessionProcessorBase): def __init__( self, @@ -319,57 +342,118 @@ def __init__( self._on_non_fatal_processor_error_callbacks = on_non_fatal_processor_error_callbacks or [] self._thread_limit = thread_limit self._polling_interval = polling_interval + self._workers: list[_SessionWorker] = [] + + def _resolve_devices(self) -> list[Optional[torch.device]]: + """Determine the per-worker devices from config. + + Returns a single `None` (legacy single-worker, device chosen by the global config) unless + `generation_devices` is configured, in which case it returns one normalized device per + listed device (deduplicated, order preserved). + """ + generation_devices = self._invoker.services.configuration.generation_devices + if not generation_devices: + return [None] + devices: list[Optional[torch.device]] = [] + seen: set[str] = set() + for device_str in generation_devices: + device = TorchDevice.normalize(device_str) + if str(device) not in seen: + seen.add(str(device)) + devices.append(device) + return devices + + def _clone_session_runner(self, template: SessionRunnerBase) -> SessionRunnerBase: + """Create an independent runner for an additional worker. + + Each worker needs its own runner because the runner stores its session's cancel event. + We carry over the template's callbacks so all workers behave identically. + """ + if isinstance(template, DefaultSessionRunner): + return DefaultSessionRunner( + on_before_run_session_callbacks=list(template._on_before_run_session_callbacks), + on_before_run_node_callbacks=list(template._on_before_run_node_callbacks), + on_after_run_node_callbacks=list(template._on_after_run_node_callbacks), + on_node_error_callbacks=list(template._on_node_error_callbacks), + on_after_run_session_callbacks=list(template._on_after_run_session_callbacks), + ) + # Unknown runner implementation — only safe to reuse in single-worker mode. + return template def start(self, invoker: Invoker) -> None: self._invoker: Invoker = invoker - self._queue_item: Optional[SessionQueueItem] = None - self._invocation: Optional[BaseInvocation] = None self._resume_event = ThreadEvent() self._stop_event = ThreadEvent() self._poll_now_event = ThreadEvent() - self._cancel_event = ThreadEvent() register_events(QueueClearedEvent, self._on_queue_cleared) register_events(BatchEnqueuedEvent, self._on_batch_enqueued) register_events(QueueItemStatusChangedEvent, self._on_queue_item_status_changed) - self._thread_semaphore = BoundedSemaphore(self._thread_limit) + devices = self._resolve_devices() # If profiling is enabled, create a profiler. The same profiler will be used for all sessions. Internally, - # the profiler will create a new profile for each session. + # the profiler will create a new profile for each session. Profiling uses a process-global cProfile, which + # cannot cleanly attribute work when multiple sessions run concurrently, so it is disabled in multi-GPU mode. + profiler_enabled = self._invoker.services.configuration.profile_graphs + if profiler_enabled and len(devices) > 1: + self._invoker.services.logger.warning( + "Graph profiling is disabled because multiple generation devices are configured." + ) + profiler_enabled = False self._profiler = ( Profiler( logger=self._invoker.services.logger, output_dir=self._invoker.services.configuration.profiles_path, prefix=self._invoker.services.configuration.profile_prefix, ) - if self._invoker.services.configuration.profile_graphs + if profiler_enabled else None ) - self.session_runner.start(services=invoker.services, cancel_event=self._cancel_event, profiler=self._profiler) - self._thread = Thread( - name="session_processor", - target=self._process, - daemon=True, - kwargs={ - "stop_event": self._stop_event, - "poll_now_event": self._poll_now_event, - "resume_event": self._resume_event, - "cancel_event": self._cancel_event, - }, - ) - self._thread.start() + self._thread_semaphore = BoundedSemaphore(len(devices)) + + # Start in the running (resumed) state. + self._stop_event.clear() + self._resume_event.set() + + self._workers = [] + for index, device in enumerate(devices): + runner = self.session_runner if index == 0 else self._clone_session_runner(self.session_runner) + worker = _SessionWorker(device=device, runner=runner) + runner.start(services=invoker.services, cancel_event=worker.cancel_event, profiler=self._profiler) + self._workers.append(worker) + + if len(self._workers) > 1: + self._invoker.services.logger.info( + f"Starting session processor with {len(self._workers)} parallel workers on devices: " + f"{', '.join(w.label for w in self._workers)}" + ) + + for index, worker in enumerate(self._workers): + worker.thread = Thread( + name=f"session_processor_{index}", + target=self._process, + daemon=True, + kwargs={ + "worker": worker, + "stop_event": self._stop_event, + "poll_now_event": self._poll_now_event, + "resume_event": self._resume_event, + }, + ) + worker.thread.start() def stop(self, *args, **kwargs) -> None: self._stop_event.set() # Cancel any in-progress generation so that long-running nodes (e.g. denoising) stop at - # the next step boundary instead of running to completion. Without this, the generation + # the next step boundary instead of running to completion. Without this, a generation # thread may still be executing CUDA operations when Python teardown begins, which can # cause a C++ std::terminate() crash ("terminate called without an active exception"). - self._cancel_event.set() - # Wake the thread if it is sleeping in poll_now_event.wait() or blocked in resume_event.wait() (paused). + for worker in self._workers: + worker.cancel_event.set() + # Wake any worker sleeping in poll_now_event.wait() or blocked in resume_event.wait() (paused). self._poll_now_event.set() self._resume_event.set() @@ -377,28 +461,31 @@ def _poll_now(self) -> None: self._poll_now_event.set() async def _on_queue_cleared(self, event: FastAPIEvent[QueueClearedEvent]) -> None: - if self._queue_item and self._queue_item.queue_id == event[1].queue_id: - self._cancel_event.set() + # Cancel every worker currently running an item from the cleared queue. + canceled = False + for worker in self._workers: + if worker.queue_item and worker.queue_item.queue_id == event[1].queue_id: + worker.cancel_event.set() + canceled = True + if canceled: self._poll_now() async def _on_batch_enqueued(self, event: FastAPIEvent[BatchEnqueuedEvent]) -> None: self._poll_now() async def _on_queue_item_status_changed(self, event: FastAPIEvent[QueueItemStatusChangedEvent]) -> None: - # Make sure the cancel event is for the currently processing queue item - if self._queue_item and self._queue_item.item_id != event[1].item_id: - return - if self._queue_item and event[1].status in ["completed", "failed", "canceled"]: - # When the queue item is canceled via HTTP, the queue item status is set to `"canceled"` and this event is - # emitted. We need to respond to this event and stop graph execution. This is done by setting the cancel - # event, which the session runner checks between invocations. If set, the session runner loop is broken. - # - # Long-running nodes that cannot be interrupted easily present a challenge. `denoise_latents` is one such - # node, but it gets a step callback, called on each step of denoising. This callback checks if the queue item - # is canceled, and if it is, raises a `CanceledException` to stop execution immediately. - if event[1].status == "canceled": - self._cancel_event.set() - self._poll_now() + # Find the worker (if any) currently running the item whose status changed. + for worker in self._workers: + if worker.queue_item and worker.queue_item.item_id == event[1].item_id: + if event[1].status in ["completed", "failed", "canceled"]: + # When the queue item is canceled via HTTP, the status is set to "canceled" and this event is + # emitted. We respond by setting that worker's cancel event, which its session runner checks + # between invocations (and which denoise_latents' step callback checks mid-node, raising + # CanceledException to stop immediately). + if event[1].status == "canceled": + worker.cancel_event.set() + self._poll_now() + return def resume(self) -> SessionProcessorStatus: if not self._resume_event.is_set(): @@ -413,22 +500,28 @@ def pause(self) -> SessionProcessorStatus: def get_status(self) -> SessionProcessorStatus: return SessionProcessorStatus( is_started=self._resume_event.is_set(), - is_processing=self._queue_item is not None, + is_processing=any(worker.queue_item is not None for worker in self._workers), ) def _process( self, + worker: _SessionWorker, stop_event: ThreadEvent, poll_now_event: ThreadEvent, resume_event: ThreadEvent, - cancel_event: ThreadEvent, ): try: - # Any unhandled exception in this block is a fatal processor error and will stop the processor. + # Any unhandled exception in this block is a fatal processor error and will stop this worker. self._thread_semaphore.acquire() - stop_event.clear() - resume_event.set() - cancel_event.clear() + + # Pin this worker thread to its device so all device-selecting code (TorchDevice.choose_torch_device, + # which nodes and the model loader consult) resolves to this GPU. CUDA's current device is per-thread. + if worker.device is not None: + TorchDevice.set_session_device(worker.device) + if worker.device.type == "cuda": + torch.cuda.set_device(worker.device) + + worker.cancel_event.clear() while not stop_event.is_set(): poll_now_event.clear() @@ -437,10 +530,14 @@ def _process( # If we are paused, wait for resume event resume_event.wait() - # Get the next session to process - self._queue_item = self._invoker.services.session_queue.dequeue() + if stop_event.is_set(): + break + + # Get the next session to process. dequeue() atomically claims the item, so concurrent + # workers never receive the same item. + worker.queue_item = self._invoker.services.session_queue.dequeue() - if self._queue_item is None: + if worker.queue_item is None: # The queue was empty, wait for next polling interval or event to try again self._invoker.services.logger.debug("Waiting for next polling interval or event") poll_now_event.wait(self._polling_interval) @@ -453,19 +550,20 @@ def _process( gc.collect() self._invoker.services.logger.info( - f"Executing queue item {self._queue_item.item_id}, session {self._queue_item.session_id}" + f"Executing queue item {worker.queue_item.item_id}, session {worker.queue_item.session_id} " + f"on {worker.label}" ) - cancel_event.clear() + worker.cancel_event.clear() # Run the graph - self.session_runner.run(queue_item=self._queue_item) + worker.runner.run(queue_item=worker.queue_item) except Exception as e: error_type = e.__class__.__name__ error_message = str(e) error_traceback = traceback.format_exc() self._on_non_fatal_processor_error( - queue_item=self._queue_item, + queue_item=worker.queue_item, error_type=error_type, error_message=error_message, error_traceback=error_traceback, @@ -474,7 +572,7 @@ def _process( poll_now_event.wait(self._polling_interval) continue except Exception as e: - # Fatal error in processor, log and pass - we're done here + # Fatal error in this worker, log and pass - we're done here error_type = e.__class__.__name__ error_message = str(e) error_traceback = traceback.format_exc() @@ -482,9 +580,9 @@ def _process( self._invoker.services.logger.error(error_traceback) pass finally: - stop_event.clear() - poll_now_event.clear() - self._queue_item = None + worker.queue_item = None + if worker.device is not None: + TorchDevice.clear_session_device() self._thread_semaphore.release() def _on_non_fatal_processor_error( diff --git a/invokeai/app/services/session_queue/session_queue_sqlite.py b/invokeai/app/services/session_queue/session_queue_sqlite.py index a05ed468857..f1bcd8c7c5c 100644 --- a/invokeai/app/services/session_queue/session_queue_sqlite.py +++ b/invokeai/app/services/session_queue/session_queue_sqlite.py @@ -1,6 +1,7 @@ import asyncio import json import sqlite3 +import threading from typing import Optional, Union, cast from pydantic_core import to_jsonable_python @@ -42,6 +43,12 @@ class SqliteSessionQueue(SessionQueueBase): __invoker: Invoker + # Serializes the select-candidate-then-claim sequence in `dequeue()`. The DB connection's + # RLock serializes individual statements, but the gap between selecting the next pending item + # and marking it 'in_progress' is a race: with multiple session-processor workers (multi-GPU), + # two workers could select the same item. Holding this lock across the whole claim prevents it. + _dequeue_lock = threading.Lock() + def start(self, invoker: Invoker) -> None: self.__invoker = invoker self._set_in_progress_to_canceled() @@ -210,27 +217,32 @@ async def enqueue_batch( return enqueue_result def dequeue(self) -> Optional[SessionQueueItem]: - with self._db.transaction() as cursor: - cursor.execute( - """--sql - SELECT - sq.*, - u.display_name as user_display_name, - u.email as user_email - FROM session_queue sq - LEFT JOIN users u ON sq.user_id = u.user_id - WHERE sq.status = 'pending' - ORDER BY - sq.priority DESC, - sq.item_id ASC - LIMIT 1 - """ - ) - result = cast(Union[sqlite3.Row, None], cursor.fetchone()) - if result is None: - return None - queue_item = SessionQueueItem.queue_item_from_dict(dict(result)) - queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="in_progress") + # Hold the dequeue lock across the select-then-claim so concurrent workers (multi-GPU) + # cannot select and claim the same pending item. `_set_queue_item_status` already no-ops + # if the item was concurrently moved to a terminal state (e.g. canceled), so we only need + # to guard against two dequeues racing for the same pending row. + with self._dequeue_lock: + with self._db.transaction() as cursor: + cursor.execute( + """--sql + SELECT + sq.*, + u.display_name as user_display_name, + u.email as user_email + FROM session_queue sq + LEFT JOIN users u ON sq.user_id = u.user_id + WHERE sq.status = 'pending' + ORDER BY + sq.priority DESC, + sq.item_id ASC + LIMIT 1 + """ + ) + result = cast(Union[sqlite3.Row, None], cursor.fetchone()) + if result is None: + return None + queue_item = SessionQueueItem.queue_item_from_dict(dict(result)) + queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="in_progress") return queue_item def get_next(self, queue_id: str) -> Optional[SessionQueueItem]: diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache.py b/invokeai/backend/model_manager/load/model_cache/model_cache.py index e3a0928e52b..1196a0f3885 100644 --- a/invokeai/backend/model_manager/load/model_cache/model_cache.py +++ b/invokeai/backend/model_manager/load/model_cache/model_cache.py @@ -229,6 +229,11 @@ def unsubscribe() -> None: return unsubscribe + @property + def execution_device(self) -> torch.device: + """Return the default execution device this cache loads models onto.""" + return self._execution_device + @property @synchronized def stats(self) -> Optional[CacheStats]: diff --git a/invokeai/backend/util/devices.py b/invokeai/backend/util/devices.py index 359ce45dc4f..d912f86a8a3 100644 --- a/invokeai/backend/util/devices.py +++ b/invokeai/backend/util/devices.py @@ -1,3 +1,4 @@ +import threading from typing import Dict, Literal, Optional, Union import torch @@ -46,9 +47,35 @@ class TorchDevice: CUDA_DEVICE = torch.device("cuda") MPS_DEVICE = torch.device("mps") + # Per-thread execution device. When set (by a session-processor worker thread bound to a + # specific GPU), `choose_torch_device()` returns it instead of consulting the global config. + # This is the lynchpin that makes the ~79 `choose_torch_device()` call sites (nodes, model + # patcher, etc.) resolve to the calling worker's GPU without per-call-site changes. + _session_device = threading.local() + + @classmethod + def set_session_device(cls, device: Union[str, torch.device]) -> None: + """Pin the calling thread's execution device. Used by multi-GPU session workers.""" + cls._session_device.device = cls.normalize(device) + + @classmethod + def get_session_device(cls) -> Optional[torch.device]: + """Return the calling thread's pinned execution device, or None if unset.""" + return getattr(cls._session_device, "device", None) + + @classmethod + def clear_session_device(cls) -> None: + """Remove the calling thread's pinned execution device, reverting to global config.""" + if hasattr(cls._session_device, "device"): + del cls._session_device.device + @classmethod def choose_torch_device(cls) -> torch.device: """Return the torch.device to use for accelerated inference.""" + # A worker thread pinned to a specific GPU takes precedence over the global config. + session_device = cls.get_session_device() + if session_device is not None: + return session_device app_config = get_config() if app_config.device != "auto": device = torch.device(app_config.device) diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/CurrentImagePreview.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/CurrentImagePreview.tsx index a39cf9be514..b22bd1b3aee 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/CurrentImagePreview.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/CurrentImagePreview.tsx @@ -22,6 +22,7 @@ import type { ImageDTO } from 'services/api/types'; import { useImageViewerContext } from './context'; import { NoContentForViewer } from './NoContentForViewer'; import { ProgressImage } from './ProgressImage2'; +import { ProgressImageTiles } from './ProgressImageTiles'; import { ProgressIndicator } from './ProgressIndicator2'; export const CurrentImagePreview = memo(({ imageDTO }: { imageDTO: ImageDTO | null }) => { @@ -30,9 +31,10 @@ export const CurrentImagePreview = memo(({ imageDTO }: { imageDTO: ImageDTO | nu const shouldShowItemDetails = useAppSelector(selectShouldShowItemDetails); const shouldShowProgressInViewer = useAppSelector(selectShouldShowProgressInViewer); const { goToPreviousImage, goToNextImage, isFetching } = useNextPrevItemNavigation(); - const { onLoadImage, $progressEvent, $progressImage } = useImageViewerContext(); + const { onLoadImage, $progressEvent, $progressImage, $activeProgressData } = useImageViewerContext(); const progressEvent = useStore($progressEvent); const progressImage = useStore($progressImage); + const activeProgressData = useStore($activeProgressData); const [imageToRender, setImageToRender] = useState(null); useEffect(() => { @@ -134,6 +136,9 @@ export const CurrentImagePreview = memo(({ imageDTO }: { imageDTO: ImageDTO | nu }); const withProgress = shouldShowProgressInViewer && progressImage !== null; + // When more than one session is generating concurrently (multi-GPU), tile their previews instead of + // showing only the most recent one. + const withTiledProgress = withProgress && activeProgressData.length > 1; return ( } {withProgress && ( - - {progressEvent && ( - + {withTiledProgress ? ( + + ) : ( + <> + + {progressEvent && ( + + )} + )} )} diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ProgressImageTiles.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ProgressImageTiles.tsx new file mode 100644 index 00000000000..6f66c02e929 --- /dev/null +++ b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ProgressImageTiles.tsx @@ -0,0 +1,39 @@ +import { Flex, Grid, GridItem } from '@invoke-ai/ui-library'; +import { memo, useMemo } from 'react'; + +import type { ViewerProgressDatum } from './context'; +import { ProgressImage } from './ProgressImage2'; +import { ProgressIndicator } from './ProgressIndicator2'; + +/** + * Renders one tile per concurrently-running session (multi-GPU). Each tile shows that session's live + * preview image plus a small progress indicator. Used by the viewer when more than one session is + * active; a single active session uses the full-size preview instead. + */ +export const ProgressImageTiles = memo(({ data }: { data: ViewerProgressDatum[] }) => { + // Lay the tiles out in a roughly-square grid that grows with the number of active sessions. + const columns = useMemo(() => Math.ceil(Math.sqrt(data.length)), [data.length]); + + return ( + + {data.map((datum) => ( + + + + + + + ))} + + ); +}); +ProgressImageTiles.displayName = 'ProgressImageTiles'; diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/context.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/context.tsx index 1cb22d61463..145ab63ba6e 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/context.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/context.tsx @@ -4,7 +4,7 @@ import { useAppSelector } from 'app/store/storeHooks'; import { selectAutoSwitch } from 'features/gallery/store/gallerySelectors'; import type { ProgressImage as ProgressImageType } from 'features/nodes/types/common'; import { LRUCache } from 'lru-cache'; -import { type Atom, atom, computed } from 'nanostores'; +import { type Atom, atom, computed, map, type MapStore } from 'nanostores'; import type { PropsWithChildren } from 'react'; import { createContext, memo, useCallback, useContext, useEffect, useMemo, useState } from 'react'; import type { S } from 'services/api/types'; @@ -12,10 +12,24 @@ import { $socket } from 'services/events/stores'; import { assert } from 'tsafe'; import type { JsonObject } from 'type-fest'; +/** Live progress for a single in-flight session (queue item). Used to tile the viewer when several + * sessions run concurrently (multi-GPU). Only items that have produced a preview image are tracked. */ +export type ViewerProgressDatum = { + itemId: number; + progressEvent: S['InvocationProgressEvent']; + progressImage: ProgressImageType; +}; + +type ViewerProgressDataMap = Record; + type ImageViewerContextValue = { $progressEvent: Atom; $progressImage: Atom; $hasProgressImage: Atom; + /** Per-session progress, keyed by queue item id. Drives the tiled multi-session preview. */ + $progressData: MapStore; + /** Active sessions (those with a preview image), sorted by item id for a stable tile order. */ + $activeProgressData: Atom; onLoadImage: () => void; }; @@ -29,6 +43,15 @@ export const ImageViewerContextProvider = memo((props: PropsWithChildren) => { const $progressEvent = useState(() => atom(null))[0]; const $progressImage = useState(() => atom(null))[0]; const $hasProgressImage = useState(() => computed($progressImage, (progressImage) => progressImage !== null))[0]; + // Per-session progress, keyed by queue item id, for the tiled multi-session preview (multi-GPU). + const $progressData = useState(() => map({}))[0]; + const $activeProgressData = useState(() => + computed($progressData, (progressData) => + Object.values(progressData) + .filter((datum): datum is ViewerProgressDatum => datum !== undefined) + .sort((a, b) => a.itemId - b.itemId) + ) + )[0]; // We can have race conditions where we receive a progress event for a queue item that has already finished. Easiest // way to handle this is to keep track of finished queue items in a cache and ignore progress events for those. const [finishedQueueItemIds] = useState(() => new LRUCache({ max: 200 })); @@ -49,6 +72,12 @@ export const ImageViewerContextProvider = memo((props: PropsWithChildren) => { $progressEvent.set(data); if (data.image) { $progressImage.set(data.image); + // Track per-session so the viewer can tile concurrent sessions (multi-GPU). + $progressData.setKey(data.item_id, { + itemId: data.item_id, + progressEvent: data, + progressImage: data.image, + }); } }; @@ -57,7 +86,7 @@ export const ImageViewerContextProvider = memo((props: PropsWithChildren) => { return () => { socket.off('invocation_progress', onInvocationProgress); }; - }, [$progressEvent, $progressImage, finishedQueueItemIds, socket]); + }, [$progressData, $progressEvent, $progressImage, finishedQueueItemIds, socket]); useEffect(() => { if (!socket) { @@ -74,6 +103,9 @@ export const ImageViewerContextProvider = memo((props: PropsWithChildren) => { } if (data.status === 'completed' || data.status === 'canceled' || data.status === 'failed') { finishedQueueItemIds.set(data.item_id, true); + // Remove this session's tile from the multi-session preview as soon as it reaches a terminal + // state. The single-image "resolve" illusion below is handled separately via onLoadImage. + $progressData.setKey(data.item_id, undefined); // Completed queue items have the progress event cleared by the onLoadImage callback. This allows the viewer to // create the illusion of the progress image "resolving" into the final image. If we cleared the progress image // now, there would be a flicker where the progress image disappears before the final image appears, and the @@ -103,7 +135,7 @@ export const ImageViewerContextProvider = memo((props: PropsWithChildren) => { return () => { socket.off('queue_item_status_changed', onQueueItemStatusChanged); }; - }, [$progressEvent, $progressImage, autoSwitch, finishedQueueItemIds, socket]); + }, [$progressData, $progressEvent, $progressImage, autoSwitch, finishedQueueItemIds, socket]); const onLoadImage = useCallback(() => { $progressEvent.set(null); @@ -111,8 +143,8 @@ export const ImageViewerContextProvider = memo((props: PropsWithChildren) => { }, [$progressEvent, $progressImage]); const value = useMemo( - () => ({ $progressEvent, $progressImage, $hasProgressImage, onLoadImage }), - [$hasProgressImage, $progressEvent, $progressImage, onLoadImage] + () => ({ $progressEvent, $progressImage, $hasProgressImage, $progressData, $activeProgressData, onLoadImage }), + [$hasProgressImage, $progressEvent, $progressImage, $progressData, $activeProgressData, onLoadImage] ); return {props.children}; diff --git a/tests/app/services/model_load/test_model_load_device_routing.py b/tests/app/services/model_load/test_model_load_device_routing.py new file mode 100644 index 00000000000..c9bb107d809 --- /dev/null +++ b/tests/app/services/model_load/test_model_load_device_routing.py @@ -0,0 +1,81 @@ +"""Tests that ModelLoadService routes to the per-device cache for the calling thread (multi-GPU).""" + +import threading + +import torch + +from invokeai.app.services.config.config_default import InvokeAIAppConfig, get_config +from invokeai.app.services.model_load.model_load_default import ModelLoadService +from invokeai.backend.util.devices import TorchDevice + + +class _FakeCache: + """Stand-in for ModelCache; ModelLoadService only needs `.execution_device` for keying.""" + + def __init__(self, device: str): + self.execution_device = torch.device(device) + + +def _build_service() -> tuple[ModelLoadService, _FakeCache, _FakeCache]: + cache0 = _FakeCache("cuda:0") + cache1 = _FakeCache("cuda:1") + service = ModelLoadService( + app_config=InvokeAIAppConfig(), + ram_cache=cache0, # type: ignore[arg-type] + ram_caches={"cuda:0": cache0, "cuda:1": cache1}, # type: ignore[arg-type] + ) + return service, cache0, cache1 + + +def test_ram_cache_routes_to_pinned_device(): + """A thread pinned to cuda:1 resolves to that device's cache; the default thread to cuda:0.""" + service, cache0, cache1 = _build_service() + + # The default thread has no session device; point config.device at cuda:0 so it resolves there. + get_config().device = "cuda:0" + assert service.ram_cache is cache0 + + results: dict[str, object] = {} + + def worker(): + TorchDevice.set_session_device("cuda:1") + try: + results["cache"] = service.ram_cache + finally: + TorchDevice.clear_session_device() + + t = threading.Thread(target=worker) + t.start() + t.join() + + assert results["cache"] is cache1 + # Main thread is unaffected by the worker's pinning. + assert service.ram_cache is cache0 + + +def test_ram_caches_exposes_all_devices(): + service, cache0, cache1 = _build_service() + caches = service.ram_caches + assert set(caches.keys()) == {"cuda:0", "cuda:1"} + assert caches["cuda:0"] is cache0 + assert caches["cuda:1"] is cache1 + + +def test_unknown_device_falls_back_to_default(): + """A thread pinned to a device with no cache falls back to the default cache.""" + service, cache0, _ = _build_service() + + results: dict[str, object] = {} + + def worker(): + TorchDevice.set_session_device("cuda:7") + try: + results["cache"] = service.ram_cache + finally: + TorchDevice.clear_session_device() + + t = threading.Thread(target=worker) + t.start() + t.join() + + assert results["cache"] is cache0 diff --git a/tests/app/services/session_queue/test_session_queue_dequeue_concurrency.py b/tests/app/services/session_queue/test_session_queue_dequeue_concurrency.py new file mode 100644 index 00000000000..8d55db941a5 --- /dev/null +++ b/tests/app/services/session_queue/test_session_queue_dequeue_concurrency.py @@ -0,0 +1,70 @@ +"""Tests that concurrent dequeue() calls (multi-GPU session workers) never claim the same item twice.""" + +import threading +import uuid + +import pytest + +from invokeai.app.services.invoker import Invoker +from invokeai.app.services.session_queue.session_queue_sqlite import SqliteSessionQueue +from invokeai.app.services.shared.graph import Graph, GraphExecutionState +from tests.test_nodes import PromptTestInvocation + + +@pytest.fixture +def session_queue(mock_invoker: Invoker) -> SqliteSessionQueue: + db = mock_invoker.services.board_records._db + queue = SqliteSessionQueue(db=db) + queue.start(mock_invoker) + return queue + + +def _insert_queue_item(session_queue: SqliteSessionQueue, user_id: str = "system") -> int: + graph = Graph() + graph.add_node(PromptTestInvocation(id="prompt", prompt="test")) + session = GraphExecutionState(graph=graph) + session_json = session.model_dump_json(warnings=False, exclude_none=True) + batch_id = str(uuid.uuid4()) + with session_queue._db.transaction() as cursor: + cursor.execute( + """--sql + INSERT INTO session_queue ( + queue_id, session, session_id, batch_id, field_values, priority, + workflow, origin, destination, retried_from_item_id, user_id + ) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ("default", session_json, session.id, batch_id, None, 0, None, None, None, None, user_id), + ) + return cursor.lastrowid + + +def test_concurrent_dequeue_never_claims_same_item_twice(session_queue: SqliteSessionQueue) -> None: + item_count = 50 + worker_count = 8 + for _ in range(item_count): + _insert_queue_item(session_queue) + + claimed_ids: list[int] = [] + claimed_lock = threading.Lock() + start_barrier = threading.Barrier(worker_count) + + def worker() -> None: + # Release all workers at once to maximize contention on the dequeue path. + start_barrier.wait() + while True: + item = session_queue.dequeue() + if item is None: + break + with claimed_lock: + claimed_ids.append(item.item_id) + + threads = [threading.Thread(target=worker) for _ in range(worker_count)] + for t in threads: + t.start() + for t in threads: + t.join() + + # Every item is claimed exactly once: no duplicates, none lost. + assert len(claimed_ids) == item_count + assert len(set(claimed_ids)) == item_count diff --git a/tests/backend/util/test_devices.py b/tests/backend/util/test_devices.py index 3f134e3c3da..39dee5cb618 100644 --- a/tests/backend/util/test_devices.py +++ b/tests/backend/util/test_devices.py @@ -2,6 +2,7 @@ Test abstract device class. """ +import threading from unittest.mock import patch import pytest @@ -24,6 +25,50 @@ def test_device_choice(device_name): assert torch_device == torch.device(device_name) +# ===== per-thread session device (multi-GPU worker pinning) ================ + + +def test_session_device_overrides_config(): + """A per-thread session device takes precedence over the global config.device.""" + config = get_config() + config.device = "cpu" + try: + TorchDevice.set_session_device("cuda:1") + assert TorchDevice.choose_torch_device() == torch.device("cuda:1") + finally: + TorchDevice.clear_session_device() + # Once cleared, we fall back to the global config. + assert TorchDevice.choose_torch_device() == torch.device("cpu") + + +def test_session_device_is_thread_local(): + """Each thread sees only its own pinned device; the main thread is unaffected.""" + config = get_config() + config.device = "cpu" + results: dict[str, torch.device] = {} + barrier = threading.Barrier(2) + + def worker(name: str, device: str): + TorchDevice.set_session_device(device) + # Wait so both threads have set their device before either reads it, proving isolation. + barrier.wait() + results[name] = TorchDevice.choose_torch_device() + TorchDevice.clear_session_device() + + t0 = threading.Thread(target=worker, args=("a", "cuda:0")) + t1 = threading.Thread(target=worker, args=("b", "cuda:1")) + t0.start() + t1.start() + t0.join() + t1.join() + + assert results["a"] == torch.device("cuda:0") + assert results["b"] == torch.device("cuda:1") + # The main thread never set a session device, so it still uses the global config. + assert TorchDevice.get_session_device() is None + assert TorchDevice.choose_torch_device() == torch.device("cpu") + + @pytest.mark.parametrize("device_dtype_pair", device_types_cpu) def test_device_dtype_cpu(device_dtype_pair): with ( From 6bb89d6ba6ff7f06c1b5057efce84e757acfae8e Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Mon, 1 Jun 2026 14:49:43 -0400 Subject: [PATCH 02/33] fix(tests): restore global device after multi-GPU cache routing test test_model_load_device_routing mutated the process-wide get_config() singleton (device = "cuda:0") to exercise the per-thread cache routing, but never restored it. The leaked CUDA device was then picked up by a later test (test_model_load::test_loading) via choose_torch_device(), which crashed with "Torch not compiled with CUDA enabled" on the CUDA-less CI runner. Add an autouse fixture to save/restore device and clear any pinned session device. Co-Authored-By: Claude Opus 4.8 --- .../model_load/test_model_load_device_routing.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/app/services/model_load/test_model_load_device_routing.py b/tests/app/services/model_load/test_model_load_device_routing.py index c9bb107d809..85b3868b92f 100644 --- a/tests/app/services/model_load/test_model_load_device_routing.py +++ b/tests/app/services/model_load/test_model_load_device_routing.py @@ -1,7 +1,9 @@ """Tests that ModelLoadService routes to the per-device cache for the calling thread (multi-GPU).""" import threading +from collections.abc import Iterator +import pytest import torch from invokeai.app.services.config.config_default import InvokeAIAppConfig, get_config @@ -9,6 +11,19 @@ from invokeai.backend.util.devices import TorchDevice +@pytest.fixture(autouse=True) +def restore_global_device() -> Iterator[None]: + """`get_config()` is a process-wide singleton; restore `device` so we don't leak a CUDA device + into later CPU-only tests (e.g. the model-loading suite on the CUDA-less CI runner).""" + config = get_config() + original_device = config.device + try: + yield + finally: + config.device = original_device + TorchDevice.clear_session_device() + + class _FakeCache: """Stand-in for ModelCache; ModelLoadService only needs `.execution_device` for keying.""" From a3be44423a6dcb772c8ba359bf53830ca6695b30 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Mon, 1 Jun 2026 14:51:14 -0400 Subject: [PATCH 03/33] chore(ui): regenerate openapi schema and frontend types for generation_devices Regenerate openapi.json (make frontend-openapi) and the frontend schema.ts types (make frontend-typegen) so they include the new generation_devices config field, fixing the openapi-checks and typegen-checks CI jobs. Co-Authored-By: Claude Opus 4.8 --- invokeai/frontend/web/openapi.json | 18 +++++++++++++++++- .../frontend/web/src/services/api/schema.ts | 5 +++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/invokeai/frontend/web/openapi.json b/invokeai/frontend/web/openapi.json index 1287ee58865..b828412fb0c 100644 --- a/invokeai/frontend/web/openapi.json +++ b/invokeai/frontend/web/openapi.json @@ -14375,7 +14375,8 @@ } }, "type": "object", - "title": "CacheStats" + "title": "CacheStats", + "description": "Collect statistics on cache performance." }, "CalculateImageTilesEvenSplitInvocation": { "category": "tiles", @@ -41151,6 +41152,21 @@ "description": "Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.
Valid values: `auto`, `cpu`, `cuda`, `mps`, `cuda:N` (where N is a device number)", "default": "auto" }, + "generation_devices": { + "anyOf": [ + { + "items": { + "type": "string" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "title": "Generation Devices", + "description": "List of devices to use for parallel generation, e.g. `[cuda:0, cuda:1]`. When set, the session processor runs one generation session per listed device concurrently, distributing jobs fairly across users. When unset (the default), generation runs serially on the single `device`.
Valid values for each entry: `cpu`, `cuda`, `mps`, `cuda:N` (where N is a device number)" + }, "precision": { "type": "string", "enum": ["auto", "float16", "bfloat16", "float32"], diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index a80183476bd..ace2e30178e 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -16506,6 +16506,11 @@ export type components = { * @default auto */ device?: string; + /** + * Generation Devices + * @description List of devices to use for parallel generation, e.g. `[cuda:0, cuda:1]`. When set, the session processor runs one generation session per listed device concurrently, distributing jobs fairly across users. When unset (the default), generation runs serially on the single `device`.
Valid values for each entry: `cpu`, `cuda`, `mps`, `cuda:N` (where N is a device number) + */ + generation_devices?: string[] | null; /** * Precision * @description Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system. From be54889811d95b6cf6aefd350f11ba76689bc64a Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Mon, 1 Jun 2026 15:07:44 -0400 Subject: [PATCH 04/33] fix(ui): regenerate openapi.json with uv to match CI generator `make frontend-openapi` used a bare `python` from a different environment that emitted the CacheStats @dataclass docstring as a schema description. CI generates the schema via `uv run`, which does not, so openapi-checks failed on the diff. Regenerate with the uv-locked environment to drop the stray description while keeping the generation_devices field. Co-Authored-By: Claude Opus 4.8 --- invokeai/frontend/web/openapi.json | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/invokeai/frontend/web/openapi.json b/invokeai/frontend/web/openapi.json index b828412fb0c..852dc866bce 100644 --- a/invokeai/frontend/web/openapi.json +++ b/invokeai/frontend/web/openapi.json @@ -14375,8 +14375,7 @@ } }, "type": "object", - "title": "CacheStats", - "description": "Collect statistics on cache performance." + "title": "CacheStats" }, "CalculateImageTilesEvenSplitInvocation": { "category": "tiles", From a119b50bc1f27eafcf3087c22b35f28b9ee2e398 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Mon, 1 Jun 2026 23:00:29 -0400 Subject: [PATCH 05/33] fix(model-manager): serialize model construction against VRAM moves to prevent meta-device corruption Parallel multi-GPU session workers could intermittently crash with "unrecognized device meta" (denoise) or "Cannot copy out of meta tensor; no data!" (l2i), because model loading relies on process-global, non-thread-safe monkey-patches. accelerate.init_empty_weights() (used directly by the loaders and implicitly by diffusers' default low_cpu_mem_usage=True in from_pretrained) swaps torch.nn.Module.register_parameter globally for the duration of a load, routing every newly-registered parameter to the meta device. The model cache's VRAM load/unload runs nn.Module.load_state_dict(assign=True), whose assign path does setattr -> __setattr__ -> register_parameter. When one worker's VRAM move overlapped another worker's from_pretrained, the move's real weights got hijacked onto meta and blew up on the next .to(device). Introduce MODEL_LOAD_LOCK, a write-preferring readers-writer lock: - write lock = model construction (_load_and_cache, load_model_from_path), exclusive. - read lock = VRAM load/unload (ModelCache.lock(), repair_required_tensors_on_device). VRAM transfers across GPUs still overlap each other; they only block while a construction holds the write lock. The lock is always acquired before any per-cache lock to keep a consistent order and avoid an AB-BA deadlock with the writer's make_room/put. Co-Authored-By: Claude Opus 4.8 (1M context) --- .../services/model_load/model_load_default.py | 17 +++-- .../backend/model_manager/load/load_base.py | 20 ++++-- .../model_manager/load/load_default.py | 64 ++++++++++++----- .../load/model_cache/model_cache.py | 70 ++++++++++++++++++- 4 files changed, 143 insertions(+), 28 deletions(-) diff --git a/invokeai/app/services/model_load/model_load_default.py b/invokeai/app/services/model_load/model_load_default.py index 45d0c354278..33c7ef6108c 100644 --- a/invokeai/app/services/model_load/model_load_default.py +++ b/invokeai/app/services/model_load/model_load_default.py @@ -18,7 +18,7 @@ ModelLoaderRegistry, ModelLoaderRegistryBase, ) -from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache +from invokeai.backend.model_manager.load.model_cache.model_cache import MODEL_LOAD_LOCK, ModelCache from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader from invokeai.backend.model_manager.taxonomy import AnyModel, SubModelType from invokeai.backend.util.devices import TorchDevice @@ -147,6 +147,15 @@ def diffusers_load_directory(directory: Path) -> AnyModel: else lambda path: safetensors_load_file(path, device="cpu") ) assert loader is not None - raw_model = loader(model_path) - ram_cache.put(key=cache_key, model=raw_model) - return LoadedModelWithoutConfig(cache_record=ram_cache.get(key=cache_key), cache=ram_cache) + # Serialize construction (see MODEL_LOAD_LOCK): the diffusers loader path uses the same + # process-global, non-thread-safe monkey-patches as the main loader, so it takes the write + # lock to exclude concurrent VRAM moves. Re-check the cache after acquiring the lock in case + # a worker sharing this cache built it while we waited. + with MODEL_LOAD_LOCK.write_lock(): + try: + return LoadedModelWithoutConfig(cache_record=ram_cache.get(key=cache_key), cache=ram_cache) + except IndexError: + pass + raw_model = loader(model_path) + ram_cache.put(key=cache_key, model=raw_model) + return LoadedModelWithoutConfig(cache_record=ram_cache.get(key=cache_key), cache=ram_cache) diff --git a/invokeai/backend/model_manager/load/load_base.py b/invokeai/backend/model_manager/load/load_base.py index 4609a2e92ab..984362f185d 100644 --- a/invokeai/backend/model_manager/load/load_base.py +++ b/invokeai/backend/model_manager/load/load_base.py @@ -17,7 +17,7 @@ from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_with_partial_load import ( CachedModelWithPartialLoad, ) -from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache +from invokeai.backend.model_manager.load.model_cache.model_cache import MODEL_LOAD_LOCK, ModelCache from invokeai.backend.model_manager.taxonomy import AnyModel, SubModelType @@ -57,7 +57,12 @@ def __init__(self, cache_record: CacheRecord, cache: ModelCache): self._cache = cache def __enter__(self) -> AnyModel: - self._cache.lock(self._cache_record, None) + # Hold the MODEL_LOAD_LOCK read lock across the VRAM load (lock() runs + # load_state_dict(assign=True), which calls register_parameter) so it can't overlap a + # concurrent model construction that has the global register_parameter -> meta patch active. + # Acquired before the cache's own lock to keep a consistent lock order (see MODEL_LOAD_LOCK). + with MODEL_LOAD_LOCK.read_lock(): + self._cache.lock(self._cache_record, None) try: self.repair_required_tensors_on_device() return self.model @@ -77,7 +82,9 @@ def model_on_device( :param working_mem_bytes: The amount of working memory to keep available on the compute device when loading the model. """ - self._cache.lock(self._cache_record, working_mem_bytes) + # See __enter__ for why the VRAM load is wrapped in the read lock. + with MODEL_LOAD_LOCK.read_lock(): + self._cache.lock(self._cache_record, working_mem_bytes) try: self.repair_required_tensors_on_device() yield (self._cache_record.cached_model.get_cpu_state_dict(), self._cache_record.cached_model.model) @@ -94,7 +101,12 @@ def repair_required_tensors_on_device(self) -> int: cached_model = self._cache_record.cached_model if not isinstance(cached_model, CachedModelWithPartialLoad): return 0 - return cached_model.repair_required_tensors_on_compute_device() + # Repair runs load_state_dict(assign=True) -> register_parameter, so it must hold the read + # lock to avoid being hijacked onto the `meta` device by a concurrent construction. This is + # also called directly (outside __enter__/model_on_device) by some text-encoder invocations, + # so the guard lives here rather than only at the call sites. + with MODEL_LOAD_LOCK.read_lock(): + return cached_model.repair_required_tensors_on_compute_device() class LoadedModel(LoadedModelWithoutConfig): diff --git a/invokeai/backend/model_manager/load/load_default.py b/invokeai/backend/model_manager/load/load_default.py index 040b55cb6ec..02929ff6132 100644 --- a/invokeai/backend/model_manager/load/load_default.py +++ b/invokeai/backend/model_manager/load/load_default.py @@ -13,7 +13,11 @@ from invokeai.backend.model_manager.configs.factory import AnyModelConfig from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord -from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache, get_model_cache_key +from invokeai.backend.model_manager.load.model_cache.model_cache import ( + MODEL_LOAD_LOCK, + ModelCache, + get_model_cache_key, +) from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init from invokeai.backend.model_manager.taxonomy import ( @@ -52,7 +56,9 @@ ) -# TO DO: The loader is not thread safe! +# The construction path is not thread-safe on its own; it monkey-patches process-global torch state +# (see MODEL_LOAD_LOCK). Concurrent callers must hold the MODEL_LOAD_LOCK write lock (see +# _load_and_cache). class ModelLoader(ModelLoaderBase): """Default implementation of ModelLoaderBase.""" @@ -85,8 +91,7 @@ def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubMo if not model_path.exists(): raise FileNotFoundError(f"Files for model '{model_config.name}' not found at {model_path}") - with skip_torch_weight_init(): - cache_record = self._load_and_cache(model_config, submodel_type) + cache_record = self._load_and_cache(model_config, submodel_type) return LoadedModel(config=model_config, cache_record=cache_record, cache=self._ram_cache) @property @@ -124,25 +129,46 @@ def _get_execution_device( def _load_and_cache(self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> CacheRecord: stats_name = ":".join([config.base, config.type, config.name, (submodel_type or "")]) + cache_key = get_model_cache_key(config.key, submodel_type) try: - return self._ram_cache.get(key=get_model_cache_key(config.key, submodel_type), stats_name=stats_name) + return self._ram_cache.get(key=cache_key, stats_name=stats_name) except IndexError: pass - config.path = str(self._get_model_path(config)) - self._ram_cache.make_room(self.get_size_fs(config, Path(config.path), submodel_type)) - loaded_model = self._load_model(config, submodel_type) - - # Determine execution device from model config, considering submodel type - execution_device = self._get_execution_device(config, submodel_type) - - self._ram_cache.put( - get_model_cache_key(config.key, submodel_type), - model=loaded_model, - execution_device=execution_device, - ) - - return self._ram_cache.get(key=get_model_cache_key(config.key, submodel_type), stats_name=stats_name) + # Cache miss: construct the model from disk. This path holds the MODEL_LOAD_LOCK *write* + # lock because it relies on process-global, non-thread-safe monkey-patches + # (skip_torch_weight_init and, inside the loaders, accelerate.init_empty_weights / diffusers + # low_cpu_mem_usage). The write lock excludes both other constructions AND concurrent VRAM + # load/unload on other workers (which take the read lock); without that, a concurrent move's + # load_state_dict(assign=True) -> register_parameter gets hijacked onto the `meta` device. + # See MODEL_LOAD_LOCK for the full explanation. + # + # Lock-ordering: the write lock is acquired before any ModelCache._lock taken below + # (get/make_room/put), matching the readers' order, so there is no AB-BA deadlock. + with MODEL_LOAD_LOCK.write_lock(): + # Double-checked locking: another worker sharing this cache may have loaded the same + # entry while we waited for the mutex. (Workers on other devices use a different cache, + # so they will still miss here and construct their own copy — which is intended.) + try: + return self._ram_cache.get(key=cache_key, stats_name=stats_name) + except IndexError: + pass + + config.path = str(self._get_model_path(config)) + self._ram_cache.make_room(self.get_size_fs(config, Path(config.path), submodel_type)) + with skip_torch_weight_init(): + loaded_model = self._load_model(config, submodel_type) + + # Determine execution device from model config, considering submodel type + execution_device = self._get_execution_device(config, submodel_type) + + self._ram_cache.put( + cache_key, + model=loaded_model, + execution_device=execution_device, + ) + + return self._ram_cache.get(key=cache_key, stats_name=stats_name) def get_size_fs( self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache.py b/invokeai/backend/model_manager/load/model_cache/model_cache.py index 1196a0f3885..2ca8dd44ba2 100644 --- a/invokeai/backend/model_manager/load/model_cache/model_cache.py +++ b/invokeai/backend/model_manager/load/model_cache/model_cache.py @@ -2,10 +2,11 @@ import logging import threading import time +from contextlib import contextmanager from dataclasses import dataclass from functools import wraps from logging import Logger -from typing import Any, Callable, Dict, List, Optional, Protocol +from typing import Any, Callable, Dict, Generator, List, Optional, Protocol import psutil import torch @@ -35,6 +36,73 @@ MB = 2**20 +class _ModelLoadReadWriteLock: + """A write-preferring readers-writer lock that serializes model construction against VRAM moves. + + The model load machinery depends on PROCESS-GLOBAL monkey-patches that are not thread-safe: + model CONSTRUCTION (diffusers `from_pretrained` / `accelerate.init_empty_weights`) temporarily + replaces `torch.nn.Module.register_parameter` so that every newly-registered parameter is routed + to the `meta` device. While that patch is installed, ANY `register_parameter` call in ANY thread + is hijacked onto `meta`. VRAM load/unload uses `nn.Module.load_state_dict(assign=True)`, which + assigns `Parameter`s via `__setattr__` -> `register_parameter` — so if it runs concurrently with + a construction on another worker thread, its real weights get stranded on `meta`. That surfaces + later as "Cannot copy out of meta tensor; no data!" or "unrecognized device meta". + + - Construction takes the WRITE lock (exclusive — no reader and no other writer may run). + - VRAM load/unload takes the READ lock (shared, so concurrent moves on different GPUs still + overlap each other; they only block while a construction holds the write lock). + + Write-preferring: once a construction is waiting, new readers queue behind it, so a steady stream + of VRAM moves from busy workers can't starve a pending load. + + Lock-ordering contract: callers MUST acquire this lock *before* any `ModelCache._lock`, never + after. Readers do so by taking the read lock around the outer `ModelCache.lock()` call (see + `LoadedModelWithoutConfig`), and writers around the whole construction (see + `ModelLoader._load_and_cache`). Acquiring it in the other order — cache lock first, then this + lock — would risk an AB-BA deadlock with a writer that takes a cache lock during `put()`. + """ + + def __init__(self) -> None: + self._cond = threading.Condition(threading.Lock()) + self._readers = 0 + self._writers_waiting = 0 + self._writer_active = False + + @contextmanager + def read_lock(self) -> Generator[None, None, None]: + with self._cond: + # Defer to any active or waiting writer (write-preferring). + while self._writer_active or self._writers_waiting > 0: + self._cond.wait() + self._readers += 1 + try: + yield + finally: + with self._cond: + self._readers -= 1 + if self._readers == 0: + self._cond.notify_all() + + @contextmanager + def write_lock(self) -> Generator[None, None, None]: + with self._cond: + self._writers_waiting += 1 + while self._writer_active or self._readers > 0: + self._cond.wait() + self._writers_waiting -= 1 + self._writer_active = True + try: + yield + finally: + with self._cond: + self._writer_active = False + self._cond.notify_all() + + +# Process-global lock guarding the non-thread-safe model load machinery. See _ModelLoadReadWriteLock. +MODEL_LOAD_LOCK = _ModelLoadReadWriteLock() + + # TODO(ryand): Where should this go? The ModelCache shouldn't be concerned with submodels. def get_model_cache_key(model_key: str, submodel_type: Optional[SubModelType] = None) -> str: """Get the cache key for a model based on the optional submodel type.""" From 70114464698543acbe78c6b6191d5c50faabe119 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Tue, 2 Jun 2026 16:34:44 -0400 Subject: [PATCH 06/33] fix(backend): fix outpainting crash caused by model download collisions --- .../model_install/model_install_default.py | 67 +++++++++++++------ 1 file changed, 46 insertions(+), 21 deletions(-) diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index 49d3cfdf7f9..5f70fc53838 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -114,6 +114,11 @@ def __init__( self._install_completed_event = threading.Event() self._download_queue = download_queue self._download_cache: Dict[int, ModelInstallJob] = {} + # Per-source locks serializing download_and_cache_model() so parallel (multi-GPU) sessions + # that need the same remote model (e.g. the LaMa infill model) don't race to download into + # the same cache directory. _download_cache_locks_guard protects the dict itself. + self._download_cache_locks: Dict[str, threading.Lock] = {} + self._download_cache_locks_guard = threading.Lock() self._running = False self._session = session self._install_thread: Optional[threading.Thread] = None @@ -711,27 +716,47 @@ def download_and_cache_model( if len(contents) > 0: return contents[0] - model_path.mkdir(parents=True, exist_ok=True) - model_source = self._guess_source(str(source)) - remote_files, _ = self._remote_files_from_source(model_source) - # Handle multiple subfolders for HFModelSource - subfolders = model_source.subfolders if isinstance(model_source, HFModelSource) else [] - job = self._multifile_download( - dest=model_path, - remote_files=remote_files, - subfolder=model_source.subfolder - if isinstance(model_source, HFModelSource) and len(subfolders) <= 1 - else None, - subfolders=subfolders if len(subfolders) > 1 else None, - ) - files_string = "file" if len(remote_files) == 1 else "files" - self._logger.info(f"Queuing model download: {source} ({len(remote_files)} {files_string})") - self._download_queue.wait_for_job(job) - if job.complete: - assert job.download_path is not None - return job.download_path - else: - raise Exception(job.error) + # Serialize concurrent downloads of the same source. Parallel multi-GPU sessions can each + # request the same remote model (e.g. the LaMa infill model) at once; without this lock they + # both download into the same cache directory and collide on the final rename, which fails on + # Windows with "WinError 32: the file is being used by another process". The other waiters + # find the completed download on the post-lock re-check below and skip downloading. + with self._download_cache_lock(str(source)): + if model_path.exists(): + contents = list(model_path.iterdir()) + if len(contents) > 0: + return contents[0] + + model_path.mkdir(parents=True, exist_ok=True) + model_source = self._guess_source(str(source)) + remote_files, _ = self._remote_files_from_source(model_source) + # Handle multiple subfolders for HFModelSource + subfolders = model_source.subfolders if isinstance(model_source, HFModelSource) else [] + job = self._multifile_download( + dest=model_path, + remote_files=remote_files, + subfolder=model_source.subfolder + if isinstance(model_source, HFModelSource) and len(subfolders) <= 1 + else None, + subfolders=subfolders if len(subfolders) > 1 else None, + ) + files_string = "file" if len(remote_files) == 1 else "files" + self._logger.info(f"Queuing model download: {source} ({len(remote_files)} {files_string})") + self._download_queue.wait_for_job(job) + if job.complete: + assert job.download_path is not None + return job.download_path + else: + raise Exception(job.error) + + def _download_cache_lock(self, source: str) -> threading.Lock: + """Return the lock that serializes downloads for a given source, creating it on first use.""" + with self._download_cache_locks_guard: + lock = self._download_cache_locks.get(source) + if lock is None: + lock = threading.Lock() + self._download_cache_locks[source] = lock + return lock def _remote_files_from_source( self, source: ModelSource From a1fe3757f051d21893f53b3b3cddd0aca703f819 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Tue, 2 Jun 2026 16:42:18 -0400 Subject: [PATCH 07/33] fix(backend): make DiskImageFileStorage thread-safe for parallel sessions Image.open() is lazy: it reads the header but defers pixel decoding (and holds the file handle open) until the first .load()/.copy()/.convert(). The opened object was cached and the same object handed to every caller, so in multi-GPU parallel mode two session-processor worker threads could call .copy() on it concurrently and race on the shared file handle and decoder state. This surfaced as "broken data stream when reading image file" and "AssertionError: self.png is not None" during inpainting with batch >1. Force the decode (image.load()) before the object enters the cache so the cached object is safe for concurrent reads, and guard the cache structures (__cache / __cache_ids) with a lock since they are now mutated from multiple threads. Co-Authored-By: Claude Opus 4.8 --- .../services/image_files/image_files_disk.py | 39 +++++++++++++------ 1 file changed, 27 insertions(+), 12 deletions(-) diff --git a/invokeai/app/services/image_files/image_files_disk.py b/invokeai/app/services/image_files/image_files_disk.py index 12b737a7cf1..ec84439547a 100644 --- a/invokeai/app/services/image_files/image_files_disk.py +++ b/invokeai/app/services/image_files/image_files_disk.py @@ -1,4 +1,5 @@ # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team +import threading from pathlib import Path from queue import Queue from typing import Optional, Union @@ -23,6 +24,9 @@ def __init__(self, output_folder: Union[str, Path]): self.__cache: dict[Path, PILImageType] = {} self.__cache_ids = Queue[Path]() self.__max_cache_size = 10 # TODO: get this from config + # Guards the cache structures (__cache / __cache_ids), which are read and mutated from + # multiple session-processor worker threads in multi-GPU parallel mode. + self.__cache_lock = threading.Lock() self.__output_folder = output_folder if isinstance(output_folder, Path) else Path(output_folder) self.__thumbnails_folder = self.__output_folder / "thumbnails" @@ -41,6 +45,13 @@ def get(self, image_name: str, image_subfolder: str = "") -> PILImageType: return cache_item image = Image.open(image_path) + # Image.open() is lazy: it reads the header but defers pixel decoding (and holds the + # file handle open) until the first .load()/.copy()/.convert(). The opened object is + # cached and the SAME object is handed to every caller, so in multi-GPU parallel mode + # two worker threads can call .copy() on it concurrently and race on the shared file + # handle and decoder state, producing "broken data stream" / "self.png is not None" + # errors. Forcing the decode here makes the cached object safe for concurrent reads. + image.load() self.__set_cache(image_path, image) return image except FileNotFoundError as e: @@ -105,16 +116,18 @@ def delete(self, image_name: str, image_subfolder: str = "") -> None: if image_path.exists(): image_path.unlink() - if image_path in self.__cache: - del self.__cache[image_path] thumbnail_name = get_thumbnail_name(image_name) thumbnail_path = self.get_path(thumbnail_name, True, image_subfolder=image_subfolder) if thumbnail_path.exists(): thumbnail_path.unlink() - if thumbnail_path in self.__cache: - del self.__cache[thumbnail_path] + + with self.__cache_lock: + if image_path in self.__cache: + del self.__cache[image_path] + if thumbnail_path in self.__cache: + del self.__cache[thumbnail_path] except Exception as e: raise ImageFileDeleteException from e @@ -185,13 +198,15 @@ def __validate_storage_folders(self) -> None: folder.mkdir(parents=True, exist_ok=True) def __get_cache(self, image_name: Path) -> Optional[PILImageType]: - return None if image_name not in self.__cache else self.__cache[image_name] + with self.__cache_lock: + return None if image_name not in self.__cache else self.__cache[image_name] def __set_cache(self, image_name: Path, image: PILImageType): - if image_name not in self.__cache: - self.__cache[image_name] = image - self.__cache_ids.put(image_name) # TODO: this should refresh position for LRU cache - if len(self.__cache) > self.__max_cache_size: - cache_id = self.__cache_ids.get() - if cache_id in self.__cache: - del self.__cache[cache_id] + with self.__cache_lock: + if image_name not in self.__cache: + self.__cache[image_name] = image + self.__cache_ids.put(image_name) # TODO: this should refresh position for LRU cache + if len(self.__cache) > self.__max_cache_size: + cache_id = self.__cache_ids.get() + if cache_id in self.__cache: + del self.__cache[cache_id] From 3db88d51ebcf2bbd07bc01ec434c574c0effc15c Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Tue, 2 Jun 2026 17:09:23 -0400 Subject: [PATCH 08/33] feat(ui): stack per-session progress bars during parallel generation The generation progress bars (under the Invoke button and the Viewer tab) both read a single global $lastProgressEvent atom, which every session overwrites. With parallel multi-GPU sessions this made the bar jump back and forth between sessions. Track progress per queue item id and render one bar per in-flight session, stacked vertically, each removed as its session reaches a terminal state. - stores.ts: add $progressEvents (map keyed by item_id), $activeProgressEvents (sorted), and set/clear helpers. - setEventListeners.tsx: populate per-item progress on invocation_progress; clear per item on terminal status; clear all on connect/disconnect/queue cleared. - ProgressBar.tsx: render a vertical stack of bars (one per active session) with a single-bar fallback for the idle / model-loading window; add containerProps so dockview tabs can position the stack. - Dockview tab call sites: move positioning into containerProps. Co-Authored-By: Claude Opus 4.8 --- .../system/components/ProgressBar.tsx | 90 ++++++++++--------- .../ui/layouts/DockviewTabCanvasViewer.tsx | 6 +- .../ui/layouts/DockviewTabCanvasWorkspace.tsx | 6 +- .../ui/layouts/DockviewTabProgress.tsx | 6 +- .../src/services/events/setEventListeners.tsx | 15 +++- .../web/src/services/events/stores.ts | 29 +++++- 6 files changed, 103 insertions(+), 49 deletions(-) diff --git a/invokeai/frontend/web/src/features/system/components/ProgressBar.tsx b/invokeai/frontend/web/src/features/system/components/ProgressBar.tsx index 5a4abdd4d28..6a305416bbf 100644 --- a/invokeai/frontend/web/src/features/system/components/ProgressBar.tsx +++ b/invokeai/frontend/web/src/features/system/components/ProgressBar.tsx @@ -1,62 +1,64 @@ -import type { ProgressProps } from '@invoke-ai/ui-library'; -import { Progress } from '@invoke-ai/ui-library'; +import type { FlexProps, ProgressProps } from '@invoke-ai/ui-library'; +import { Flex, Progress } from '@invoke-ai/ui-library'; import { useStore } from '@nanostores/react'; import { memo, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import { useGetQueueStatusQuery } from 'services/api/endpoints/queue'; -import { $isConnected, $lastProgressEvent, $loadingModelsCount } from 'services/events/stores'; +import { $activeProgressEvents, $isConnected, $loadingModelsCount } from 'services/events/stores'; -const ProgressBar = (props: ProgressProps) => { +type ProgressBarProps = ProgressProps & { + /** Applied to the Flex that stacks the per-session bars. Use for positioning (e.g. absolute). */ + containerProps?: FlexProps; +}; + +type BarDescriptor = { + key: number | string; + value: number; + isIndeterminate: boolean; +}; + +const ProgressBar = ({ containerProps, ...props }: ProgressBarProps) => { const { t } = useTranslation(); const { data: queueStatus } = useGetQueueStatusQuery(); const isConnected = useStore($isConnected); - const lastProgressEvent = useStore($lastProgressEvent); + const activeProgressEvents = useStore($activeProgressEvents); const loadingModelsCount = useStore($loadingModelsCount); - const value = useMemo(() => { - if (!lastProgressEvent) { - return 0; - } - return (lastProgressEvent.percentage ?? 0) * 100; - }, [lastProgressEvent]); - - const isIndeterminate = useMemo(() => { - if (!isConnected) { - return false; - } - if (loadingModelsCount > 0) { - return true; + const bars = useMemo(() => { + // One bar per in-flight session (multi-GPU). Each session's progress is tracked independently, so + // the bars no longer jump back and forth when several sessions render simultaneously. + if (activeProgressEvents.length > 0) { + return activeProgressEvents.map((event) => ({ + key: event.item_id, + value: (event.percentage ?? 0) * 100, + isIndeterminate: isConnected && (loadingModelsCount > 0 || event.percentage === null || event.percentage === 0), + })); } - if (!queueStatus?.queue.in_progress) { - return false; + // Fallback single bar: idle, or generation has started but no progress event has arrived yet (e.g. + // while models are loading). Mirrors the previous single-bar indeterminate behavior. + let isIndeterminate = false; + if (isConnected && (loadingModelsCount > 0 || Boolean(queueStatus?.queue.in_progress))) { + isIndeterminate = true; } - - if (!lastProgressEvent) { - return true; - } - - if (lastProgressEvent.percentage === null) { - return true; - } - - if (lastProgressEvent.percentage === 0) { - return true; - } - - return false; - }, [isConnected, lastProgressEvent, queueStatus?.queue.in_progress, loadingModelsCount]); + return [{ key: 'idle', value: 0, isIndeterminate }]; + }, [activeProgressEvents, isConnected, loadingModelsCount, queueStatus?.queue.in_progress]); return ( - + + {bars.map((bar) => ( + + ))} + ); }; diff --git a/invokeai/frontend/web/src/features/ui/layouts/DockviewTabCanvasViewer.tsx b/invokeai/frontend/web/src/features/ui/layouts/DockviewTabCanvasViewer.tsx index 80f851ab7af..a53e0c3c4cb 100644 --- a/invokeai/frontend/web/src/features/ui/layouts/DockviewTabCanvasViewer.tsx +++ b/invokeai/frontend/web/src/features/ui/layouts/DockviewTabCanvasViewer.tsx @@ -34,7 +34,11 @@ export const DockviewTabCanvasViewer = memo((props: IDockviewPanelHeaderProps {currentQueueItemDestination === 'canvas' && isGenerationInProgress && ( - + )}
); diff --git a/invokeai/frontend/web/src/features/ui/layouts/DockviewTabCanvasWorkspace.tsx b/invokeai/frontend/web/src/features/ui/layouts/DockviewTabCanvasWorkspace.tsx index 440847d7451..f96381511fc 100644 --- a/invokeai/frontend/web/src/features/ui/layouts/DockviewTabCanvasWorkspace.tsx +++ b/invokeai/frontend/web/src/features/ui/layouts/DockviewTabCanvasWorkspace.tsx @@ -37,7 +37,11 @@ export const DockviewTabCanvasWorkspace = memo((props: IDockviewPanelHeaderProps {t(props.params.i18nKey)} {currentQueueItemDestination === canvasSessionId && isGenerationInProgress && ( - + )} ); diff --git a/invokeai/frontend/web/src/features/ui/layouts/DockviewTabProgress.tsx b/invokeai/frontend/web/src/features/ui/layouts/DockviewTabProgress.tsx index 1d997caaf78..180babf8191 100644 --- a/invokeai/frontend/web/src/features/ui/layouts/DockviewTabProgress.tsx +++ b/invokeai/frontend/web/src/features/ui/layouts/DockviewTabProgress.tsx @@ -32,7 +32,11 @@ export const DockviewTabProgress = memo((props: IDockviewPanelHeaderProps {isGenerationInProgress && ( - + )} ); diff --git a/invokeai/frontend/web/src/services/events/setEventListeners.tsx b/invokeai/frontend/web/src/services/events/setEventListeners.tsx index 1e73abb2027..fa8e2895ba3 100644 --- a/invokeai/frontend/web/src/services/events/setEventListeners.tsx +++ b/invokeai/frontend/web/src/services/events/setEventListeners.tsx @@ -53,7 +53,13 @@ import type { ClientToServerEvents, ServerToClientEvents } from 'services/events import type { Socket } from 'socket.io-client'; import type { JsonObject } from 'type-fest'; -import { $lastProgressEvent, $loadingModelsCount } from './stores'; +import { + $lastProgressEvent, + $loadingModelsCount, + clearAllProgressEvents, + clearProgressEvent, + setProgressEvent, +} from './stores'; const log = logger('events'); @@ -84,6 +90,7 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis socket.emit('subscribe_queue', { queue_id: 'default' }); socket.emit('subscribe_bulk_download', { bulk_download_id: 'default' }); $lastProgressEvent.set(null); + clearAllProgressEvents(); $loadingModelsCount.set(0); }); @@ -91,6 +98,7 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis log.debug('Connect error'); setIsConnected(false); $lastProgressEvent.set(null); + clearAllProgressEvents(); $loadingModelsCount.set(0); if (error && error.message) { const data: string | undefined = (error as unknown as { data: string | undefined }).data; @@ -108,6 +116,7 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis socket.on('disconnect', () => { log.debug('Disconnected'); $lastProgressEvent.set(null); + clearAllProgressEvents(); $loadingModelsCount.set(0); setIsConnected(false); }); @@ -148,6 +157,7 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis log.trace({ data } as JsonObject, _message); $lastProgressEvent.set(data); + setProgressEvent(data); if (origin === 'workflows') { const nes = $nodeExecutionStates.get()[invocation_source_id]; @@ -491,11 +501,14 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis } // If the queue item is completed, failed, or cancelled, we want to clear the last progress event $lastProgressEvent.set(null); + // Also remove this session's per-item progress so its stacked progress bar disappears. + clearProgressEvent(item_id); } }); socket.on('queue_cleared', (data) => { log.debug({ data }, 'Queue cleared'); + clearAllProgressEvents(); dispatch( queueApi.util.invalidateTags([ 'SessionQueueStatus', diff --git a/invokeai/frontend/web/src/services/events/stores.ts b/invokeai/frontend/web/src/services/events/stores.ts index 180f4a3a636..95c88bc28cc 100644 --- a/invokeai/frontend/web/src/services/events/stores.ts +++ b/invokeai/frontend/web/src/services/events/stores.ts @@ -1,5 +1,5 @@ import { round } from 'es-toolkit/compat'; -import { atom, computed } from 'nanostores'; +import { atom, computed, map } from 'nanostores'; import type { S } from 'services/api/types'; import type { AppSocket } from 'services/events/types'; @@ -8,6 +8,33 @@ export const $isConnected = atom(false); export const $lastProgressEvent = atom(null); export const $loadingModelsCount = atom(0); +/** + * Live progress events keyed by queue item id. Unlike `$lastProgressEvent` (a single global value that + * is overwritten by whichever session reported last), this tracks each in-flight session separately so + * the UI can render one progress bar per concurrent session (multi-GPU). Entries are added as progress + * events arrive and removed when the session reaches a terminal state. + */ +export const $progressEvents = map>({}); + +/** In-flight sessions sorted by queue item id, for a stable top-to-bottom bar order. */ +export const $activeProgressEvents = computed($progressEvents, (events) => + Object.values(events) + .filter((event): event is S['InvocationProgressEvent'] => event !== undefined) + .sort((a, b) => a.item_id - b.item_id) +); + +export const setProgressEvent = (event: S['InvocationProgressEvent']) => { + $progressEvents.setKey(event.item_id, event); +}; + +export const clearProgressEvent = (itemId: number) => { + $progressEvents.setKey(itemId, undefined); +}; + +export const clearAllProgressEvents = () => { + $progressEvents.set({}); +}; + export const $lastProgressMessage = computed($lastProgressEvent, (val) => { if (!val) { return null; From 351758c60643365e85d302b4b48d778341a83224 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Tue, 2 Jun 2026 17:13:35 -0400 Subject: [PATCH 09/33] fix(ui): make $progressEvents module-local to satisfy knip $progressEvents is only referenced within stores.ts (via the $activeProgressEvents computed and the set/clear helpers), so exporting it tripped knip's unused-exports check. Co-Authored-By: Claude Opus 4.8 --- invokeai/frontend/web/src/services/events/stores.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/invokeai/frontend/web/src/services/events/stores.ts b/invokeai/frontend/web/src/services/events/stores.ts index 95c88bc28cc..7c7630e2019 100644 --- a/invokeai/frontend/web/src/services/events/stores.ts +++ b/invokeai/frontend/web/src/services/events/stores.ts @@ -14,7 +14,7 @@ export const $loadingModelsCount = atom(0); * the UI can render one progress bar per concurrent session (multi-GPU). Entries are added as progress * events arrive and removed when the session reaches a terminal state. */ -export const $progressEvents = map>({}); +const $progressEvents = map>({}); /** In-flight sessions sorted by queue item id, for a stable top-to-bottom bar order. */ export const $activeProgressEvents = computed($progressEvents, (events) => From 4f6613f0002f38401b4cba4f5539d8864b6af44d Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Tue, 2 Jun 2026 17:23:44 -0400 Subject: [PATCH 10/33] fix(ui): cap stacked tab progress bars to fit below the tab label With 4 GPUs the stacked per-session progress bars grew past the bottom strip of the dockview tab and overlapped the "Viewer" label. Add a fitHeightPx prop: in fit mode the stack is capped to the available strip (10px below the ~40px tab's centered label) and the bars flex to share it, shrinking below their natural height only once they no longer fit. With 1-2 sessions the bars keep their familiar thin height; with 3+ they scale down to stay within the strip. The sidebar bar is unaffected and continues to stack at natural height (it has the vertical room). Co-Authored-By: Claude Opus 4.8 --- .../system/components/ProgressBar.tsx | 29 +++++++++++++++++-- .../ui/layouts/DockviewTabCanvasViewer.tsx | 2 +- .../ui/layouts/DockviewTabCanvasWorkspace.tsx | 2 +- .../ui/layouts/DockviewTabProgress.tsx | 2 +- 4 files changed, 30 insertions(+), 5 deletions(-) diff --git a/invokeai/frontend/web/src/features/system/components/ProgressBar.tsx b/invokeai/frontend/web/src/features/system/components/ProgressBar.tsx index 6a305416bbf..a38b2603625 100644 --- a/invokeai/frontend/web/src/features/system/components/ProgressBar.tsx +++ b/invokeai/frontend/web/src/features/system/components/ProgressBar.tsx @@ -6,9 +6,20 @@ import { useTranslation } from 'react-i18next'; import { useGetQueueStatusQuery } from 'services/api/endpoints/queue'; import { $activeProgressEvents, $isConnected, $loadingModelsCount } from 'services/events/stores'; +// In "fit" mode (e.g. the strip below a dockview tab label) the stack is constrained to a fixed height. +// Bars stay at FIT_BAR_HEIGHT_PX while they fit, then shrink to share the available space so they never +// overlap the label, no matter how many sessions are running. +const FIT_BAR_HEIGHT_PX = 4; +const FIT_BAR_GAP_PX = 1; + type ProgressBarProps = ProgressProps & { /** Applied to the Flex that stacks the per-session bars. Use for positioning (e.g. absolute). */ containerProps?: FlexProps; + /** + * When set, the stacked bars are constrained to this total height (in px) and shrink to share it, so + * they never grow past the available space (e.g. the strip below a dockview tab label). + */ + fitHeightPx?: number; }; type BarDescriptor = { @@ -17,7 +28,7 @@ type BarDescriptor = { isIndeterminate: boolean; }; -const ProgressBar = ({ containerProps, ...props }: ProgressBarProps) => { +const ProgressBar = ({ containerProps, fitHeightPx, ...props }: ProgressBarProps) => { const { t } = useTranslation(); const { data: queueStatus } = useGetQueueStatusQuery(); const isConnected = useStore($isConnected); @@ -44,8 +55,21 @@ const ProgressBar = ({ containerProps, ...props }: ProgressBarProps) => { return [{ key: 'idle', value: 0, isIndeterminate }]; }, [activeProgressEvents, isConnected, loadingModelsCount, queueStatus?.queue.in_progress]); + // In fit mode, cap the whole stack to the available strip and let the bars flex to share it. When the + // bars fit at their natural height the stack is shorter than the cap; once they don't, they shrink. + const isFit = fitHeightPx !== undefined; + const fitContainerProps = useMemo(() => { + if (!isFit) { + return undefined; + } + const naturalHeight = bars.length * FIT_BAR_HEIGHT_PX + Math.max(0, bars.length - 1) * FIT_BAR_GAP_PX; + return { h: `${Math.min(naturalHeight, fitHeightPx)}px`, gap: `${FIT_BAR_GAP_PX}px` }; + }, [bars.length, fitHeightPx, isFit]); + + const fitBarProps: ProgressProps | undefined = isFit ? { flex: '1 1 0', minH: 0, h: 'auto' } : undefined; + return ( - + {bars.map((bar) => ( { w="full" colorScheme="invokeBlue" {...props} + {...fitBarProps} /> ))} diff --git a/invokeai/frontend/web/src/features/ui/layouts/DockviewTabCanvasViewer.tsx b/invokeai/frontend/web/src/features/ui/layouts/DockviewTabCanvasViewer.tsx index a53e0c3c4cb..62246faa0f8 100644 --- a/invokeai/frontend/web/src/features/ui/layouts/DockviewTabCanvasViewer.tsx +++ b/invokeai/frontend/web/src/features/ui/layouts/DockviewTabCanvasViewer.tsx @@ -35,8 +35,8 @@ export const DockviewTabCanvasViewer = memo((props: IDockviewPanelHeaderProps {currentQueueItemDestination === 'canvas' && isGenerationInProgress && ( )} diff --git a/invokeai/frontend/web/src/features/ui/layouts/DockviewTabCanvasWorkspace.tsx b/invokeai/frontend/web/src/features/ui/layouts/DockviewTabCanvasWorkspace.tsx index f96381511fc..285afa3a1b6 100644 --- a/invokeai/frontend/web/src/features/ui/layouts/DockviewTabCanvasWorkspace.tsx +++ b/invokeai/frontend/web/src/features/ui/layouts/DockviewTabCanvasWorkspace.tsx @@ -38,8 +38,8 @@ export const DockviewTabCanvasWorkspace = memo((props: IDockviewPanelHeaderProps {currentQueueItemDestination === canvasSessionId && isGenerationInProgress && ( )} diff --git a/invokeai/frontend/web/src/features/ui/layouts/DockviewTabProgress.tsx b/invokeai/frontend/web/src/features/ui/layouts/DockviewTabProgress.tsx index 180babf8191..c89f682e66a 100644 --- a/invokeai/frontend/web/src/features/ui/layouts/DockviewTabProgress.tsx +++ b/invokeai/frontend/web/src/features/ui/layouts/DockviewTabProgress.tsx @@ -33,8 +33,8 @@ export const DockviewTabProgress = memo((props: IDockviewPanelHeaderProps {isGenerationInProgress && ( )} From 2a65c4aa6e381d104bfee69bcd89ecf185d41f51 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Tue, 2 Jun 2026 17:34:26 -0400 Subject: [PATCH 11/33] feat(config): support "auto" generation_devices to use all GPUs by default MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit generation_devices now accepts "auto" (the new default), which expands to every visible CUDA device — so multi-GPU parallel generation works out of the box without manually listing devices. On GPU-less systems "auto" resolves to the single cpu/mps device, preserving serial behavior. - config_default.py: type is now Union[Literal["auto"], list[str]], default "auto"; validator accepts "auto" or a list of device strings. - devices.py: add TorchDevice.get_generation_devices(), the single resolver that expands "auto", normalizes, and deduplicates. - session_processor / model_manager: both consumers use the resolver instead of iterating the raw config value (which would have iterated the characters of the "auto" string). - Regenerated docs/src/generated/settings.json. - Tests for the resolver (auto-with/without-CUDA, dedup, empty). An explicit single-device list (e.g. [cuda:0]) or an empty list opts out of parallelism. Co-Authored-By: Claude Opus 4.8 --- docs/src/generated/settings.json | 6 +-- .../app/services/config/config_default.py | 10 ++--- .../model_manager/model_manager_default.py | 9 ++-- .../session_processor_default.py | 19 ++++----- invokeai/backend/util/devices.py | 28 +++++++++++++ tests/backend/util/test_devices.py | 41 +++++++++++++++++++ 6 files changed, 88 insertions(+), 25 deletions(-) diff --git a/docs/src/generated/settings.json b/docs/src/generated/settings.json index eb26d39960f..35cea553b96 100644 --- a/docs/src/generated/settings.json +++ b/docs/src/generated/settings.json @@ -492,13 +492,13 @@ }, { "category": "DEVICE", - "default": null, - "description": "List of devices to use for parallel generation, e.g. `[cuda:0, cuda:1]`. When set, the session processor runs one generation session per listed device concurrently, distributing jobs fairly across users. When unset (the default), generation runs serially on the single `device`.
Valid values for each entry: `cpu`, `cuda`, `mps`, `cuda:N` (where N is a device number)", + "default": "auto", + "description": "Devices to use for parallel generation. `auto` (the default) uses every available GPU, running one generation session per GPU concurrently and distributing jobs fairly across users. Provide an explicit list (e.g. `[cuda:0, cuda:1]`) to use specific devices, or a single-device list (e.g. `[cuda:0]`) to run serially. On systems without a GPU, `auto` resolves to the single `cpu`/`mps` device.
Valid values: `auto`, or a list whose entries are each `cpu`, `cuda`, `mps`, or `cuda:N` (where N is a device number)", "env_var": "INVOKEAI_GENERATION_DEVICES", "literal_values": [], "name": "generation_devices", "required": false, - "type": "typing.Optional[list[str]]", + "type": "typing.Union[typing.Literal['auto'], list[str]]", "validation": {} }, { diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py index a70f5f7e97c..4d9755654a3 100644 --- a/invokeai/app/services/config/config_default.py +++ b/invokeai/app/services/config/config_default.py @@ -11,7 +11,7 @@ import shutil from functools import lru_cache from pathlib import Path -from typing import Any, Literal, Optional +from typing import Any, Literal, Optional, Union import yaml from pydantic import BaseModel, Field, PrivateAttr, field_validator @@ -205,7 +205,7 @@ class InvokeAIAppConfig(BaseSettings): # DEVICE device: str = Field(default="auto", description="Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.
Valid values: `auto`, `cpu`, `cuda`, `mps`, `cuda:N` (where N is a device number)", pattern=r"^(auto|cpu|mps|cuda(:\d+)?)$") - generation_devices: Optional[list[str]] = Field(default=None, description="List of devices to use for parallel generation, e.g. `[cuda:0, cuda:1]`. When set, the session processor runs one generation session per listed device concurrently, distributing jobs fairly across users. When unset (the default), generation runs serially on the single `device`.
Valid values for each entry: `cpu`, `cuda`, `mps`, `cuda:N` (where N is a device number)") + generation_devices: Union[Literal["auto"], list[str]] = Field(default="auto", description="Devices to use for parallel generation. `auto` (the default) uses every available GPU, running one generation session per GPU concurrently and distributing jobs fairly across users. Provide an explicit list (e.g. `[cuda:0, cuda:1]`) to use specific devices, or a single-device list (e.g. `[cuda:0]`) to run serially. On systems without a GPU, `auto` resolves to the single `cpu`/`mps` device.
Valid values: `auto`, or a list whose entries are each `cpu`, `cuda`, `mps`, or `cuda:N` (where N is a device number)") precision: PRECISION = Field(default="auto", description="Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.") # GENERATION @@ -260,14 +260,14 @@ class InvokeAIAppConfig(BaseSettings): @field_validator("generation_devices") @classmethod - def validate_generation_devices(cls, v: Optional[list[str]]) -> Optional[list[str]]: - if v is None: + def validate_generation_devices(cls, v: Union[str, list[str]]) -> Union[str, list[str]]: + if v == "auto": return v pattern = re.compile(r"^(cpu|mps|cuda(:\d+)?)$") for device in v: if not pattern.match(device): raise ValueError( - f"Invalid generation device '{device}'. Valid values are 'cpu', 'mps', 'cuda', or 'cuda:N'." + f"Invalid generation device '{device}'. Valid values are 'auto', 'cpu', 'mps', 'cuda', or 'cuda:N'." ) return v diff --git a/invokeai/app/services/model_manager/model_manager_default.py b/invokeai/app/services/model_manager/model_manager_default.py index eaeb5d4e612..b7680524a34 100644 --- a/invokeai/app/services/model_manager/model_manager_default.py +++ b/invokeai/app/services/model_manager/model_manager_default.py @@ -108,11 +108,10 @@ def build_cache(device: torch.device) -> ModelCache: # worker is pinned to a device (see TorchDevice.set_session_device) and resolves to its own # cache. The default cache is always included by ModelLoadService. ram_caches: dict[str, ModelCache] = {str(TorchDevice.normalize(default_device)): ram_cache} - if app_config.generation_devices: - for device_str in app_config.generation_devices: - key = str(TorchDevice.normalize(device_str)) - if key not in ram_caches: - ram_caches[key] = build_cache(torch.device(key)) + for device in TorchDevice.get_generation_devices(app_config.generation_devices): + key = str(device) + if key not in ram_caches: + ram_caches[key] = build_cache(device) loader = ModelLoadService( app_config=app_config, diff --git a/invokeai/app/services/session_processor/session_processor_default.py b/invokeai/app/services/session_processor/session_processor_default.py index c6d566255b2..c6edb5069f8 100644 --- a/invokeai/app/services/session_processor/session_processor_default.py +++ b/invokeai/app/services/session_processor/session_processor_default.py @@ -347,21 +347,16 @@ def __init__( def _resolve_devices(self) -> list[Optional[torch.device]]: """Determine the per-worker devices from config. - Returns a single `None` (legacy single-worker, device chosen by the global config) unless - `generation_devices` is configured, in which case it returns one normalized device per - listed device (deduplicated, order preserved). + Resolves `generation_devices` (which defaults to `"auto"` — every available GPU) into one + normalized device per worker. Returns a single `None` (legacy single-worker, device chosen by + the global config) only if the resolution is empty (e.g. `generation_devices` set to an empty + list). """ generation_devices = self._invoker.services.configuration.generation_devices - if not generation_devices: + devices = TorchDevice.get_generation_devices(generation_devices) + if not devices: return [None] - devices: list[Optional[torch.device]] = [] - seen: set[str] = set() - for device_str in generation_devices: - device = TorchDevice.normalize(device_str) - if str(device) not in seen: - seen.add(str(device)) - devices.append(device) - return devices + return list(devices) def _clone_session_runner(self, template: SessionRunnerBase) -> SessionRunnerBase: """Create an independent runner for an additional worker. diff --git a/invokeai/backend/util/devices.py b/invokeai/backend/util/devices.py index d912f86a8a3..0511601b557 100644 --- a/invokeai/backend/util/devices.py +++ b/invokeai/backend/util/devices.py @@ -120,6 +120,34 @@ def get_torch_device_name(cls) -> str: device = cls.choose_torch_device() return torch.cuda.get_device_name(device) if device.type == "cuda" else device.type.upper() + @classmethod + def get_generation_devices(cls, generation_devices: Union[str, list[str], None]) -> list[torch.device]: + """Resolve the configured `generation_devices` into a concrete, deduplicated device list. + + - ``"auto"`` (the default) expands to every visible CUDA device, or the single best available + device (mps/cpu) when CUDA is unavailable. + - An explicit list is normalized and deduplicated, with order preserved. + - ``None`` or an empty list yields an empty list; the caller decides the single-device fallback. + """ + if generation_devices == "auto": + if torch.cuda.is_available(): + device_strs: list[str] = [f"cuda:{index}" for index in range(torch.cuda.device_count())] + else: + device_strs = [str(cls.choose_torch_device())] + elif not generation_devices: + return [] + else: + device_strs = list(generation_devices) + + devices: list[torch.device] = [] + seen: set[str] = set() + for device_str in device_strs: + device = cls.normalize(device_str) + if str(device) not in seen: + seen.add(str(device)) + devices.append(device) + return devices + @classmethod def normalize(cls, device: Union[str, torch.device]) -> torch.device: """Add the device index to CUDA devices.""" diff --git a/tests/backend/util/test_devices.py b/tests/backend/util/test_devices.py index 39dee5cb618..aa8433c632e 100644 --- a/tests/backend/util/test_devices.py +++ b/tests/backend/util/test_devices.py @@ -69,6 +69,47 @@ def worker(name: str, device: str): assert TorchDevice.choose_torch_device() == torch.device("cpu") +# ===== generation_devices resolution (config -> concrete device list) ======= + + +def test_get_generation_devices_auto_expands_to_all_cuda(): + """`auto` enumerates every visible CUDA device.""" + with ( + patch("invokeai.backend.util.devices.torch.cuda.is_available", return_value=True), + patch("invokeai.backend.util.devices.torch.cuda.device_count", return_value=3), + ): + assert TorchDevice.get_generation_devices("auto") == [ + torch.device("cuda:0"), + torch.device("cuda:1"), + torch.device("cuda:2"), + ] + + +def test_get_generation_devices_auto_without_cuda(): + """`auto` falls back to the single best device when CUDA is unavailable.""" + config = get_config() + config.device = "cpu" + with ( + patch("invokeai.backend.util.devices.torch.cuda.is_available", return_value=False), + patch("invokeai.backend.util.devices.torch.backends.mps.is_available", return_value=False), + ): + assert TorchDevice.get_generation_devices("auto") == [torch.device("cpu")] + + +def test_get_generation_devices_explicit_list_is_deduplicated(): + """An explicit list is normalized and deduplicated, preserving order.""" + assert TorchDevice.get_generation_devices(["cuda:0", "cuda:0", "cuda:1"]) == [ + torch.device("cuda:0"), + torch.device("cuda:1"), + ] + + +@pytest.mark.parametrize("value", [None, []]) +def test_get_generation_devices_empty(value): + """`None` or an empty list resolves to an empty list (caller handles the single-device fallback).""" + assert TorchDevice.get_generation_devices(value) == [] + + @pytest.mark.parametrize("device_dtype_pair", device_types_cpu) def test_device_dtype_cpu(device_dtype_pair): with ( From 420978083c9889067aa67902c6f2e9ba86609ca0 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Tue, 2 Jun 2026 21:26:18 -0400 Subject: [PATCH 12/33] chore(frontend): typegen+openapi --- invokeai/frontend/web/openapi.json | 15 --------------- invokeai/frontend/web/src/services/api/schema.ts | 5 ----- 2 files changed, 20 deletions(-) diff --git a/invokeai/frontend/web/openapi.json b/invokeai/frontend/web/openapi.json index 852dc866bce..1287ee58865 100644 --- a/invokeai/frontend/web/openapi.json +++ b/invokeai/frontend/web/openapi.json @@ -41151,21 +41151,6 @@ "description": "Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.
Valid values: `auto`, `cpu`, `cuda`, `mps`, `cuda:N` (where N is a device number)", "default": "auto" }, - "generation_devices": { - "anyOf": [ - { - "items": { - "type": "string" - }, - "type": "array" - }, - { - "type": "null" - } - ], - "title": "Generation Devices", - "description": "List of devices to use for parallel generation, e.g. `[cuda:0, cuda:1]`. When set, the session processor runs one generation session per listed device concurrently, distributing jobs fairly across users. When unset (the default), generation runs serially on the single `device`.
Valid values for each entry: `cpu`, `cuda`, `mps`, `cuda:N` (where N is a device number)" - }, "precision": { "type": "string", "enum": ["auto", "float16", "bfloat16", "float32"], diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index ace2e30178e..a80183476bd 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -16506,11 +16506,6 @@ export type components = { * @default auto */ device?: string; - /** - * Generation Devices - * @description List of devices to use for parallel generation, e.g. `[cuda:0, cuda:1]`. When set, the session processor runs one generation session per listed device concurrently, distributing jobs fairly across users. When unset (the default), generation runs serially on the single `device`.
Valid values for each entry: `cpu`, `cuda`, `mps`, `cuda:N` (where N is a device number) - */ - generation_devices?: string[] | null; /** * Precision * @description Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system. From 914c577679f217c8f70abe3a4c241bc5d61eca58 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Tue, 2 Jun 2026 21:43:11 -0400 Subject: [PATCH 13/33] docs(multi-gpu): add configuration information --- .../docs/configuration/invokeai-yaml.mdx | 33 +++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/docs/src/content/docs/configuration/invokeai-yaml.mdx b/docs/src/content/docs/configuration/invokeai-yaml.mdx index 987c8eb98a2..6ac56053928 100644 --- a/docs/src/content/docs/configuration/invokeai-yaml.mdx +++ b/docs/src/content/docs/configuration/invokeai-yaml.mdx @@ -114,6 +114,39 @@ Most common algorithms are supported, like `md5`, `sha256`, and `sha512`. These These options set the paths of various directories and files used by InvokeAI. Any user-defined paths should be absolute paths. +#### Multi-GPU Generation + +On a machine with more than one GPU, InvokeAI can run several generation sessions at the same time — one per GPU — instead of processing the queue one job at a time. Jobs are distributed fairly across users, so a single user's large batch cannot monopolize every GPU while others wait. + +This is controlled by the `generation_devices` setting: + +```yaml +generation_devices: auto # default value +``` + +| Value | Behavior | +| -------------------------- | ----------------------------------------------------------------------------------------------------------------------- | +| `auto` | Use every available CUDA GPU, running one generation session per GPU concurrently. This is the default. | +| `[cuda:0,cuda:1]` | Use the specific devices listed, one session per device. Useful for reserving a GPU for other work. | +| `[cuda:0]` | Use a single specific device. Generation runs serially, as it did before multi-GPU support. | +| `[]` | Use the first detected device. Generation runs serially, as it did before multi-GPU support. | + +Each entry in the list must be one of `cpu`, `cuda`, `mps`, or `cuda:N`, where `N` is a zero-based device number (`cuda:0` is the first GPU, `cuda:1` the second, and so on). + +```yaml +# Use the first and third GPUs, leaving the second free for other tasks +generation_devices: [cuda:0, cuda:2] +``` + +Notes: + +- On a system without a CUDA GPU, `auto` resolves to the single best available device (`mps` on Apple Silicon, otherwise `cpu`), so generation runs serially. +- Each active GPU gets its own model cache, and model weights are duplicated in system RAM for every device. Running many GPUs in parallel therefore increases RAM usage — ensure you have ample system memory before enabling a large device list. +- Duplicate entries are ignored; `[cuda:0, cuda:0]` is treated as `[cuda:0]`. +- You can restrict which physical GPUs InvokeAI sees with the `CUDA_VISIBLE_DEVICES` environment variable. When set, `auto` only enumerates the visible subset, and `cuda:N` indices refer to positions within that subset. + +During parallel generation, the progress display shows one progress bar per active session, stacked vertically, each disappearing as its session completes. + #### Image Subfolder Strategy By default, generated images are stored in a single flat directory under `outputs/images/`. The `image_subfolder_strategy` setting lets you organize newly-created images into subfolders automatically. You can edit this setting in `invokeai.yaml` or, as an admin user, in the Settings panel. From a928a756f90902e65e00904062256eb761104d3a Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Tue, 2 Jun 2026 21:55:25 -0400 Subject: [PATCH 14/33] chore(frontend): typegen + openapi again --- invokeai/frontend/web/openapi.json | 17 +++++++++++++++++ .../frontend/web/src/services/api/schema.ts | 6 ++++++ 2 files changed, 23 insertions(+) diff --git a/invokeai/frontend/web/openapi.json b/invokeai/frontend/web/openapi.json index 1287ee58865..dbbe40d5ad4 100644 --- a/invokeai/frontend/web/openapi.json +++ b/invokeai/frontend/web/openapi.json @@ -41151,6 +41151,23 @@ "description": "Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.
Valid values: `auto`, `cpu`, `cuda`, `mps`, `cuda:N` (where N is a device number)", "default": "auto" }, + "generation_devices": { + "anyOf": [ + { + "type": "string", + "const": "auto" + }, + { + "items": { + "type": "string" + }, + "type": "array" + } + ], + "title": "Generation Devices", + "description": "Devices to use for parallel generation. `auto` (the default) uses every available GPU, running one generation session per GPU concurrently and distributing jobs fairly across users. Provide an explicit list (e.g. `[cuda:0, cuda:1]`) to use specific devices, or a single-device list (e.g. `[cuda:0]`) to run serially. On systems without a GPU, `auto` resolves to the single `cpu`/`mps` device.
Valid values: `auto`, or a list whose entries are each `cpu`, `cuda`, `mps`, or `cuda:N` (where N is a device number)", + "default": "auto" + }, "precision": { "type": "string", "enum": ["auto", "float16", "bfloat16", "float32"], diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index a80183476bd..91393f53a9c 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -16506,6 +16506,12 @@ export type components = { * @default auto */ device?: string; + /** + * Generation Devices + * @description Devices to use for parallel generation. `auto` (the default) uses every available GPU, running one generation session per GPU concurrently and distributing jobs fairly across users. Provide an explicit list (e.g. `[cuda:0, cuda:1]`) to use specific devices, or a single-device list (e.g. `[cuda:0]`) to run serially. On systems without a GPU, `auto` resolves to the single `cpu`/`mps` device.
Valid values: `auto`, or a list whose entries are each `cpu`, `cuda`, `mps`, or `cuda:N` (where N is a device number) + * @default auto + */ + generation_devices?: "auto" | string[]; /** * Precision * @description Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system. From 48458c13cf3c973eb6c3b8cf2d53ca43c21d8996 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Wed, 3 Jun 2026 14:29:24 -0400 Subject: [PATCH 15/33] feat(settings): add Generation Devices selector to Settings dialog Add a badges UI in the Generation section of the Settings dialog for choosing which devices `generation_devices` should use, modeled on the Log Namespaces toggle UI. Backend: - New `GET /api/v1/app/generation_device_options` endpoint listing the selectable devices (cuda:N with GPU names, or the sole mps/cpu fallback). - Add `generation_devices` to the runtime-config update allowlist with validation rejecting invalid device strings and explicit nulls. Frontend: - New SettingsGenerationDevices component with active/inactive badges. "Auto (all GPUs)" is exclusive; removing the last explicit device reverts to auto. Admin/multiuser gated; notes restart requirement. - Wire into the Generation section; regenerate schema; add en strings. Co-Authored-By: Claude Opus 4.8 --- invokeai/app/api/routers/app_info.py | 60 ++++- invokeai/frontend/web/public/locales/en.json | 4 + .../SettingsGenerationDevices.tsx | 217 ++++++++++++++++++ .../SettingsModal/SettingsModal.tsx | 2 + .../web/src/services/api/endpoints/appInfo.ts | 11 + .../frontend/web/src/services/api/schema.ts | 61 +++++ tests/app/routers/test_app_info.py | 58 +++++ 7 files changed, 411 insertions(+), 2 deletions(-) create mode 100644 invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsGenerationDevices.tsx diff --git a/invokeai/app/api/routers/app_info.py b/invokeai/app/api/routers/app_info.py index 832e58f5e24..a8e0c68d781 100644 --- a/invokeai/app/api/routers/app_info.py +++ b/invokeai/app/api/routers/app_info.py @@ -1,15 +1,16 @@ import locale +import re from enum import Enum from importlib.metadata import distributions from pathlib import Path as FilePath from threading import Lock -from typing import Any +from typing import Any, Literal, Union import torch import yaml from fastapi import Body, HTTPException, Path from fastapi.routing import APIRouter -from pydantic import BaseModel, Field, model_validator +from pydantic import BaseModel, Field, field_validator, model_validator from invokeai.app.api.auth_dependencies import AdminUserOrDefault from invokeai.app.api.dependencies import ApiDependencies @@ -118,6 +119,16 @@ def _remove_nullable_default_from_schema(schema: dict[str, Any]) -> None: schema.update(non_null_schemas[0]) +_GENERATION_DEVICE_PATTERN = re.compile(r"^(cpu|mps|cuda(:\d+)?)$") + + +class GenerationDeviceOption(BaseModel): + """A device that may be selected for generation.""" + + device: str = Field(description="The device identifier, e.g. 'cuda:0', 'mps', or 'cpu'") + name: str = Field(description="Human-readable device name") + + class UpdateAppGenerationSettingsRequest(BaseModel): """Writable generation-related app settings.""" @@ -131,14 +142,59 @@ class UpdateAppGenerationSettingsRequest(BaseModel): ge=0, description="Keep the last N completed, failed, and canceled queue items on startup. Set to 0 to prune all terminal items.", ) + generation_devices: Union[Literal["auto"], list[str]] | None = Field( + default=None, + description="Devices to use for parallel generation. `auto` uses every available GPU; provide an explicit list (e.g. `[cuda:0, cuda:1]`) to use specific devices. Takes effect after restarting InvokeAI.", + json_schema_extra=_remove_nullable_default_from_schema, + ) + + @field_validator("generation_devices") + @classmethod + def validate_generation_devices( + cls, v: Union[Literal["auto"], list[str], None] + ) -> Union[Literal["auto"], list[str], None]: + if v is None or v == "auto": + return v + for device in v: + if not _GENERATION_DEVICE_PATTERN.match(device): + raise ValueError( + f"Invalid generation device '{device}'. Valid values are 'auto', 'cpu', 'mps', 'cuda', or 'cuda:N'." + ) + return v @model_validator(mode="after") def validate_explicit_nulls(self) -> "UpdateAppGenerationSettingsRequest": if "image_subfolder_strategy" in self.model_fields_set and self.image_subfolder_strategy is None: raise ValueError("image_subfolder_strategy may not be null") + if "generation_devices" in self.model_fields_set and self.generation_devices is None: + raise ValueError("generation_devices may not be null") return self +@app_router.get( + "/generation_device_options", + operation_id="get_generation_device_options", + status_code=200, + response_model=list[GenerationDeviceOption], +) +async def get_generation_device_options() -> list[GenerationDeviceOption]: + """List the devices available for generation, for use with the `generation_devices` setting.""" + options: list[GenerationDeviceOption] = [] + if torch.cuda.is_available(): + for index in range(torch.cuda.device_count()): + device = f"cuda:{index}" + try: + name = torch.cuda.get_device_name(index) + except Exception: + name = device + options.append(GenerationDeviceOption(device=device, name=name)) + elif torch.backends.mps.is_available(): + options.append(GenerationDeviceOption(device="mps", name="Apple MPS")) + else: + options.append(GenerationDeviceOption(device="cpu", name="CPU")) + return options + + @app_router.get( "/runtime_config", operation_id="get_runtime_config", status_code=200, response_model=InvokeAIAppConfigWithSetFields ) diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 4f6c49ee9a0..45e52babe1a 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -1843,6 +1843,10 @@ "enableNSFWChecker": "Enable NSFW Checker", "general": "General", "generation": "Generation", + "generationDevices": "Generation Devices", + "generationDevicesAuto": "Auto (all GPUs)", + "generationDevicesHelp": "Select which devices to use for parallel generation, one session per device. \"Auto\" uses every available GPU. Changes take effect after restarting InvokeAI.", + "generationDevicesSaveFailed": "Failed to save Generation Devices", "imageSubfolderStrategy": "Image Subfolder Strategy", "imageSubfolderStrategyDate": "Date", "imageSubfolderStrategyFlat": "Flat", diff --git a/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsGenerationDevices.tsx b/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsGenerationDevices.tsx new file mode 100644 index 00000000000..0aa525c1673 --- /dev/null +++ b/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsGenerationDevices.tsx @@ -0,0 +1,217 @@ +import { + Flex, + FormControl, + FormHelperText, + FormLabel, + Tag, + TagCloseButton, + Text, + Tooltip, +} from '@invoke-ai/ui-library'; +import { useAppSelector } from 'app/store/storeHooks'; +import { selectCurrentUser } from 'features/auth/store/authSlice'; +import { toast } from 'features/toast/toast'; +import { memo, useCallback, useMemo } from 'react'; +import { useTranslation } from 'react-i18next'; +import { + useGetGenerationDeviceOptionsQuery, + useGetRuntimeConfigQuery, + useUpdateRuntimeConfigMutation, +} from 'services/api/endpoints/appInfo'; + +const AUTO = 'auto'; + +type GenerationDevicesValue = 'auto' | string[]; + +type DeviceBadge = { + /** The device identifier, or 'auto' for the special "use all GPUs" badge. */ + device: string; + /** The label shown on the badge. */ + label: string; + /** A human-readable description shown on hover (e.g. the GPU model name). */ + tooltip?: string; +}; + +export const SettingsGenerationDevices = memo(() => { + const { t } = useTranslation(); + const currentUser = useAppSelector(selectCurrentUser); + const { data: runtimeConfig } = useGetRuntimeConfigQuery(); + const { data: deviceOptions } = useGetGenerationDeviceOptionsQuery(); + const [updateRuntimeConfig, { isLoading }] = useUpdateRuntimeConfigMutation(); + + const generationDevices: GenerationDevicesValue = runtimeConfig?.config.generation_devices ?? AUTO; + const isAuto = generationDevices === AUTO; + const selectedDevices = useMemo(() => (isAuto ? [] : [...generationDevices]), [generationDevices, isAuto]); + + const canEditRuntimeConfig = runtimeConfig ? !runtimeConfig.config.multiuser || currentUser?.is_admin : false; + const isDisabled = !runtimeConfig || !canEditRuntimeConfig || isLoading; + + const save = useCallback( + async (value: GenerationDevicesValue) => { + try { + await updateRuntimeConfig({ generation_devices: value }).unwrap(); + } catch { + toast({ + id: 'SETTINGS_GENERATION_DEVICES_SAVE_FAILED', + title: t('settings.generationDevicesSaveFailed'), + status: 'error', + }); + } + }, + [t, updateRuntimeConfig] + ); + + const autoBadge = useMemo(() => ({ device: AUTO, label: t('settings.generationDevicesAuto') }), [t]); + + // The active badges: the `auto` pseudo-device, or the explicitly-selected devices in config order. + const activeBadges = useMemo(() => { + if (isAuto) { + return [autoBadge]; + } + return selectedDevices.map((device) => ({ + device, + label: device, + tooltip: deviceOptions?.find((option) => option.device === device)?.name, + })); + }, [autoBadge, deviceOptions, isAuto, selectedDevices]); + + // The inactive badges: `auto` (when an explicit list is active) plus any unselected devices. + const inactiveBadges = useMemo(() => { + const badges: DeviceBadge[] = []; + if (!isAuto) { + badges.push(autoBadge); + } + for (const option of deviceOptions ?? []) { + if (!selectedDevices.includes(option.device)) { + badges.push({ device: option.device, label: option.device, tooltip: option.name }); + } + } + return badges; + }, [autoBadge, deviceOptions, isAuto, selectedDevices]); + + const onActivate = useCallback( + (device: string) => { + if (isDisabled) { + return; + } + if (device === AUTO) { + save(AUTO); + return; + } + // Switching from `auto` starts a fresh explicit list; otherwise append to the current selection. + const next = isAuto ? [device] : Array.from(new Set([...selectedDevices, device])); + save(next); + }, + [isAuto, isDisabled, save, selectedDevices] + ); + + const onDeactivate = useCallback( + (device: string) => { + if (isDisabled) { + return; + } + const next = selectedDevices.filter((d) => d !== device); + // Never leave an empty selection — fall back to `auto`, which is always meaningful. + save(next.length > 0 ? next : AUTO); + }, + [isDisabled, save, selectedDevices] + ); + + return ( + + {t('settings.generationDevices')} + + {activeBadges.map((badge) => ( + + ))} + + {inactiveBadges.length > 0 && ( + + {inactiveBadges.map((badge) => ( + + ))} + + )} + {t('settings.generationDevicesHelp')} + + ); +}); + +SettingsGenerationDevices.displayName = 'SettingsGenerationDevices'; + +type DeviceTagProps = { + badge: DeviceBadge; + isActive: boolean; + isClosable: boolean; + isDisabled: boolean; + onActivate: (device: string) => void; + onDeactivate: (device: string) => void; +}; + +const DeviceTag = memo(({ badge, isActive, isClosable, isDisabled, onActivate, onDeactivate }: DeviceTagProps) => { + const onClick = useCallback(() => { + if (isDisabled) { + return; + } + if (isActive) { + // An active, non-closable badge (the exclusive `auto`) is a no-op when clicked. + if (isClosable) { + onDeactivate(badge.device); + } + } else { + onActivate(badge.device); + } + }, [badge.device, isActive, isClosable, isDisabled, onActivate, onDeactivate]); + + const isInteractive = !isDisabled && (!isActive || isClosable); + + const tag = ( + + + {badge.label} + + {isActive && isClosable && } + + ); + + if (!badge.tooltip) { + return tag; + } + + return ( + + {tag} + + ); +}); + +DeviceTag.displayName = 'DeviceTag'; diff --git a/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsModal.tsx b/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsModal.tsx index 64478953a37..62604ba0eab 100644 --- a/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsModal.tsx +++ b/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsModal.tsx @@ -28,6 +28,7 @@ import { useRefreshAfterResetModal } from 'features/system/components/SettingsMo import { SettingsDeveloperLogIsEnabled } from 'features/system/components/SettingsModal/SettingsDeveloperLogIsEnabled'; import { SettingsDeveloperLogLevel } from 'features/system/components/SettingsModal/SettingsDeveloperLogLevel'; import { SettingsDeveloperLogNamespaces } from 'features/system/components/SettingsModal/SettingsDeveloperLogNamespaces'; +import { SettingsGenerationDevices } from 'features/system/components/SettingsModal/SettingsGenerationDevices'; import { SettingsImageSubfolderStrategySelect } from 'features/system/components/SettingsModal/SettingsImageSubfolderStrategySelect'; import { useClearIntermediates } from 'features/system/components/SettingsModal/useClearIntermediates'; import { StickyScrollable } from 'features/system/components/StickyScrollable'; @@ -321,6 +322,7 @@ const SettingsModal = (props: { children: ReactElement }) => { + diff --git a/invokeai/frontend/web/src/services/api/endpoints/appInfo.ts b/invokeai/frontend/web/src/services/api/endpoints/appInfo.ts index 653f458dde8..d8801fe9845 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/appInfo.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/appInfo.ts @@ -58,6 +58,16 @@ export const appInfoApi = api.injectEndpoints({ }), providesTags: ['AppConfig'], }), + getGenerationDeviceOptions: build.query< + paths['/api/v1/app/generation_device_options']['get']['responses']['200']['content']['application/json'], + void + >({ + query: () => ({ + url: buildAppInfoUrl('generation_device_options'), + method: 'GET', + }), + providesTags: ['FetchOnReconnect'], + }), updateRuntimeConfig: build.mutation< paths['/api/v1/app/runtime_config']['patch']['responses']['200']['content']['application/json'], paths['/api/v1/app/runtime_config']['patch']['requestBody']['content']['application/json'] @@ -149,6 +159,7 @@ export const { useGetAppDepsQuery, useGetPatchmatchStatusQuery, useGetRuntimeConfigQuery, + useGetGenerationDeviceOptionsQuery, useGetExternalProviderStatusesQuery, useGetExternalProviderConfigsQuery, useSetExternalProviderConfigMutation, diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index 91393f53a9c..dd0d454b473 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -1652,6 +1652,26 @@ export type paths = { patch?: never; trace?: never; }; + "/api/v1/app/generation_device_options": { + parameters: { + query?: never; + header?: never; + path?: never; + cookie?: never; + }; + /** + * Get Generation Device Options + * @description List the devices available for generation, for use with the `generation_devices` setting. + */ + get: operations["get_generation_device_options"]; + put?: never; + post?: never; + delete?: never; + options?: never; + head?: never; + patch?: never; + trace?: never; + }; "/api/v1/app/runtime_config": { parameters: { query?: never; @@ -12190,6 +12210,22 @@ export type components = { */ password: string; }; + /** + * GenerationDeviceOption + * @description A device that may be selected for generation. + */ + GenerationDeviceOption: { + /** + * Device + * @description The device identifier, e.g. 'cuda:0', 'mps', or 'cpu' + */ + device: string; + /** + * Name + * @description Human-readable device name + */ + name: string; + }; /** * Get Image Mask Bounding Box * @description Gets the bounding box of the given mask image. @@ -30871,6 +30907,11 @@ export type components = { * @description Keep the last N completed, failed, and canceled queue items on startup. Set to 0 to prune all terminal items. */ max_queue_history?: number | null; + /** + * Generation Devices + * @description Devices to use for parallel generation. `auto` uses every available GPU; provide an explicit list (e.g. `[cuda:0, cuda:1]`) to use specific devices. Takes effect after restarting InvokeAI. + */ + generation_devices?: unknown; }; /** * UserDTO @@ -36321,6 +36362,26 @@ export interface operations { }; }; }; + get_generation_device_options: { + parameters: { + query?: never; + header?: never; + path?: never; + cookie?: never; + }; + requestBody?: never; + responses: { + /** @description Successful Response */ + 200: { + headers: { + [name: string]: unknown; + }; + content: { + "application/json": components["schemas"]["GenerationDeviceOption"][]; + }; + }; + }; + }; get_runtime_config: { parameters: { query?: never; diff --git a/tests/app/routers/test_app_info.py b/tests/app/routers/test_app_info.py index da493cee457..96eb23f1342 100644 --- a/tests/app/routers/test_app_info.py +++ b/tests/app/routers/test_app_info.py @@ -225,6 +225,64 @@ def test_update_runtime_config_image_subfolder_strategy_schema() -> None: } +def test_update_runtime_config_persists_generation_devices( + monkeypatch: Any, mock_invoker: Invoker, client: TestClient +) -> None: + monkeypatch.setattr("invokeai.app.api.auth_dependencies.ApiDependencies", MockApiDependencies(mock_invoker)) + + response = client.patch("/api/v1/app/runtime_config", json={"generation_devices": ["cuda:0", "cuda:1"]}) + + assert response.status_code == 200 + assert response.json()["config"]["generation_devices"] == ["cuda:0", "cuda:1"] + + config_path = get_config().config_file_path + file_config = load_and_migrate_config(config_path) + assert file_config.generation_devices == ["cuda:0", "cuda:1"] + assert get_config().generation_devices == ["cuda:0", "cuda:1"] + + # "auto" round-trips back to the default. + response = client.patch("/api/v1/app/runtime_config", json={"generation_devices": "auto"}) + assert response.status_code == 200 + assert response.json()["config"]["generation_devices"] == "auto" + assert get_config().generation_devices == "auto" + + +def test_update_runtime_config_rejects_invalid_generation_device( + monkeypatch: Any, mock_invoker: Invoker, client: TestClient +) -> None: + monkeypatch.setattr("invokeai.app.api.auth_dependencies.ApiDependencies", MockApiDependencies(mock_invoker)) + + response = client.patch("/api/v1/app/runtime_config", json={"generation_devices": ["gpu0"]}) + + assert response.status_code == 422 + + +def test_update_runtime_config_rejects_null_generation_devices( + monkeypatch: Any, mock_invoker: Invoker, client: TestClient +) -> None: + monkeypatch.setattr("invokeai.app.api.auth_dependencies.ApiDependencies", MockApiDependencies(mock_invoker)) + + response = client.patch("/api/v1/app/runtime_config", json={"generation_devices": None}) + + assert response.status_code == 422 + + +def test_get_generation_device_options_lists_devices( + monkeypatch: Any, mock_invoker: Invoker, client: TestClient +) -> None: + monkeypatch.setattr(app_info.torch.cuda, "is_available", lambda: True) + monkeypatch.setattr(app_info.torch.cuda, "device_count", lambda: 2) + monkeypatch.setattr(app_info.torch.cuda, "get_device_name", lambda index: f"GPU {index}") + + response = client.get("/api/v1/app/generation_device_options") + + assert response.status_code == 200 + assert response.json() == [ + {"device": "cuda:0", "name": "GPU 0"}, + {"device": "cuda:1", "name": "GPU 1"}, + ] + + def test_update_runtime_config_reads_and_writes_yaml_under_config_lock( monkeypatch: Any, mock_invoker: Invoker, client: TestClient ) -> None: From 5100605df4daf4cf384462f329e62746a97f6303 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Wed, 3 Jun 2026 15:13:28 -0400 Subject: [PATCH 16/33] feat(settings): boldface the restart notice on Generation Devices Split the restart sentence into its own string and render it bold so users notice that device changes require restarting InvokeAI. Co-Authored-By: Claude Opus 4.8 --- invokeai/frontend/web/public/locales/en.json | 3 ++- .../components/SettingsModal/SettingsGenerationDevices.tsx | 7 ++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 45e52babe1a..c0cf1ee5c46 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -1845,7 +1845,8 @@ "generation": "Generation", "generationDevices": "Generation Devices", "generationDevicesAuto": "Auto (all GPUs)", - "generationDevicesHelp": "Select which devices to use for parallel generation, one session per device. \"Auto\" uses every available GPU. Changes take effect after restarting InvokeAI.", + "generationDevicesHelp": "Select which devices to use for parallel generation, one session per device. \"Auto\" uses every available GPU.", + "generationDevicesRestart": "Changes take effect after restarting InvokeAI.", "generationDevicesSaveFailed": "Failed to save Generation Devices", "imageSubfolderStrategy": "Image Subfolder Strategy", "imageSubfolderStrategyDate": "Date", diff --git a/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsGenerationDevices.tsx b/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsGenerationDevices.tsx index 0aa525c1673..d71e2f40d58 100644 --- a/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsGenerationDevices.tsx +++ b/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsGenerationDevices.tsx @@ -149,7 +149,12 @@ export const SettingsGenerationDevices = memo(() => { ))}
)} - {t('settings.generationDevicesHelp')} + + {t('settings.generationDevicesHelp')}{' '} + + {t('settings.generationDevicesRestart')} + + ); }); From 200b3a32a413851840ec251583a3323b658fdf30 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Wed, 3 Jun 2026 15:23:59 -0400 Subject: [PATCH 17/33] feat(settings): show GPU name in Generation Devices badges Render device badges as "cuda:0 (RTX 3090 #1)" so identical cards can be told apart. Strips the "NVIDIA GeForce" vendor prefix and adds a 1-based "#N" suffix only when multiple cards share a name. The full device name remains available as the badge tooltip. Co-Authored-By: Claude Opus 4.8 --- .../SettingsGenerationDevices.tsx | 43 +++++++++++++++---- 1 file changed, 35 insertions(+), 8 deletions(-) diff --git a/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsGenerationDevices.tsx b/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsGenerationDevices.tsx index d71e2f40d58..2980fb85c73 100644 --- a/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsGenerationDevices.tsx +++ b/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsGenerationDevices.tsx @@ -23,6 +23,9 @@ const AUTO = 'auto'; type GenerationDevicesValue = 'auto' | string[]; +/** Drop the verbose vendor prefix so e.g. "NVIDIA GeForce RTX 3090" reads as "RTX 3090". */ +const shortenDeviceName = (name: string): string => name.replace(/^NVIDIA GeForce /, '').replace(/^NVIDIA /, ''); + type DeviceBadge = { /** The device identifier, or 'auto' for the special "use all GPUs" badge. */ device: string; @@ -63,17 +66,41 @@ export const SettingsGenerationDevices = memo(() => { const autoBadge = useMemo(() => ({ device: AUTO, label: t('settings.generationDevicesAuto') }), [t]); + // Build a per-device badge (label + tooltip) keyed by device id, e.g. "cuda:0 (RTX 3090 #1)". + // Cards sharing a name get a 1-based "#N" suffix so identical GPUs can be told apart. + const deviceBadges = useMemo>(() => { + const options = deviceOptions ?? []; + const nameCounts = new Map(); + for (const option of options) { + const name = shortenDeviceName(option.name); + nameCounts.set(name, (nameCounts.get(name) ?? 0) + 1); + } + const ordinals = new Map(); + const badges: Record = {}; + for (const option of options) { + const name = shortenDeviceName(option.name); + const ordinal = (ordinals.get(name) ?? 0) + 1; + ordinals.set(name, ordinal); + const namePart = (nameCounts.get(name) ?? 0) > 1 ? `${name} #${ordinal}` : name; + badges[option.device] = { device: option.device, label: `${option.device} (${namePart})`, tooltip: option.name }; + } + return badges; + }, [deviceOptions]); + + // Fall back to a bare device id when a configured device isn't in the current options (e.g. a + // GPU that's no longer present). + const getDeviceBadge = useCallback( + (device: string): DeviceBadge => deviceBadges[device] ?? { device, label: device }, + [deviceBadges] + ); + // The active badges: the `auto` pseudo-device, or the explicitly-selected devices in config order. const activeBadges = useMemo(() => { if (isAuto) { return [autoBadge]; } - return selectedDevices.map((device) => ({ - device, - label: device, - tooltip: deviceOptions?.find((option) => option.device === device)?.name, - })); - }, [autoBadge, deviceOptions, isAuto, selectedDevices]); + return selectedDevices.map(getDeviceBadge); + }, [autoBadge, getDeviceBadge, isAuto, selectedDevices]); // The inactive badges: `auto` (when an explicit list is active) plus any unselected devices. const inactiveBadges = useMemo(() => { @@ -83,11 +110,11 @@ export const SettingsGenerationDevices = memo(() => { } for (const option of deviceOptions ?? []) { if (!selectedDevices.includes(option.device)) { - badges.push({ device: option.device, label: option.device, tooltip: option.name }); + badges.push(getDeviceBadge(option.device)); } } return badges; - }, [autoBadge, deviceOptions, isAuto, selectedDevices]); + }, [autoBadge, deviceOptions, getDeviceBadge, isAuto, selectedDevices]); const onActivate = useCallback( (device: string) => { From cdcf7dff458fddd365cd87f0777a3e6638a024b8 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Wed, 3 Jun 2026 15:59:19 -0400 Subject: [PATCH 18/33] chore(frontend): openapi --- invokeai/frontend/web/openapi.json | 46 ++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/invokeai/frontend/web/openapi.json b/invokeai/frontend/web/openapi.json index dbbe40d5ad4..97997071f95 100644 --- a/invokeai/frontend/web/openapi.json +++ b/invokeai/frontend/web/openapi.json @@ -6431,6 +6431,30 @@ } } }, + "/api/v1/app/generation_device_options": { + "get": { + "tags": ["app"], + "summary": "Get Generation Device Options", + "description": "List the devices available for generation, for use with the `generation_devices` setting.", + "operationId": "get_generation_device_options", + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "items": { + "$ref": "#/components/schemas/GenerationDeviceOption" + }, + "type": "array", + "title": "Response Get Generation Device Options" + } + } + } + } + } + } + }, "/api/v1/app/runtime_config": { "get": { "tags": ["app"], @@ -28461,6 +28485,24 @@ "title": "GeneratePasswordResponse", "description": "Response containing a generated password." }, + "GenerationDeviceOption": { + "properties": { + "device": { + "type": "string", + "title": "Device", + "description": "The device identifier, e.g. 'cuda:0', 'mps', or 'cpu'" + }, + "name": { + "type": "string", + "title": "Name", + "description": "Human-readable device name" + } + }, + "type": "object", + "required": ["device", "name"], + "title": "GenerationDeviceOption", + "description": "A device that may be selected for generation." + }, "GetMaskBoundingBoxInvocation": { "category": "mask", "class": "invocation", @@ -70155,6 +70197,10 @@ ], "title": "Max Queue History", "description": "Keep the last N completed, failed, and canceled queue items on startup. Set to 0 to prune all terminal items." + }, + "generation_devices": { + "title": "Generation Devices", + "description": "Devices to use for parallel generation. `auto` uses every available GPU; provide an explicit list (e.g. `[cuda:0, cuda:1]`) to use specific devices. Takes effect after restarting InvokeAI." } }, "type": "object", From 8cef3cf6ca48ece546f84fccd4617b9891753423 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Thu, 11 Jun 2026 22:41:33 -0400 Subject: [PATCH 19/33] feat(multi-gpu): surface per-session GPU number in logs and UI Help users track which CUDA device is processing each session: - Model-load log: "Loaded model ... onto cuda device #N in ..s" - Denoise progress bars: "Denoising (#N)" across all architectures (SD1.5/SDXL, FLUX, FLUX2, Z-Image, Anima, SD3, CogView4) - Progress preview circle: GPU number centered in the ring, via a new `device` field on InvocationProgressEvent (resolved from the worker's thread-local session device) - Session Queue: new "GPU #" column between STATUS and TIME, backed by a `device` column on session_queue (migration_32) recorded when a worker claims an item Adds TorchDevice.get_session_device_label()/get_session_device_index() helpers and a frontend getCudaDeviceIndex() parser (with tests). Shows the number on CUDA only; CPU/MPS show nothing. Co-Authored-By: Claude Opus 4.8 --- invokeai/app/invocations/anima_denoise.py | 6 ++-- invokeai/app/invocations/cogview4_denoise.py | 2 +- invokeai/app/invocations/sd3_denoise.py | 5 ++- invokeai/app/invocations/z_image_denoise.py | 4 +-- invokeai/app/services/events/events_common.py | 12 +++++++ .../session_processor_default.py | 7 ++-- .../session_queue/session_queue_base.py | 4 +-- .../session_queue/session_queue_common.py | 4 +++ .../session_queue/session_queue_sqlite.py | 10 +++--- .../app/services/shared/sqlite/sqlite_util.py | 2 ++ .../migrations/migration_32.py | 36 +++++++++++++++++++ invokeai/backend/flux/denoise.py | 8 +++-- invokeai/backend/flux2/denoise.py | 8 +++-- .../load/model_cache/model_cache.py | 6 +++- .../stable_diffusion/diffusion_backend.py | 5 ++- invokeai/backend/util/devices.py | 17 +++++++++ invokeai/frontend/web/public/locales/en.json | 1 + .../common/util/getCudaDeviceIndex.test.ts | 29 +++++++++++++++ .../web/src/common/util/getCudaDeviceIndex.ts | 13 +++++++ .../ImageViewer/ProgressIndicator2.tsx | 25 +++++++++++-- .../QueueList/QueueItemComponent.tsx | 6 ++++ .../components/QueueList/QueueListHeader.tsx | 1 + .../queue/components/QueueList/constants.ts | 1 + .../frontend/web/src/services/api/schema.ts | 11 ++++++ .../test_session_queue_dequeue_concurrency.py | 20 +++++++++++ 25 files changed, 221 insertions(+), 22 deletions(-) create mode 100644 invokeai/app/services/shared/sqlite_migrator/migrations/migration_32.py create mode 100644 invokeai/frontend/web/src/common/util/getCudaDeviceIndex.test.ts create mode 100644 invokeai/frontend/web/src/common/util/getCudaDeviceIndex.ts diff --git a/invokeai/app/invocations/anima_denoise.py b/invokeai/app/invocations/anima_denoise.py index 9fa4b3fb07a..b301e817f9c 100644 --- a/invokeai/app/invocations/anima_denoise.py +++ b/invokeai/app/invocations/anima_denoise.py @@ -608,7 +608,7 @@ def _run_transformer(ctx: torch.Tensor, x: torch.Tensor, t: torch.Tensor) -> tor if driver is not None: user_step = 0 - pbar = tqdm(total=total_steps, desc="Denoising (Anima)") + pbar = tqdm(total=total_steps, desc=f"Denoising (Anima){TorchDevice.get_session_device_label()}") for it in driver.iterations(): timestep = torch.tensor( [it.sigma_curr * ANIMA_MULTIPLIER], device=device, dtype=inference_dtype @@ -655,7 +655,9 @@ def _run_transformer(ctx: torch.Tensor, x: torch.Tensor, t: torch.Tensor) -> tor pbar.close() else: # Built-in Euler implementation (default for Anima) - for step_idx in tqdm(range(total_steps), desc="Denoising (Anima)"): + for step_idx in tqdm( + range(total_steps), desc=f"Denoising (Anima){TorchDevice.get_session_device_label()}" + ): sigma_curr = sigmas[step_idx] sigma_prev = sigmas[step_idx + 1] diff --git a/invokeai/app/invocations/cogview4_denoise.py b/invokeai/app/invocations/cogview4_denoise.py index c04210401be..cb06d2b3ff6 100644 --- a/invokeai/app/invocations/cogview4_denoise.py +++ b/invokeai/app/invocations/cogview4_denoise.py @@ -294,7 +294,7 @@ def _run_diffusion( assert isinstance(transformer, CogView4Transformer2DModel) # Denoising loop - for step_idx in tqdm(range(total_steps)): + for step_idx in tqdm(range(total_steps), desc=f"Denoising{TorchDevice.get_session_device_label()}"): t_curr = timesteps[step_idx] sigma_curr = sigmas[step_idx] sigma_prev = sigmas[step_idx + 1] diff --git a/invokeai/app/invocations/sd3_denoise.py b/invokeai/app/invocations/sd3_denoise.py index f6c90b9690c..10c9080ac5e 100644 --- a/invokeai/app/invocations/sd3_denoise.py +++ b/invokeai/app/invocations/sd3_denoise.py @@ -284,7 +284,10 @@ def _run_diffusion( assert isinstance(transformer, SD3Transformer2DModel) # 6. Denoising loop - for step_idx, (t_curr, t_prev) in tqdm(list(enumerate(zip(timesteps[:-1], timesteps[1:], strict=True)))): + for step_idx, (t_curr, t_prev) in tqdm( + list(enumerate(zip(timesteps[:-1], timesteps[1:], strict=True))), + desc=f"Denoising{TorchDevice.get_session_device_label()}", + ): # Expand the latents if we are doing CFG. latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents # Expand the timestep to match the latent model input. diff --git a/invokeai/app/invocations/z_image_denoise.py b/invokeai/app/invocations/z_image_denoise.py index c1e864ea179..c6887840df8 100644 --- a/invokeai/app/invocations/z_image_denoise.py +++ b/invokeai/app/invocations/z_image_denoise.py @@ -569,7 +569,7 @@ def _run_diffusion(self, context: InvocationContext) -> torch.Tensor: # Use diffusers scheduler for stepping # Use tqdm with total_steps (user-facing steps) not num_scheduler_steps (internal steps) # This ensures progress bar shows 1/8, 2/8, etc. even when scheduler uses more internal steps - pbar = tqdm(total=total_steps, desc="Denoising") + pbar = tqdm(total=total_steps, desc=f"Denoising{TorchDevice.get_session_device_label()}") for step_index in range(num_scheduler_steps): sched_timestep = scheduler.timesteps[step_index] # Convert scheduler timestep (0-1000) to normalized sigma (0-1) @@ -686,7 +686,7 @@ def _run_diffusion(self, context: InvocationContext) -> torch.Tensor: pbar.close() else: # Original Euler implementation (default, optimized for Z-Image) - for step_idx in tqdm(range(total_steps)): + for step_idx in tqdm(range(total_steps), desc=f"Denoising{TorchDevice.get_session_device_label()}"): sigma_curr = sigmas[step_idx] sigma_prev = sigmas[step_idx + 1] diff --git a/invokeai/app/services/events/events_common.py b/invokeai/app/services/events/events_common.py index 0c530f9a2f7..c30fa31b75c 100644 --- a/invokeai/app/services/events/events_common.py +++ b/invokeai/app/services/events/events_common.py @@ -138,6 +138,10 @@ class InvocationProgressEvent(InvocationEventBase): image: ProgressImage | None = Field( default=None, description="An image representing the current state of the progress" ) + device: str | None = Field( + default=None, + description="The device processing this session, e.g. 'cuda:1' (set only when running on a CUDA GPU)", + ) @classmethod def build( @@ -148,6 +152,13 @@ def build( percentage: float | None = None, image: ProgressImage | None = None, ) -> "InvocationProgressEvent": + # This is emitted from the session-processor worker thread, which pins its CUDA device via + # TorchDevice.set_session_device(). Resolve that here so the UI can label progress by GPU. + from invokeai.backend.util.devices import TorchDevice + + session_device = TorchDevice.get_session_device() + device = str(session_device) if session_device is not None and session_device.type == "cuda" else None + return cls( queue_id=queue_item.queue_id, item_id=queue_item.item_id, @@ -161,6 +172,7 @@ def build( percentage=percentage, image=image, message=message, + device=device, ) diff --git a/invokeai/app/services/session_processor/session_processor_default.py b/invokeai/app/services/session_processor/session_processor_default.py index c6edb5069f8..27c1f2a8632 100644 --- a/invokeai/app/services/session_processor/session_processor_default.py +++ b/invokeai/app/services/session_processor/session_processor_default.py @@ -529,8 +529,11 @@ def _process( break # Get the next session to process. dequeue() atomically claims the item, so concurrent - # workers never receive the same item. - worker.queue_item = self._invoker.services.session_queue.dequeue() + # workers never receive the same item. Pass this worker's device so the item is + # tagged with the GPU that ran it (None in single-device/legacy mode). + worker.queue_item = self._invoker.services.session_queue.dequeue( + device=str(worker.device) if worker.device is not None else None + ) if worker.queue_item is None: # The queue was empty, wait for next polling interval or event to try again diff --git a/invokeai/app/services/session_queue/session_queue_base.py b/invokeai/app/services/session_queue/session_queue_base.py index 73acf9c31aa..07f4be1fded 100644 --- a/invokeai/app/services/session_queue/session_queue_base.py +++ b/invokeai/app/services/session_queue/session_queue_base.py @@ -31,8 +31,8 @@ class SessionQueueBase(ABC): """Base class for session queue""" @abstractmethod - def dequeue(self) -> Optional[SessionQueueItem]: - """Dequeues the next session queue item.""" + def dequeue(self, device: Optional[str] = None) -> Optional[SessionQueueItem]: + """Dequeues the next session queue item, recording the processing device (e.g. 'cuda:1') if given.""" pass @abstractmethod diff --git a/invokeai/app/services/session_queue/session_queue_common.py b/invokeai/app/services/session_queue/session_queue_common.py index d87221fbbae..8e149af3afe 100644 --- a/invokeai/app/services/session_queue/session_queue_common.py +++ b/invokeai/app/services/session_queue/session_queue_common.py @@ -262,6 +262,10 @@ class SessionQueueItem(BaseModel): retried_from_item_id: Optional[int] = Field( default=None, description="The item_id of the queue item that this item was retried from" ) + device: Optional[str] = Field( + default=None, + description="The device that processed this queue item, e.g. 'cuda:1' (set only when running on a CUDA GPU)", + ) session: GraphExecutionState = Field(description="The fully-populated session to be executed") workflow: Optional[WorkflowWithoutID] = Field( default=None, description="The workflow associated with this queue item" diff --git a/invokeai/app/services/session_queue/session_queue_sqlite.py b/invokeai/app/services/session_queue/session_queue_sqlite.py index f1bcd8c7c5c..aa2ccc371a6 100644 --- a/invokeai/app/services/session_queue/session_queue_sqlite.py +++ b/invokeai/app/services/session_queue/session_queue_sqlite.py @@ -216,7 +216,7 @@ async def enqueue_batch( self.__invoker.services.events.emit_batch_enqueued(enqueue_result, user_id=user_id) return enqueue_result - def dequeue(self) -> Optional[SessionQueueItem]: + def dequeue(self, device: Optional[str] = None) -> Optional[SessionQueueItem]: # Hold the dequeue lock across the select-then-claim so concurrent workers (multi-GPU) # cannot select and claim the same pending item. `_set_queue_item_status` already no-ops # if the item was concurrently moved to a terminal state (e.g. canceled), so we only need @@ -242,7 +242,8 @@ def dequeue(self) -> Optional[SessionQueueItem]: if result is None: return None queue_item = SessionQueueItem.queue_item_from_dict(dict(result)) - queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="in_progress") + # Record the claiming worker's device so the UI can label the item by GPU. + queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="in_progress", device=device) return queue_item def get_next(self, queue_id: str) -> Optional[SessionQueueItem]: @@ -299,6 +300,7 @@ def _set_queue_item_status( error_type: Optional[str] = None, error_message: Optional[str] = None, error_traceback: Optional[str] = None, + device: Optional[str] = None, ) -> SessionQueueItem: with self._db.transaction() as cursor: cursor.execute( @@ -320,10 +322,10 @@ def _set_queue_item_status( cursor.execute( """--sql UPDATE session_queue - SET status = ?, status_sequence = COALESCE(status_sequence, 0) + 1, error_type = ?, error_message = ?, error_traceback = ? + SET status = ?, status_sequence = COALESCE(status_sequence, 0) + 1, error_type = ?, error_message = ?, error_traceback = ?, device = COALESCE(?, device) WHERE item_id = ? """, - (status, error_type, error_message, error_traceback, item_id), + (status, error_type, error_message, error_traceback, device, item_id), ) queue_item = self.get_queue_item(item_id) diff --git a/invokeai/app/services/shared/sqlite/sqlite_util.py b/invokeai/app/services/shared/sqlite/sqlite_util.py index 12642610c8c..3e1d5c53f3e 100644 --- a/invokeai/app/services/shared/sqlite/sqlite_util.py +++ b/invokeai/app/services/shared/sqlite/sqlite_util.py @@ -34,6 +34,7 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_29 import build_migration_29 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_30 import build_migration_30 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_31 import build_migration_31 +from invokeai.app.services.shared.sqlite_migrator.migrations.migration_32 import build_migration_32 from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator @@ -85,6 +86,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto migrator.register_migration(build_migration_29()) migrator.register_migration(build_migration_30()) migrator.register_migration(build_migration_31()) + migrator.register_migration(build_migration_32()) migrator.run_migrations() return db diff --git a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_32.py b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_32.py new file mode 100644 index 00000000000..fe60433463d --- /dev/null +++ b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_32.py @@ -0,0 +1,36 @@ +"""Migration 32: Add device column to session_queue table. + +This records which device (e.g. 'cuda:1') processed a queue item, so the UI can show a per-item +GPU number in the Session Queue. Existing rows get NULL (unknown device). +""" + +import sqlite3 + +from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration + + +class Migration32Callback: + """Migration to add a device column to the session_queue table.""" + + def __call__(self, cursor: sqlite3.Cursor) -> None: + cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='session_queue';") + if cursor.fetchone() is None: + return + + cursor.execute("PRAGMA table_info(session_queue);") + columns = [row[1] for row in cursor.fetchall()] + + if "device" not in columns: + cursor.execute("ALTER TABLE session_queue ADD COLUMN device TEXT;") + + +def build_migration_32() -> Migration: + """Builds the migration object for migrating from version 31 to version 32. + + This migration adds a device column to the session_queue table. + """ + return Migration( + from_version=31, + to_version=32, + callback=Migration32Callback(), + ) diff --git a/invokeai/backend/flux/denoise.py b/invokeai/backend/flux/denoise.py index 0f4cf07ee5b..7b29a58d44f 100644 --- a/invokeai/backend/flux/denoise.py +++ b/invokeai/backend/flux/denoise.py @@ -15,6 +15,7 @@ from invokeai.backend.flux.model import Flux from invokeai.backend.rectified_flow.rectified_flow_inpaint_extension import RectifiedFlowInpaintExtension from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState +from invokeai.backend.util.devices import TorchDevice def denoise( @@ -95,7 +96,7 @@ def denoise( # Use diffusers scheduler for stepping # Use tqdm with total_steps (user-facing steps) not num_scheduler_steps (internal steps) # This ensures progress bar shows 1/8, 2/8, etc. even when scheduler uses more internal steps - pbar = tqdm(total=total_steps, desc="Denoising") + pbar = tqdm(total=total_steps, desc=f"Denoising{TorchDevice.get_session_device_label()}") for step_index in range(num_scheduler_steps): timestep = scheduler.timesteps[step_index] # Convert scheduler timestep (0-1000) to normalized (0-1) for the model @@ -266,7 +267,10 @@ def denoise( return img # Original Euler implementation (when scheduler is None) - for step_index, (t_curr, t_prev) in tqdm(list(enumerate(zip(timesteps[:-1], timesteps[1:], strict=True)))): + for step_index, (t_curr, t_prev) in tqdm( + list(enumerate(zip(timesteps[:-1], timesteps[1:], strict=True))), + desc=f"Denoising{TorchDevice.get_session_device_label()}", + ): # DyPE: Update step state for timestep-dependent scaling if dype_extension is not None and dype_embedder is not None: dype_extension.update_step_state( diff --git a/invokeai/backend/flux2/denoise.py b/invokeai/backend/flux2/denoise.py index 2ff66236ce8..cd84b14b99d 100644 --- a/invokeai/backend/flux2/denoise.py +++ b/invokeai/backend/flux2/denoise.py @@ -14,6 +14,7 @@ from invokeai.backend.rectified_flow.rectified_flow_inpaint_extension import RectifiedFlowInpaintExtension from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState +from invokeai.backend.util.devices import TorchDevice def denoise( @@ -118,7 +119,7 @@ def denoise( is_heun = hasattr(scheduler, "state_in_first_order") user_step = 0 - pbar = tqdm(total=total_steps, desc="Denoising") + pbar = tqdm(total=total_steps, desc=f"Denoising{TorchDevice.get_session_device_label()}") for step_index in range(num_scheduler_steps): timestep = scheduler.timesteps[step_index] # Convert scheduler timestep (0-1000) to normalized (0-1) for the model @@ -226,7 +227,10 @@ def denoise( pbar.close() else: # Manual Euler stepping (original behavior) - for step_index, (t_curr, t_prev) in tqdm(list(enumerate(zip(timesteps[:-1], timesteps[1:], strict=True)))): + for step_index, (t_curr, t_prev) in tqdm( + list(enumerate(zip(timesteps[:-1], timesteps[1:], strict=True))), + desc=f"Denoising{TorchDevice.get_session_device_label()}", + ): t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) # Run the transformer model (matching diffusers: guidance=guidance, return_dict=False) diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache.py b/invokeai/backend/model_manager/load/model_cache/model_cache.py index 2ca8dd44ba2..762bbe167cb 100644 --- a/invokeai/backend/model_manager/load/model_cache/model_cache.py +++ b/invokeai/backend/model_manager/load/model_cache/model_cache.py @@ -619,9 +619,13 @@ def _load_locked_model(self, cache_entry: CacheRecord, working_mem_bytes: Option loaded_percent = model_cur_vram_bytes / model_total_bytes if model_total_bytes > 0 else 0 # Use the model's actual compute_device for logging, not the cache's default model_device = cache_entry.cached_model.compute_device + if model_device.type == "cuda": + device_label = f"cuda device #{model_device.index}" if model_device.index is not None else "cuda device" + else: + device_label = f"{model_device.type} device" self._logger.info( f"Loaded model '{cache_entry.key}' ({cache_entry.cached_model.model.__class__.__name__}) onto " - f"{model_device.type} device in {(time.time() - start_time):.2f}s. " + f"{device_label} in {(time.time() - start_time):.2f}s. " f"Total model size: {model_total_bytes / MB:.2f}MB, " f"VRAM: {model_cur_vram_bytes / MB:.2f}MB ({loaded_percent:.1%})" ) diff --git a/invokeai/backend/stable_diffusion/diffusion_backend.py b/invokeai/backend/stable_diffusion/diffusion_backend.py index 4191db734f9..be3800411ad 100644 --- a/invokeai/backend/stable_diffusion/diffusion_backend.py +++ b/invokeai/backend/stable_diffusion/diffusion_backend.py @@ -10,6 +10,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager +from invokeai.backend.util.devices import TorchDevice class StableDiffusionBackend: @@ -44,7 +45,9 @@ def latents_from_embeddings(self, ctx: DenoiseContext, ext_manager: ExtensionsMa # ext: preview[pre_denoise_loop, priority=low] ext_manager.run_callback(ExtensionCallbackType.PRE_DENOISE_LOOP, ctx) - for ctx.step_index, ctx.timestep in enumerate(tqdm(ctx.inputs.timesteps)): # noqa: B020 + for ctx.step_index, ctx.timestep in enumerate( # noqa: B020 + tqdm(ctx.inputs.timesteps, desc=f"Denoising{TorchDevice.get_session_device_label()}") + ): # ext: inpaint (apply mask to latents on non-inpaint models) ext_manager.run_callback(ExtensionCallbackType.PRE_STEP, ctx) diff --git a/invokeai/backend/util/devices.py b/invokeai/backend/util/devices.py index 0511601b557..7f5d9a96feb 100644 --- a/invokeai/backend/util/devices.py +++ b/invokeai/backend/util/devices.py @@ -69,6 +69,23 @@ def clear_session_device(cls) -> None: if hasattr(cls._session_device, "device"): del cls._session_device.device + @classmethod + def get_session_device_index(cls) -> Optional[int]: + """Return the CUDA index of the calling thread's effective device, or None if not on CUDA. + + Resolves the thread-local session device when a worker has pinned one (multi-GPU), otherwise + falls back to the globally-configured device. Used to annotate logs/progress with the GPU + number so concurrent sessions can be told apart. + """ + device = cls.get_session_device() or cls.choose_torch_device() + return device.index if device.type == "cuda" else None + + @classmethod + def get_session_device_label(cls) -> str: + """Return a ``" (#N)"`` suffix for the calling thread's CUDA device, or ``""`` when not on CUDA.""" + index = cls.get_session_device_index() + return f" (#{index})" if index is not None else "" + @classmethod def choose_torch_device(cls) -> torch.device: """Return the torch.device to use for accelerated inference.""" diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 23a0f70aeed..d6711bac63c 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -443,6 +443,7 @@ "next": "Next", "status": "Status", "total": "Total", + "gpu": "GPU #", "time": "Time", "credits": "Credits", "pending": "Pending", diff --git a/invokeai/frontend/web/src/common/util/getCudaDeviceIndex.test.ts b/invokeai/frontend/web/src/common/util/getCudaDeviceIndex.test.ts new file mode 100644 index 00000000000..3348ae14a2f --- /dev/null +++ b/invokeai/frontend/web/src/common/util/getCudaDeviceIndex.test.ts @@ -0,0 +1,29 @@ +import { describe, expect, it } from 'vitest'; + +import { getCudaDeviceIndex } from './getCudaDeviceIndex'; + +describe('getCudaDeviceIndex', () => { + it('parses the index from a cuda device string', () => { + expect(getCudaDeviceIndex('cuda:0')).toBe(0); + expect(getCudaDeviceIndex('cuda:1')).toBe(1); + expect(getCudaDeviceIndex('cuda:11')).toBe(11); + }); + + it('returns null for non-cuda devices', () => { + expect(getCudaDeviceIndex('cpu')).toBeNull(); + expect(getCudaDeviceIndex('mps')).toBeNull(); + }); + + it('returns null for null/undefined/empty', () => { + expect(getCudaDeviceIndex(null)).toBeNull(); + expect(getCudaDeviceIndex(undefined)).toBeNull(); + expect(getCudaDeviceIndex('')).toBeNull(); + }); + + it('returns null for malformed cuda strings', () => { + expect(getCudaDeviceIndex('cuda')).toBeNull(); + expect(getCudaDeviceIndex('cuda:')).toBeNull(); + expect(getCudaDeviceIndex('cuda:x')).toBeNull(); + expect(getCudaDeviceIndex('cuda:0:0')).toBeNull(); + }); +}); diff --git a/invokeai/frontend/web/src/common/util/getCudaDeviceIndex.ts b/invokeai/frontend/web/src/common/util/getCudaDeviceIndex.ts new file mode 100644 index 00000000000..d4a394b48fc --- /dev/null +++ b/invokeai/frontend/web/src/common/util/getCudaDeviceIndex.ts @@ -0,0 +1,13 @@ +/** + * Parse the CUDA device index from a device string (e.g. `"cuda:1"` → `1`). + * + * Returns `null` when the device is null/undefined or is not a CUDA device (e.g. `"cpu"`, `"mps"`). + * Used to label progress previews and queue items with their GPU number in multi-GPU setups. + */ +export const getCudaDeviceIndex = (device: string | null | undefined): number | null => { + if (!device) { + return null; + } + const match = /^cuda:(\d+)$/.exec(device); + return match ? Number(match[1]) : null; +}; diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ProgressIndicator2.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ProgressIndicator2.tsx index b635c37d804..f5ca94f732d 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ProgressIndicator2.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ProgressIndicator2.tsx @@ -1,18 +1,37 @@ import type { CircularProgressProps, SystemStyleObject } from '@invoke-ai/ui-library'; -import { CircularProgress, Tooltip } from '@invoke-ai/ui-library'; +import { CircularProgress, Text, Tooltip } from '@invoke-ai/ui-library'; +import { getCudaDeviceIndex } from 'common/util/getCudaDeviceIndex'; import { memo } from 'react'; import type { S } from 'services/api/types'; import { formatProgressMessage } from 'services/events/stores'; const circleStyles: SystemStyleObject = { + // The callers position this circle with `position="absolute"`, which makes it the containing + // block for the absolutely-centered GPU label below. Do NOT set `position` here — an `sx` value + // would override the caller's prop and break the circle's corner anchoring. circle: { transitionProperty: 'none', transitionDuration: '0s', }, }; +// Centered GPU-number label drawn inside the ring (CircularProgressLabel isn't exported by the ui-library). +const labelStyles: SystemStyleObject = { + position: 'absolute', + top: '50%', + left: '50%', + transform: 'translate(-50%, -50%)', + fontSize: '0.6rem', + lineHeight: 1, + fontWeight: 'bold', + color: 'invokeBlue.300', + textShadow: '0 0 3px var(--invoke-colors-base-900)', + pointerEvents: 'none', +}; + export const ProgressIndicator = memo( ({ progressEvent, ...rest }: { progressEvent: S['InvocationProgressEvent'] } & CircularProgressProps) => { + const gpuIndex = getCudaDeviceIndex(progressEvent?.device); return ( + > + {gpuIndex !== null && {gpuIndex}} + ); } diff --git a/invokeai/frontend/web/src/features/queue/components/QueueList/QueueItemComponent.tsx b/invokeai/frontend/web/src/features/queue/components/QueueList/QueueItemComponent.tsx index e1c5f4ec973..6d3f773a2e9 100644 --- a/invokeai/frontend/web/src/features/queue/components/QueueList/QueueItemComponent.tsx +++ b/invokeai/frontend/web/src/features/queue/components/QueueList/QueueItemComponent.tsx @@ -1,6 +1,7 @@ import type { ChakraProps, CollapseProps, FlexProps } from '@invoke-ai/ui-library'; import { ButtonGroup, Collapse, Flex, IconButton, Text } from '@invoke-ai/ui-library'; import { useAppSelector } from 'app/store/storeHooks'; +import { getCudaDeviceIndex } from 'common/util/getCudaDeviceIndex'; import { selectCurrentUser } from 'features/auth/store/authSlice'; import QueueStatusBadge from 'features/queue/components/common/QueueStatusBadge'; import { useDestinationText } from 'features/queue/components/QueueList/useDestinationText'; @@ -95,6 +96,8 @@ const QueueItemComponent = ({ index, item }: InnerItemProps) => { return `${seconds}s`; }, [item]); + const gpuIndex = useMemo(() => getCudaDeviceIndex(item.device), [item.device]); + const isCanceled = useMemo(() => ['canceled', 'completed', 'failed'].includes(item.status), [item.status]); const isFailed = useMemo(() => ['canceled', 'failed'].includes(item.status), [item.status]); const originText = useOriginText(item.origin); @@ -140,6 +143,9 @@ const QueueItemComponent = ({ index, item }: InnerItemProps) => { + + {gpuIndex !== null ? gpuIndex : '-'} + {executionTime || '-'} diff --git a/invokeai/frontend/web/src/features/queue/components/QueueList/QueueListHeader.tsx b/invokeai/frontend/web/src/features/queue/components/QueueList/QueueListHeader.tsx index 4cd3397d217..9f6e2fa5458 100644 --- a/invokeai/frontend/web/src/features/queue/components/QueueList/QueueListHeader.tsx +++ b/invokeai/frontend/web/src/features/queue/components/QueueList/QueueListHeader.tsx @@ -33,6 +33,7 @@ const QueueListHeader = () => { w={COLUMN_WIDTHS.statusBadge} alignItems="center" /> + None: # Every item is claimed exactly once: no duplicates, none lost. assert len(claimed_ids) == item_count assert len(set(claimed_ids)) == item_count + + +def test_dequeue_records_processing_device(session_queue: SqliteSessionQueue) -> None: + _insert_queue_item(session_queue) + + item = session_queue.dequeue(device="cuda:1") + assert item is not None + assert item.device == "cuda:1" + + # The device persists across later status transitions (which pass device=None). + completed = session_queue._set_queue_item_status(item.item_id, "completed") + assert completed.device == "cuda:1" + + +def test_dequeue_without_device_leaves_device_unset(session_queue: SqliteSessionQueue) -> None: + _insert_queue_item(session_queue) + + item = session_queue.dequeue() + assert item is not None + assert item.device is None From 1f55a4f8956db4390ad9ca00e29fd4cd47e751c4 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Wed, 24 Jun 2026 22:00:56 -0400 Subject: [PATCH 20/33] feat(multi-gpu): show per-device names in startup log and progress circles - Startup log lists each generation device with its GPU number and id, e.g. "Using torch device: [AMD Radeon PRO W7900 #1 (cuda:0), ...]". Single-device setups keep the bare device name. - Canvas progress circles now show the CUDA device index in the center, matching the viewer panel. - Progress-circle tooltips show the device name and number on hover. - Both are hidden when only a single GPU is available. Co-Authored-By: Claude Opus 4.8 (1M context) --- invokeai/app/run_app.py | 2 +- invokeai/backend/util/devices.py | 38 +++++++++++++++++- .../common/hooks/useProgressDeviceLabel.ts | 39 +++++++++++++++++++ .../common/util/getDeviceNameLabels.test.ts | 38 ++++++++++++++++++ .../src/common/util/getDeviceNameLabels.ts | 25 ++++++++++++ .../StagingArea/QueueItemCircularProgress.tsx | 26 +++++++++++-- .../ImageViewer/ProgressIndicator2.tsx | 9 +++-- 7 files changed, 167 insertions(+), 10 deletions(-) create mode 100644 invokeai/frontend/web/src/common/hooks/useProgressDeviceLabel.ts create mode 100644 invokeai/frontend/web/src/common/util/getDeviceNameLabels.test.ts create mode 100644 invokeai/frontend/web/src/common/util/getDeviceNameLabels.ts diff --git a/invokeai/app/run_app.py b/invokeai/app/run_app.py index febd4f4d4b1..389b61e7347 100644 --- a/invokeai/app/run_app.py +++ b/invokeai/app/run_app.py @@ -41,7 +41,7 @@ def run_app() -> None: from invokeai.app.invocations.load_custom_nodes import load_custom_nodes from invokeai.backend.util.devices import TorchDevice - torch_device_name = TorchDevice.get_torch_device_name() + torch_device_name = TorchDevice.get_generation_devices_summary(app_config.generation_devices) logger.info(f"Using torch device: {torch_device_name}") # Import from startup_utils here to avoid importing torch before configure_torch_cuda_allocator() is called. diff --git a/invokeai/backend/util/devices.py b/invokeai/backend/util/devices.py index 7f5d9a96feb..0055acd1289 100644 --- a/invokeai/backend/util/devices.py +++ b/invokeai/backend/util/devices.py @@ -1,4 +1,5 @@ import threading +from collections import Counter, defaultdict from typing import Dict, Literal, Optional, Union import torch @@ -131,11 +132,44 @@ def choose_torch_dtype(cls, device: Optional[torch.device] = None) -> torch.dtyp # CPU / safe fallback return cls._to_dtype("float32") + @classmethod + def get_device_name(cls, device: torch.device) -> str: + """Return the human-readable name for a torch device (e.g. 'AMD Radeon PRO W7900', 'CPU').""" + return torch.cuda.get_device_name(device) if device.type == "cuda" else device.type.upper() + @classmethod def get_torch_device_name(cls) -> str: """Return the device name for the current torch device.""" - device = cls.choose_torch_device() - return torch.cuda.get_device_name(device) if device.type == "cuda" else device.type.upper() + return cls.get_device_name(cls.choose_torch_device()) + + @classmethod + def get_generation_devices_summary(cls, generation_devices: Union[str, list[str], None]) -> str: + """Build a human-readable summary of the devices that will be used for generation. + + For a single device, returns just its name (e.g. ``'AMD Radeon PRO W7900'`` or ``'CPU'``). For + multiple devices, returns a bracketed list annotating each with its GPU number and device id, + e.g. ``'[AMD Radeon PRO W7900 #1 (cuda:0), AMD Radeon PRO W7900 #2 (cuda:1)]'``. Identically + named GPUs get a 1-based ``#N`` suffix so they can be told apart; a uniquely named device gets + no suffix. + """ + devices = cls.get_generation_devices(generation_devices) + if not devices: + # Empty resolution (e.g. `generation_devices` set to an empty list) falls back to the + # single globally-configured device. + devices = [cls.choose_torch_device()] + + names = [cls.get_device_name(device) for device in devices] + if len(devices) == 1: + return names[0] + + name_counts = Counter(names) + ordinals: dict[str, int] = defaultdict(int) + parts: list[str] = [] + for device, name in zip(devices, names, strict=True): + ordinals[name] += 1 + label = f"{name} #{ordinals[name]}" if name_counts[name] > 1 else name + parts.append(f"{label} ({device})") + return "[" + ", ".join(parts) + "]" @classmethod def get_generation_devices(cls, generation_devices: Union[str, list[str], None]) -> list[torch.device]: diff --git a/invokeai/frontend/web/src/common/hooks/useProgressDeviceLabel.ts b/invokeai/frontend/web/src/common/hooks/useProgressDeviceLabel.ts new file mode 100644 index 00000000000..701f7ce1ae9 --- /dev/null +++ b/invokeai/frontend/web/src/common/hooks/useProgressDeviceLabel.ts @@ -0,0 +1,39 @@ +import { getCudaDeviceIndex } from 'common/util/getCudaDeviceIndex'; +import { getDeviceNameLabels } from 'common/util/getDeviceNameLabels'; +import { useMemo } from 'react'; +import { useGetGenerationDeviceOptionsQuery } from 'services/api/endpoints/appInfo'; + +type ProgressDeviceLabel = { + /** The CUDA device index, shown in the center of the progress circle (e.g. `0`). */ + index: number; + /** The human-readable device name and number, shown on hover (e.g. `"AMD Radeon PRO W7900 #1"`). */ + name: string; +}; + +/** + * Resolve a device string (e.g. `"cuda:0"`) to the GPU index + name used to annotate progress + * previews. + * + * Returns `null` when there is nothing to show: the device is not a CUDA GPU, or only a single GPU + * is available (single-GPU setups show neither the index nor the hover tooltip). + */ +export const useProgressDeviceLabel = (device: string | null | undefined): ProgressDeviceLabel | null => { + const { data: deviceOptions } = useGetGenerationDeviceOptionsQuery(); + + return useMemo(() => { + const index = getCudaDeviceIndex(device); + if (index === null) { + return null; + } + const options = deviceOptions ?? []; + // With a single GPU there is no ambiguity to resolve, so we show nothing. + if (options.length <= 1) { + return null; + } + const name = device ? getDeviceNameLabels(options)[device] : undefined; + if (!name) { + return null; + } + return { index, name }; + }, [device, deviceOptions]); +}; diff --git a/invokeai/frontend/web/src/common/util/getDeviceNameLabels.test.ts b/invokeai/frontend/web/src/common/util/getDeviceNameLabels.test.ts new file mode 100644 index 00000000000..7a35d57d4f8 --- /dev/null +++ b/invokeai/frontend/web/src/common/util/getDeviceNameLabels.test.ts @@ -0,0 +1,38 @@ +import type { S } from 'services/api/types'; +import { describe, expect, it } from 'vitest'; + +import { getDeviceNameLabels } from './getDeviceNameLabels'; + +const opt = (device: string, name: string): S['GenerationDeviceOption'] => ({ device, name }); + +describe('getDeviceNameLabels', () => { + it('adds a 1-based #N suffix to identically-named devices', () => { + const labels = getDeviceNameLabels([opt('cuda:0', 'AMD Radeon PRO W7900'), opt('cuda:1', 'AMD Radeon PRO W7900')]); + expect(labels).toEqual({ + 'cuda:0': 'AMD Radeon PRO W7900 #1', + 'cuda:1': 'AMD Radeon PRO W7900 #2', + }); + }); + + it('does not add a suffix to a uniquely-named device', () => { + const labels = getDeviceNameLabels([opt('cuda:0', 'AMD Radeon PRO W7900')]); + expect(labels).toEqual({ 'cuda:0': 'AMD Radeon PRO W7900' }); + }); + + it('only suffixes the names that are duplicated', () => { + const labels = getDeviceNameLabels([ + opt('cuda:0', 'RTX 4090'), + opt('cuda:1', 'RTX 3090'), + opt('cuda:2', 'RTX 3090'), + ]); + expect(labels).toEqual({ + 'cuda:0': 'RTX 4090', + 'cuda:1': 'RTX 3090 #1', + 'cuda:2': 'RTX 3090 #2', + }); + }); + + it('returns an empty map for no options', () => { + expect(getDeviceNameLabels([])).toEqual({}); + }); +}); diff --git a/invokeai/frontend/web/src/common/util/getDeviceNameLabels.ts b/invokeai/frontend/web/src/common/util/getDeviceNameLabels.ts new file mode 100644 index 00000000000..210e7b88c67 --- /dev/null +++ b/invokeai/frontend/web/src/common/util/getDeviceNameLabels.ts @@ -0,0 +1,25 @@ +import type { S } from 'services/api/types'; + +/** + * Build a map of device id (e.g. `"cuda:0"`) → human-readable label (e.g. `"AMD Radeon PRO W7900 #1"`). + * + * Devices that share a name get a 1-based `#N` suffix so identical GPUs can be told apart; a + * uniquely-named device gets no suffix. The ordinal is assigned in the order the options are + * provided (which the backend returns in CUDA-index order). Used to label progress previews with + * the GPU they are rendering on in multi-GPU setups. + */ +export const getDeviceNameLabels = (options: S['GenerationDeviceOption'][]): Record => { + const nameCounts = new Map(); + for (const option of options) { + nameCounts.set(option.name, (nameCounts.get(option.name) ?? 0) + 1); + } + + const ordinals = new Map(); + const labels: Record = {}; + for (const option of options) { + const ordinal = (ordinals.get(option.name) ?? 0) + 1; + ordinals.set(option.name, ordinal); + labels[option.device] = (nameCounts.get(option.name) ?? 0) > 1 ? `${option.name} #${ordinal}` : option.name; + } + return labels; +}; diff --git a/invokeai/frontend/web/src/features/controlLayers/components/StagingArea/QueueItemCircularProgress.tsx b/invokeai/frontend/web/src/features/controlLayers/components/StagingArea/QueueItemCircularProgress.tsx index 0a92106a6e6..7095be54aae 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/StagingArea/QueueItemCircularProgress.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/StagingArea/QueueItemCircularProgress.tsx @@ -1,5 +1,6 @@ import type { CircularProgressProps, SystemStyleObject } from '@invoke-ai/ui-library'; -import { CircularProgress, Tooltip } from '@invoke-ai/ui-library'; +import { CircularProgress, Text, Tooltip } from '@invoke-ai/ui-library'; +import { useProgressDeviceLabel } from 'common/hooks/useProgressDeviceLabel'; import { getProgressMessage } from 'features/controlLayers/components/StagingArea/shared'; import { memo } from 'react'; import type { S } from 'services/api/types'; @@ -16,17 +17,34 @@ const circleStyles: SystemStyleObject = { right: 2, }; +// Centered GPU-number label drawn inside the ring (CircularProgressLabel isn't exported by the ui-library). +const labelStyles: SystemStyleObject = { + position: 'absolute', + top: '50%', + left: '50%', + transform: 'translate(-50%, -50%)', + fontSize: '0.6rem', + lineHeight: 1, + fontWeight: 'bold', + color: 'invokeBlue.300', + textShadow: '0 0 3px var(--invoke-colors-base-900)', + pointerEvents: 'none', +}; + type Props = { itemId: number; status: S['SessionQueueItem']['status'] } & CircularProgressProps; export const QueueItemCircularProgress = memo(({ itemId, status, ...rest }: Props) => { const { progressEvent } = useProgressDatum(itemId); + const deviceLabel = useProgressDeviceLabel(progressEvent?.device); if (status !== 'in_progress') { return null; } + const message = getProgressMessage(progressEvent); + return ( - + + > + {deviceLabel && {deviceLabel.index}} + ); }); diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ProgressIndicator2.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ProgressIndicator2.tsx index f5ca94f732d..30e7312e2aa 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ProgressIndicator2.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ProgressIndicator2.tsx @@ -1,6 +1,6 @@ import type { CircularProgressProps, SystemStyleObject } from '@invoke-ai/ui-library'; import { CircularProgress, Text, Tooltip } from '@invoke-ai/ui-library'; -import { getCudaDeviceIndex } from 'common/util/getCudaDeviceIndex'; +import { useProgressDeviceLabel } from 'common/hooks/useProgressDeviceLabel'; import { memo } from 'react'; import type { S } from 'services/api/types'; import { formatProgressMessage } from 'services/events/stores'; @@ -31,9 +31,10 @@ const labelStyles: SystemStyleObject = { export const ProgressIndicator = memo( ({ progressEvent, ...rest }: { progressEvent: S['InvocationProgressEvent'] } & CircularProgressProps) => { - const gpuIndex = getCudaDeviceIndex(progressEvent?.device); + const deviceLabel = useProgressDeviceLabel(progressEvent?.device); + const message = formatProgressMessage(progressEvent); return ( - + - {gpuIndex !== null && {gpuIndex}} + {deviceLabel && {deviceLabel.index}} ); From a1fde9613a3ed54c1fedc305931f733f8ddf71e1 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Thu, 25 Jun 2026 20:27:57 -0400 Subject: [PATCH 21/33] feat(model-cache): share one CPU copy of model weights across per-GPU caches MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit In multi-GPU mode the model manager builds one ModelCache per generation device, each with storage_device="cpu" and its own RAM-resident copy of every model. A model loaded on N GPUs therefore occupied N copies in RAM, and each cache sized itself against max_cache_ram_gb independently, so RAM use during the text/reference-image encoding phases skyrocketed and the system swapped — worst when two images rendered at once. This deduplicates the CPU-resident weights and makes RAM accounting global. - SharedCpuWeightsStore: process-/manager-global, refcounted store of one canonical CPU state_dict per model key. The first device to load a key registers its weights; subsequent devices adopt the canonical tensors and re-point their module's params at them (load_state_dict(assign=True)), freeing the duplicate. Weights live once in RAM regardless of GPU count; freed only when the last device releases. Per-device modules are kept (params are device-shuffled in place, so two GPUs need two modules), but their CPU-resident params alias the shared tensors. - RamBudget: single system-wide RAM authority. Splits RAM into shared (counted once via the store) and non-shared (per-instance). ModelCache eviction now runs against the global, deduplicated total and re-checks availability each iteration, since evicting a model another device still holds frees no RAM. build_model_manager wires one store + one budget into all device caches; the cap is max_cache_ram_gb as a true system-wide limit, else the sum of per-cache heuristics. Passing ram_budget=None preserves the prior local accounting. - LoRA/patch safety: direct LoRA patching did an in-place copy_ on the weight, which would corrupt the now-shared canonical tensor (and taint keep_ram_copy even with one GPU) when patching a CPU-resident weight. Switched to an out-of-place add (memory- equivalent) so the canonical tensor is never mutated; fixed the FluxControlLoRA expansion path to target the module's live parameter. Sidecar patching and FreeU/Seamless (which patch forward methods) were already safe. Validated on 2x AMD W7900 / ROCm: correct inference on both GPUs from one shared copy (full + partial load + Q8_0 GGUF quantized), concurrent load/unload without corruption, and LoRA isolation across devices. ~40 new tests; existing suites unchanged. Adds scripts/multigpu_ram_driver.py to drive concurrent dual-GPU generations via the queue API and measure peak RSS / leak drift. Co-Authored-By: Claude Opus 4.8 (1M context) --- .../model_manager/model_manager_default.py | 26 ++ .../cached_model_only_full_load.py | 50 ++- .../cached_model_with_partial_load.py | 52 ++- .../load/model_cache/model_cache.py | 93 +++++- .../load/model_cache/ram_budget.py | 64 ++++ .../load/model_cache/shared_cpu_weights.py | 118 +++++++ invokeai/backend/patches/layer_patcher.py | 17 +- scripts/multigpu_ram_driver.py | 295 ++++++++++++++++++ .../test_cached_model_shared_weights.py | 100 ++++++ .../test_model_cache_ram_budget.py | 118 +++++++ .../test_model_cache_shared_weights.py | 86 +++++ .../load/model_cache/test_ram_budget.py | 48 +++ .../model_cache/test_shared_cpu_weights.py | 106 +++++++ .../model_cache/test_shared_weights_gpu.py | 236 ++++++++++++++ .../test_layer_patcher_shared_weights.py | 106 +++++++ 15 files changed, 1502 insertions(+), 13 deletions(-) create mode 100644 invokeai/backend/model_manager/load/model_cache/ram_budget.py create mode 100644 invokeai/backend/model_manager/load/model_cache/shared_cpu_weights.py create mode 100755 scripts/multigpu_ram_driver.py create mode 100644 tests/backend/model_manager/load/model_cache/cached_model/test_cached_model_shared_weights.py create mode 100644 tests/backend/model_manager/load/model_cache/test_model_cache_ram_budget.py create mode 100644 tests/backend/model_manager/load/model_cache/test_model_cache_shared_weights.py create mode 100644 tests/backend/model_manager/load/model_cache/test_ram_budget.py create mode 100644 tests/backend/model_manager/load/model_cache/test_shared_cpu_weights.py create mode 100644 tests/backend/model_manager/load/model_cache/test_shared_weights_gpu.py create mode 100644 tests/backend/patches/test_layer_patcher_shared_weights.py diff --git a/invokeai/app/services/model_manager/model_manager_default.py b/invokeai/app/services/model_manager/model_manager_default.py index b7680524a34..404fc8c72ee 100644 --- a/invokeai/app/services/model_manager/model_manager_default.py +++ b/invokeai/app/services/model_manager/model_manager_default.py @@ -17,6 +17,8 @@ from invokeai.app.services.model_manager.model_manager_base import ModelManagerServiceBase from invokeai.app.services.model_records.model_records_base import ModelRecordServiceBase from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache +from invokeai.backend.model_manager.load.model_cache.ram_budget import RamBudget +from invokeai.backend.model_manager.load.model_cache.shared_cpu_weights import SharedCpuWeightsStore from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.logging import InvokeAILogger @@ -86,6 +88,12 @@ def build_model_manager( logger = InvokeAILogger.get_logger(cls.__name__) logger.setLevel(app_config.log_level.upper()) + # One store + budget shared by every per-device cache. The store deduplicates each model's CPU + # weights to a single copy across GPUs (see SharedCpuWeightsStore); the budget is the single + # system-wide RAM authority so per-device caches stop double-counting shared weights when they + # decide what to evict (see RamBudget). + shared_store = SharedCpuWeightsStore() + def build_cache(device: torch.device) -> ModelCache: return ModelCache( execution_device_working_mem_gb=app_config.device_working_mem_gb, @@ -98,6 +106,7 @@ def build_cache(device: torch.device) -> ModelCache: log_memory_usage=app_config.log_memory_usage, logger=logger, keep_alive_minutes=app_config.model_cache_keep_alive_min, + shared_cpu_weights=shared_store, ) # The default cache for callers without a pinned device (API threads, single-device installs). @@ -113,6 +122,23 @@ def build_cache(device: torch.device) -> ModelCache: if key not in ram_caches: ram_caches[key] = build_cache(device) + # Attach the single global RAM budget. The cap is the user's max_cache_ram_gb interpreted as a + # true system-wide limit; when unset, it is the sum of the caches' individually-calculated + # sizes, so each device keeps its effective capacity and weight deduplication becomes headroom. + gb = 2**30 + distinct_caches = list(dict.fromkeys(ram_caches.values())) + if app_config.max_cache_ram_gb is not None: + global_ram_budget_bytes = int(app_config.max_cache_ram_gb * gb) + else: + global_ram_budget_bytes = sum(c.local_ram_cache_size_bytes for c in distinct_caches) + ram_budget = RamBudget(max_bytes=global_ram_budget_bytes, shared_store=shared_store) + for cache in distinct_caches: + cache.set_ram_budget(ram_budget) + logger.info( + f"Model cache global RAM budget: {global_ram_budget_bytes / gb:.2f} GB " + f"across {len(distinct_caches)} device cache(s)." + ) + loader = ModelLoadService( app_config=app_config, ram_cache=ram_cache, diff --git a/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_only_full_load.py b/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_only_full_load.py index bb04edef9b5..8d1239cdfea 100644 --- a/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_only_full_load.py +++ b/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_only_full_load.py @@ -2,6 +2,7 @@ import torch +from invokeai.backend.model_manager.load.model_cache.shared_cpu_weights import SharedCpuWeightsStore from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor @@ -12,7 +13,13 @@ class CachedModelOnlyFullLoad: """ def __init__( - self, model: torch.nn.Module | Any, compute_device: torch.device, total_bytes: int, keep_ram_copy: bool = False + self, + model: torch.nn.Module | Any, + compute_device: torch.device, + total_bytes: int, + keep_ram_copy: bool = False, + shared_store: SharedCpuWeightsStore | None = None, + cache_key: str | None = None, ): """Initialize a CachedModelOnlyFullLoad. Args: @@ -22,16 +29,34 @@ def __init__( keep_ram_copy (bool): Whether to keep a read-only copy of the model's state dict in RAM. Keeping a RAM copy increases RAM usage, but speeds up model offload from VRAM and LoRA patching (assuming there is sufficient RAM). + shared_store (SharedCpuWeightsStore | None): If provided (along with cache_key), share a single canonical + CPU copy of the weights across per-device caches instead of one copy per device. + cache_key (str | None): The model cache key used to identify shared weights in `shared_store`. """ # model is often a torch.nn.Module, but could be any model type. Throughout this class, we handle both cases. self._model = model self._compute_device = compute_device self._offload_device = torch.device("cpu") + # When set, this model's CPU weights are a shared canonical copy owned by `shared_store` + # under `cache_key`; `release_shared_weights()` must be called exactly once on eviction. + self._shared_store: SharedCpuWeightsStore | None = None + self._shared_key: str | None = None # A CPU read-only copy of the model's state dict. self._cpu_state_dict: dict[str, torch.Tensor] | None = None if isinstance(model, torch.nn.Module) and keep_ram_copy: - self._cpu_state_dict = model.state_dict() + cpu_state_dict = model.state_dict() + # In multi-GPU mode, share one canonical CPU copy across the per-device caches (see + # SharedCpuWeightsStore). If another device already registered this key, re-point our + # module at the shared tensors and drop our duplicate so the weights live once in RAM. + if shared_store is not None and cache_key is not None: + canonical = shared_store.acquire(cache_key, cpu_state_dict) + if canonical is not cpu_state_dict: + model.load_state_dict(canonical, assign=True) + cpu_state_dict = canonical + self._shared_store = shared_store + self._shared_key = cache_key + self._cpu_state_dict = cpu_state_dict self._total_bytes = total_bytes self._is_in_vram = False @@ -45,6 +70,27 @@ def get_cpu_state_dict(self) -> dict[str, torch.Tensor] | None: # TODO(ryand): Document this better. return self._cpu_state_dict + @property + def uses_shared_weights(self) -> bool: + """True if this model's CPU weights are deduplicated in a SharedCpuWeightsStore. + + When True, its RAM is accounted by the store (counted once across devices); when False, its + RAM is per-instance and must be counted by the RamBudget's non-shared total. + """ + return self._shared_store is not None + + def release_shared_weights(self) -> None: + """Release this model's reference to its shared canonical CPU weights, if any. + + Must be called exactly once when the cache entry is evicted. Idempotent: a second call is a + no-op. After release, the shared store frees the canonical tensors once the last device that + held this key releases it. + """ + if self._shared_store is not None and self._shared_key is not None: + self._shared_store.release(self._shared_key) + self._shared_store = None + self._shared_key = None + def total_bytes(self) -> int: """Get the total size (in bytes) of all the weights in the model.""" return self._total_bytes diff --git a/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_with_partial_load.py b/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_with_partial_load.py index 328978b45b1..0f9b534d716 100644 --- a/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_with_partial_load.py +++ b/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_with_partial_load.py @@ -1,5 +1,6 @@ import torch +from invokeai.backend.model_manager.load.model_cache.shared_cpu_weights import SharedCpuWeightsStore from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import ( CustomModuleMixin, ) @@ -14,14 +15,40 @@ class CachedModelWithPartialLoad: MPS memory, etc. """ - def __init__(self, model: torch.nn.Module, compute_device: torch.device, keep_ram_copy: bool = False): + def __init__( + self, + model: torch.nn.Module, + compute_device: torch.device, + keep_ram_copy: bool = False, + shared_store: SharedCpuWeightsStore | None = None, + cache_key: str | None = None, + ): self._model = model self._compute_device = compute_device + # When set, this model's CPU weights are a shared canonical copy owned by `shared_store` + # under `cache_key`; `release_shared_weights()` must be called exactly once on eviction. + self._shared_store: SharedCpuWeightsStore | None = None + self._shared_key: str | None = None model_state_dict = model.state_dict() # A CPU read-only copy of the model's state dict. Used for faster model unloads from VRAM, and to speed up LoRA # patching. Set to `None` if keep_ram_copy is False. - self._cpu_state_dict: dict[str, torch.Tensor] | None = model_state_dict if keep_ram_copy else None + cpu_state_dict: dict[str, torch.Tensor] | None = model_state_dict if keep_ram_copy else None + + # In multi-GPU mode, share a single canonical CPU copy of the weights across the per-device + # caches instead of keeping one copy per device (see SharedCpuWeightsStore). If another + # device already registered this key, re-point our module's params at the shared tensors and + # drop our freshly-built duplicate so the weights live once in RAM. + if cpu_state_dict is not None and shared_store is not None and cache_key is not None: + canonical = shared_store.acquire(cache_key, cpu_state_dict) + if canonical is not cpu_state_dict: + self._model.load_state_dict(canonical, assign=True) + model_state_dict = canonical + cpu_state_dict = canonical + self._shared_store = shared_store + self._shared_key = cache_key + + self._cpu_state_dict: dict[str, torch.Tensor] | None = cpu_state_dict # A dictionary of the size of each tensor in the state dict. # HACK(ryand): We use this dictionary any time we are doing byte tracking calculations. We do this for @@ -121,6 +148,27 @@ def get_cpu_state_dict(self) -> dict[str, torch.Tensor] | None: # TODO(ryand): Document this better. return self._cpu_state_dict + @property + def uses_shared_weights(self) -> bool: + """True if this model's CPU weights are deduplicated in a SharedCpuWeightsStore. + + When True, its RAM is accounted by the store (counted once across devices); when False, its + RAM is per-instance and must be counted by the RamBudget's non-shared total. + """ + return self._shared_store is not None + + def release_shared_weights(self) -> None: + """Release this model's reference to its shared canonical CPU weights, if any. + + Must be called exactly once when the cache entry is evicted. Idempotent: a second call is a + no-op. After release, the shared store frees the canonical tensors once the last device that + held this key releases it. + """ + if self._shared_store is not None and self._shared_key is not None: + self._shared_store.release(self._shared_key) + self._shared_store = None + self._shared_key = None + def total_bytes(self) -> int: """Get the total size (in bytes) of all the weights in the model.""" return self._total_bytes diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache.py b/invokeai/backend/model_manager/load/model_cache/model_cache.py index 762bbe167cb..69547ee5fb1 100644 --- a/invokeai/backend/model_manager/load/model_cache/model_cache.py +++ b/invokeai/backend/model_manager/load/model_cache/model_cache.py @@ -20,6 +20,11 @@ from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_with_partial_load import ( CachedModelWithPartialLoad, ) +from invokeai.backend.model_manager.load.model_cache.ram_budget import RamBudget +from invokeai.backend.model_manager.load.model_cache.shared_cpu_weights import ( + SHARED_CPU_WEIGHTS, + SharedCpuWeightsStore, +) from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import ( apply_custom_layers_to_model, ) @@ -216,6 +221,8 @@ def __init__( log_memory_usage: bool = False, logger: Optional[Logger] = None, keep_alive_minutes: float = 0, + shared_cpu_weights: SharedCpuWeightsStore | None = SHARED_CPU_WEIGHTS, + ram_budget: RamBudget | None = None, ): """Initialize the model RAM cache. @@ -236,7 +243,15 @@ def __init__( behaviour. :param logger: InvokeAILogger to use (otherwise creates one) :param keep_alive_minutes: How long to keep models in cache after last use (in minutes). 0 means keep indefinitely. + :param shared_cpu_weights: Process-global store that lets per-device caches share a single CPU copy of each + model's weights (see SharedCpuWeightsStore). Defaults to the global store so that, in multi-GPU mode, a + model loaded on multiple GPUs occupies RAM only once. Pass None to disable sharing for this cache. + :param ram_budget: Optional shared RamBudget used as the single global RAM authority across all per-device + caches. When provided, eviction decisions are made against the deduplicated, system-wide RAM total rather + than this cache's local (double-counted) sum. When None, the cache uses its own local RAM accounting. """ + self._shared_cpu_weights = shared_cpu_weights + self._ram_budget = ram_budget self._enable_partial_loading = enable_partial_loading self._keep_ram_copy_of_weights = keep_ram_copy_of_weights self._execution_device_working_mem_gb = execution_device_working_mem_gb @@ -302,6 +317,22 @@ def execution_device(self) -> torch.device: """Return the default execution device this cache loads models onto.""" return self._execution_device + def set_ram_budget(self, ram_budget: RamBudget) -> None: + """Attach the shared global RamBudget after construction. + + Used by the model manager once all per-device caches exist and the global cap has been + computed from their individual sizes (see ModelManagerService.build_model_manager). + """ + self._ram_budget = ram_budget + + @property + def local_ram_cache_size_bytes(self) -> int: + """The RAM cache size this cache computed for itself (from max_cache_ram_gb or the heuristic). + + Used by the model manager to seed the global RamBudget cap when no explicit limit is set. + """ + return self._ram_cache_size_bytes + @property @synchronized def stats(self) -> Optional[CacheStats]: @@ -313,9 +344,12 @@ def stats(self) -> Optional[CacheStats]: def stats(self, stats: CacheStats) -> None: """Set the CacheStats object for collecting cache statistics.""" self._stats = stats - # Populate the cache size in the stats object when it's set + # Populate the cache size in the stats object when it's set. Prefer the global budget cap + # (the real system-wide limit) when one is attached. if self._stats is not None: - self._stats.cache_size = self._ram_cache_size_bytes + self._stats.cache_size = ( + self._ram_budget.max_bytes if self._ram_budget is not None else self._ram_cache_size_bytes + ) def _record_activity(self) -> None: """Record model activity and reset the timeout timer if configured. @@ -423,16 +457,30 @@ def put(self, key: str, model: AnyModel, execution_device: Optional[torch.device # Wrap model. if isinstance(model, torch.nn.Module) and running_with_cuda and self._enable_partial_loading: wrapped_model = CachedModelWithPartialLoad( - model, effective_execution_device, keep_ram_copy=self._keep_ram_copy_of_weights + model, + effective_execution_device, + keep_ram_copy=self._keep_ram_copy_of_weights, + shared_store=self._shared_cpu_weights, + cache_key=key, ) else: wrapped_model = CachedModelOnlyFullLoad( - model, effective_execution_device, size, keep_ram_copy=self._keep_ram_copy_of_weights + model, + effective_execution_device, + size, + keep_ram_copy=self._keep_ram_copy_of_weights, + shared_store=self._shared_cpu_weights, + cache_key=key, ) cache_record = CacheRecord(key=key, cached_model=wrapped_model) self._cached_models[key] = cache_record self._cache_stack.append(key) + # Account this model's RAM in the global budget. Shared weights are tracked once by the + # SharedCpuWeightsStore; only non-deduplicated models are added to the budget's non-shared + # total (a non-shared model resident on N devices correctly counts N times). + if self._ram_budget is not None and not wrapped_model.uses_shared_weights: + self._ram_budget.add_non_shared(wrapped_model.total_bytes()) self._logger.debug( f"Added model {key} (Type: {model.__class__.__name__}, Wrap mode: {wrapped_model.__class__.__name__}, Model size: {size / MB:.2f}MB)" ) @@ -774,11 +822,20 @@ def _calc_ram_available_to_model_cache(self) -> int: return ram_available_to_model_cache def _get_ram_in_use(self) -> int: - """Get the amount of RAM currently in use.""" + """Get the amount of RAM currently in use. + + With a shared RamBudget attached, this returns the deduplicated, system-wide total across all + per-device caches (shared model weights counted once). Without one, it returns this cache's + local sum. + """ + if self._ram_budget is not None: + return self._ram_budget.total_in_use() return sum(ce.cached_model.total_bytes() for ce in self._cached_models.values()) def _get_ram_available(self) -> int: """Get the amount of RAM available for the cache to use.""" + if self._ram_budget is not None: + return self._ram_budget.available() return self._ram_cache_size_bytes - self._get_ram_in_use() def _capture_memory_snapshot(self) -> Optional[MemorySnapshot]: @@ -917,7 +974,18 @@ def _make_room_internal(self, bytes_needed: int) -> None: ram_bytes_freed = 0 pos = 0 models_cleared = 0 - while ram_bytes_freed < ram_bytes_to_free and pos < len(self._cache_stack): + while pos < len(self._cache_stack): + # Stop once there is enough room. With a shared RamBudget, re-check the global, + # deduplicated availability each iteration: evicting a model that other devices still + # hold frees no RAM (its shared weights stay live until the last reference is released), + # so a fixed "bytes freed" tally would be wrong. Without a budget, the local tally is + # exact, so the original cheaper check is kept. + if self._ram_budget is not None: + if bytes_needed <= self._get_ram_available(): + break + elif ram_bytes_freed >= ram_bytes_to_free: + break + model_key = self._cache_stack[pos] cache_entry = self._cached_models[model_key] @@ -961,8 +1029,21 @@ def _make_room_internal(self, bytes_needed: int) -> None: def _delete_cache_entry(self, cache_entry: CacheRecord) -> None: """Delete cache_entry from the cache if it exists. No exception is thrown if it doesn't exist.""" + was_present = cache_entry.key in self._cached_models self._cache_stack = [key for key in self._cache_stack if key != cache_entry.key] self._cached_models.pop(cache_entry.key, None) + # Drop this device's reference to the shared canonical CPU weights so they can be freed once + # the last device releases them. Guard on was_present so a double-delete doesn't + # double-release (release_shared_weights is itself idempotent, but a re-added entry under the + # same key must not be released by a stale delete). + if was_present: + uses_shared = cache_entry.cached_model.uses_shared_weights + total_bytes = cache_entry.cached_model.total_bytes() + cache_entry.cached_model.release_shared_weights() + # Drop the matching non-shared contribution from the global budget (shared weights are + # released via the store above). Captured before release_shared_weights() flips the flag. + if self._ram_budget is not None and not uses_shared: + self._ram_budget.remove_non_shared(total_bytes) @synchronized def drop_model(self, model_key: str) -> int: diff --git a/invokeai/backend/model_manager/load/model_cache/ram_budget.py b/invokeai/backend/model_manager/load/model_cache/ram_budget.py new file mode 100644 index 00000000000..6428c646753 --- /dev/null +++ b/invokeai/backend/model_manager/load/model_cache/ram_budget.py @@ -0,0 +1,64 @@ +import threading +from typing import Optional + +from invokeai.backend.model_manager.load.model_cache.shared_cpu_weights import SharedCpuWeightsStore + + +class RamBudget: + """The single global authority for how much RAM the model caches are actually using. + + In multi-GPU mode there is one `ModelCache` per device. Each cache independently sums the + `total_bytes()` of the models it holds, so a model resident on N devices is counted N times — + even though Phase 1/2 made its CPU weights live only ONCE in RAM (see SharedCpuWeightsStore). + That per-cache double-count makes the caches believe RAM is fuller than it is, causing premature + eviction and reload churn, and makes `max_cache_ram_gb` meaningless as a system-wide cap. + + RamBudget fixes the accounting by separating RAM into two non-overlapping parts: + + - Shared weights: model weights that are deduplicated in the SharedCpuWeightsStore. Counted + exactly once via `store.total_bytes_in_use()`, regardless of how many devices hold them. + - Non-shared RAM: models that are NOT deduplicated (keep_ram_copy disabled, or non-Module + models whose single in-RAM copy is per-device). These are tracked here as an explicit running + total; a model resident on N devices contributes N times, which is correct because it really + does occupy N copies of RAM. + + `total_in_use()` is the sum of the two and reflects the true RAM footprint. All per-device caches + share one RamBudget and make their eviction decisions against it. + + Thread-safety / lock ordering: RamBudget guards its own counter with an internal lock and NEVER + acquires a ModelCache lock (it only reads the store, which has its own lock). Callers update it + while holding their cache lock, so the only lock order is cache-lock -> (store-lock | budget-lock), + never the reverse — so it cannot deadlock against the per-device caches. + """ + + def __init__(self, max_bytes: int, shared_store: Optional[SharedCpuWeightsStore]): + self._max_bytes = max_bytes + self._store = shared_store + self._non_shared_bytes = 0 + self._lock = threading.Lock() + + @property + def max_bytes(self) -> int: + """The global cap on actual model-cache RAM, in bytes.""" + return self._max_bytes + + def add_non_shared(self, nbytes: int) -> None: + """Record `nbytes` of newly-resident non-deduplicated model RAM.""" + with self._lock: + self._non_shared_bytes += nbytes + + def remove_non_shared(self, nbytes: int) -> None: + """Record the release of `nbytes` of non-deduplicated model RAM.""" + with self._lock: + self._non_shared_bytes = max(0, self._non_shared_bytes - nbytes) + + def total_in_use(self) -> int: + """The true total RAM used by the model caches: shared weights (counted once) + non-shared.""" + shared = self._store.total_bytes_in_use() if self._store is not None else 0 + with self._lock: + non_shared = self._non_shared_bytes + return shared + non_shared + + def available(self) -> int: + """Bytes remaining under the global cap (may be negative if over budget).""" + return self._max_bytes - self.total_in_use() diff --git a/invokeai/backend/model_manager/load/model_cache/shared_cpu_weights.py b/invokeai/backend/model_manager/load/model_cache/shared_cpu_weights.py new file mode 100644 index 00000000000..4b1c634a25b --- /dev/null +++ b/invokeai/backend/model_manager/load/model_cache/shared_cpu_weights.py @@ -0,0 +1,118 @@ +import threading +from dataclasses import dataclass, field + +import torch + +from invokeai.backend.util.calc_tensor_size import calc_tensor_size + + +@dataclass +class _SharedWeightsEntry: + """A single canonical CPU state dict shared across per-device caches.""" + + state_dict: dict[str, torch.Tensor] + total_bytes: int + # Number of per-device cached models currently aliasing this entry. The entry is freed + # (its RAM released) when this drops to zero. + refcount: int = 0 + _key_bytes: dict[str, int] = field(default_factory=dict) + + +class SharedCpuWeightsStore: + """Process-global store of canonical CPU weight tensors, shared across per-device model caches. + + In multi-GPU mode there is one `ModelCache` per generation device. Without coordination each + cache keeps its own CPU copy of every model's weights, so a model loaded on N GPUs occupies N + copies in RAM. The cached-model wrappers cannot simply share a single `torch.nn.Module`, because + loading to VRAM mutates a module's parameters in place (`load_state_dict(assign=True)` / `.to`), + and two GPUs running the same model concurrently need their params on two different devices at + once. The CPU weight tensors, however, are read-only and device-agnostic, so they can be shared. + + This store keeps a single canonical CPU `state_dict` per cache key. The first device to load a + key registers its freshly-built state dict as canonical; subsequent devices `acquire()` the + canonical and re-point their own module's CPU parameters at the shared tensors (via + `load_state_dict(assign=True)`), discarding their private duplicate. The result: model weights + live once in RAM regardless of how many GPUs hold the model. + + Lifetime is reference-counted. Each per-device cached model that adopts an entry must call + `release()` exactly once when it is evicted; the canonical tensors are dropped only when the + last device releases them. + + Thread-safety: `acquire()`/`release()` are guarded by an internal lock. Note that model + construction (where `acquire()` is normally called) is already serialized process-globally by + `MODEL_LOAD_LOCK.write_lock()`; the internal lock here additionally protects `release()`, which + runs under a per-cache lock off the global construction lock. + """ + + def __init__(self) -> None: + self._lock = threading.Lock() + self._entries: dict[str, _SharedWeightsEntry] = {} + + def acquire(self, key: str, state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + """Adopt the canonical CPU state dict for `key`, registering `state_dict` as canonical if + this is the first acquire. + + Increments the entry's refcount. The caller MUST pair every `acquire()` with exactly one + `release()`. + + Returns: + The canonical state dict. If this call registered the entry, the returned object is the + same `state_dict` that was passed in (the caller keeps using its own tensors). Otherwise + it is the previously-registered canonical dict, and the caller is responsible for + re-pointing its module at these tensors and dropping the `state_dict` it passed in. + """ + with self._lock: + entry = self._entries.get(key) + if entry is None: + entry = _SharedWeightsEntry( + state_dict=state_dict, + total_bytes=sum(calc_tensor_size(v) for v in state_dict.values()), + ) + self._entries[key] = entry + entry.refcount += 1 + return entry.state_dict + + def release(self, key: str) -> None: + """Release one reference to `key`'s canonical state dict, freeing it when the count hits 0. + + A `release()` for a key that is not present is a no-op (e.g. a cached model that never + acquired shared weights, or a double eviction guard). + """ + with self._lock: + entry = self._entries.get(key) + if entry is None: + return + entry.refcount -= 1 + if entry.refcount <= 0: + del self._entries[key] + + # -- Introspection / accounting (also used by tests) ---------------------------------------- + + def __contains__(self, key: str) -> bool: + with self._lock: + return key in self._entries + + def refcount(self, key: str) -> int: + """Return the current refcount for `key`, or 0 if not present.""" + with self._lock: + entry = self._entries.get(key) + return entry.refcount if entry is not None else 0 + + def total_bytes_in_use(self) -> int: + """Return the total size (in bytes) of all canonical state dicts currently held. + + This counts each shared model's weights exactly once, regardless of how many devices alias + it — i.e. the true RAM footprint of cached weights, not the per-device double-count. + """ + with self._lock: + return sum(entry.total_bytes for entry in self._entries.values()) + + def keys(self) -> list[str]: + with self._lock: + return list(self._entries.keys()) + + +# Process-global default store. Per-device caches share this instance so that the same model loaded +# on multiple GPUs keeps a single CPU copy. Tests may construct isolated `SharedCpuWeightsStore` +# instances instead. +SHARED_CPU_WEIGHTS = SharedCpuWeightsStore() diff --git a/invokeai/backend/patches/layer_patcher.py b/invokeai/backend/patches/layer_patcher.py index fbfcd04de20..e14e35baa08 100644 --- a/invokeai/backend/patches/layer_patcher.py +++ b/invokeai/backend/patches/layer_patcher.py @@ -216,7 +216,10 @@ def _apply_model_layer_patch( param_name, torch.nn.Parameter(expanded_weight, requires_grad=module_param.requires_grad), ) - module_param = expanded_weight + # Point at the module's live (expanded) parameter so the out-of-place weight + # update below lands on the module. `expanded_weight` is a detached raw tensor; + # reassigning its `.data` would not propagate to the newly-set Parameter. + module_param = module_to_patch.get_parameter(param_name) else: # For other LoRAs, shape mismatch indicates architecture incompatibility - skip the layer logger = InvokeAILogger.get_logger(LayerPatcher.__name__) @@ -227,9 +230,17 @@ def _apply_model_layer_patch( ) continue - # Convert param_weight to the correct device and dtype, then apply to model weights + # Convert param_weight to the correct device and dtype, then apply to model weights. param_weight_converted = param_weight.to(device=device, dtype=dtype) - module_param.data.copy_(module_param.data + param_weight_converted) + # Apply out-of-place (assign a new tensor) rather than an in-place `copy_`. The weight we + # are patching may be the model's canonical CPU copy, which is shared across the + # per-device model caches in multi-GPU mode (see SharedCpuWeightsStore) and is also the + # cache's keep_ram_copy used to restore the model after unpatching. An in-place mutation + # here would corrupt that shared/cached tensor — and every other device's view of it. + # Reassigning `.data` leaves the original tensor untouched while giving this module the + # patched weights, and is memory-equivalent (the in-place form already allocated the + # `module_param.data + param_weight_converted` temporary). + module_param.data = module_param.data + param_weight_converted patch.to(device=TorchDevice.CPU_DEVICE) diff --git a/scripts/multigpu_ram_driver.py b/scripts/multigpu_ram_driver.py new file mode 100755 index 00000000000..2da725debde --- /dev/null +++ b/scripts/multigpu_ram_driver.py @@ -0,0 +1,295 @@ +#!/usr/bin/env python +"""Driver to exercise the multi-GPU shared-RAM model cache under real, concurrent generations. + +It repeatedly enqueues N batches at once (so the multi-GPU session processor runs them in parallel +across devices), polls the queue until each round drains, and samples the InvokeAI server process's +RAM (RSS) the whole time. It then reports: + + - baseline (idle) RSS, + - peak RSS during generation (this is the text/reference-encode spike you care about), and + - idle RSS after each round -> a leak verdict (does RAM return to baseline, or creep up?). + +This automates the two manual checks from the test plan: + #1 "dual concurrent encode RAM" -> run with --rounds 1 --pairs <#gpus> and read the peak. + #5 "leak check over many gens" -> run with --rounds 25+ and read the idle drift. + +------------------------------------------------------------------------------------------------ +Getting a batch file +------------------------------------------------------------------------------------------------ +The script needs the exact body InvokeAI's UI sends to enqueue a generation. Easiest way to capture +it: + 1. Open InvokeAI in the browser with devtools -> Network open. + 2. Click Invoke once. + 3. Find the POST to `.../queue/default/enqueue_batch`, copy its JSON request body, save to a file + (e.g. batch.json). It looks like {"prepend": false, "batch": {"graph": {...}, "runs": 1}}. + +The script bust the node cache by default (sets use_cache=false on every node and randomizes any +"seed" fields) so every submission actually runs the model instead of returning a cached result. + +------------------------------------------------------------------------------------------------ +Examples +------------------------------------------------------------------------------------------------ + # Headline dual-GPU encode RAM (2 GPUs -> 2 concurrent jobs), one round: + python scripts/multigpu_ram_driver.py --graph batch.json --pairs 2 --rounds 1 + + # Leak soak: 30 rounds of 2 concurrent jobs, save timeline for plotting: + python scripts/multigpu_ram_driver.py --graph batch.json --pairs 2 --rounds 30 --csv ram.csv + + # If PID auto-detection fails, point it at the server explicitly: + python scripts/multigpu_ram_driver.py --graph batch.json --pid 12345 +""" + +from __future__ import annotations + +import argparse +import copy +import json +import random +import sys +import threading +import time +import urllib.error +import urllib.parse +import urllib.request +from dataclasses import dataclass, field + +import psutil + + +# -------------------------------------------------------------------------------------------------- +# HTTP helpers (stdlib only) +# -------------------------------------------------------------------------------------------------- +def _request(method: str, url: str, body: dict | None = None, timeout: float = 60.0) -> dict: + data = json.dumps(body).encode() if body is not None else None + req = urllib.request.Request( + url, data=data, headers={"Content-Type": "application/json"}, method=method + ) + try: + with urllib.request.urlopen(req, timeout=timeout) as resp: + raw = resp.read() + return json.loads(raw) if raw else {} + except urllib.error.HTTPError as e: + detail = e.read().decode(errors="replace") + raise SystemExit(f"HTTP {e.code} on {method} {url}\n{detail}") from e + except urllib.error.URLError as e: + raise SystemExit(f"Could not reach {url}: {e.reason}. Is the server running?") from e + + +def enqueue(base: str, queue_id: str, body: dict) -> dict: + return _request("POST", f"{base}/api/v1/queue/{queue_id}/enqueue_batch", body) + + +def queue_counts(base: str, queue_id: str) -> tuple[int, int]: + """Return (pending, in_progress), searching the response defensively for those keys.""" + resp = _request("GET", f"{base}/api/v1/queue/{queue_id}/status") + # The status payload nests the queue counts under "queue"; fall back to top-level. + node = resp.get("queue", resp) if isinstance(resp, dict) else {} + return int(node.get("pending", 0)), int(node.get("in_progress", 0)) + + +# -------------------------------------------------------------------------------------------------- +# Batch preparation +# -------------------------------------------------------------------------------------------------- +def normalize_body(loaded: dict) -> dict: + """Accept either the full {"prepend":..., "batch": {...}} body or a bare Batch ({"graph":...}).""" + if "batch" in loaded: + return copy.deepcopy(loaded) + if "graph" in loaded: + return {"prepend": False, "batch": copy.deepcopy(loaded)} + raise SystemExit("Batch file must contain either a top-level 'batch' or 'graph' key.") + + +def bust_cache(body: dict, mutate_seed: bool, disable_cache: bool) -> dict: + """Return a copy of the body with the node cache busted so the submission really computes.""" + body = copy.deepcopy(body) + nodes = body.get("batch", {}).get("graph", {}).get("nodes", {}) + if not isinstance(nodes, dict): + return body + for node in nodes.values(): + if not isinstance(node, dict): + continue + if disable_cache: + node["use_cache"] = False + if mutate_seed and "seed" in node: + node["seed"] = random.randint(0, 2**31 - 1) + return body + + +# -------------------------------------------------------------------------------------------------- +# Process discovery + RSS sampling +# -------------------------------------------------------------------------------------------------- +def find_server_pid(port: int) -> int: + """Best-effort: find the PID listening on `port`, else a process whose cmdline looks like the server.""" + for conn in psutil.net_connections(kind="inet"): + if conn.laddr and conn.laddr.port == port and conn.pid: + return conn.pid + needles = ("invokeai-web", "invokeai.app.run_app", "invokeai_web", "uvicorn") + for proc in psutil.process_iter(["pid", "cmdline"]): + cmd = " ".join(proc.info.get("cmdline") or []) + if any(n in cmd for n in needles): + return proc.info["pid"] + raise SystemExit( + f"Could not auto-detect the InvokeAI server PID on port {port}. Pass --pid explicitly." + ) + + +def tree_rss(proc: psutil.Process, use_uss: bool) -> int: + """RSS (or USS) of the process and its children, in bytes.""" + procs = [proc] + proc.children(recursive=True) + total = 0 + for p in procs: + try: + if use_uss: + total += p.memory_full_info().uss + else: + total += p.memory_info().rss + except (psutil.NoSuchProcess, psutil.AccessDenied): + continue + return total + + +@dataclass +class Sampler: + proc: psutil.Process + hz: float + use_uss: bool + samples: list[tuple[float, int]] = field(default_factory=list) + _stop: threading.Event = field(default_factory=threading.Event) + _thread: threading.Thread | None = None + + def start(self) -> None: + self._thread = threading.Thread(target=self._run, daemon=True) + self._thread.start() + + def _run(self) -> None: + period = 1.0 / self.hz + while not self._stop.is_set(): + self.samples.append((time.monotonic(), tree_rss(self.proc, self.use_uss))) + time.sleep(period) + + def stop(self) -> None: + self._stop.set() + if self._thread: + self._thread.join(timeout=2.0) + + def current(self) -> int: + return self.samples[-1][1] if self.samples else tree_rss(self.proc, self.use_uss) + + def peak_between(self, t0: float, t1: float) -> int: + vals = [rss for t, rss in self.samples if t0 <= t <= t1] + return max(vals) if vals else 0 + + +# -------------------------------------------------------------------------------------------------- +# Round loop +# -------------------------------------------------------------------------------------------------- +GB = 1024**3 + + +def gb(n: int) -> float: + return n / GB + + +def wait_drained(base: str, queue_id: str, timeout: float) -> None: + deadline = time.monotonic() + timeout + while time.monotonic() < deadline: + pending, in_progress = queue_counts(base, queue_id) + if pending == 0 and in_progress == 0: + return + time.sleep(0.5) + raise SystemExit(f"Queue did not drain within {timeout}s. Aborting.") + + +def main() -> None: + ap = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) + ap.add_argument("--graph", required=True, help="Path to a captured enqueue_batch body (JSON).") + ap.add_argument("--url", default="http://127.0.0.1:9090", help="Server base URL.") + ap.add_argument("--queue-id", default="default") + ap.add_argument("--pairs", type=int, default=2, help="Concurrent batches per round (>= #GPUs).") + ap.add_argument("--rounds", type=int, default=1, help="Number of rounds (use 25+ for leak soak).") + ap.add_argument("--pid", type=int, default=None, help="Server PID (auto-detected if omitted).") + ap.add_argument("--hz", type=float, default=10.0, help="RSS sampling rate.") + ap.add_argument("--uss", action="store_true", help="Sample USS instead of RSS (more accurate, slower).") + ap.add_argument("--settle", type=float, default=4.0, help="Seconds to wait after each round for RAM to release.") + ap.add_argument("--timeout", type=float, default=1800.0, help="Per-round drain timeout (s).") + ap.add_argument("--warmup", action="store_true", help="Run one un-measured round first (loads models from disk).") + ap.add_argument("--keep-cache", action="store_true", help="Do NOT set use_cache=false on nodes.") + ap.add_argument("--no-seed-mutate", action="store_true", help="Do NOT randomize node 'seed' fields.") + ap.add_argument("--csv", default=None, help="Write the full (t, rss_gb) timeline here.") + args = ap.parse_args() + + with open(args.graph) as f: + body = normalize_body(json.load(f)) + + base = args.url.rstrip("/") + port = urllib.parse.urlparse(base).port or 9090 + pid = args.pid or find_server_pid(port) + proc = psutil.Process(pid) + print(f"Server PID {pid}: {' '.join(proc.cmdline()[:3])} ...") + print(f"Metric: {'USS' if args.uss else 'RSS'} (process tree) | pairs/round={args.pairs} rounds={args.rounds}") + + def submit_round() -> tuple[float, float]: + t0 = time.monotonic() + for _ in range(args.pairs): + prepared = bust_cache(body, mutate_seed=not args.no_seed_mutate, disable_cache=not args.keep_cache) + res = enqueue(base, args.queue_id, prepared) + if res.get("enqueued", 0) < 1: + raise SystemExit(f"Enqueue returned nothing useful: {res}") + wait_drained(base, args.queue_id, args.timeout) + return t0, time.monotonic() + + sampler = Sampler(proc=proc, hz=args.hz, use_uss=args.uss) + sampler.start() + try: + if args.warmup: + print("Warmup round (not measured)...") + submit_round() + time.sleep(args.settle) + + time.sleep(2.0) # settle before baseline + baseline = sampler.current() + print(f"\nBaseline idle {('USS' if args.uss else 'RSS')}: {gb(baseline):.2f} GB\n") + print(f"{'round':>5} {'peak_GB':>9} {'Δpeak_GB':>9} {'idle_after_GB':>14} {'Δidle_GB':>9}") + + idle_after_first = None + overall_peak = baseline + for r in range(1, args.rounds + 1): + t0, t1 = submit_round() + peak = sampler.peak_between(t0, t1) + overall_peak = max(overall_peak, peak) + time.sleep(args.settle) + idle_after = sampler.current() + if idle_after_first is None: + idle_after_first = idle_after + print( + f"{r:>5} {gb(peak):>9.2f} {gb(peak - baseline):>9.2f} " + f"{gb(idle_after):>14.2f} {gb(idle_after - baseline):>9.2f}" + ) + finally: + sampler.stop() + + # Summary + idle_drift = (sampler.current() - (idle_after_first or baseline)) + print("\n--- Summary ---") + print(f"Baseline idle: {gb(baseline):.2f} GB") + print(f"Overall peak: {gb(overall_peak):.2f} GB (Δ {gb(overall_peak - baseline):+.2f} GB over baseline)") + print(f"Idle drift (leak): {gb(idle_drift):+.2f} GB across {args.rounds} rounds") + verdict = "LIKELY LEAK" if idle_drift > 0.5 * GB else "no leak detected" + print(f"Leak verdict: {verdict} (threshold 0.50 GB)") + print("Interpretation: peak Δ should be ~1x the encoder size (not Nx). Idle drift should be ~0.") + + if args.csv: + t_start = sampler.samples[0][0] if sampler.samples else 0.0 + with open(args.csv, "w") as f: + f.write("t_seconds,rss_gb\n") + for t, rss in sampler.samples: + f.write(f"{t - t_start:.3f},{gb(rss):.4f}\n") + print(f"\nTimeline written to {args.csv} ({len(sampler.samples)} samples).") + + +if __name__ == "__main__": + try: + main() + except KeyboardInterrupt: + print("\nInterrupted.", file=sys.stderr) + sys.exit(130) diff --git a/tests/backend/model_manager/load/model_cache/cached_model/test_cached_model_shared_weights.py b/tests/backend/model_manager/load/model_cache/cached_model/test_cached_model_shared_weights.py new file mode 100644 index 00000000000..9969c50fb96 --- /dev/null +++ b/tests/backend/model_manager/load/model_cache/cached_model/test_cached_model_shared_weights.py @@ -0,0 +1,100 @@ +"""Tests for sharing a single canonical CPU copy of model weights across per-device cached models. + +These exercise the multi-GPU RAM-dedup path: two cached models built for the same cache key (as +would happen on two GPUs) must end up aliasing one set of CPU tensors instead of holding two +copies. They run on CPU — the wrapper constructors never touch VRAM, so no GPU is required. +""" + +import torch + +from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_only_full_load import ( + CachedModelOnlyFullLoad, +) +from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_with_partial_load import ( + CachedModelWithPartialLoad, +) +from invokeai.backend.model_manager.load.model_cache.shared_cpu_weights import SharedCpuWeightsStore +from tests.backend.model_manager.load.model_cache.cached_model.utils import DummyModule + +CPU = torch.device("cpu") + + +def _data_ptrs(state_dict: dict[str, torch.Tensor]) -> dict[str, int]: + return {k: v.data_ptr() for k, v in state_dict.items()} + + +def test_partial_load_shares_cpu_weights_across_devices(): + store = SharedCpuWeightsStore() + # Two independently-initialised modules (distinct weights), as two devices would build. + model_a = DummyModule() + model_b = DummyModule() + a_ptrs = _data_ptrs(model_a.state_dict()) + + cached_a = CachedModelWithPartialLoad(model_a, CPU, keep_ram_copy=True, shared_store=store, cache_key="m") + cached_b = CachedModelWithPartialLoad(model_b, CPU, keep_ram_copy=True, shared_store=store, cache_key="m") + + # Both cached models expose the SAME canonical CPU tensors. + assert cached_a.get_cpu_state_dict() is cached_b.get_cpu_state_dict() + assert _data_ptrs(cached_b.get_cpu_state_dict()) == a_ptrs + + # model_b's own parameters were re-pointed at the canonical tensors (b's originals are gone). + assert _data_ptrs(model_b.state_dict()) == a_ptrs + + assert store.refcount("m") == 2 + # Counted once despite two devices holding it. + assert store.total_bytes_in_use() == cached_a.total_bytes() + + +def test_full_load_shares_cpu_weights_across_devices(): + store = SharedCpuWeightsStore() + model_a = DummyModule() + model_b = DummyModule() + a_ptrs = _data_ptrs(model_a.state_dict()) + + cached_a = CachedModelOnlyFullLoad(model_a, CPU, total_bytes=100, keep_ram_copy=True, shared_store=store, cache_key="m") + cached_b = CachedModelOnlyFullLoad(model_b, CPU, total_bytes=100, keep_ram_copy=True, shared_store=store, cache_key="m") + + assert cached_a.get_cpu_state_dict() is cached_b.get_cpu_state_dict() + assert _data_ptrs(model_b.state_dict()) == a_ptrs + assert store.refcount("m") == 2 + + +def test_release_shared_weights_frees_at_last_reference(): + store = SharedCpuWeightsStore() + cached_a = CachedModelWithPartialLoad(DummyModule(), CPU, keep_ram_copy=True, shared_store=store, cache_key="m") + cached_b = CachedModelWithPartialLoad(DummyModule(), CPU, keep_ram_copy=True, shared_store=store, cache_key="m") + assert store.refcount("m") == 2 + + cached_a.release_shared_weights() + assert store.refcount("m") == 1 + assert "m" in store + + cached_b.release_shared_weights() + assert "m" not in store + assert store.total_bytes_in_use() == 0 + + +def test_release_shared_weights_is_idempotent(): + store = SharedCpuWeightsStore() + cached = CachedModelWithPartialLoad(DummyModule(), CPU, keep_ram_copy=True, shared_store=store, cache_key="m") + cached.release_shared_weights() + cached.release_shared_weights() # second call must not double-decrement + assert store.refcount("m") == 0 + assert "m" not in store + + +def test_no_store_means_no_sharing_and_no_release_error(): + # Without a shared store, behaviour is unchanged: each model keeps its own CPU state dict. + model = DummyModule() + cached = CachedModelWithPartialLoad(model, CPU, keep_ram_copy=True) + assert cached.get_cpu_state_dict() is not None + # release is a safe no-op when nothing was shared. + cached.release_shared_weights() + + +def test_keep_ram_copy_false_does_not_touch_store(): + store = SharedCpuWeightsStore() + cached = CachedModelWithPartialLoad(DummyModule(), CPU, keep_ram_copy=False, shared_store=store, cache_key="m") + assert cached.get_cpu_state_dict() is None + assert "m" not in store + assert store.refcount("m") == 0 diff --git a/tests/backend/model_manager/load/model_cache/test_model_cache_ram_budget.py b/tests/backend/model_manager/load/model_cache/test_model_cache_ram_budget.py new file mode 100644 index 00000000000..ad248c75716 --- /dev/null +++ b/tests/backend/model_manager/load/model_cache/test_model_cache_ram_budget.py @@ -0,0 +1,118 @@ +"""End-to-end tests of the global RamBudget driving eviction across per-device caches. + +Validates that the budget counts a shared model once (not once-per-GPU), counts non-deduplicated +models per-instance, and that eviction is made against the global deduplicated total — including the +case where a cache cannot free RAM because another device still holds the model. Runs on CPU. +""" + +import logging +from unittest.mock import MagicMock + +import pytest + +from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache +from invokeai.backend.model_manager.load.model_cache.ram_budget import RamBudget +from invokeai.backend.model_manager.load.model_cache.shared_cpu_weights import SharedCpuWeightsStore +from invokeai.backend.util.calc_tensor_size import calc_tensor_size +from tests.backend.model_manager.load.model_cache.cached_model.utils import DummyModule + +# Persistent state-dict bytes of one DummyModule (what the shared store accounts for a shared model). +S = sum(calc_tensor_size(v) for v in DummyModule().state_dict().values()) + + +@pytest.fixture +def mock_logger(): + logger = MagicMock() + logger.getEffectiveLevel.return_value = logging.INFO + return logger + + +def _make_cache(store, budget, logger, keep_ram_copy=True) -> ModelCache: + return ModelCache( + execution_device_working_mem_gb=1.0, + enable_partial_loading=False, + keep_ram_copy_of_weights=keep_ram_copy, + execution_device="cpu", + storage_device="cpu", + logger=logger, + shared_cpu_weights=store, + ram_budget=budget, + ) + + +def test_shared_model_counts_once_in_global_budget(mock_logger): + store = SharedCpuWeightsStore() + budget = RamBudget(max_bytes=10**12, shared_store=store) + cache_a = _make_cache(store, budget, mock_logger) + cache_b = _make_cache(store, budget, mock_logger) + try: + cache_a.put("m", DummyModule()) + one_device = budget.total_in_use() + assert one_device == S + + cache_b.put("m", DummyModule()) + # Second device shares the weights -> the global budget total does NOT grow. + assert budget.total_in_use() == one_device + finally: + cache_a.shutdown() + cache_b.shutdown() + + +def test_non_shared_model_counts_per_device(mock_logger): + store = SharedCpuWeightsStore() + budget = RamBudget(max_bytes=10**12, shared_store=store) + # keep_ram_copy=False -> not deduplicated, so each device's copy is real RAM. + cache_a = _make_cache(store, budget, mock_logger, keep_ram_copy=False) + cache_b = _make_cache(store, budget, mock_logger, keep_ram_copy=False) + try: + cache_a.put("m", DummyModule()) + one = budget.total_in_use() + assert one > 0 + cache_b.put("m", DummyModule()) + # Two independent copies -> counted twice. + assert budget.total_in_use() == 2 * one + finally: + cache_a.shutdown() + cache_b.shutdown() + + +def test_global_budget_evicts_lru_in_single_cache(mock_logger): + # Budget fits one model but not two -> putting the second evicts the first. + store = SharedCpuWeightsStore() + budget = RamBudget(max_bytes=int(S * 1.4), shared_store=store) + cache = _make_cache(store, budget, mock_logger) + try: + cache.put("a", DummyModule()) + cache.put("b", DummyModule()) + assert "a" not in cache._cached_models # evicted to make room for b + assert "b" in cache._cached_models + assert "a" not in store and store.refcount("b") == 1 + assert budget.total_in_use() == S + finally: + cache.shutdown() + + +def test_eviction_cannot_free_ram_held_by_another_device(mock_logger): + """If a cache's only droppable model is still held by another device, eviction frees nothing + globally (the shared weights stay live) and the new model is still admitted -> transiently over + budget until the other device releases. The eviction loop must handle this without spinning.""" + store = SharedCpuWeightsStore() + budget = RamBudget(max_bytes=int(S * 1.4), shared_store=store) + cache_a = _make_cache(store, budget, mock_logger) + cache_b = _make_cache(store, budget, mock_logger) + try: + cache_a.put("shared", DummyModule()) + cache_b.put("shared", DummyModule()) # both devices hold "shared" (refcount 2, counted once) + assert budget.total_in_use() == S + + cache_a.put("new", DummyModule()) # triggers make_room; "shared" is a's only droppable entry + # a dropped its ref to "shared", but b still holds it, so the shared weights weren't freed. + assert "shared" not in cache_a._cached_models + assert "shared" in cache_b._cached_models + assert store.refcount("shared") == 1 + assert "new" in cache_a._cached_models + # "shared" (still alive via b) + "new" -> over the 1.4*S cap, as expected. + assert budget.total_in_use() == 2 * S + finally: + cache_a.shutdown() + cache_b.shutdown() diff --git a/tests/backend/model_manager/load/model_cache/test_model_cache_shared_weights.py b/tests/backend/model_manager/load/model_cache/test_model_cache_shared_weights.py new file mode 100644 index 00000000000..1b6eee525ab --- /dev/null +++ b/tests/backend/model_manager/load/model_cache/test_model_cache_shared_weights.py @@ -0,0 +1,86 @@ +"""End-to-end test of CPU-weight sharing through ModelCache.put()/eviction. + +Simulates the multi-GPU topology — one ModelCache per device, all sharing a single +SharedCpuWeightsStore — and asserts that the same model loaded into both caches keeps exactly one +CPU copy, with RAM freed only when the last device evicts it. Runs on CPU (no VRAM moves). +""" + +import logging +from unittest.mock import MagicMock + +import pytest + +from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache +from invokeai.backend.model_manager.load.model_cache.shared_cpu_weights import SharedCpuWeightsStore +from tests.backend.model_manager.load.model_cache.cached_model.utils import DummyModule + + +@pytest.fixture +def mock_logger(): + logger = MagicMock() + logger.getEffectiveLevel.return_value = logging.INFO + return logger + + +def _make_cache(store: SharedCpuWeightsStore, logger: MagicMock) -> ModelCache: + return ModelCache( + execution_device_working_mem_gb=1.0, + enable_partial_loading=False, + keep_ram_copy_of_weights=True, + execution_device="cpu", + storage_device="cpu", + logger=logger, + shared_cpu_weights=store, + ) + + +def test_two_device_caches_share_one_cpu_copy(mock_logger: MagicMock): + store = SharedCpuWeightsStore() + cache_a = _make_cache(store, mock_logger) + cache_b = _make_cache(store, mock_logger) + try: + cache_a.put("m", DummyModule()) + ram_one_device = store.total_bytes_in_use() + assert ram_one_device > 0 + + cache_b.put("m", DummyModule()) + + # One canonical CPU copy shared by both "devices": the second device's put adds NO RAM. + assert store.refcount("m") == 2 + assert store.total_bytes_in_use() == ram_one_device + sd_a = cache_a.get("m").cached_model.get_cpu_state_dict() + sd_b = cache_b.get("m").cached_model.get_cpu_state_dict() + assert sd_a is sd_b + + # Evicting from one device drops only its reference; the weights stay for the other. + cache_a.make_room(10**12) + assert "m" not in cache_a._cached_models + assert store.refcount("m") == 1 + assert "m" in store + + # Evicting from the last device frees the shared RAM. + cache_b.make_room(10**12) + assert store.refcount("m") == 0 + assert "m" not in store + assert store.total_bytes_in_use() == 0 + finally: + cache_a.shutdown() + cache_b.shutdown() + + +def test_drop_model_releases_shared_weights(mock_logger: MagicMock): + store = SharedCpuWeightsStore() + cache_a = _make_cache(store, mock_logger) + cache_b = _make_cache(store, mock_logger) + try: + cache_a.put("m", DummyModule()) + cache_b.put("m", DummyModule()) + assert store.refcount("m") == 2 + + assert cache_a.drop_model("m") == 1 + assert store.refcount("m") == 1 + assert cache_b.drop_model("m") == 1 + assert "m" not in store + finally: + cache_a.shutdown() + cache_b.shutdown() diff --git a/tests/backend/model_manager/load/model_cache/test_ram_budget.py b/tests/backend/model_manager/load/model_cache/test_ram_budget.py new file mode 100644 index 00000000000..d8704fffad5 --- /dev/null +++ b/tests/backend/model_manager/load/model_cache/test_ram_budget.py @@ -0,0 +1,48 @@ +import torch + +from invokeai.backend.model_manager.load.model_cache.ram_budget import RamBudget +from invokeai.backend.model_manager.load.model_cache.shared_cpu_weights import SharedCpuWeightsStore + + +def test_total_in_use_sums_store_and_non_shared(): + store = SharedCpuWeightsStore() + store.acquire("k", {"a": torch.ones(100, dtype=torch.float32)}) # 400 bytes + budget = RamBudget(max_bytes=10_000, shared_store=store) + + assert budget.total_in_use() == 400 # store only + budget.add_non_shared(600) + assert budget.total_in_use() == 1000 + assert budget.available() == 9000 + budget.remove_non_shared(600) + assert budget.total_in_use() == 400 + + +def test_shared_weights_counted_once_regardless_of_refcount(): + store = SharedCpuWeightsStore() + sd = {"a": torch.ones(100, dtype=torch.float32)} # 400 bytes + store.acquire("k", sd) + store.acquire("k", sd) # second device acquires the same key + budget = RamBudget(max_bytes=10_000, shared_store=store) + # Two references, one physical copy -> counted once. + assert budget.total_in_use() == 400 + + +def test_remove_non_shared_floors_at_zero(): + budget = RamBudget(max_bytes=10_000, shared_store=None) + budget.add_non_shared(100) + budget.remove_non_shared(500) + assert budget.total_in_use() == 0 + + +def test_available_can_go_negative_when_over_budget(): + budget = RamBudget(max_bytes=100, shared_store=None) + budget.add_non_shared(250) + assert budget.available() == -150 + + +def test_no_store_tracks_only_non_shared(): + budget = RamBudget(max_bytes=1000, shared_store=None) + assert budget.total_in_use() == 0 + budget.add_non_shared(300) + assert budget.total_in_use() == 300 + assert budget.max_bytes == 1000 diff --git a/tests/backend/model_manager/load/model_cache/test_shared_cpu_weights.py b/tests/backend/model_manager/load/model_cache/test_shared_cpu_weights.py new file mode 100644 index 00000000000..23d8fe875ef --- /dev/null +++ b/tests/backend/model_manager/load/model_cache/test_shared_cpu_weights.py @@ -0,0 +1,106 @@ +import threading + +import torch + +from invokeai.backend.model_manager.load.model_cache.shared_cpu_weights import SharedCpuWeightsStore + + +def _state_dict() -> dict[str, torch.Tensor]: + return { + "a": torch.ones(10, 10, dtype=torch.float32), # 400 bytes + "b": torch.ones(5, dtype=torch.float32), # 20 bytes + } + + +def test_first_acquire_registers_and_returns_same_object(): + store = SharedCpuWeightsStore() + sd = _state_dict() + canonical = store.acquire("k", sd) + # The first acquire keeps the caller's own dict as canonical. + assert canonical is sd + assert store.refcount("k") == 1 + assert "k" in store + + +def test_second_acquire_returns_canonical_not_the_new_dict(): + store = SharedCpuWeightsStore() + first = _state_dict() + second = _state_dict() # distinct tensors, same shapes + canonical_first = store.acquire("k", first) + canonical_second = store.acquire("k", second) + + # The second caller gets the originally-registered tensors, not its own. + assert canonical_second is canonical_first + assert canonical_second["a"].data_ptr() == first["a"].data_ptr() + assert canonical_second["a"].data_ptr() != second["a"].data_ptr() + assert store.refcount("k") == 2 + + +def test_total_bytes_counts_each_key_once(): + store = SharedCpuWeightsStore() + # Two devices acquire the same key -> counted once. + store.acquire("k", _state_dict()) + store.acquire("k", _state_dict()) + assert store.total_bytes_in_use() == 420 + # A different key adds its own bytes. + store.acquire("k2", {"x": torch.ones(100, dtype=torch.float32)}) # 400 bytes + assert store.total_bytes_in_use() == 820 + + +def test_release_frees_only_at_zero(): + store = SharedCpuWeightsStore() + store.acquire("k", _state_dict()) + store.acquire("k", _state_dict()) + assert store.refcount("k") == 2 + + store.release("k") + assert store.refcount("k") == 1 + assert "k" in store + assert store.total_bytes_in_use() == 420 + + store.release("k") + assert store.refcount("k") == 0 + assert "k" not in store + assert store.total_bytes_in_use() == 0 + + +def test_release_unknown_key_is_noop(): + store = SharedCpuWeightsStore() + store.release("missing") # must not raise + assert store.total_bytes_in_use() == 0 + + +def test_reacquire_after_full_release_registers_fresh(): + store = SharedCpuWeightsStore() + first = _state_dict() + store.acquire("k", first) + store.release("k") + assert "k" not in store + + second = _state_dict() + canonical = store.acquire("k", second) + # After a full release the next caller becomes the new canonical. + assert canonical is second + assert store.refcount("k") == 1 + + +def test_concurrent_acquire_release_is_consistent(): + store = SharedCpuWeightsStore() + sd = _state_dict() + # Pre-register so the key exists for the whole run and the count never hits zero. + store.acquire("k", sd) + + def worker(): + for _ in range(200): + store.acquire("k", _state_dict()) + store.release("k") + + threads = [threading.Thread(target=worker) for _ in range(8)] + for t in threads: + t.start() + for t in threads: + t.join() + + # Every acquire was paired with a release, so only the pre-registration reference remains. + assert store.refcount("k") == 1 + assert store.total_bytes_in_use() == 420 diff --git a/tests/backend/model_manager/load/model_cache/test_shared_weights_gpu.py b/tests/backend/model_manager/load/model_cache/test_shared_weights_gpu.py new file mode 100644 index 00000000000..601fdb3b5df --- /dev/null +++ b/tests/backend/model_manager/load/model_cache/test_shared_weights_gpu.py @@ -0,0 +1,236 @@ +"""Real-GPU validation of cross-device CPU-weight sharing. + +These require two CUDA (incl. ROCm/HIP) devices. They prove the properties the CPU-only unit tests +cannot: that a module re-pointed at shared canonical CPU weights (a) loads onto its GPU and produces +correct inference output, and (b) survives two GPUs loading/unloading from the *same* shared CPU +state dict concurrently without corrupting each other's results. +""" + +import copy +import logging +import threading +from unittest.mock import MagicMock + +import gguf +import pytest +import torch + +from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_only_full_load import ( + CachedModelOnlyFullLoad, +) +from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_with_partial_load import ( + CachedModelWithPartialLoad, +) +from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache +from invokeai.backend.model_manager.load.model_cache.shared_cpu_weights import SharedCpuWeightsStore +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import ( + apply_custom_layers_to_model, +) +from tests.backend.model_manager.load.model_cache.cached_model.utils import DummyModule +from tests.backend.quantization.gguf.test_ggml_tensor import quantize_tensor + +requires_two_gpus = pytest.mark.skipif( + torch.cuda.device_count() < 2, reason="Requires at least 2 CUDA devices." +) + +DEVICE_A = "cuda:0" +DEVICE_B = "cuda:1" + + +def _mock_logger() -> MagicMock: + logger = MagicMock() + logger.getEffectiveLevel.return_value = logging.INFO + return logger + + +def _make_cache(store: SharedCpuWeightsStore, device: str, partial: bool) -> ModelCache: + return ModelCache( + execution_device_working_mem_gb=1.0, + enable_partial_loading=partial, + keep_ram_copy_of_weights=True, + execution_device=device, + storage_device="cpu", + logger=_mock_logger(), + shared_cpu_weights=store, + ) + + +@requires_two_gpus +@pytest.mark.parametrize("partial", [False, True]) +def test_shared_weights_produce_correct_output_on_both_gpus(partial: bool): + """A model loaded on two GPUs from one shared CPU copy must compute correct results on both.""" + torch.manual_seed(0) + model_a = DummyModule() + # model_b starts with DIFFERENT weights; sharing must overwrite them with model_a's canonical + # weights (both keys map to the same logical model). + torch.manual_seed(1) + model_b = DummyModule() + + x = torch.randn(4, 10) + # Reference output from model_a's original weights, computed before any cache/device mutation. + reference = copy.deepcopy(model_a)(x) + + store = SharedCpuWeightsStore() + cache_a = _make_cache(store, DEVICE_A, partial) + cache_b = _make_cache(store, DEVICE_B, partial) + try: + cache_a.put("m", model_a) + cache_b.put("m", model_b) + + # Single shared CPU copy across both devices. + assert store.refcount("m") == 2 + assert cache_a.get("m").cached_model.get_cpu_state_dict() is cache_b.get("m").cached_model.get_cpu_state_dict() + + rec_a = cache_a.get("m") + rec_b = cache_b.get("m") + cache_a.lock(rec_a, None) + cache_b.lock(rec_b, None) + try: + out_a = rec_a.cached_model.model(x.to(DEVICE_A)) + out_b = rec_b.cached_model.model(x.to(DEVICE_B)) + finally: + cache_a.unlock(rec_a) + cache_b.unlock(rec_b) + + # Both devices reproduce model_a's output (so model_b really adopted the shared weights). + assert torch.allclose(out_a.cpu(), reference, atol=1e-5) + assert torch.allclose(out_b.cpu(), reference, atol=1e-5) + finally: + cache_a.shutdown() + cache_b.shutdown() + + +@requires_two_gpus +@pytest.mark.parametrize("wrapper_cls", [CachedModelOnlyFullLoad, CachedModelWithPartialLoad]) +def test_concurrent_load_unload_from_shared_state_dict(wrapper_cls): + """Two GPUs repeatedly loading/unloading from one shared CPU state dict must not corrupt each + other. Each thread drives its own device's wrapper; the canonical CPU tensors are read-only and + must stay intact across concurrent .to(device) reads and load_state_dict(assign=True) restores. + """ + torch.manual_seed(0) + model_a = DummyModule() + torch.manual_seed(1) + model_b = DummyModule() + + x = torch.randn(4, 10) + reference = copy.deepcopy(model_a)(x) + + store = SharedCpuWeightsStore() + + def build(model, device): + if wrapper_cls is CachedModelWithPartialLoad: + return CachedModelWithPartialLoad( + model, torch.device(device), keep_ram_copy=True, shared_store=store, cache_key="m" + ) + return CachedModelOnlyFullLoad( + model, torch.device(device), total_bytes=1000, keep_ram_copy=True, shared_store=store, cache_key="m" + ) + + cached_a = build(model_a, DEVICE_A) + cached_b = build(model_b, DEVICE_B) + + errors: list[Exception] = [] + barrier = threading.Barrier(2) + + def run(cached, device): + try: + xd = x.to(device) + for _ in range(20): + barrier.wait() # maximise overlap of the two devices' loads + cached.full_load_to_vram() + out = cached.model(xd) + assert torch.allclose(out.cpu(), reference, atol=1e-5) + cached.full_unload_from_vram() + except Exception as e: # noqa: BLE001 - surface to main thread + errors.append(e) + try: + barrier.abort() + except Exception: + pass + + t_a = threading.Thread(target=run, args=(cached_a, DEVICE_A)) + t_b = threading.Thread(target=run, args=(cached_b, DEVICE_B)) + t_a.start() + t_b.start() + t_a.join() + t_b.join() + + assert not errors, f"Concurrent load/unload corrupted results: {errors[0]!r}" + # Canonical CPU weights survived and are still shared. + assert store.refcount("m") == 2 + cached_a.release_shared_weights() + cached_b.release_shared_weights() + assert "m" not in store + + +class _GGUFModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(32, 64) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(x) + + +def _build_gguf_model(seed: int) -> _GGUFModel: + """A small model whose linear weight is a Q8_0 GGML-quantized (CPU-resident) tensor. + + This mirrors how large quantized transformers/encoders are stored: the weights live on the CPU + as GGMLTensors and are dequantized on the fly during the forward pass. It is the path that goes + through the shared-CPU-weights mechanism, so it validates that re-pointing a quantized state + dict across devices preserves correct dequantized inference. + """ + torch.manual_seed(seed) + model = _GGUFModel() + model.linear.weight = torch.nn.Parameter(quantize_tensor(model.linear.weight, gguf.GGMLQuantizationType.Q8_0)) + return model + + +@requires_two_gpus +def test_shared_gguf_quantized_weights_correct_on_both_gpus(): + """A GGUF-quantized model loaded on two GPUs from one shared CPU copy must dequantize and + compute correct results on both devices.""" + x = torch.randn(1, 32, dtype=torch.float32) + + # Reference: a standalone copy of the same (seed-0) quantized weights, run via the autocast + # custom layers. Weights stay on CPU; compute happens on the device. + reference_model = _build_gguf_model(0) + apply_custom_layers_to_model(reference_model, device_autocasting_enabled=True) + reference = reference_model(x.to(DEVICE_A)).cpu() + + model_a = _build_gguf_model(0) + model_b = _build_gguf_model(1) # different weights; sharing must overwrite with canonical + + store = SharedCpuWeightsStore() + # enable_partial_loading=True routes quantized nn.Modules through CachedModelWithPartialLoad. + cache_a = _make_cache(store, DEVICE_A, partial=True) + cache_b = _make_cache(store, DEVICE_B, partial=True) + try: + cache_a.put("m", model_a) + ram_one_device = store.total_bytes_in_use() + cache_b.put("m", model_b) + + # One shared CPU copy of the quantized weights; second device adds no RAM. + assert store.refcount("m") == 2 + assert store.total_bytes_in_use() == ram_one_device + rec_a = cache_a.get("m") + rec_b = cache_b.get("m") + assert rec_a.cached_model.get_cpu_state_dict() is rec_b.cached_model.get_cpu_state_dict() + # model_b's quantized weight was re-pointed at model_a's canonical tensor. + assert rec_b.cached_model.model.linear.weight.data_ptr() == rec_a.cached_model.model.linear.weight.data_ptr() + + cache_a.lock(rec_a, None) + cache_b.lock(rec_b, None) + try: + out_a = rec_a.cached_model.model(x.to(DEVICE_A)) + out_b = rec_b.cached_model.model(x.to(DEVICE_B)) + finally: + cache_a.unlock(rec_a) + cache_b.unlock(rec_b) + + # Both GPUs reproduce the reference dequantized output. + assert torch.allclose(out_a.cpu(), reference, atol=1e-5) + assert torch.allclose(out_b.cpu(), reference, atol=1e-5) + finally: + cache_a.shutdown() + cache_b.shutdown() diff --git a/tests/backend/patches/test_layer_patcher_shared_weights.py b/tests/backend/patches/test_layer_patcher_shared_weights.py new file mode 100644 index 00000000000..328f93b04cc --- /dev/null +++ b/tests/backend/patches/test_layer_patcher_shared_weights.py @@ -0,0 +1,106 @@ +"""Regression tests: LoRA direct patching must not mutate the model's canonical CPU weights. + +In multi-GPU mode the per-device caches share one canonical CPU state_dict (SharedCpuWeightsStore), +and that same dict is the keep_ram_copy used to restore a model after unpatching. Direct patching +must therefore never mutate those tensors in place — otherwise a LoRA applied on one GPU would +corrupt the weights seen by the other GPU (and taint the cached "clean" copy even with one GPU). + +These run on CPU and force direct patching, which is the path that touches CPU-resident weights. +""" + +import torch + +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import ( + apply_custom_layers_to_model, +) +from invokeai.backend.patches.layer_patcher import LayerPatcher +from invokeai.backend.patches.layers.lora_layer import LoRALayer +from invokeai.backend.patches.model_patch_raw import ModelPatchRaw +from tests.backend.patches.test_layer_patcher import DummyModuleWithOneLayer + + +def _make_loras(num_loras: int, in_features: int, out_features: int, rank: int): + lora_models: list[tuple[ModelPatchRaw, float]] = [] + for _ in range(num_loras): + layers = { + "linear_layer_1": LoRALayer.from_state_dict_values( + values={ + "lora_down.weight": torch.ones((rank, in_features), device="cpu", dtype=torch.float32), + "lora_up.weight": torch.ones((out_features, rank), device="cpu", dtype=torch.float32), + }, + ) + } + lora_models.append((ModelPatchRaw(layers), 0.5)) + return lora_models + + +@torch.no_grad() +def test_force_direct_patch_does_not_mutate_canonical_cpu_weights(): + in_features, out_features, rank = 4, 8, 2 + model = DummyModuleWithOneLayer(in_features, out_features, device="cpu", dtype=torch.float32) + apply_custom_layers_to_model(model) + + # `canonical` holds references to the model's actual parameter tensors — exactly what the shared + # store would hand out as the canonical CPU copy and what model_on_device() passes as + # cached_weights. We snapshot their values to detect any in-place mutation. + canonical = dict(model.state_dict()) + snapshot = {k: v.detach().clone() for k, v in canonical.items()} + + lora_models = _make_loras(num_loras=2, in_features=in_features, out_features=out_features, rank=rank) + x = torch.randn(1, in_features, dtype=torch.float32) + out_before = model(x) + + with LayerPatcher.apply_smart_model_patches( + model=model, + patches=lora_models, + prefix="", + dtype=torch.float32, + cached_weights=canonical, + force_direct_patching=True, + ): + # Sanity: this really is the direct path (no sidecar wrappers), so weights were applied + # directly — and the patch actually changed the output. + assert model.linear_layer_1.get_num_patches() == 0 + out_during = model(x) + assert not torch.allclose(out_before, out_during) + + # The canonical tensors must be untouched even while the patch is active. + for k in canonical: + torch.testing.assert_close(canonical[k], snapshot[k]) + + # ...and after unpatching. + for k in canonical: + torch.testing.assert_close(canonical[k], snapshot[k]) + assert torch.allclose(out_before, model(x)) + + +@torch.no_grad() +def test_two_models_sharing_canonical_are_isolated_under_direct_patch(): + """Patch one model built from the shared canonical weights; a second model built from the same + canonical tensors must be unaffected (no cross-device bleed).""" + in_features, out_features, rank = 4, 8, 2 + model_a = DummyModuleWithOneLayer(in_features, out_features, device="cpu", dtype=torch.float32) + apply_custom_layers_to_model(model_a) + canonical = dict(model_a.state_dict()) + + # model_b shares the canonical tensors (as a second device's cache would via load_state_dict). + model_b = DummyModuleWithOneLayer(in_features, out_features, device="cpu", dtype=torch.float32) + apply_custom_layers_to_model(model_b) + model_b.load_state_dict(canonical, assign=True) + + x = torch.randn(1, in_features, dtype=torch.float32) + out_b_before = model_b(x) + + lora_models = _make_loras(num_loras=1, in_features=in_features, out_features=out_features, rank=rank) + with LayerPatcher.apply_smart_model_patches( + model=model_a, + patches=lora_models, + prefix="", + dtype=torch.float32, + cached_weights=canonical, + force_direct_patching=True, + ): + # model_a is patched; model_b (sharing the canonical weights) must be unchanged. + assert torch.allclose(model_b(x), out_b_before) + + assert torch.allclose(model_b(x), out_b_before) From fbd95a83e6347b0e5b63abc225604c7c10a84940 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Thu, 25 Jun 2026 20:38:56 -0400 Subject: [PATCH 22/33] fix(session-queue): cancel all in-progress items in bulk-cancel APIs (multi-GPU) With one session-processor worker per device, multiple queue items can be in_progress at once. cancel_by_batch_ids(), cancel_by_destination() and cancel_by_queue_id() excluded in_progress rows from their bulk UPDATE and then canceled only the single get_current() item (LIMIT 1), so on multi-GPU the other running items kept consuming a GPU and could still produce output after the user requested cancellation. Each running item must be canceled via _set_queue_item_status(), which emits the QueueItemStatusChangedEvent that the processor maps to the worker running that item_id and uses to set its cancel event. Add _cancel_in_progress_matching() to cancel every in-progress item matching the same filter (with user-id scoping preserved) and call it from all three bulk-cancel methods. The returned `canceled` count now includes canceled in-progress items. Adds regression tests that dequeue two items onto separate devices and assert every bulk cancel API moves all matching in_progress items to canceled and emits a cancel event for each (and that user-scoped cancel leaves another user's in-progress item running). Reported by JPPhoto in review of #9263. Co-Authored-By: Claude Opus 4.8 (1M context) --- .../session_queue/session_queue_sqlite.py | 112 +++++++++------- .../test_session_queue_multigpu_cancel.py | 124 ++++++++++++++++++ 2 files changed, 190 insertions(+), 46 deletions(-) create mode 100644 tests/app/services/session_queue/test_session_queue_multigpu_cancel.py diff --git a/invokeai/app/services/session_queue/session_queue_sqlite.py b/invokeai/app/services/session_queue/session_queue_sqlite.py index 25558c6260f..8bb9e6173d2 100644 --- a/invokeai/app/services/session_queue/session_queue_sqlite.py +++ b/invokeai/app/services/session_queue/session_queue_sqlite.py @@ -2,7 +2,7 @@ import json import sqlite3 import threading -from typing import Optional, Union, cast +from typing import Any, Optional, Union, cast from pydantic_core import to_jsonable_python @@ -475,30 +475,61 @@ def fail_queue_item( ) return queue_item + def _cancel_in_progress_matching(self, match_filter: str, params: list[Any]) -> int: + """Cancel every in-progress item matching `match_filter`, emitting a cancel event for each. + + The bulk-cancel methods exclude in-progress items from their single UPDATE statement, because + a running item must be canceled via `_set_queue_item_status()` so that its + `QueueItemStatusChangedEvent` is emitted — the session processor responds to that event by + setting the cancel event of the worker running that exact item_id. With multiple workers + (multi-GPU) more than one item can be in_progress at once, so each matching item is canceled + individually here rather than relying on a single `get_current()` (which returns only one). + + `match_filter` is a WHERE fragment without the leading WHERE (e.g. + "queue_id == ? AND batch_id IN (?, ?)"); `params` are its bound values. + + Returns the number of in-progress items actually canceled. + """ + with self._db.transaction() as cursor: + cursor.execute( + f"""--sql + SELECT item_id + FROM session_queue + WHERE status == 'in_progress' AND {match_filter}; + """, + tuple(params), + ) + item_ids = [row[0] for row in cursor.fetchall()] + + canceled = 0 + for item_id in item_ids: + # _set_queue_item_status no-ops (and returns the existing item) if the item finished + # between the SELECT and now, so count only the ones we actually moved to 'canceled'. + if self._set_queue_item_status(item_id, "canceled").status == "canceled": + canceled += 1 + return canceled + def cancel_by_batch_ids( self, queue_id: str, batch_ids: list[str], user_id: Optional[str] = None ) -> CancelByBatchIDsResult: - with self._db.transaction() as cursor: - current_queue_item = self.get_current(queue_id) - placeholders = ", ".join(["?" for _ in batch_ids]) + placeholders = ", ".join(["?" for _ in batch_ids]) + # Build the match filter (with optional user_id filter) shared by the bulk update and the + # in-progress cancellation below. + user_filter = "AND user_id = ?" if user_id is not None else "" + match_filter = f"queue_id == ? AND batch_id IN ({placeholders}) {user_filter}" + params: list[Any] = [queue_id] + batch_ids + if user_id is not None: + params.append(user_id) - # Build WHERE clause with optional user_id filter - user_filter = "AND user_id = ?" if user_id is not None else "" + with self._db.transaction() as cursor: where = f"""--sql - WHERE - queue_id == ? - AND batch_id IN ({placeholders}) + WHERE {match_filter} AND status != 'canceled' AND status != 'completed' AND status != 'failed' - -- We will cancel the current item separately below - skip it here + -- In-progress items are canceled individually below so each worker is signaled. AND status != 'in_progress' - {user_filter} """ - params = [queue_id] + batch_ids - if user_id is not None: - params.append(user_id) - cursor.execute( f"""--sql SELECT COUNT(*) @@ -518,36 +549,28 @@ def cancel_by_batch_ids( tuple(params), ) - # Handle current item separately - check ownership if user_id is provided - if current_queue_item is not None and current_queue_item.batch_id in batch_ids: - if user_id is None or current_queue_item.user_id == user_id: - self._set_queue_item_status(current_queue_item.item_id, "canceled") - + # Cancel every in-progress item matching the same filter (multi-GPU: possibly several at once). + count += self._cancel_in_progress_matching(match_filter, params) return CancelByBatchIDsResult(canceled=count) def cancel_by_destination( self, queue_id: str, destination: str, user_id: Optional[str] = None ) -> CancelByDestinationResult: - with self._db.transaction() as cursor: - current_queue_item = self.get_current(queue_id) + user_filter = "AND user_id = ?" if user_id is not None else "" + match_filter = f"queue_id == ? AND destination == ? {user_filter}" + params: list[Any] = [queue_id, destination] + if user_id is not None: + params.append(user_id) - # Build WHERE clause with optional user_id filter - user_filter = "AND user_id = ?" if user_id is not None else "" + with self._db.transaction() as cursor: where = f"""--sql - WHERE - queue_id == ? - AND destination == ? + WHERE {match_filter} AND status != 'canceled' AND status != 'completed' AND status != 'failed' - -- We will cancel the current item separately below - skip it here + -- In-progress items are canceled individually below so each worker is signaled. AND status != 'in_progress' - {user_filter} """ - params = [queue_id, destination] - if user_id is not None: - params.append(user_id) - cursor.execute( f"""--sql SELECT COUNT(*) @@ -567,11 +590,8 @@ def cancel_by_destination( tuple(params), ) - # Handle current item separately - check ownership if user_id is provided - if current_queue_item is not None and current_queue_item.destination == destination: - if user_id is None or current_queue_item.user_id == user_id: - self._set_queue_item_status(current_queue_item.item_id, "canceled") - + # Cancel every in-progress item matching the same filter (multi-GPU: possibly several at once). + count += self._cancel_in_progress_matching(match_filter, params) return CancelByDestinationResult(canceled=count) def delete_by_destination( @@ -649,18 +669,18 @@ def delete_all_except_current(self, queue_id: str, user_id: Optional[str] = None return DeleteAllExceptCurrentResult(deleted=count) def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult: + match_filter = "queue_id == ?" + params: list[Any] = [queue_id] + with self._db.transaction() as cursor: - current_queue_item = self.get_current(queue_id) - where = """--sql - WHERE - queue_id is ? + where = f"""--sql + WHERE {match_filter} AND status != 'canceled' AND status != 'completed' AND status != 'failed' - -- We will cancel the current item separately below - skip it here + -- In-progress items are canceled individually below so each worker is signaled. AND status != 'in_progress' """ - params = [queue_id] cursor.execute( f"""--sql SELECT COUNT(*) @@ -680,8 +700,8 @@ def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult: tuple(params), ) - if current_queue_item is not None and current_queue_item.queue_id == queue_id: - self._set_queue_item_status(current_queue_item.item_id, "canceled") + # Cancel every in-progress item in the queue (multi-GPU: possibly several at once). + count += self._cancel_in_progress_matching(match_filter, params) return CancelByQueueIDResult(canceled=count) def cancel_all_except_current(self, queue_id: str, user_id: Optional[str] = None) -> CancelAllExceptCurrentResult: diff --git a/tests/app/services/session_queue/test_session_queue_multigpu_cancel.py b/tests/app/services/session_queue/test_session_queue_multigpu_cancel.py new file mode 100644 index 00000000000..23901bec7c1 --- /dev/null +++ b/tests/app/services/session_queue/test_session_queue_multigpu_cancel.py @@ -0,0 +1,124 @@ +"""Regression tests for multi-GPU bulk cancellation. + +With one session-processor worker per device, several queue items can be `in_progress` at the same +time. The bulk-cancel APIs must cancel ALL matching in-progress items (each emitting a cancel event +so its worker stops), not just the single `get_current()` item. See JPPhoto's review on PR #9263. +""" + +import uuid + +import pytest + +from invokeai.app.services.events.events_common import QueueItemStatusChangedEvent +from invokeai.app.services.invoker import Invoker +from invokeai.app.services.session_queue.session_queue_sqlite import SqliteSessionQueue +from invokeai.app.services.shared.graph import Graph, GraphExecutionState +from tests.test_nodes import PromptTestInvocation, TestEventService + + +@pytest.fixture +def session_queue(mock_invoker: Invoker) -> SqliteSessionQueue: + db = mock_invoker.services.board_records._db + queue = SqliteSessionQueue(db=db) + queue.start(mock_invoker) + return queue + + +def _insert( + session_queue: SqliteSessionQueue, + batch_id: str, + destination: str | None = None, + user_id: str = "system", + queue_id: str = "default", +) -> int: + graph = Graph() + graph.add_node(PromptTestInvocation(id="prompt", prompt="test")) + session = GraphExecutionState(graph=graph) + session_json = session.model_dump_json(warnings=False, exclude_none=True) + with session_queue._db.transaction() as cursor: + cursor.execute( + """--sql + INSERT INTO session_queue ( + queue_id, session, session_id, batch_id, field_values, priority, + workflow, origin, destination, retried_from_item_id, user_id + ) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + (queue_id, session_json, session.id, batch_id, None, 0, None, None, destination, None, user_id), + ) + return cursor.lastrowid + + +def _canceled_event_item_ids(mock_invoker: Invoker) -> set[int]: + event_bus: TestEventService = mock_invoker.services.events + return { + e.item_id + for e in event_bus.events + if isinstance(e, QueueItemStatusChangedEvent) and e.status == "canceled" + } + + +def _dequeue_two_on_separate_devices(session_queue: SqliteSessionQueue) -> tuple[int, int]: + a = session_queue.dequeue(device="cuda:0") + b = session_queue.dequeue(device="cuda:1") + assert a is not None and b is not None + assert a.status == "in_progress" and b.status == "in_progress" + return a.item_id, b.item_id + + +def test_cancel_by_batch_ids_cancels_all_in_progress(session_queue: SqliteSessionQueue, mock_invoker: Invoker): + batch_id = str(uuid.uuid4()) + _insert(session_queue, batch_id=batch_id) + _insert(session_queue, batch_id=batch_id) + id_a, id_b = _dequeue_two_on_separate_devices(session_queue) + + result = session_queue.cancel_by_batch_ids("default", [batch_id]) + + assert result.canceled == 2 + assert session_queue.get_queue_item(id_a).status == "canceled" + assert session_queue.get_queue_item(id_b).status == "canceled" + # Each worker must have received a cancel event for its item. + assert {id_a, id_b} <= _canceled_event_item_ids(mock_invoker) + + +def test_cancel_by_destination_cancels_all_in_progress(session_queue: SqliteSessionQueue, mock_invoker: Invoker): + _insert(session_queue, batch_id=str(uuid.uuid4()), destination="canvas") + _insert(session_queue, batch_id=str(uuid.uuid4()), destination="canvas") + id_a, id_b = _dequeue_two_on_separate_devices(session_queue) + + result = session_queue.cancel_by_destination("default", "canvas") + + assert result.canceled == 2 + assert session_queue.get_queue_item(id_a).status == "canceled" + assert session_queue.get_queue_item(id_b).status == "canceled" + assert {id_a, id_b} <= _canceled_event_item_ids(mock_invoker) + + +def test_cancel_by_queue_id_cancels_all_in_progress(session_queue: SqliteSessionQueue, mock_invoker: Invoker): + _insert(session_queue, batch_id=str(uuid.uuid4())) + _insert(session_queue, batch_id=str(uuid.uuid4())) + id_a, id_b = _dequeue_two_on_separate_devices(session_queue) + + result = session_queue.cancel_by_queue_id("default") + + assert result.canceled == 2 + assert session_queue.get_queue_item(id_a).status == "canceled" + assert session_queue.get_queue_item(id_b).status == "canceled" + assert {id_a, id_b} <= _canceled_event_item_ids(mock_invoker) + + +def test_cancel_by_batch_ids_respects_user_scope(session_queue: SqliteSessionQueue, mock_invoker: Invoker): + """A user-scoped cancel must not cancel another user's in-progress item in the same batch.""" + batch_id = str(uuid.uuid4()) + _insert(session_queue, batch_id=batch_id, user_id="alice") + _insert(session_queue, batch_id=batch_id, user_id="bob") + alice_item = session_queue.dequeue(device="cuda:0") + bob_item = session_queue.dequeue(device="cuda:1") + assert alice_item is not None and bob_item is not None + + result = session_queue.cancel_by_batch_ids("default", [batch_id], user_id="alice") + + assert result.canceled == 1 + assert session_queue.get_queue_item(alice_item.item_id).status == "canceled" + assert session_queue.get_queue_item(bob_item.item_id).status == "in_progress" + assert _canceled_event_item_ids(mock_invoker) == {alice_item.item_id} From ed2a330ae641a0d38c0175dbe2bd8091f966bbaa Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Thu, 25 Jun 2026 21:17:23 -0400 Subject: [PATCH 23/33] fix(multi-gpu): address review findings (cancel race, bulk delete, device guards, refcount leak) Fixes from the code review of PR #9263: - Cancellation could be silently lost around dequeue: the per-iteration worker.cancel_event.clear() ran AFTER dequeue + gc.collect() + logging, so a cancel arriving in that window was set by the status handler and then wiped. Move the clear to before dequeue, and after claiming an item re-check (cancel_event + a fresh DB status read via _is_queue_item_terminal) and skip running if it is already terminal, closing both race windows. The runner's stale queue_item.status check could not catch this. - delete_by_destination only stopped one in-progress item (get_current) before deleting all matching rows, leaving other GPU workers running (and then failing to update a deleted row). Cancel every matching in-progress item via _cancel_in_progress_matching first. - generation_devices validation: a bare non-"auto" string (e.g. "cuda:0") was iterated character-by-character; an empty list silently fell back to one device. Reject both with a clear message. - get_generation_devices now fails fast on a CUDA device that does not exist (index past device_count, or CUDA unavailable) instead of starting a worker that errors cryptically at first allocation. - Shared-weights wrappers: if the canonical re-point (load_state_dict assign=True) threw after acquire(), the reference was leaked (the wrapper never entered the cache). Compute size metadata first, make acquire the last step, and release on failure. Adds tests for each: post-dequeue terminal guard, delete_by_destination cancellation, generation_devices validation, absent-device rejection, and acquire-released-on-repoint-failure. Co-Authored-By: Claude Opus 4.8 (1M context) --- .../app/services/config/config_default.py | 9 ++++ .../session_processor_default.py | 31 ++++++++++- .../session_queue/session_queue_sqlite.py | 23 ++++----- .../cached_model_only_full_load.py | 12 +++-- .../cached_model_with_partial_load.py | 42 ++++++++------- invokeai/backend/util/devices.py | 11 ++++ .../config/test_config_generation_devices.py | 38 ++++++++++++++ .../test_session_processor_cancel_guard.py | 51 +++++++++++++++++++ .../test_session_queue_multigpu_cancel.py | 20 ++++++++ .../test_cached_model_shared_weights.py | 22 ++++++++ tests/backend/util/test_devices.py | 20 ++++++++ 11 files changed, 246 insertions(+), 33 deletions(-) create mode 100644 tests/app/services/config/test_config_generation_devices.py create mode 100644 tests/app/services/session_processor/test_session_processor_cancel_guard.py diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py index 15d447dd182..7e7b8e8f1c9 100644 --- a/invokeai/app/services/config/config_default.py +++ b/invokeai/app/services/config/config_default.py @@ -263,6 +263,15 @@ class InvokeAIAppConfig(BaseSettings): def validate_generation_devices(cls, v: Union[str, list[str]]) -> Union[str, list[str]]: if v == "auto": return v + # A non-"auto" string would otherwise be iterated character-by-character below (rejecting + # 'c' from "cuda:0"), producing a confusing error. Require an explicit list instead. + if isinstance(v, str): + raise ValueError( + f"Invalid generation_devices value '{v}'. Use 'auto' or a list of devices, " + "e.g. ['cuda:0', 'cuda:1']." + ) + if len(v) == 0: + raise ValueError("generation_devices cannot be an empty list. Use 'auto' or a list of devices.") pattern = re.compile(r"^(cpu|mps|cuda(:\d+)?)$") for device in v: if not pattern.match(device): diff --git a/invokeai/app/services/session_processor/session_processor_default.py b/invokeai/app/services/session_processor/session_processor_default.py index 27c1f2a8632..93c4554b1fe 100644 --- a/invokeai/app/services/session_processor/session_processor_default.py +++ b/invokeai/app/services/session_processor/session_processor_default.py @@ -498,6 +498,19 @@ def get_status(self) -> SessionProcessorStatus: is_processing=any(worker.queue_item is not None for worker in self._workers), ) + def _is_queue_item_terminal(self, item_id: int) -> bool: + """Return True if the queue item is already finished (canceled/failed/completed) or gone. + + Checked right after a worker claims an item to catch a cancellation that raced the claim and + so never reached this worker's cancel_event — e.g. the status-changed handler ran before the + worker recorded `queue_item` and so couldn't match a worker to signal. + """ + try: + status = self._invoker.services.session_queue.get_queue_item(item_id).status + except SessionQueueItemNotFoundError: + return True + return status in ("canceled", "failed", "completed") + def _process( self, worker: _SessionWorker, @@ -528,6 +541,13 @@ def _process( if stop_event.is_set(): break + # Clear any stale cancel signal from the previous item BEFORE claiming the next + # one. Clearing it after dequeue (as before) could wipe a cancel that arrived for + # the item we just claimed — e.g. during the gc.collect() below — silently losing + # the cancellation. Any cancel that arrives after this point for the claimed item + # stays set and is caught by the runner's _is_canceled() check. + worker.cancel_event.clear() + # Get the next session to process. dequeue() atomically claims the item, so concurrent # workers never receive the same item. Pass this worker's device so the item is # tagged with the GPU that ran it (None in single-device/legacy mode). @@ -541,6 +561,16 @@ def _process( poll_now_event.wait(self._polling_interval) continue + # A cancellation can race the claim: it may have marked the row terminal before + # this worker recorded `queue_item`, so _on_queue_item_status_changed couldn't set + # our cancel_event. Re-check (cancel_event + a fresh DB status read) and skip + # running if the item is already finished, so the cancel is never lost. + if worker.cancel_event.is_set() or self._is_queue_item_terminal(worker.queue_item.item_id): + self._invoker.services.logger.debug( + f"Queue item {worker.queue_item.item_id} was canceled before it started; skipping." + ) + continue + # GC-ing here can reduce peak memory usage of the invoke process by freeing allocated memory blocks. # Most queue items take seconds to execute, so the relative cost of a GC is very small. # Python will never cede allocated memory back to the OS, so anything we can do to reduce the peak @@ -551,7 +581,6 @@ def _process( f"Executing queue item {worker.queue_item.item_id}, session {worker.queue_item.session_id} " f"on {worker.label}" ) - worker.cancel_event.clear() # Run the graph worker.runner.run(queue_item=worker.queue_item) diff --git a/invokeai/app/services/session_queue/session_queue_sqlite.py b/invokeai/app/services/session_queue/session_queue_sqlite.py index 8bb9e6173d2..e6d8229860d 100644 --- a/invokeai/app/services/session_queue/session_queue_sqlite.py +++ b/invokeai/app/services/session_queue/session_queue_sqlite.py @@ -597,20 +597,19 @@ def cancel_by_destination( def delete_by_destination( self, queue_id: str, destination: str, user_id: Optional[str] = None ) -> DeleteByDestinationResult: - with self._db.transaction() as cursor: - current_queue_item = self.get_current(queue_id) - - # Handle current item separately - check ownership if user_id is provided - if current_queue_item is not None and current_queue_item.destination == destination: - if user_id is None or current_queue_item.user_id == user_id: - self.cancel_queue_item(current_queue_item.item_id) + user_filter = "AND user_id = ?" if user_id is not None else "" + match_filter = f"queue_id == ? AND destination == ? {user_filter}" + params: list[Any] = [queue_id, destination] + if user_id is not None: + params.append(user_id) - # Build WHERE clause with optional user_id filter - user_filter = "AND user_id = ?" if user_id is not None else "" - params = [queue_id, destination] - if user_id is not None: - params.append(user_id) + # Cancel every in-progress item first so each running worker is signaled to stop before we + # delete its row. With multiple workers (multi-GPU) more than one item can be in_progress; + # canceling only get_current() would leave the others running (and then failing to update a + # deleted row). See _cancel_in_progress_matching. + self._cancel_in_progress_matching(match_filter, params) + with self._db.transaction() as cursor: cursor.execute( f"""--sql SELECT COUNT(*) diff --git a/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_only_full_load.py b/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_only_full_load.py index 8d1239cdfea..243a00015d6 100644 --- a/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_only_full_load.py +++ b/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_only_full_load.py @@ -51,11 +51,17 @@ def __init__( # module at the shared tensors and drop our duplicate so the weights live once in RAM. if shared_store is not None and cache_key is not None: canonical = shared_store.acquire(cache_key, cpu_state_dict) - if canonical is not cpu_state_dict: - model.load_state_dict(canonical, assign=True) - cpu_state_dict = canonical self._shared_store = shared_store self._shared_key = cache_key + try: + if canonical is not cpu_state_dict: + model.load_state_dict(canonical, assign=True) + cpu_state_dict = canonical + except Exception: + # The re-point failed after acquiring a reference; release it so the shared + # entry's refcount isn't leaked (this wrapper will never enter the cache). + self.release_shared_weights() + raise self._cpu_state_dict = cpu_state_dict self._total_bytes = total_bytes diff --git a/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_with_partial_load.py b/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_with_partial_load.py index 0f9b534d716..2a1d83cb011 100644 --- a/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_with_partial_load.py +++ b/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_with_partial_load.py @@ -35,37 +35,45 @@ def __init__( # patching. Set to `None` if keep_ram_copy is False. cpu_state_dict: dict[str, torch.Tensor] | None = model_state_dict if keep_ram_copy else None - # In multi-GPU mode, share a single canonical CPU copy of the weights across the per-device - # caches instead of keeping one copy per device (see SharedCpuWeightsStore). If another - # device already registered this key, re-point our module's params at the shared tensors and - # drop our freshly-built duplicate so the weights live once in RAM. - if cpu_state_dict is not None and shared_store is not None and cache_key is not None: - canonical = shared_store.acquire(cache_key, cpu_state_dict) - if canonical is not cpu_state_dict: - self._model.load_state_dict(canonical, assign=True) - model_state_dict = canonical - cpu_state_dict = canonical - self._shared_store = shared_store - self._shared_key = cache_key - - self._cpu_state_dict: dict[str, torch.Tensor] | None = cpu_state_dict - # A dictionary of the size of each tensor in the state dict. # HACK(ryand): We use this dictionary any time we are doing byte tracking calculations. We do this for # consistency in case the application code has modified the model's size (e.g. by casting to a different # precision). Of course, this means that we are making model cache load/unload decisions based on model size # data that may not be fully accurate. + # + # Note: these are computed from the model's own state dict *before* the shared-weights re-point + # below. The re-point only swaps tensor storage; keys, shapes and dtypes are unchanged, so the + # metadata is identical either way. Computing it first keeps the acquire the last (and only + # failure-prone) step, so a failure there can release the reference cleanly without leaking it. self._state_dict_bytes = {k: calc_tensor_size(v) for k, v in model_state_dict.items()} - self._total_bytes = sum(self._state_dict_bytes.values()) self._cur_vram_bytes: int | None = None - self._modules_that_support_autocast = self._find_modules_that_support_autocast() self._keys_in_modules_that_do_not_support_autocast = self._find_keys_in_modules_that_do_not_support_autocast( model_state_dict ) self._state_dict_keys_by_module_prefix = self._group_state_dict_keys_by_module_prefix(model_state_dict) + # In multi-GPU mode, share a single canonical CPU copy of the weights across the per-device + # caches instead of keeping one copy per device (see SharedCpuWeightsStore). If another + # device already registered this key, re-point our module's params at the shared tensors and + # drop our freshly-built duplicate so the weights live once in RAM. + if cpu_state_dict is not None and shared_store is not None and cache_key is not None: + canonical = shared_store.acquire(cache_key, cpu_state_dict) + self._shared_store = shared_store + self._shared_key = cache_key + try: + if canonical is not cpu_state_dict: + self._model.load_state_dict(canonical, assign=True) + cpu_state_dict = canonical + except Exception: + # The re-point failed after acquiring a reference; release it so the shared entry's + # refcount isn't leaked (this wrapper will never be inserted into the cache). + self.release_shared_weights() + raise + + self._cpu_state_dict: dict[str, torch.Tensor] | None = cpu_state_dict + def _find_modules_that_support_autocast(self) -> dict[str, torch.nn.Module]: """Find all modules that support autocasting.""" return {n: m for n, m in self._model.named_modules() if isinstance(m, CustomModuleMixin)} # type: ignore diff --git a/invokeai/backend/util/devices.py b/invokeai/backend/util/devices.py index 0055acd1289..7a5e8f3e8b9 100644 --- a/invokeai/backend/util/devices.py +++ b/invokeai/backend/util/devices.py @@ -194,6 +194,17 @@ def get_generation_devices(cls, generation_devices: Union[str, list[str], None]) seen: set[str] = set() for device_str in device_strs: device = cls.normalize(device_str) + # Fail fast on a CUDA device that doesn't exist, rather than starting a worker pinned to + # it that only errors cryptically at the first tensor allocation. ("auto" only generates + # valid indices, so this just validates explicitly-configured devices.) + if device.type == "cuda": + if not torch.cuda.is_available(): + raise ValueError(f"generation_devices requested '{device_str}', but no CUDA device is available.") + if device.index is not None and device.index >= torch.cuda.device_count(): + raise ValueError( + f"generation_devices requested '{device_str}', but only {torch.cuda.device_count()} " + f"CUDA device(s) are available (valid indices 0-{torch.cuda.device_count() - 1})." + ) if str(device) not in seen: seen.add(str(device)) devices.append(device) diff --git a/tests/app/services/config/test_config_generation_devices.py b/tests/app/services/config/test_config_generation_devices.py new file mode 100644 index 00000000000..e589b35dd3d --- /dev/null +++ b/tests/app/services/config/test_config_generation_devices.py @@ -0,0 +1,38 @@ +"""Validation tests for the multi-GPU `generation_devices` config field.""" + +import pytest +from pydantic import ValidationError + +from invokeai.app.services.config.config_default import InvokeAIAppConfig + + +@pytest.mark.parametrize( + "value", + [ + "auto", + ["cuda:0"], + ["cuda:0", "cuda:1"], + ["cpu"], + ["mps"], + ["cuda"], + ], +) +def test_valid_generation_devices(value): + cfg = InvokeAIAppConfig(generation_devices=value) + assert cfg.generation_devices == value + + +def test_non_auto_string_is_rejected(): + # A bare string (other than "auto") would otherwise be iterated character-by-character. + with pytest.raises(ValidationError): + InvokeAIAppConfig(generation_devices="cuda:0") + + +def test_empty_list_is_rejected(): + with pytest.raises(ValidationError): + InvokeAIAppConfig(generation_devices=[]) + + +def test_invalid_device_name_is_rejected(): + with pytest.raises(ValidationError): + InvokeAIAppConfig(generation_devices=["gpu0"]) diff --git a/tests/app/services/session_processor/test_session_processor_cancel_guard.py b/tests/app/services/session_processor/test_session_processor_cancel_guard.py new file mode 100644 index 00000000000..b99f19a3068 --- /dev/null +++ b/tests/app/services/session_processor/test_session_processor_cancel_guard.py @@ -0,0 +1,51 @@ +"""Tests for the post-dequeue cancellation guard that closes the multi-GPU cancel-loss race. + +A cancellation can mark a queue item terminal in the window between dequeue claiming it and the +worker recording `queue_item` (so the status-changed handler can't set the worker's cancel_event). +`_is_queue_item_terminal` is the fresh DB re-check the worker uses to skip running such an item. +""" + +from types import SimpleNamespace + +import pytest + +from invokeai.app.services.session_processor.session_processor_default import DefaultSessionProcessor +from invokeai.app.services.session_queue.session_queue_common import SessionQueueItemNotFoundError + + +class _Queue: + def __init__(self, status: str | None = None, raise_not_found: bool = False): + self._status = status + self._raise = raise_not_found + + def get_queue_item(self, item_id: int): + if self._raise: + raise SessionQueueItemNotFoundError("gone") + return SimpleNamespace(item_id=item_id, status=self._status) + + +def _processor_with_queue(queue: _Queue) -> DefaultSessionProcessor: + processor = DefaultSessionProcessor() + processor._invoker = SimpleNamespace(services=SimpleNamespace(session_queue=queue)) # type: ignore[attr-defined] + return processor + + +@pytest.mark.parametrize( + ("status", "expected"), + [ + ("in_progress", False), + ("pending", False), + ("canceled", True), + ("failed", True), + ("completed", True), + ], +) +def test_is_queue_item_terminal_status(status: str, expected: bool): + processor = _processor_with_queue(_Queue(status=status)) + assert processor._is_queue_item_terminal(1) is expected + + +def test_is_queue_item_terminal_treats_missing_as_terminal(): + # A deleted row (e.g. queue cleared during the race) should be treated as terminal, not run. + processor = _processor_with_queue(_Queue(raise_not_found=True)) + assert processor._is_queue_item_terminal(1) is True diff --git a/tests/app/services/session_queue/test_session_queue_multigpu_cancel.py b/tests/app/services/session_queue/test_session_queue_multigpu_cancel.py index 23901bec7c1..d06d41c6070 100644 --- a/tests/app/services/session_queue/test_session_queue_multigpu_cancel.py +++ b/tests/app/services/session_queue/test_session_queue_multigpu_cancel.py @@ -11,6 +11,7 @@ from invokeai.app.services.events.events_common import QueueItemStatusChangedEvent from invokeai.app.services.invoker import Invoker +from invokeai.app.services.session_queue.session_queue_common import SessionQueueItemNotFoundError from invokeai.app.services.session_queue.session_queue_sqlite import SqliteSessionQueue from invokeai.app.services.shared.graph import Graph, GraphExecutionState from tests.test_nodes import PromptTestInvocation, TestEventService @@ -107,6 +108,25 @@ def test_cancel_by_queue_id_cancels_all_in_progress(session_queue: SqliteSession assert {id_a, id_b} <= _canceled_event_item_ids(mock_invoker) +def test_delete_by_destination_cancels_all_in_progress(session_queue: SqliteSessionQueue, mock_invoker: Invoker): + """delete_by_destination must signal every running worker (not just get_current()) before + deleting their rows, or the un-canceled workers keep running and then fail to update a deleted + row.""" + _insert(session_queue, batch_id=str(uuid.uuid4()), destination="canvas") + _insert(session_queue, batch_id=str(uuid.uuid4()), destination="canvas") + id_a, id_b = _dequeue_two_on_separate_devices(session_queue) + + result = session_queue.delete_by_destination("default", "canvas") + + assert result.deleted == 2 + # Both in-progress workers were signaled to cancel before deletion. + assert {id_a, id_b} <= _canceled_event_item_ids(mock_invoker) + # Rows are gone. + for item_id in (id_a, id_b): + with pytest.raises(SessionQueueItemNotFoundError): + session_queue.get_queue_item(item_id) + + def test_cancel_by_batch_ids_respects_user_scope(session_queue: SqliteSessionQueue, mock_invoker: Invoker): """A user-scoped cancel must not cancel another user's in-progress item in the same batch.""" batch_id = str(uuid.uuid4()) diff --git a/tests/backend/model_manager/load/model_cache/cached_model/test_cached_model_shared_weights.py b/tests/backend/model_manager/load/model_cache/cached_model/test_cached_model_shared_weights.py index 9969c50fb96..d297d86d9ed 100644 --- a/tests/backend/model_manager/load/model_cache/cached_model/test_cached_model_shared_weights.py +++ b/tests/backend/model_manager/load/model_cache/cached_model/test_cached_model_shared_weights.py @@ -5,6 +5,7 @@ copies. They run on CPU — the wrapper constructors never touch VRAM, so no GPU is required. """ +import pytest import torch from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_only_full_load import ( @@ -98,3 +99,24 @@ def test_keep_ram_copy_false_does_not_touch_store(): assert cached.get_cpu_state_dict() is None assert "m" not in store assert store.refcount("m") == 0 + + +class _RepointFailsModule(DummyModule): + """A model whose load_state_dict raises, to simulate a re-point failure during construction.""" + + def load_state_dict(self, *args, **kwargs): # type: ignore[override] + raise RuntimeError("simulated re-point failure") + + +def test_acquire_is_released_if_repoint_fails(): + # First device registers the canonical weights (refcount 1). + store = SharedCpuWeightsStore() + CachedModelWithPartialLoad(DummyModule(), CPU, keep_ram_copy=True, shared_store=store, cache_key="m") + assert store.refcount("m") == 1 + + # Second device adopts the canonical copy, but its re-point throws. The just-acquired reference + # must be released so the store's refcount is not leaked (the wrapper never enters the cache). + with pytest.raises(RuntimeError, match="simulated re-point failure"): + CachedModelWithPartialLoad(_RepointFailsModule(), CPU, keep_ram_copy=True, shared_store=store, cache_key="m") + + assert store.refcount("m") == 1 # back to just the first device, not leaked at 2 diff --git a/tests/backend/util/test_devices.py b/tests/backend/util/test_devices.py index aa8433c632e..b95a9775d6f 100644 --- a/tests/backend/util/test_devices.py +++ b/tests/backend/util/test_devices.py @@ -255,3 +255,23 @@ def test_choose_anima_inference_dtype_auto_delegates_to_safe_dtype(): result = TorchDevice.choose_anima_inference_dtype(device) assert result is sentinel mock_safe.assert_called_once_with(device) + + +@patch("torch.cuda.device_count", return_value=2) +@patch("torch.cuda.is_available", return_value=True) +def test_get_generation_devices_rejects_out_of_range_cuda(mock_avail, mock_count): + # cuda:2 does not exist on a 2-GPU machine — fail fast instead of deferring to first allocation. + with pytest.raises(ValueError, match="only 2 CUDA"): + TorchDevice.get_generation_devices(["cuda:2"]) + + +@patch("torch.cuda.device_count", return_value=2) +@patch("torch.cuda.is_available", return_value=True) +def test_get_generation_devices_accepts_in_range_cuda(mock_avail, mock_count): + assert [str(d) for d in TorchDevice.get_generation_devices(["cuda:1"])] == ["cuda:1"] + + +@patch("torch.cuda.is_available", return_value=False) +def test_get_generation_devices_rejects_cuda_when_unavailable(mock_avail): + with pytest.raises(ValueError, match="no CUDA"): + TorchDevice.get_generation_devices(["cuda:0"]) From 57e1e79d7b54598d24e3ff4c5bdcead407a727f7 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Thu, 25 Jun 2026 21:40:27 -0400 Subject: [PATCH 24/33] fix(ci): ruff format + make CPU-incompatible device test mock CUDA - Apply ruff 0.11.2 formatting to the files flagged by `ruff format --check`. - The new fail-fast guard in get_generation_devices() (reject a CUDA device that doesn't exist) made the pre-existing test_get_generation_devices_explicit_list_is_deduplicated fail on CPU-only CI runners, since it passes a cuda list with no CUDA present. Mock torch.cuda.is_available/device_count in that test (matching the existing pattern in this file) so it validates dedup on any runner. Co-Authored-By: Claude Opus 4.8 (1M context) --- invokeai/app/services/config/config_default.py | 3 +-- scripts/multigpu_ram_driver.py | 10 +++------- .../test_session_queue_multigpu_cancel.py | 4 +--- .../test_cached_model_shared_weights.py | 8 ++++++-- .../load/model_cache/test_shared_weights_gpu.py | 4 +--- tests/backend/util/test_devices.py | 13 +++++++++---- 6 files changed, 21 insertions(+), 21 deletions(-) diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py index 7e7b8e8f1c9..8c07c2139f4 100644 --- a/invokeai/app/services/config/config_default.py +++ b/invokeai/app/services/config/config_default.py @@ -267,8 +267,7 @@ def validate_generation_devices(cls, v: Union[str, list[str]]) -> Union[str, lis # 'c' from "cuda:0"), producing a confusing error. Require an explicit list instead. if isinstance(v, str): raise ValueError( - f"Invalid generation_devices value '{v}'. Use 'auto' or a list of devices, " - "e.g. ['cuda:0', 'cuda:1']." + f"Invalid generation_devices value '{v}'. Use 'auto' or a list of devices, e.g. ['cuda:0', 'cuda:1']." ) if len(v) == 0: raise ValueError("generation_devices cannot be an empty list. Use 'auto' or a list of devices.") diff --git a/scripts/multigpu_ram_driver.py b/scripts/multigpu_ram_driver.py index 2da725debde..9d985b4d98c 100755 --- a/scripts/multigpu_ram_driver.py +++ b/scripts/multigpu_ram_driver.py @@ -61,9 +61,7 @@ # -------------------------------------------------------------------------------------------------- def _request(method: str, url: str, body: dict | None = None, timeout: float = 60.0) -> dict: data = json.dumps(body).encode() if body is not None else None - req = urllib.request.Request( - url, data=data, headers={"Content-Type": "application/json"}, method=method - ) + req = urllib.request.Request(url, data=data, headers={"Content-Type": "application/json"}, method=method) try: with urllib.request.urlopen(req, timeout=timeout) as resp: raw = resp.read() @@ -128,9 +126,7 @@ def find_server_pid(port: int) -> int: cmd = " ".join(proc.info.get("cmdline") or []) if any(n in cmd for n in needles): return proc.info["pid"] - raise SystemExit( - f"Could not auto-detect the InvokeAI server PID on port {port}. Pass --pid explicitly." - ) + raise SystemExit(f"Could not auto-detect the InvokeAI server PID on port {port}. Pass --pid explicitly.") def tree_rss(proc: psutil.Process, use_uss: bool) -> int: @@ -269,7 +265,7 @@ def submit_round() -> tuple[float, float]: sampler.stop() # Summary - idle_drift = (sampler.current() - (idle_after_first or baseline)) + idle_drift = sampler.current() - (idle_after_first or baseline) print("\n--- Summary ---") print(f"Baseline idle: {gb(baseline):.2f} GB") print(f"Overall peak: {gb(overall_peak):.2f} GB (Δ {gb(overall_peak - baseline):+.2f} GB over baseline)") diff --git a/tests/app/services/session_queue/test_session_queue_multigpu_cancel.py b/tests/app/services/session_queue/test_session_queue_multigpu_cancel.py index d06d41c6070..0d97ad6ab03 100644 --- a/tests/app/services/session_queue/test_session_queue_multigpu_cancel.py +++ b/tests/app/services/session_queue/test_session_queue_multigpu_cancel.py @@ -53,9 +53,7 @@ def _insert( def _canceled_event_item_ids(mock_invoker: Invoker) -> set[int]: event_bus: TestEventService = mock_invoker.services.events return { - e.item_id - for e in event_bus.events - if isinstance(e, QueueItemStatusChangedEvent) and e.status == "canceled" + e.item_id for e in event_bus.events if isinstance(e, QueueItemStatusChangedEvent) and e.status == "canceled" } diff --git a/tests/backend/model_manager/load/model_cache/cached_model/test_cached_model_shared_weights.py b/tests/backend/model_manager/load/model_cache/cached_model/test_cached_model_shared_weights.py index d297d86d9ed..79a34ff0d96 100644 --- a/tests/backend/model_manager/load/model_cache/cached_model/test_cached_model_shared_weights.py +++ b/tests/backend/model_manager/load/model_cache/cached_model/test_cached_model_shared_weights.py @@ -52,8 +52,12 @@ def test_full_load_shares_cpu_weights_across_devices(): model_b = DummyModule() a_ptrs = _data_ptrs(model_a.state_dict()) - cached_a = CachedModelOnlyFullLoad(model_a, CPU, total_bytes=100, keep_ram_copy=True, shared_store=store, cache_key="m") - cached_b = CachedModelOnlyFullLoad(model_b, CPU, total_bytes=100, keep_ram_copy=True, shared_store=store, cache_key="m") + cached_a = CachedModelOnlyFullLoad( + model_a, CPU, total_bytes=100, keep_ram_copy=True, shared_store=store, cache_key="m" + ) + cached_b = CachedModelOnlyFullLoad( + model_b, CPU, total_bytes=100, keep_ram_copy=True, shared_store=store, cache_key="m" + ) assert cached_a.get_cpu_state_dict() is cached_b.get_cpu_state_dict() assert _data_ptrs(model_b.state_dict()) == a_ptrs diff --git a/tests/backend/model_manager/load/model_cache/test_shared_weights_gpu.py b/tests/backend/model_manager/load/model_cache/test_shared_weights_gpu.py index 601fdb3b5df..4f5b700cbb3 100644 --- a/tests/backend/model_manager/load/model_cache/test_shared_weights_gpu.py +++ b/tests/backend/model_manager/load/model_cache/test_shared_weights_gpu.py @@ -29,9 +29,7 @@ from tests.backend.model_manager.load.model_cache.cached_model.utils import DummyModule from tests.backend.quantization.gguf.test_ggml_tensor import quantize_tensor -requires_two_gpus = pytest.mark.skipif( - torch.cuda.device_count() < 2, reason="Requires at least 2 CUDA devices." -) +requires_two_gpus = pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 CUDA devices.") DEVICE_A = "cuda:0" DEVICE_B = "cuda:1" diff --git a/tests/backend/util/test_devices.py b/tests/backend/util/test_devices.py index b95a9775d6f..1d7dfa75614 100644 --- a/tests/backend/util/test_devices.py +++ b/tests/backend/util/test_devices.py @@ -98,10 +98,15 @@ def test_get_generation_devices_auto_without_cuda(): def test_get_generation_devices_explicit_list_is_deduplicated(): """An explicit list is normalized and deduplicated, preserving order.""" - assert TorchDevice.get_generation_devices(["cuda:0", "cuda:0", "cuda:1"]) == [ - torch.device("cuda:0"), - torch.device("cuda:1"), - ] + # Mock CUDA as present so the device-existence validation passes on CPU-only runners. + with ( + patch("invokeai.backend.util.devices.torch.cuda.is_available", return_value=True), + patch("invokeai.backend.util.devices.torch.cuda.device_count", return_value=2), + ): + assert TorchDevice.get_generation_devices(["cuda:0", "cuda:0", "cuda:1"]) == [ + torch.device("cuda:0"), + torch.device("cuda:1"), + ] @pytest.mark.parametrize("value", [None, []]) From 2d3802a2bef19793a5d2dd8b6a95adcadbbe4fc6 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Fri, 26 Jun 2026 16:55:26 -0400 Subject: [PATCH 25/33] fix(multi-gpu): stop RAM blowup/swapping during concurrent generations Three RAM fixes for multi-GPU (and one that helps single-GPU too), addressing transient spikes to ~100% RAM and swapping during text-encode/transformer loads: 1. Cap the global RAM-cache budget at a safe fraction of system RAM. When max_cache_ram_gb is unset, the budget was the *sum* of the per-device cache heuristics, so N GPUs each claiming ~50% of RAM summed to ~N*50% and starved the OS. Now clamp the sum to ModelCache.calc_system_ram_headroom_bytes() (50% of RAM - 2GB baseline, floored at 4GB). Promote the sizing magic numbers to named constants shared by the per-device heuristic and the global cap. 2. Adopt already-resident CPU weights across devices at load time. When a second device loads a model another device already holds, deep-copy a registered meta-weight structural clone and assign the shared canonical weights, instead of re-reading the model from disk and materializing a full transient second copy. Loader-agnostic (one mechanism in ModelLoader, no per-loader code): works for diffusers, single-file checkpoint, GGUF and transformers models, and preserves registered hooks (e.g. fp8 layerwise-cast). Best-effort with a meta-tensor self-check and fallback to a normal disk load on any failure. Skipped on single-device installs. 3. Dequantize FLUX.2 FP8 checkpoints straight to bf16. _dequantize_fp8_weights materialized the whole model in float32 (~36GB for 9B) before a later cast to bf16; now the multiply is done in float32 but stored bf16 per-weight, so the model is never held in float32. Numerically identical; halves the cold-load transient (helps single-GPU too). Co-Authored-By: Claude Opus 4.8 --- .../model_manager/model_manager_default.py | 16 +- .../model_manager/load/load_default.py | 112 +++++++++++++- .../load/model_cache/model_cache.py | 49 +++++- .../load/model_cache/shared_cpu_weights.py | 34 +++++ .../model_manager/load/model_loaders/flux.py | 13 +- .../test_model_cache_ram_budget.py | 49 +++++- .../load/test_shared_weight_adoption.py | 140 ++++++++++++++++++ 7 files changed, 400 insertions(+), 13 deletions(-) create mode 100644 tests/backend/model_manager/load/test_shared_weight_adoption.py diff --git a/invokeai/app/services/model_manager/model_manager_default.py b/invokeai/app/services/model_manager/model_manager_default.py index 404fc8c72ee..176b61ddcab 100644 --- a/invokeai/app/services/model_manager/model_manager_default.py +++ b/invokeai/app/services/model_manager/model_manager_default.py @@ -125,12 +125,26 @@ def build_cache(device: torch.device) -> ModelCache: # Attach the single global RAM budget. The cap is the user's max_cache_ram_gb interpreted as a # true system-wide limit; when unset, it is the sum of the caches' individually-calculated # sizes, so each device keeps its effective capacity and weight deduplication becomes headroom. + # That sum is then clamped to a safe fraction of system RAM: each per-device heuristic already + # allows up to ~half of system RAM, so summing across N GPUs would otherwise claim ~N× that and + # leave nothing for the OS, causing swap thrashing. The clamp leaves real headroom; shared-weight + # dedup means the true footprint usually stays well under the cap regardless. gb = 2**30 distinct_caches = list(dict.fromkeys(ram_caches.values())) + # Cross-device weight adoption (and its per-model meta-shell capture) only pays off with more + # than one device cache; disable the capture cost otherwise. + shared_store.enable_shell_capture = len(distinct_caches) > 1 if app_config.max_cache_ram_gb is not None: global_ram_budget_bytes = int(app_config.max_cache_ram_gb * gb) else: - global_ram_budget_bytes = sum(c.local_ram_cache_size_bytes for c in distinct_caches) + summed_cache_bytes = sum(c.local_ram_cache_size_bytes for c in distinct_caches) + system_ram_headroom_bytes = ModelCache.calc_system_ram_headroom_bytes() + global_ram_budget_bytes = min(summed_cache_bytes, system_ram_headroom_bytes) + if global_ram_budget_bytes < summed_cache_bytes: + logger.info( + f"Capping model cache RAM budget at {global_ram_budget_bytes / gb:.2f} GB to leave system " + f"headroom (sum of per-device caches was {summed_cache_bytes / gb:.2f} GB)." + ) ram_budget = RamBudget(max_bytes=global_ram_budget_bytes, shared_store=shared_store) for cache in distinct_caches: cache.set_ram_budget(ram_budget) diff --git a/invokeai/backend/model_manager/load/load_default.py b/invokeai/backend/model_manager/load/load_default.py index 02929ff6132..de87c797e8e 100644 --- a/invokeai/backend/model_manager/load/load_default.py +++ b/invokeai/backend/model_manager/load/load_default.py @@ -1,6 +1,8 @@ # Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team """Default implementation of model loading in InvokeAI.""" +import copy +import itertools import re from logging import Logger from pathlib import Path @@ -155,9 +157,26 @@ def _load_and_cache(self, config: AnyModelConfig, submodel_type: Optional[SubMod pass config.path = str(self._get_model_path(config)) - self._ram_cache.make_room(self.get_size_fs(config, Path(config.path), submodel_type)) - with skip_torch_weight_init(): - loaded_model = self._load_model(config, submodel_type) + + # Fast path (multi-GPU): if another device already loaded this exact model, its canonical + # CPU weights are still resident in the shared store along with an empty (meta-weight) + # clone of the built module. Adopt those weights instead of re-reading the model from + # disk — this avoids both the redundant disk read and the large transient second copy + # that would otherwise spike RAM (and, on a RAM-constrained box, drive the system into + # swap). Any failure falls back to a normal load, so it can never change the result. + loaded_model = self._try_adopt_shared_weights(cache_key) + + shell_to_register: Optional[torch.nn.Module] = None + if loaded_model is None: + self._ram_cache.make_room(self.get_size_fs(config, Path(config.path), submodel_type)) + with skip_torch_weight_init(): + loaded_model = self._load_model(config, submodel_type) + # Snapshot a meta-weight clone now — before put() applies custom layers or any VRAM + # move — so the next device to load this model can adopt these weights (see above). + # Skipped in single-device setups, where no other cache will ever adopt it. + shared_store = self._ram_cache.shared_cpu_weights + if shared_store is not None and shared_store.enable_shell_capture: + shell_to_register = self._build_meta_shell(loaded_model) # Determine execution device from model config, considering submodel type execution_device = self._get_execution_device(config, submodel_type) @@ -168,6 +187,13 @@ def _load_and_cache(self, config: AnyModelConfig, submodel_type: Optional[SubMod execution_device=execution_device, ) + # Register the shell only after put() has created the shared entry (via the wrapper's + # acquire); it is dropped automatically when that entry's last reference is released. + if shell_to_register is not None: + shared_store = self._ram_cache.shared_cpu_weights + if shared_store is not None: + shared_store.set_shell(cache_key, shell_to_register) + return self._ram_cache.get(key=cache_key, stats_name=stats_name) def get_size_fs( @@ -329,6 +355,86 @@ def post_hook(mod: torch.nn.Module, _args: object, _output: object) -> None: module.register_forward_pre_hook(pre_hook) module.register_forward_hook(post_hook, always_call=True) + def _try_adopt_shared_weights(self, cache_key: str) -> Optional[AnyModel]: + """Build this model by adopting another device's already-resident CPU weights, skipping the + disk read entirely. + + Returns the constructed model, or None if adoption is unavailable or fails for any reason (in + which case the caller loads the model from disk normally). Loader-agnostic: it deep-copies the + meta-weight shell that the first device registered (`_build_meta_shell`) and assigns the + shared canonical weights into the copy — no per-loader architecture knowledge required, and + fp8 cast hooks carried by the shell are preserved automatically. + + Must be called while holding the MODEL_LOAD_LOCK write lock (as `_load_and_cache` does), so + the peeked canonical weights and shell cannot be evicted between the peek and the adopt. + """ + shared_store = self._ram_cache.shared_cpu_weights + if shared_store is None: + return None + canonical = shared_store.peek(cache_key) + shell = shared_store.get_shell(cache_key) + if canonical is None or shell is None: + return None + + try: + # Independent module per device (its params will be moved to its own GPU); deep-copying an + # all-meta shell is cheap (no weight data). assign=True then re-points the copy's + # parameters at the shared canonical tensors with no allocation. + model = copy.deepcopy(shell) + model.load_state_dict(canonical, assign=True) + # Safety net: if anything is left on the meta device (e.g. a persistent buffer somehow + # missing from the canonical state dict) the model would silently produce wrong results. + for tensor in itertools.chain(model.parameters(), model.buffers()): + if tensor.is_meta: + raise RuntimeError("adopted model has tensors left on the meta device") + except Exception as e: + # Adoption is best-effort; never let it break a load. Fall back to a normal disk load. + self._logger.warning( + f"Could not adopt shared CPU weights for '{cache_key}' ({e!r}); loading from disk instead." + ) + return None + + self._logger.info( + f"Adopted shared CPU weights for '{cache_key}' from another device's cache (skipped disk load)." + ) + return model + + @staticmethod + def _build_meta_shell(model: AnyModel) -> Optional[torch.nn.Module]: + """Return an empty, meta-weight structural clone of `model`, or None if it can't be cloned. + + The clone has the identical module structure, registered hooks (e.g. the fp8 layerwise-cast + hooks), and non-persistent buffers as `model`, but every parameter and persistent buffer is + replaced by a 0-byte tensor on the `meta` device. A second device adopts it by deep-copying + and assigning the shared canonical weights — so this works for every model family (diffusers, + single-file checkpoint, GGUF, transformers) without any per-loader code. + + Best-effort: returns None on any failure (the model then simply isn't adoptable, and the next + device loads it from disk as before). + """ + if not isinstance(model, torch.nn.Module): + return None + try: + # Persistent buffers come from the canonical state dict on adoption, so they (like params) + # are replaced by meta placeholders. Non-persistent buffers are NOT in the state dict, so + # they must be carried over with real data (deepcopy copies them); they are typically + # small (e.g. rotary-embedding tables, attention masks). + persistent_names = set(model.state_dict().keys()) + persistent_buffer_ids = {id(b) for n, b in model.named_buffers() if n in persistent_names} + + memo: dict[int, object] = {} + for param in model.parameters(recurse=True): + memo[id(param)] = torch.nn.Parameter( + torch.empty_like(param, device="meta"), requires_grad=param.requires_grad + ) + for buffer in model.buffers(recurse=True): + if id(buffer) in persistent_buffer_ids: + memo[id(buffer)] = torch.empty_like(buffer, device="meta") + + return copy.deepcopy(model, memo) + except Exception: + return None + # This needs to be implemented in the subclass def _load_model( self, diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache.py b/invokeai/backend/model_manager/load/model_cache/model_cache.py index 69547ee5fb1..f370528cda6 100644 --- a/invokeai/backend/model_manager/load/model_cache/model_cache.py +++ b/invokeai/backend/model_manager/load/model_cache/model_cache.py @@ -40,6 +40,17 @@ # Size of a MB in bytes. MB = 2**20 +# Default RAM-cache sizing constants. These are used both by the per-device heuristic +# (_calc_ram_available_to_model_cache) and by the multi-GPU global budget cap +# (ModelManagerService.build_model_manager), so the two stay consistent. +# +# - RAM_CACHE_SYSTEM_FRACTION: fraction of total system RAM the model cache may use by default. +# - RAM_CACHE_BASELINE_BYTES: assumed non-model RAM used by InvokeAI itself, reserved before sizing. +# - MIN_RAM_CACHE_BYTES: absolute floor so the cache is never sized uselessly small. +RAM_CACHE_SYSTEM_FRACTION = 0.5 +RAM_CACHE_BASELINE_BYTES = 2 * GB +MIN_RAM_CACHE_BYTES = 4 * GB + class _ModelLoadReadWriteLock: """A write-preferring readers-writer lock that serializes model construction against VRAM moves. @@ -317,6 +328,15 @@ def execution_device(self) -> torch.device: """Return the default execution device this cache loads models onto.""" return self._execution_device + @property + def shared_cpu_weights(self) -> SharedCpuWeightsStore | None: + """The process-global store this cache deduplicates CPU weights into, or None if disabled. + + Exposed so the loader can check (via `peek`) whether another device already holds a model's + canonical CPU weights and adopt them at construction time instead of re-reading from disk. + """ + return self._shared_cpu_weights + def set_ram_budget(self, ram_budget: RamBudget) -> None: """Attach the shared global RamBudget after construction. @@ -795,8 +815,10 @@ def _calc_ram_available_to_model_cache(self) -> int: heuristics_applied = [1] total_system_ram_bytes = psutil.virtual_memory().total # Assumed baseline RAM used by InvokeAI for non-model stuff. - baseline_ram_used_by_invokeai = 2 * GB - ram_available_to_model_cache = int(total_system_ram_bytes * 0.5 - baseline_ram_used_by_invokeai) + baseline_ram_used_by_invokeai = RAM_CACHE_BASELINE_BYTES + ram_available_to_model_cache = int( + total_system_ram_bytes * RAM_CACHE_SYSTEM_FRACTION - baseline_ram_used_by_invokeai + ) # Apply heuristic 2. # ------------------ @@ -812,15 +834,34 @@ def _calc_ram_available_to_model_cache(self) -> int: # Apply heuristic 3. # ------------------ - if ram_available_to_model_cache < 4 * GB: + if ram_available_to_model_cache < MIN_RAM_CACHE_BYTES: heuristics_applied.append(3) - ram_available_to_model_cache = 4 * GB + ram_available_to_model_cache = MIN_RAM_CACHE_BYTES self._logger.info( f"Calculated model RAM cache size: {ram_available_to_model_cache / MB:.2f} MB. Heuristics applied: {heuristics_applied}." ) return ram_available_to_model_cache + @staticmethod + def calc_system_ram_headroom_bytes() -> int: + """The default system-wide cap on TOTAL model-cache RAM, leaving headroom for the OS. + + This is the maximum RAM the model caches should collectively use when the user has not set an + explicit `max_cache_ram_gb`. It mirrors heuristic 1 of `_calc_ram_available_to_model_cache` + (a fraction of system RAM, less InvokeAI's baseline) with the same minimum floor. + + In multi-GPU mode there is one cache per device, and each device's heuristic independently + allows up to this fraction of system RAM; summed across N devices that would claim ~N× as + much RAM and cause the system to swap. The model manager uses this value to cap that sum so a + safe amount of RAM is always left for the OS and other processes. + """ + total_system_ram_bytes = psutil.virtual_memory().total + return max( + int(total_system_ram_bytes * RAM_CACHE_SYSTEM_FRACTION) - RAM_CACHE_BASELINE_BYTES, + MIN_RAM_CACHE_BYTES, + ) + def _get_ram_in_use(self) -> int: """Get the amount of RAM currently in use. diff --git a/invokeai/backend/model_manager/load/model_cache/shared_cpu_weights.py b/invokeai/backend/model_manager/load/model_cache/shared_cpu_weights.py index 4b1c634a25b..4ce456e45b6 100644 --- a/invokeai/backend/model_manager/load/model_cache/shared_cpu_weights.py +++ b/invokeai/backend/model_manager/load/model_cache/shared_cpu_weights.py @@ -15,6 +15,10 @@ class _SharedWeightsEntry: # Number of per-device cached models currently aliasing this entry. The entry is freed # (its RAM released) when this drops to zero. refcount: int = 0 + # An empty (meta-weight) structural clone of the first-built module, used so a second device can + # adopt the canonical weights without re-reading the model from disk. None until registered (and + # for entries whose model isn't an nn.Module). Holds ~no real RAM: its weights are on `meta`. + shell: object | None = None _key_bytes: dict[str, int] = field(default_factory=dict) @@ -47,6 +51,10 @@ class SharedCpuWeightsStore: def __init__(self) -> None: self._lock = threading.Lock() self._entries: dict[str, _SharedWeightsEntry] = {} + # Whether to capture per-model meta-weight shells for cross-device adoption. Only useful with + # more than one device cache, so the model manager disables it in single-device setups to + # avoid the (small) per-first-load clone cost. See ModelLoader._build_meta_shell. + self.enable_shell_capture: bool = True def acquire(self, key: str, state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: """Adopt the canonical CPU state dict for `key`, registering `state_dict` as canonical if @@ -72,6 +80,32 @@ def acquire(self, key: str, state_dict: dict[str, torch.Tensor]) -> dict[str, to entry.refcount += 1 return entry.state_dict + def peek(self, key: str) -> dict[str, torch.Tensor] | None: + """Return the canonical state dict for `key` WITHOUT changing its refcount, or None if absent. + + Used by the loader to adopt already-resident weights at construction time (skipping the disk + read) when another device has already loaded this model. The reference is taken later, in the + cached-model wrapper's `acquire()`, exactly as for a normal load — so this peek must not + itself increment the count. + """ + with self._lock: + entry = self._entries.get(key) + return entry.state_dict if entry is not None else None + + def set_shell(self, key: str, shell: object) -> None: + """Register the empty (meta-weight) structural clone for `key`, if an entry exists and none + is set yet. A no-op when the key has no canonical entry (e.g. keep_ram_copy disabled).""" + with self._lock: + entry = self._entries.get(key) + if entry is not None and entry.shell is None: + entry.shell = shell + + def get_shell(self, key: str) -> object | None: + """Return the registered meta-weight shell for `key`, or None if absent.""" + with self._lock: + entry = self._entries.get(key) + return entry.shell if entry is not None else None + def release(self, key: str) -> None: """Release one reference to `key`'s canonical state dict, freeing it when the count hits 0. diff --git a/invokeai/backend/model_manager/load/model_loaders/flux.py b/invokeai/backend/model_manager/load/model_loaders/flux.py index b3c46d04db3..739ba458888 100644 --- a/invokeai/backend/model_manager/load/model_loaders/flux.py +++ b/invokeai/backend/model_manager/load/model_loaders/flux.py @@ -1080,7 +1080,13 @@ def _dequantize_fp8_weights(self, sd: dict) -> dict: if block_size > 1: scale = scale.repeat_interleave(block_size, dim=dim) - sd[weight_key] = weight_float * scale + # Do the multiply in float32 for precision, but store bf16 (FLUX.2's compute dtype) + # immediately so the *whole* model is never materialized in float32. Holding every + # dequantized weight as float32 here doubled RAM transiently (~36GB vs ~17GB for a 9B + # model) and was the dominant cold-load spike, especially with two GPUs. The result is + # identical to the previous code, which cast the same values to bf16 a few steps later. + sd[weight_key] = (weight_float * scale).to(torch.bfloat16) + del weight_float # Filter out scale metadata keys and other FP8 metadata keys_to_remove = [ @@ -1110,8 +1116,9 @@ def _dequantize_fp8_weights(self, sd: dict) -> dict: del sd[k] for key in keys_to_convert: - # Convert FP8 tensor to float32 - sd[key] = sd[key].float() + # Convert native FP8 tensors straight to bf16 (FLUX.2's compute dtype) rather than float32, + # so a cold load never transiently holds the whole model in float32 (see the scaled path). + sd[key] = sd[key].to(torch.bfloat16) return sd diff --git a/tests/backend/model_manager/load/model_cache/test_model_cache_ram_budget.py b/tests/backend/model_manager/load/model_cache/test_model_cache_ram_budget.py index ad248c75716..cb40827702c 100644 --- a/tests/backend/model_manager/load/model_cache/test_model_cache_ram_budget.py +++ b/tests/backend/model_manager/load/model_cache/test_model_cache_ram_budget.py @@ -6,11 +6,17 @@ """ import logging -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch import pytest -from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache +from invokeai.backend.model_manager.load.model_cache.model_cache import ( + GB, + MIN_RAM_CACHE_BYTES, + RAM_CACHE_BASELINE_BYTES, + RAM_CACHE_SYSTEM_FRACTION, + ModelCache, +) from invokeai.backend.model_manager.load.model_cache.ram_budget import RamBudget from invokeai.backend.model_manager.load.model_cache.shared_cpu_weights import SharedCpuWeightsStore from invokeai.backend.util.calc_tensor_size import calc_tensor_size @@ -92,6 +98,45 @@ def test_global_budget_evicts_lru_in_single_cache(mock_logger): cache.shutdown() +def _mock_total_ram(total_bytes: int): + """Patch psutil.virtual_memory().total as seen by model_cache.""" + vm = MagicMock() + vm.total = total_bytes + return patch( + "invokeai.backend.model_manager.load.model_cache.model_cache.psutil.virtual_memory", + return_value=vm, + ) + + +def test_system_ram_headroom_is_fraction_minus_baseline(): + # On a 96 GB box, the default cap is 50% - 2 GB = 46 GB, leaving real headroom for the OS. + with _mock_total_ram(96 * GB): + headroom = ModelCache.calc_system_ram_headroom_bytes() + assert headroom == int(96 * GB * RAM_CACHE_SYSTEM_FRACTION) - RAM_CACHE_BASELINE_BYTES + assert headroom == 46 * GB + # And it must leave at least half the system for everything else. + assert headroom <= 96 * GB * 0.5 + + +def test_system_ram_headroom_respects_floor_on_tiny_systems(): + # A machine with almost no RAM still gets the absolute minimum, never a negative/zero budget. + with _mock_total_ram(2 * GB): + headroom = ModelCache.calc_system_ram_headroom_bytes() + assert headroom == MIN_RAM_CACHE_BYTES + + +def test_headroom_clamps_summed_multi_gpu_budget(): + # Reproduces the multi-GPU blowup: two 45 GB per-device caches sum to 90 GB, which would leave + # only ~6 GB on a 96 GB machine. The headroom cap must clamp the budget below that sum. + per_device_cache_bytes = 45 * GB + summed = 2 * per_device_cache_bytes # 90 GB, as the old code used verbatim + with _mock_total_ram(96 * GB): + headroom = ModelCache.calc_system_ram_headroom_bytes() + clamped = min(summed, headroom) + assert clamped == headroom < summed + assert clamped == 46 * GB + + def test_eviction_cannot_free_ram_held_by_another_device(mock_logger): """If a cache's only droppable model is still held by another device, eviction frees nothing globally (the shared weights stay live) and the new model is still admitted -> transiently over diff --git a/tests/backend/model_manager/load/test_shared_weight_adoption.py b/tests/backend/model_manager/load/test_shared_weight_adoption.py new file mode 100644 index 00000000000..3393e2af91b --- /dev/null +++ b/tests/backend/model_manager/load/test_shared_weight_adoption.py @@ -0,0 +1,140 @@ +"""Tests for load-time adoption of shared CPU weights (multi-GPU RAM-spike fix). + +When a second device loads a model that another device already holds, the loader deep-copies the +empty (meta-weight) structural shell the first device registered and assigns the canonical CPU +weights into it — instead of re-reading the model from disk and materializing a full transient +second copy. This is loader-agnostic (no per-model-family code): it works by cloning a built module, +so it covers diffusers, single-file checkpoints, GGUF and transformers models alike, and preserves +any registered hooks (e.g. fp8 layerwise-cast hooks). +""" + +from unittest.mock import MagicMock + +import torch + +from invokeai.backend.model_manager.load.load_default import ModelLoader +from invokeai.backend.model_manager.load.model_cache.shared_cpu_weights import SharedCpuWeightsStore + + +class _TinyModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.lin = torch.nn.Linear(4, 4) + # A non-persistent buffer: not in the state dict, so adoption must carry it over with data. + self.register_buffer("scale", torch.tensor([2.0]), persistent=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.lin(x) * self.scale + + +def _loader_with_store(store: SharedCpuWeightsStore | None) -> ModelLoader: + loader = ModelLoader.__new__(ModelLoader) # bypass __init__ (needs app deps we don't use here) + loader._logger = MagicMock() + loader._ram_cache = MagicMock() + loader._ram_cache.shared_cpu_weights = store + return loader + + +def _populate(store: SharedCpuWeightsStore, key: str, model: torch.nn.Module) -> None: + """Mimic the first device's load: register canonical weights + a meta shell for `model`.""" + store.acquire(key, model.state_dict()) + shell = ModelLoader._build_meta_shell(model) + assert shell is not None + store.set_shell(key, shell) + + +def test_meta_shell_has_no_real_weight_storage(): + model = _TinyModel() + shell = ModelLoader._build_meta_shell(model) + assert shell is not None + # Parameters are on meta (0 bytes); the non-persistent buffer keeps real data. + assert all(p.is_meta for p in shell.parameters()) + assert not shell.scale.is_meta + assert torch.equal(shell.scale, model.scale) + + +def test_build_meta_shell_returns_none_for_non_module(): + assert ModelLoader._build_meta_shell({"not": "a module"}) is None # type: ignore[arg-type] + + +def test_adopts_canonical_weights_without_copying(): + store = SharedCpuWeightsStore() + source = _TinyModel() + _populate(store, "m", source) + canonical = store.peek("m") + refcount_before = store.refcount("m") + + model = _loader_with_store(store)._try_adopt_shared_weights("m") + + assert model is not None + # The adopted params ARE the canonical tensors (assign=True, no copy) -> no extra RAM. + assert model.lin.weight.data_ptr() == canonical["lin.weight"].data_ptr() + assert model.lin.bias.data_ptr() == canonical["lin.bias"].data_ptr() + assert not any(t.is_meta for t in model.parameters()) + assert not any(t.is_meta for t in model.buffers()) + # peek()/get_shell() must not have taken a reference -- the wrapper's acquire() does that later. + assert store.refcount("m") == refcount_before + + +def test_adopted_model_produces_correct_output(): + store = SharedCpuWeightsStore() + source = _TinyModel() + _populate(store, "m", source) + x = torch.randn(3, 4) + + model = _loader_with_store(store)._try_adopt_shared_weights("m") + + assert torch.allclose(model(x), source(x), atol=1e-6) + + +def test_adoption_preserves_forward_hooks(): + # fp8 layerwise casting is implemented as forward hooks; cloning the built module must keep them. + store = SharedCpuWeightsStore() + source = _TinyModel() + fired: list[str] = [] + source.lin.register_forward_pre_hook(lambda mod, args: fired.append("pre")) + _populate(store, "m", source) + + model = _loader_with_store(store)._try_adopt_shared_weights("m") + model(torch.randn(1, 4)) + + assert fired == ["pre"] # the cloned module's hook fired + + +def test_no_shell_means_no_adoption(): + # Canonical present but no shell registered (e.g. first device couldn't clone) -> fall back. + store = SharedCpuWeightsStore() + store.acquire("m", _TinyModel().state_dict()) + assert _loader_with_store(store)._try_adopt_shared_weights("m") is None + + +def test_absent_key_means_no_adoption(): + assert _loader_with_store(SharedCpuWeightsStore())._try_adopt_shared_weights("missing") is None + + +def test_no_shared_store_means_no_adoption(): + assert _loader_with_store(None)._try_adopt_shared_weights("m") is None + + +def test_mismatched_canonical_falls_back_safely(): + # If the canonical weights don't match the shell's structure, adoption must fail soft (-> None), + # not raise, so the caller can load normally. + store = SharedCpuWeightsStore() + source = _TinyModel() + shell = ModelLoader._build_meta_shell(source) + assert shell is not None + store.acquire("m", {"unexpected.key": torch.zeros(2)}) # wrong state dict + store.set_shell("m", shell) + + loader = _loader_with_store(store) + assert loader._try_adopt_shared_weights("m") is None + loader._logger.warning.assert_called_once() + + +def test_shell_dropped_when_entry_released(): + store = SharedCpuWeightsStore() + _populate(store, "m", _TinyModel()) + assert store.get_shell("m") is not None + store.release("m") # last reference -> entry (and its shell) gone + assert store.get_shell("m") is None + assert "m" not in store From 69039113f034bcb5955927b18867e601f22cef72 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Fri, 26 Jun 2026 22:53:56 -0400 Subject: [PATCH 26/33] fix(qwen-image): reserve VAE working memory so decode/encode don't OOM MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The Qwen Image VAE encode/decode invocations called model_on_device() without a working-memory estimate, unlike every other VAE family (SD/SDXL/SD3/CogView4/FLUX). So the model cache reserved only its small default working memory, never offloaded a large resident transformer (the VAE weights themselves are tiny), and the VAE's forward-pass activations then OOM'd VRAM — e.g. a ~40GB Qwen Image Edit transformer left ~1GB free while decode needed ~5GB. Reproduces single-GPU; unrelated to the multi-GPU RAM work. Add estimate_vae_working_memory_qwen_image() (same per-output-pixel scaling as the other estimators, handling the 5D Qwen latents) and pass it from both the i2l (encode, used for reference images in Image Edit) and l2i (decode) nodes, so the cache offloads the transformer before the VAE runs. Co-Authored-By: Claude Opus 4.8 --- .../qwen_image_image_to_latents.py | 6 ++++- .../qwen_image_latents_to_image.py | 6 ++++- invokeai/backend/util/vae_working_memory.py | 27 +++++++++++++++++++ 3 files changed, 37 insertions(+), 2 deletions(-) diff --git a/invokeai/app/invocations/qwen_image_image_to_latents.py b/invokeai/app/invocations/qwen_image_image_to_latents.py index ef88e03082b..cac536f00c9 100644 --- a/invokeai/app/invocations/qwen_image_image_to_latents.py +++ b/invokeai/app/invocations/qwen_image_image_to_latents.py @@ -18,6 +18,7 @@ from invokeai.backend.model_manager.load.load_base import LoadedModel from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor from invokeai.backend.util.devices import TorchDevice +from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_qwen_image @invocation( @@ -44,7 +45,10 @@ class QwenImageImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard) @staticmethod def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor: - with vae_info.model_on_device() as (_, vae): + # Reserve working memory for the encode so the cache offloads any large resident model first; + # otherwise the encode's activations OOM (the VAE weights themselves are tiny). + estimated_working_memory = estimate_vae_working_memory_qwen_image("encode", image_tensor, vae_info.model) + with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae): assert isinstance(vae, AutoencoderKLQwenImage) vae.disable_tiling() diff --git a/invokeai/app/invocations/qwen_image_latents_to_image.py b/invokeai/app/invocations/qwen_image_latents_to_image.py index b3ea39c4bbf..6b03e903d7d 100644 --- a/invokeai/app/invocations/qwen_image_latents_to_image.py +++ b/invokeai/app/invocations/qwen_image_latents_to_image.py @@ -19,6 +19,7 @@ from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt from invokeai.backend.util.devices import TorchDevice +from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_qwen_image @invocation( @@ -41,9 +42,12 @@ def invoke(self, context: InvocationContext) -> ImageOutput: vae_info = context.models.load(self.vae.vae) assert isinstance(vae_info.model, AutoencoderKLQwenImage) + # Reserve working memory for the decode so the cache offloads any large resident model (e.g. + # the transformer) first; otherwise the decode's activations OOM. See estimator for details. + estimated_working_memory = estimate_vae_working_memory_qwen_image("decode", latents, vae_info.model) with ( SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes), - vae_info.model_on_device() as (_, vae), + vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae), ): context.util.signal_progress("Running VAE") assert isinstance(vae, AutoencoderKLQwenImage) diff --git a/invokeai/backend/util/vae_working_memory.py b/invokeai/backend/util/vae_working_memory.py index f9228ced652..8edd0a794f2 100644 --- a/invokeai/backend/util/vae_working_memory.py +++ b/invokeai/backend/util/vae_working_memory.py @@ -2,6 +2,7 @@ import torch from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL +from diffusers.models.autoencoders.autoencoder_kl_qwenimage import AutoencoderKLQwenImage from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR @@ -92,6 +93,32 @@ def estimate_vae_working_memory_flux( return int(working_memory) +def estimate_vae_working_memory_qwen_image( + operation: Literal["encode", "decode"], image_tensor: torch.Tensor, vae: AutoencoderKLQwenImage +) -> int: + """Estimate the working memory required by the invocation in bytes. + + Without this, the Qwen Image VAE encode/decode passes no working-memory estimate to the model + cache, so the cache reserves only its small default and never offloads a large resident + transformer (the VAE weights themselves are tiny). The decode then OOMs on its activations. This + mirrors the other VAE estimators: peak working memory scales ~linearly with the number of output + pixels and the element size. The Qwen Image latents are 5D (B, C, frames, H, W); the trailing two + dims are spatial, same as the 2D VAEs. See #8414. + """ + latent_scale_factor_for_operation = LATENT_SCALE_FACTOR if operation == "decode" else 1 + + h = latent_scale_factor_for_operation * image_tensor.shape[-2] + w = latent_scale_factor_for_operation * image_tensor.shape[-1] + element_size = next(vae.parameters()).element_size() + + # This constant is determined experimentally and takes into consideration both allocated and reserved memory. See #8414 + # Encoding uses ~45% the working memory as decoding. + scaling_constant = 2200 if operation == "decode" else 1100 + working_memory = h * w * element_size * scaling_constant + + return int(working_memory) + + def estimate_vae_working_memory_sd3( operation: Literal["encode", "decode"], image_tensor: torch.Tensor, vae: AutoencoderKL ) -> int: From 3e917dd2df1b708a25124c8ebb701b8257ba81c5 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Sat, 27 Jun 2026 10:17:27 -0400 Subject: [PATCH 27/33] fix(flux2): tile reference-image VAE encode to avoid VRAM OOM MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The FLUX.2 VAE encoder's mid-block self-attention scales quadratically with the input's spatial size, and on ROCm scaled_dot_product_attention falls back to a materialized attention matrix. Encoding a reference image (kontext) at full size therefore allocated ~15GB in a single attention call at 1024px — and hundreds of GB at the 2024px reference cap — OOMing VRAM regardless of how much other model memory was freed. Tile the reference-image encode to bound per-tile attention. The VAE's default tile size equals its sample_size (1024), whose per-tile attention still OOMs, so force a 512px tile (with a matching latent tile size derived from the config). Save/restore the VAE's tiling config since it is a shared, cached instance, so the final image decode does not inherit these settings. Co-Authored-By: Claude Opus 4.8 --- invokeai/backend/flux2/ref_image_extension.py | 36 +++++++++++++------ 1 file changed, 26 insertions(+), 10 deletions(-) diff --git a/invokeai/backend/flux2/ref_image_extension.py b/invokeai/backend/flux2/ref_image_extension.py index 9cc6240db66..368f3c4452f 100644 --- a/invokeai/backend/flux2/ref_image_extension.py +++ b/invokeai/backend/flux2/ref_image_extension.py @@ -208,16 +208,32 @@ def _prepare_ref_images(self) -> tuple[torch.Tensor, torch.Tensor]: vae_dtype = next(iter(vae.parameters())).dtype image_tensor = image_tensor.to(device=TorchDevice.choose_torch_device(), dtype=vae_dtype) - # FLUX.2 VAE uses diffusers API - latent_dist = vae.encode(image_tensor, return_dict=False)[0] - - # Use mode() for deterministic encoding (no sampling) - if hasattr(latent_dist, "mode"): - ref_image_latents_unpacked = latent_dist.mode() - elif hasattr(latent_dist, "sample"): - ref_image_latents_unpacked = latent_dist.sample() - else: - ref_image_latents_unpacked = latent_dist + # The FLUX.2 VAE encoder's mid-block self-attention scales quadratically with the + # input's spatial size (and on ROCm, SDPA falls back to a *materialized* attention + # matrix), so encoding a reference image at full size OOMs VRAM — ~15GB at 1024px, + # hundreds of GB at the 2024px reference cap. Tile the encode to bound peak memory + # regardless of reference resolution. The VAE's default tile size equals its + # sample_size (1024), which still OOMs per tile, so force a smaller 512px tile. + # Save/restore the tiling config because this VAE is a shared, cached instance (e.g. + # the final image decode must not inherit these settings). + downsample = 2 ** (len(vae.config.block_out_channels) - 1) + prev_tiling = (vae.use_tiling, vae.tile_sample_min_size, vae.tile_latent_min_size) + vae.use_tiling = True + vae.tile_sample_min_size = 512 + vae.tile_latent_min_size = 512 // downsample + try: + # FLUX.2 VAE uses diffusers API + latent_dist = vae.encode(image_tensor, return_dict=False)[0] + + # Use mode() for deterministic encoding (no sampling) + if hasattr(latent_dist, "mode"): + ref_image_latents_unpacked = latent_dist.mode() + elif hasattr(latent_dist, "sample"): + ref_image_latents_unpacked = latent_dist.sample() + else: + ref_image_latents_unpacked = latent_dist + finally: + vae.use_tiling, vae.tile_sample_min_size, vae.tile_latent_min_size = prev_tiling TorchDevice.empty_cache() From 902c77d8e957d59d53fde3e6d48155c6085db439 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Sat, 27 Jun 2026 10:17:27 -0400 Subject: [PATCH 28/33] fix(multi-gpu): query execution device for VRAM-in-use accounting ModelCache._get_vram_in_use() called torch.cuda.memory_allocated() with no device argument, while _get_vram_available() reads memory_allocated(execution_device). The formula relies on those two canceling. In multi-GPU mode each worker calls torch.cuda.set_device for its own GPU, so the process-current device flips between workers; the no-argument call can then read a different (e.g. idle) GPU's allocation, breaking the cancellation and inflating "available" VRAM toward the card total. The cache then believes there is room and never offloads, so VRAM offloading effectively ignores device_working_mem_gb in multi-GPU. Single-GPU was unaffected (current device always equals the execution device). Query self._execution_device in both _get_vram_in_use() and the cache-state debug log. Add a regression test asserting the per-cache execution device is used. Co-Authored-By: Claude Opus 4.8 --- .../load/model_cache/model_cache.py | 15 ++++++++-- .../test_model_cache_ram_budget.py | 30 +++++++++++++++++++ 2 files changed, 43 insertions(+), 2 deletions(-) diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache.py b/invokeai/backend/model_manager/load/model_cache/model_cache.py index f370528cda6..7bc6931e5a3 100644 --- a/invokeai/backend/model_manager/load/model_cache/model_cache.py +++ b/invokeai/backend/model_manager/load/model_cache/model_cache.py @@ -770,7 +770,13 @@ def _get_vram_available(self, working_mem_bytes: Optional[int]) -> int: def _get_vram_in_use(self) -> int: """Get the amount of VRAM currently in use by the cache.""" if self._execution_device.type == "cuda": - return torch.cuda.memory_allocated() + # Must be queried for THIS cache's execution device, not the process-current device. In + # multi-GPU mode each worker calls torch.cuda.set_device for its own GPU, so the current + # device flips between workers; querying without the device argument can read a different + # (e.g. idle) GPU's allocation. That breaks the cancellation in _get_vram_available + # (which adds vram_allocated(execution_device)), inflating "available" toward total VRAM + # so the cache never offloads — causing VRAM OOMs that ignore device_working_mem_gb. + return torch.cuda.memory_allocated(self._execution_device) elif self._execution_device.type == "mps": return torch.mps.current_allocated_memory() else: @@ -967,7 +973,12 @@ def _log_cache_state(self, title: str = "Model cache state:", include_entry_deta ) if torch.cuda.is_available(): - log += " {:<30} {:.1f} MB\n".format("CUDA Memory Allocated:", torch.cuda.memory_allocated() / MB) + # Query this cache's execution device (not the process-current one) for correct + # per-device numbers in multi-GPU mode. See _get_vram_in_use. + allocated = ( + torch.cuda.memory_allocated(self._execution_device) if self._execution_device.type == "cuda" else 0 + ) + log += " {:<30} {:.1f} MB\n".format("CUDA Memory Allocated:", allocated / MB) log += " {:<30} {}\n".format("Total models:", len(self._cached_models)) if include_entry_details and len(self._cached_models) > 0: diff --git a/tests/backend/model_manager/load/model_cache/test_model_cache_ram_budget.py b/tests/backend/model_manager/load/model_cache/test_model_cache_ram_budget.py index cb40827702c..3ceb36c8dd6 100644 --- a/tests/backend/model_manager/load/model_cache/test_model_cache_ram_budget.py +++ b/tests/backend/model_manager/load/model_cache/test_model_cache_ram_budget.py @@ -98,6 +98,36 @@ def test_global_budget_evicts_lru_in_single_cache(mock_logger): cache.shutdown() +def test_get_vram_in_use_queries_this_caches_execution_device(mock_logger): + """Regression: _get_vram_in_use must query its own execution device, not the process-current one. + + In multi-GPU mode each worker calls torch.cuda.set_device for its GPU, so a no-argument + memory_allocated() can read a different device. That breaks the cancellation in + _get_vram_available and inflates "available" VRAM, so the cache never offloads and OOMs while + ignoring device_working_mem_gb. + """ + import torch + + mc = "invokeai.backend.model_manager.load.model_cache.model_cache" + with ( + patch(f"{mc}.torch.cuda.mem_get_info", return_value=(10 * GB, 48 * GB)), + patch(f"{mc}.torch.cuda.memory_allocated", return_value=42) as mock_alloc, + ): + cache = ModelCache( + execution_device_working_mem_gb=3.0, + enable_partial_loading=True, + keep_ram_copy_of_weights=True, + execution_device="cuda:1", + storage_device="cpu", + logger=mock_logger, + ) + try: + assert cache._get_vram_in_use() == 42 + mock_alloc.assert_called_with(torch.device("cuda:1")) + finally: + cache.shutdown() + + def _mock_total_ram(total_bytes: int): """Patch psutil.virtual_memory().total as seen by model_cache.""" vm = MagicMock() From aa6fec8d081cc8fdafa0b48f91584bc7119f4b6c Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Sat, 27 Jun 2026 15:32:11 -0400 Subject: [PATCH 29/33] fix(qwen-image): calibrate VAE working-memory estimate to the 3D-conv decode peak MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The Qwen Image VAE is a 3D-conv (video) VAE whose decode allocates large conv3d feature maps. A ~1MP decode was measured to peak at ~17 GiB of VRAM — far above what the generic 2200/1100 SD/FLUX constants reserved (~4.6 GiB), so the cache concluded the decode "fit" alongside the resident 20GB transformer + 15GB text encoder, never offloaded them, and OOMed. The offload only frees ~(working_mem - free) bytes, so the reservation must both cover the real peak and be large enough to trigger the offload of models the decode doesn't need. Raise the Qwen decode/encode constants (13000/6500) to match the measured peak. It's linear in output pixels, so it over-reserves past ~1.5MP (where the decode can exceed the card even after offloading) — that case is covered by force_tiled_decode. Co-Authored-By: Claude Opus 4.8 --- invokeai/backend/util/vae_working_memory.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/invokeai/backend/util/vae_working_memory.py b/invokeai/backend/util/vae_working_memory.py index 8edd0a794f2..8b91dc54161 100644 --- a/invokeai/backend/util/vae_working_memory.py +++ b/invokeai/backend/util/vae_working_memory.py @@ -111,9 +111,16 @@ def estimate_vae_working_memory_qwen_image( w = latent_scale_factor_for_operation * image_tensor.shape[-1] element_size = next(vae.parameters()).element_size() - # This constant is determined experimentally and takes into consideration both allocated and reserved memory. See #8414 - # Encoding uses ~45% the working memory as decoding. - scaling_constant = 2200 if operation == "decode" else 1100 + # Calibrated for the Qwen Image VAE, a 3D-conv (video) VAE whose decode allocates large conv3d + # feature maps — a ~1MP decode was measured to peak at ~17 GiB of VRAM, far above the 2D SD/FLUX + # VAEs the generic 2200/1100 constants were tuned for. The reservation must cover that peak AND be + # large enough to make the cache offload an otherwise-resident transformer + text encoder (which + # the decode doesn't need): the offload only frees ~(working_mem - free) bytes, so under-reserving + # leaves the big models resident and the decode OOMs. Over-reserving is safe here (it just offloads + # models the decode doesn't use). Encoding uses ~half the working memory of decoding. + # NOTE: this is linear in output pixels; a sufficiently large output (>~1.5MP) can still exceed + # the card even after offloading everything — that case needs tiled decode, handled separately. + scaling_constant = 13000 if operation == "decode" else 6500 working_memory = h * w * element_size * scaling_constant return int(working_memory) From 43a46bd8580752b09628325cce4eb2e9ddbee6b9 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Sat, 27 Jun 2026 15:32:11 -0400 Subject: [PATCH 30/33] feat(qwen-image): honor force_tiled_decode in the l2i node The Qwen Image latents-to-image node hardcoded vae.disable_tiling(), ignoring the global force_tiled_decode setting that the SD/SDXL l2i node honors. Wire it up the same way so users can opt into tiled VAE decode for very large outputs that exceed VRAM even after the transformer/text encoder are offloaded. Off by default, so normal-size decodes are unchanged (full-frame, no tile blending). Co-Authored-By: Claude Opus 4.8 --- .../app/invocations/qwen_image_latents_to_image.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/invokeai/app/invocations/qwen_image_latents_to_image.py b/invokeai/app/invocations/qwen_image_latents_to_image.py index 6b03e903d7d..c418fe43cbe 100644 --- a/invokeai/app/invocations/qwen_image_latents_to_image.py +++ b/invokeai/app/invocations/qwen_image_latents_to_image.py @@ -53,7 +53,15 @@ def invoke(self, context: InvocationContext) -> ImageOutput: assert isinstance(vae, AutoencoderKLQwenImage) latents = latents.to(device=TorchDevice.choose_torch_device(), dtype=vae.dtype) - vae.disable_tiling() + # Honor the global force_tiled_decode setting, like the SD/SDXL l2i node. Tiling bounds the + # VAE's per-tile memory, which is the scalable way to decode very large outputs that would + # exceed VRAM even after offloading the transformer/text encoder. For normal sizes, leave + # it off (faster, no tile blending) — the reserved working memory offloads other models so + # the full-frame decode fits. + if context.config.get().force_tiled_decode: + vae.enable_tiling() + else: + vae.disable_tiling() tiling_context = nullcontext() From b84d4507895c721cac80e72b0382f7675bd990ce Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Sat, 27 Jun 2026 15:43:29 -0400 Subject: [PATCH 31/33] fix(ui): stop progress disk flashing during indeterminate phases The preview-panel progress circle re-renders on every InvocationProgressEvent. The parent passes a fresh progressEvent object each event, so the CircularProgress re-rendered constantly; during the indeterminate phases (everything except denoising) that restarted its CSS spin animation each time, which looked like the disk flashing. (Determinate denoising was unaffected because the value genuinely changes per step.) Split the circle into a memoized, ref-forwarding subcomponent keyed on its visual props (isIndeterminate, value, device label) so message-only updates no longer re-render it and the spin animation stays continuous. The Tooltip still anchors to it via the forwarded ref. Co-Authored-By: Claude Opus 4.8 --- .../ImageViewer/ProgressIndicator2.tsx | 33 ++++++++++++++----- 1 file changed, 24 insertions(+), 9 deletions(-) diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ProgressIndicator2.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ProgressIndicator2.tsx index 30e7312e2aa..dde90ead5fd 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ProgressIndicator2.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ProgressIndicator2.tsx @@ -1,7 +1,8 @@ import type { CircularProgressProps, SystemStyleObject } from '@invoke-ai/ui-library'; import { CircularProgress, Text, Tooltip } from '@invoke-ai/ui-library'; import { useProgressDeviceLabel } from 'common/hooks/useProgressDeviceLabel'; -import { memo } from 'react'; +import type { ComponentRef } from 'react'; +import { forwardRef, memo } from 'react'; import type { S } from 'services/api/types'; import { formatProgressMessage } from 'services/events/stores'; @@ -29,23 +30,37 @@ const labelStyles: SystemStyleObject = { pointerEvents: 'none', }; +type ProgressDeviceLabel = ReturnType; + +// The circle is split out and memoized so it does NOT re-render when only the tooltip message +// changes. Every progress event re-renders the parent, and during the indeterminate phases +// (everything except denoising) those events keep the same `isIndeterminate`/`value` — but +// re-rendering the CircularProgress restarts its CSS spin animation, which reads as the disk +// "flashing". Memoizing on the visual props keeps the animation continuous. forwardRef so the +// wrapping Tooltip can still anchor to it. +const ProgressCircle = memo( + forwardRef, { deviceLabel: ProgressDeviceLabel } & CircularProgressProps>( + ({ deviceLabel, ...rest }, ref) => ( + + {deviceLabel && {deviceLabel.index}} + + ) + ) +); +ProgressCircle.displayName = 'ProgressCircle'; + export const ProgressIndicator = memo( ({ progressEvent, ...rest }: { progressEvent: S['InvocationProgressEvent'] } & CircularProgressProps) => { const deviceLabel = useProgressDeviceLabel(progressEvent?.device); const message = formatProgressMessage(progressEvent); return ( - - {deviceLabel && {deviceLabel.index}} - + /> ); } From 0037a21835fb67071d0b785480d620e014822c81 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Sun, 28 Jun 2026 19:29:09 -0400 Subject: [PATCH 32/33] feat(multi-gpu): offload text encoders to idle GPUs Adds `offload_text_encoders_to_idle_gpus` (default on): when more than one generation device is configured and a GPU is idle, a session's text/prompt encoder runs on the idle GPU instead of the one running its denoise pipeline. This avoids evicting the denoise model from VRAM to make room for the encoder, and lets a cached encoder be reused across generations. Under full load (no idle GPU) behavior is unchanged. Mechanism: - New GENERATION_DEVICE_POOL arbiter (backend/util/device_pool.py) with a per-device exclusive-use lock. A native session blocking-acquires its own device's lock for the whole run; an encoder node try-borrows an idle device's lock for the duration of the node. This makes a borrowed encoder and a native session mutually exclusive on a GPU -- preventing the shared-encoder corruption that produced garbled images -- and is deadlock-free (borrows are non-blocking; a session only ever blocks on its own device). - DefaultSessionRunner re-pins the worker thread to the borrowed device for the whole encoder node; conditioning is stored on the CPU and the denoiser picks it up on its own GPU afterward. - Nodes opt in via @invocation(idle_gpu_offloadable=True), mirroring the existing `bottleneck` ClassVar marker. Applied to the text/prompt encoder nodes (compel + sdxl/refiner, flux, sd3, qwen-image, anima, cogview4, flux2 klein, z-image, flux_redux). Inspired by #9310; supersedes it. Tests: device-pool lock semantics, two concurrency regression tests asserting a session and a borrow never use a GPU at the same time, the runner offload context-manager behavior, and a marker-wiring check. Docs: invokeai-yaml.mdx (config setting) and creating-nodes.mdx (how to support the feature in a node). Co-Authored-By: Claude Opus 4.8 (1M context) --- .../docs/configuration/invokeai-yaml.mdx | 21 ++ .../development/Guides/creating-nodes.mdx | 33 +++ docs/src/generated/settings.json | 11 + .../app/invocations/anima_text_encoder.py | 1 + invokeai/app/invocations/baseinvocation.py | 10 + .../app/invocations/cogview4_text_encoder.py | 1 + invokeai/app/invocations/compel.py | 3 + .../invocations/flux2_klein_text_encoder.py | 1 + invokeai/app/invocations/flux_redux.py | 1 + invokeai/app/invocations/flux_text_encoder.py | 1 + .../invocations/qwen_image_text_encoder.py | 1 + invokeai/app/invocations/sd3_text_encoder.py | 1 + .../app/invocations/z_image_text_encoder.py | 1 + .../app/services/config/config_default.py | 1 + .../session_processor_default.py | 65 +++++- invokeai/backend/util/device_pool.py | 119 +++++++++++ invokeai/frontend/web/openapi.json | 6 + .../frontend/web/src/services/api/schema.ts | 6 + .../session_processor/test_encoder_offload.py | 201 ++++++++++++++++++ tests/backend/util/test_device_pool.py | 158 ++++++++++++++ 20 files changed, 636 insertions(+), 6 deletions(-) create mode 100644 invokeai/backend/util/device_pool.py create mode 100644 tests/app/services/session_processor/test_encoder_offload.py create mode 100644 tests/backend/util/test_device_pool.py diff --git a/docs/src/content/docs/configuration/invokeai-yaml.mdx b/docs/src/content/docs/configuration/invokeai-yaml.mdx index 6ac56053928..1c79fbf82aa 100644 --- a/docs/src/content/docs/configuration/invokeai-yaml.mdx +++ b/docs/src/content/docs/configuration/invokeai-yaml.mdx @@ -147,6 +147,27 @@ Notes: During parallel generation, the progress display shows one progress bar per active session, stacked vertically, each disappearing as its session completes. +#### Text Encoder Offload to Idle GPUs + +When more than one GPU is configured for generation but not all of them are busy, InvokeAI can run a session's text/prompt encoder on a currently-idle GPU instead of the GPU running its denoise pipeline. This avoids evicting the denoise model from VRAM just to make room for the encoder, and lets the cached encoder be reused across generations — making repeated generations noticeably smoother. + +This is controlled by the `offload_text_encoders_to_idle_gpus` setting: + +```yaml +offload_text_encoders_to_idle_gpus: true # default value +``` + +| Value | Behavior | +| ------- | ---------------------------------------------------------------------------------------------------------------- | +| `true` | Run text encoders on an idle GPU when one is available. This is the default. | +| `false` | Always run text encoders on the same GPU as the rest of the pipeline (the behavior before this feature existed). | + +Notes: + +- This has no effect unless at least two `generation_devices` are configured. On a single device — or when every GPU is already busy with its own session — encoders run on the session's own GPU, exactly as if the setting were `false`. +- It is purely a placement optimization and does not change generated images. +- A borrowed GPU is used exclusively for the encoder while it runs, so it never interferes with a generation session running on that same GPU. + #### Image Subfolder Strategy By default, generated images are stored in a single flat directory under `outputs/images/`. The `image_subfolder_strategy` setting lets you organize newly-created images into subfolders automatically. You can edit this setting in `invokeai.yaml` or, as an admin user, in the Settings panel. diff --git a/docs/src/content/docs/development/Guides/creating-nodes.mdx b/docs/src/content/docs/development/Guides/creating-nodes.mdx index f2dbee639bc..abc905f6e6a 100644 --- a/docs/src/content/docs/development/Guides/creating-nodes.mdx +++ b/docs/src/content/docs/development/Guides/creating-nodes.mdx @@ -21,6 +21,39 @@ import { Steps, LinkCard } from '@astrojs/starlight/components'; 4. A maintainer will review the pull request and node. If the node is aligned with the direction of the project, you may be asked for permission to include it in the core project. +### Supporting multi-GPU text-encoder offload + +On a machine with more than one GPU, InvokeAI can run several generation sessions at once — one per GPU. When fewer sessions are running than there are GPUs, the spare GPUs sit idle. To put that capacity to use, InvokeAI can run a session's **prompt/text encoder** on a currently-idle GPU instead of on the GPU running the denoise pipeline. This avoids evicting the denoise model from VRAM just to make room for the encoder, and lets the cached encoder be reused across generations. + +This is controlled globally by the `offload_text_encoders_to_idle_gpus` config setting (enabled by default) and opted into **per node** via the `@invocation` decorator: + +```python +from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation + + +@invocation( + "my_text_encoder", + title="Prompt - My Model", + category="conditioning", + version="1.0.0", + idle_gpu_offloadable=True, # opt in to idle-GPU offload +) +class MyTextEncoderInvocation(BaseInvocation): + ... +``` + +When the feature is enabled and an idle GPU is available, the **entire node** is temporarily re-pinned to a borrowed idle GPU: any model it loads goes onto that GPU and runs there. If no idle GPU is free (e.g. every GPU is busy with its own session), the node simply runs on its own GPU, unchanged. The borrow holds the idle GPU exclusively for the duration of the node, so it can never run concurrently against a native session on that same GPU. + +Because the whole node is moved to another device, only mark a node `idle_gpu_offloadable=True` if **all** of the following hold: + +- **It is encoder-only.** Its sole GPU work is loading one or more encoder models and running their forward pass. It must not load or run the denoise/transformer or VAE, or do any other work tied to the session's own GPU. +- **It stores its result on the CPU before returning.** Move output tensors to the CPU (`tensor.detach().to("cpu")`) and save them as conditioning/tensors. The denoiser picks them up and moves them onto its own GPU later — this is what makes the cross-GPU handoff safe and device-agnostic. +- **It places inputs on the loaded model's device, not a fixed device.** Resolve the device from the model you just loaded (e.g. `get_effective_device(model)` from `invokeai.backend.model_manager.load.model_cache.utils`, or `TorchDevice.choose_torch_device()`), rather than hard-coding `cuda:0`. The built-in `flux_text_encoder` and `compel` nodes are good references. + +:::caution[Only mark encoder-only nodes] +If a node that also runs the denoiser, VAE, or other session-GPU work is marked `idle_gpu_offloadable=True`, that work will be re-pinned to the wrong GPU and can misplace tensors or raise device-mismatch errors. When in doubt, leave it unset (the default is `False`) — the node will still work correctly, just without the offload optimization. +::: + ### Community Node Template Append the following template to your pull request and the [Community Nodes](../../../workflows/community-nodes) page when submitting a node to be added to the community nodes list: diff --git a/docs/src/generated/settings.json b/docs/src/generated/settings.json index 1987a90abce..d2d62d04f57 100644 --- a/docs/src/generated/settings.json +++ b/docs/src/generated/settings.json @@ -501,6 +501,17 @@ "type": "typing.Union[typing.Literal['auto'], list[str]]", "validation": {} }, + { + "category": "DEVICE", + "default": true, + "description": "When running on multiple GPUs, load text encoders onto a currently-idle GPU instead of the one running the denoise pipeline. This avoids churning the denoise model in and out of VRAM to make room for the encoder, and lets a cached encoder be reused across generations. Has no effect unless at least two `generation_devices` are configured and a GPU is idle; under full load encoders run on the session's own GPU as before.", + "env_var": "INVOKEAI_OFFLOAD_TEXT_ENCODERS_TO_IDLE_GPUS", + "literal_values": [], + "name": "offload_text_encoders_to_idle_gpus", + "required": false, + "type": "", + "validation": {} + }, { "category": "DEVICE", "default": "auto", diff --git a/invokeai/app/invocations/anima_text_encoder.py b/invokeai/app/invocations/anima_text_encoder.py index f1d4fbff8f1..c9bad65f3d0 100644 --- a/invokeai/app/invocations/anima_text_encoder.py +++ b/invokeai/app/invocations/anima_text_encoder.py @@ -59,6 +59,7 @@ category="conditioning", version="1.4.0", classification=Classification.Prototype, + idle_gpu_offloadable=True, ) class AnimaTextEncoderInvocation(BaseInvocation): """Encodes and preps a prompt for an Anima image. diff --git a/invokeai/app/invocations/baseinvocation.py b/invokeai/app/invocations/baseinvocation.py index 0546dabebb5..95cac4065a3 100644 --- a/invokeai/app/invocations/baseinvocation.py +++ b/invokeai/app/invocations/baseinvocation.py @@ -271,6 +271,12 @@ def invoke_internal(self, context: InvocationContext, services: "InvocationServi bottleneck: ClassVar[Bottleneck] + idle_gpu_offloadable: ClassVar[bool] = False + """Whether this node's entire execution may be temporarily re-pinned to an idle GPU when + `offload_text_encoders_to_idle_gpus` is enabled in multi-GPU mode. Only set this to True on nodes + that exclusively load encoder model(s), run a forward pass, and store their result on the CPU — + i.e. nodes that do no work tied to the session's own GPU. Set via the `@invocation` decorator.""" + UIConfig: ClassVar[UIConfigBase] model_config = ConfigDict( @@ -459,6 +465,7 @@ def get_output_for_type(cls, output_type: str) -> type[BaseInvocationOutput] | N "type", "workflow", "bottleneck", + "idle_gpu_offloadable", } RESERVED_INPUT_FIELD_NAMES = {"metadata", "board"} @@ -643,6 +650,7 @@ def invocation( use_cache: Optional[bool] = True, classification: Classification = Classification.Stable, bottleneck: Bottleneck = Bottleneck.GPU, + idle_gpu_offloadable: bool = False, ) -> Callable[[Type[TBaseInvocation]], Type[TBaseInvocation]]: """ Registers an invocation. @@ -655,6 +663,7 @@ def invocation( :param Optional[bool] use_cache: Whether or not to use the invocation cache. Defaults to True. The user may override this in the workflow editor. :param Classification classification: The classification of the invocation. Defaults to FeatureClassification.Stable. Use Beta or Prototype if the invocation is unstable. :param Bottleneck bottleneck: The bottleneck of the invocation. Defaults to Bottleneck.GPU. Use Network if the invocation is network-bound. + :param bool idle_gpu_offloadable: Whether this node's whole execution may run on a borrowed idle GPU when `offload_text_encoders_to_idle_gpus` is enabled. Only set True for encoder-only nodes that store their result on the CPU and do no work on the session's own GPU. Defaults to False. """ def wrapper(cls: Type[TBaseInvocation]) -> Type[TBaseInvocation]: @@ -712,6 +721,7 @@ def wrapper(cls: Type[TBaseInvocation]) -> Type[TBaseInvocation]: cls.model_fields["use_cache"].default = use_cache cls.bottleneck = bottleneck + cls.idle_gpu_offloadable = idle_gpu_offloadable # Add the invocation type to the model. diff --git a/invokeai/app/invocations/cogview4_text_encoder.py b/invokeai/app/invocations/cogview4_text_encoder.py index 13234889fba..c303e55b828 100644 --- a/invokeai/app/invocations/cogview4_text_encoder.py +++ b/invokeai/app/invocations/cogview4_text_encoder.py @@ -23,6 +23,7 @@ category="prompt", version="1.0.0", classification=Classification.Prototype, + idle_gpu_offloadable=True, ) class CogView4TextEncoderInvocation(BaseInvocation): """Encodes and preps a prompt for a cogview4 image.""" diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index 99373531d8e..428f72d3964 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -45,6 +45,7 @@ tags=["prompt", "compel"], category="prompt", version="1.2.1", + idle_gpu_offloadable=True, ) class CompelInvocation(BaseInvocation): """Parse prompt using compel package to conditioning.""" @@ -250,6 +251,7 @@ def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]: tags=["sdxl", "compel", "prompt"], category="prompt", version="1.2.1", + idle_gpu_offloadable=True, ) class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): """Parse prompt using compel package to conditioning.""" @@ -344,6 +346,7 @@ def invoke(self, context: InvocationContext) -> ConditioningOutput: tags=["sdxl", "compel", "prompt"], category="prompt", version="1.1.2", + idle_gpu_offloadable=True, ) class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): """Parse prompt using compel package to conditioning.""" diff --git a/invokeai/app/invocations/flux2_klein_text_encoder.py b/invokeai/app/invocations/flux2_klein_text_encoder.py index b2728d1d7cc..2b4b53faf72 100644 --- a/invokeai/app/invocations/flux2_klein_text_encoder.py +++ b/invokeai/app/invocations/flux2_klein_text_encoder.py @@ -48,6 +48,7 @@ category="prompt", version="1.1.1", classification=Classification.Prototype, + idle_gpu_offloadable=True, ) class Flux2KleinTextEncoderInvocation(BaseInvocation): """Encodes and preps a prompt for Flux2 Klein image generation. diff --git a/invokeai/app/invocations/flux_redux.py b/invokeai/app/invocations/flux_redux.py index b68e9911c56..ac1f5764d78 100644 --- a/invokeai/app/invocations/flux_redux.py +++ b/invokeai/app/invocations/flux_redux.py @@ -50,6 +50,7 @@ class FluxReduxOutput(BaseInvocationOutput): category="conditioning", version="2.1.0", classification=Classification.Beta, + idle_gpu_offloadable=True, ) class FluxReduxInvocation(BaseInvocation): """Runs a FLUX Redux model to generate a conditioning tensor.""" diff --git a/invokeai/app/invocations/flux_text_encoder.py b/invokeai/app/invocations/flux_text_encoder.py index 8b3b33fad1c..e3f28e57d72 100644 --- a/invokeai/app/invocations/flux_text_encoder.py +++ b/invokeai/app/invocations/flux_text_encoder.py @@ -30,6 +30,7 @@ tags=["prompt", "conditioning", "flux"], category="prompt", version="1.1.2", + idle_gpu_offloadable=True, ) class FluxTextEncoderInvocation(BaseInvocation): """Encodes and preps a prompt for a flux image.""" diff --git a/invokeai/app/invocations/qwen_image_text_encoder.py b/invokeai/app/invocations/qwen_image_text_encoder.py index d2aecd9f226..9d9347a8cf9 100644 --- a/invokeai/app/invocations/qwen_image_text_encoder.py +++ b/invokeai/app/invocations/qwen_image_text_encoder.py @@ -68,6 +68,7 @@ def _build_prompt(user_prompt: str, num_images: int) -> str: category="conditioning", version="1.2.0", classification=Classification.Prototype, + idle_gpu_offloadable=True, ) class QwenImageTextEncoderInvocation(BaseInvocation): """Encodes text and reference images for Qwen Image using Qwen2.5-VL.""" diff --git a/invokeai/app/invocations/sd3_text_encoder.py b/invokeai/app/invocations/sd3_text_encoder.py index 7af138fe45e..d9f5c3f1f15 100644 --- a/invokeai/app/invocations/sd3_text_encoder.py +++ b/invokeai/app/invocations/sd3_text_encoder.py @@ -33,6 +33,7 @@ tags=["prompt", "conditioning", "sd3"], category="prompt", version="1.0.1", + idle_gpu_offloadable=True, ) class Sd3TextEncoderInvocation(BaseInvocation): """Encodes and preps a prompt for a SD3 image.""" diff --git a/invokeai/app/invocations/z_image_text_encoder.py b/invokeai/app/invocations/z_image_text_encoder.py index 71af6085d0e..148cff5c269 100644 --- a/invokeai/app/invocations/z_image_text_encoder.py +++ b/invokeai/app/invocations/z_image_text_encoder.py @@ -37,6 +37,7 @@ category="prompt", version="1.1.0", classification=Classification.Prototype, + idle_gpu_offloadable=True, ) class ZImageTextEncoderInvocation(BaseInvocation): """Encodes and preps a prompt for a Z-Image image. diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py index 8c07c2139f4..d5c8a9634a5 100644 --- a/invokeai/app/services/config/config_default.py +++ b/invokeai/app/services/config/config_default.py @@ -206,6 +206,7 @@ class InvokeAIAppConfig(BaseSettings): # DEVICE device: str = Field(default="auto", description="Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.
Valid values: `auto`, `cpu`, `cuda`, `mps`, `cuda:N` (where N is a device number)", pattern=r"^(auto|cpu|mps|cuda(:\d+)?)$") generation_devices: Union[Literal["auto"], list[str]] = Field(default="auto", description="Devices to use for parallel generation. `auto` (the default) uses every available GPU, running one generation session per GPU concurrently and distributing jobs fairly across users. Provide an explicit list (e.g. `[cuda:0, cuda:1]`) to use specific devices, or a single-device list (e.g. `[cuda:0]`) to run serially. On systems without a GPU, `auto` resolves to the single `cpu`/`mps` device.
Valid values: `auto`, or a list whose entries are each `cpu`, `cuda`, `mps`, or `cuda:N` (where N is a device number)") + offload_text_encoders_to_idle_gpus: bool = Field(default=True, description="When running on multiple GPUs, load text encoders onto a currently-idle GPU instead of the one running the denoise pipeline. This avoids churning the denoise model in and out of VRAM to make room for the encoder, and lets a cached encoder be reused across generations. Has no effect unless at least two `generation_devices` are configured and a GPU is idle; under full load encoders run on the session's own GPU as before.") precision: PRECISION = Field(default="auto", description="Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.") # GENERATION diff --git a/invokeai/app/services/session_processor/session_processor_default.py b/invokeai/app/services/session_processor/session_processor_default.py index 93c4554b1fe..e776dc79614 100644 --- a/invokeai/app/services/session_processor/session_processor_default.py +++ b/invokeai/app/services/session_processor/session_processor_default.py @@ -1,9 +1,9 @@ import gc import traceback -from contextlib import suppress +from contextlib import contextmanager, suppress from threading import BoundedSemaphore, Thread from threading import Event as ThreadEvent -from typing import Optional +from typing import Iterator, Optional import torch @@ -33,6 +33,7 @@ from invokeai.app.services.shared.graph import NodeInputError from invokeai.app.services.shared.invocation_context import InvocationContextData, build_invocation_context from invokeai.app.util.profiler import Profiler +from invokeai.backend.util.device_pool import GENERATION_DEVICE_POOL from invokeai.backend.util.devices import TorchDevice @@ -129,8 +130,9 @@ def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem): is_canceled=self._is_canceled, ) - # Invoke the node - output = invocation.invoke_internal(context=context, services=self._services) + # Invoke the node, optionally on a borrowed idle GPU (text encoders only). + with self._maybe_offload_to_idle_gpu(invocation): + output = invocation.invoke_internal(context=context, services=self._services) # Save output and history queue_item.session.complete(invocation.id, output) @@ -156,6 +158,45 @@ def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem): error_traceback=error_traceback, ) + @contextmanager + def _maybe_offload_to_idle_gpu(self, invocation: BaseInvocation) -> Iterator[None]: + """Temporarily re-pin this worker thread to an idle GPU for a text-encoder node. + + When ``offload_text_encoders_to_idle_gpus`` is enabled and an idle generation GPU can be + borrowed, the encoder model loads into that GPU's cache and its forward runs there (all + device-selecting code resolves to the pinned device), keeping the busy GPU's denoise model + resident. The conditioning output is stored on the CPU, so the denoiser picks it up on the + worker's own GPU after the pin is restored. + + The borrow holds the idle device's exclusive-use lock for the whole node, so a native + session on that GPU can never run concurrently against the same cached encoder (which would + corrupt it). If no idle GPU is free, the node runs on the worker's own GPU unchanged. + """ + native_device = TorchDevice.get_session_device() + if ( + native_device is None + or native_device.type != "cuda" + or not invocation.idle_gpu_offloadable + or not self._services.configuration.offload_text_encoders_to_idle_gpus + ): + yield + return + + borrowed_device = GENERATION_DEVICE_POOL.try_borrow(exclude=native_device) + if borrowed_device is None: + yield + return + + self._services.logger.debug( + f"Running {invocation.get_type()} on idle device {borrowed_device} (session device {native_device})." + ) + TorchDevice.set_session_device(borrowed_device) + try: + yield + finally: + TorchDevice.set_session_device(native_device) + GENERATION_DEVICE_POOL.release_borrow(borrowed_device) + def _on_before_run_session(self, queue_item: SessionQueueItem) -> None: """Called before a session is run. @@ -388,6 +429,10 @@ def start(self, invoker: Invoker) -> None: devices = self._resolve_devices() + # Register the generation devices so the model loader can discover idle GPUs to host text + # encoders on (see offload_text_encoders_to_idle_gpus). None means legacy single-device mode. + GENERATION_DEVICE_POOL.set_generation_devices([d for d in devices if d is not None]) + # If profiling is enabled, create a profiler. The same profiler will be used for all sessions. Internally, # the profiler will create a new profile for each session. Profiling uses a process-global cProfile, which # cannot cleanly attribute work when multiple sessions run concurrently, so it is disabled in multi-GPU mode. @@ -582,8 +627,16 @@ def _process( f"on {worker.label}" ) - # Run the graph - worker.runner.run(queue_item=worker.queue_item) + # Run the graph. Hold this GPU's exclusive-use lock for the whole session so no + # other worker can borrow it for text-encoder offload while we're running on it + # (a borrow + concurrent native session on one GPU would corrupt the shared + # cached encoder). Acquired here, after dequeue, so an idle worker doesn't hold + # the lock and block borrows while waiting for work. + GENERATION_DEVICE_POOL.acquire_session(worker.device) + try: + worker.runner.run(queue_item=worker.queue_item) + finally: + GENERATION_DEVICE_POOL.release_session(worker.device) except Exception as e: error_type = e.__class__.__name__ diff --git a/invokeai/backend/util/device_pool.py b/invokeai/backend/util/device_pool.py new file mode 100644 index 00000000000..1e6675161a6 --- /dev/null +++ b/invokeai/backend/util/device_pool.py @@ -0,0 +1,119 @@ +"""Process-global arbiter that lends idle generation GPUs for text-encoder offload. + +In multi-GPU mode (see ``generation_devices``) the session processor runs one generation worker +per GPU. When fewer sessions are running than there are GPUs, some GPUs sit idle. This arbiter lets +a busy worker temporarily *borrow* an idle GPU to host a text encoder, instead of churning the busy +GPU's denoise model in and out of VRAM. + +Correctness hinges on one rule: **a borrowed GPU must never run an encoder at the same time as a +native generation session on that same GPU.** They share that device's single ``ModelCache``, and a +model's forward pass (including in-place LoRA patching) runs with no cache lock held — so two +threads touching the same cached encoder concurrently corrupts it (garbled output). + +To enforce the rule, each generation device has one lock used for *both* roles: + +- A native session holds its device's lock for the entire run (blocking acquire). +- A borrower *try*-acquires another device's lock for the duration of one encoder node; if the lock + is already held (that GPU is running, or just started, a session) the borrow simply fails and the + encoder runs on the worker's own GPU instead. + +Because borrows are non-blocking try-acquires and a session only ever blocking-acquires its *own* +device lock, there is no lock-ordering cycle — the design is deadlock-free. The only cost is that, +in the startup race where a borrow wins the lock a moment before the lent GPU's own session starts, +that session waits out the (short) encoder node before beginning. +""" + +import threading +from typing import Optional + +import torch + +from invokeai.backend.util.devices import TorchDevice + + +class _GenerationDevicePool: + """Arbitrates exclusive use of each generation device between native sessions and borrowers.""" + + def __init__(self) -> None: + self._registry_lock = threading.Lock() + # Registration order is preserved so borrow selection is deterministic (and therefore sticky + # across repeated single-session generations, letting a cached encoder be reused). Maps + # normalized device string -> that device's exclusive-use lock. + self._device_locks: dict[str, threading.Lock] = {} + self._order: list[str] = [] + + def set_generation_devices(self, devices: list[torch.device]) -> None: + """Register the full set of generation devices (called once at processor startup). + + Only CUDA devices participate in idle-offload; others are ignored. + """ + with self._registry_lock: + self._device_locks = {} + self._order = [] + for device in devices: + if device.type != "cuda": + continue + key = str(TorchDevice.normalize(device)) + if key not in self._device_locks: + self._device_locks[key] = threading.Lock() + self._order.append(key) + + def _get_lock(self, device: torch.device) -> Optional[threading.Lock]: + key = str(TorchDevice.normalize(device)) + with self._registry_lock: + return self._device_locks.get(key) + + def acquire_session(self, device: Optional[torch.device]) -> None: + """Take exclusive use of ``device`` for a native generation session (blocking). + + Waits out any in-flight borrow that won the lock first, guaranteeing the session never runs + concurrently with a borrowed encoder on the same GPU. No-op for non-CUDA / unregistered + devices (e.g. legacy single-device mode). + """ + if device is None or device.type != "cuda": + return + lock = self._get_lock(device) + if lock is not None: + lock.acquire() + + def release_session(self, device: Optional[torch.device]) -> None: + """Release the exclusive use taken by :meth:`acquire_session`.""" + if device is None or device.type != "cuda": + return + lock = self._get_lock(device) + if lock is not None: + lock.release() + + def try_borrow(self, exclude: torch.device) -> Optional[torch.device]: + """Try to take exclusive use of an idle CUDA device other than ``exclude`` (non-blocking). + + Returns the borrowed device (whose lock the caller now holds and must release via + :meth:`release_borrow`), or ``None`` if no other registered device is currently free. + Selection is deterministic (lowest registration order) so repeated borrows reuse the same + GPU and the encoder cached there. + """ + if exclude.type != "cuda": + return None + exclude_key = str(TorchDevice.normalize(exclude)) + with self._registry_lock: + candidates = [(key, self._device_locks[key]) for key in self._order if key != exclude_key] + for key, lock in candidates: + if lock.acquire(blocking=False): + return torch.device(key) + return None + + def release_borrow(self, device: torch.device) -> None: + """Release a device taken by :meth:`try_borrow`.""" + lock = self._get_lock(device) + if lock is not None: + lock.release() + + def reset(self) -> None: + """Clear all registered devices (used by tests).""" + with self._registry_lock: + self._device_locks = {} + self._order = [] + + +# Process-global singleton. +GENERATION_DEVICE_POOL = _GenerationDevicePool() diff --git a/invokeai/frontend/web/openapi.json b/invokeai/frontend/web/openapi.json index 522cd1ce4aa..4ca744496d3 100644 --- a/invokeai/frontend/web/openapi.json +++ b/invokeai/frontend/web/openapi.json @@ -41226,6 +41226,12 @@ "description": "Devices to use for parallel generation. `auto` (the default) uses every available GPU, running one generation session per GPU concurrently and distributing jobs fairly across users. Provide an explicit list (e.g. `[cuda:0, cuda:1]`) to use specific devices, or a single-device list (e.g. `[cuda:0]`) to run serially. On systems without a GPU, `auto` resolves to the single `cpu`/`mps` device.
Valid values: `auto`, or a list whose entries are each `cpu`, `cuda`, `mps`, or `cuda:N` (where N is a device number)", "default": "auto" }, + "offload_text_encoders_to_idle_gpus": { + "type": "boolean", + "title": "Offload Text Encoders To Idle Gpus", + "description": "When running on multiple GPUs, load text encoders onto a currently-idle GPU instead of the one running the denoise pipeline. This avoids churning the denoise model in and out of VRAM to make room for the encoder, and lets a cached encoder be reused across generations. Has no effect unless at least two `generation_devices` are configured and a GPU is idle; under full load encoders run on the session's own GPU as before.", + "default": true + }, "precision": { "type": "string", "enum": ["auto", "float16", "bfloat16", "float32"], diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index 75dafa37f34..a18faeaed26 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -16556,6 +16556,12 @@ export type components = { * @default auto */ generation_devices?: "auto" | string[]; + /** + * Offload Text Encoders To Idle Gpus + * @description When running on multiple GPUs, load text encoders onto a currently-idle GPU instead of the one running the denoise pipeline. This avoids churning the denoise model in and out of VRAM to make room for the encoder, and lets a cached encoder be reused across generations. Has no effect unless at least two `generation_devices` are configured and a GPU is idle; under full load encoders run on the session's own GPU as before. + * @default true + */ + offload_text_encoders_to_idle_gpus?: boolean; /** * Precision * @description Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system. diff --git a/tests/app/services/session_processor/test_encoder_offload.py b/tests/app/services/session_processor/test_encoder_offload.py new file mode 100644 index 00000000000..972783a9ea9 --- /dev/null +++ b/tests/app/services/session_processor/test_encoder_offload.py @@ -0,0 +1,201 @@ +"""Tests for DefaultSessionRunner._maybe_offload_to_idle_gpu (idle-GPU text-encoder offload). + +These exercise the re-pinning + borrow-lock logic without needing real CUDA: the session device is +a thread-local set via TorchDevice, and the device pool only manipulates locks keyed by device +string. +""" + +import logging +import threading +import time +from collections.abc import Iterator + +import pytest +import torch + +from invokeai.app.services.session_processor.session_processor_default import DefaultSessionRunner +from invokeai.backend.util.device_pool import GENERATION_DEVICE_POOL +from invokeai.backend.util.devices import TorchDevice + + +@pytest.fixture(autouse=True) +def reset_state() -> Iterator[None]: + GENERATION_DEVICE_POOL.reset() + try: + yield + finally: + TorchDevice.clear_session_device() + GENERATION_DEVICE_POOL.reset() + + +class _FakeInvocation: + def __init__(self, idle_gpu_offloadable: bool, type_str: str = "fake_node"): + self.idle_gpu_offloadable = idle_gpu_offloadable + self._type_str = type_str + + def get_type(self) -> str: + return self._type_str + + +class _FakeConfig: + def __init__(self, enabled: bool = True): + self.offload_text_encoders_to_idle_gpus = enabled + + +class _FakeServices: + def __init__(self, enabled: bool = True): + self.configuration = _FakeConfig(enabled) + self.logger = logging.getLogger("test-encoder-offload") + + +def _runner(enabled: bool = True) -> DefaultSessionRunner: + runner = DefaultSessionRunner() + runner._services = _FakeServices(enabled) # type: ignore[assignment] + return runner + + +def test_encoder_node_repins_to_idle_gpu_and_restores(): + runner = _runner() + GENERATION_DEVICE_POOL.set_generation_devices([torch.device("cuda:0"), torch.device("cuda:1")]) + TorchDevice.set_session_device("cuda:0") + + with runner._maybe_offload_to_idle_gpu(_FakeInvocation(True, "flux_text_encoder")): + # Re-pinned to the borrowed idle GPU for the duration of the node... + assert TorchDevice.get_session_device() == torch.device("cuda:1") + # ...and that GPU is locked, so nothing else can borrow it. + assert GENERATION_DEVICE_POOL.try_borrow(exclude=torch.device("cuda:0")) is None + + # Pin restored and the borrow released. + assert TorchDevice.get_session_device() == torch.device("cuda:0") + assert GENERATION_DEVICE_POOL.try_borrow(exclude=torch.device("cuda:0")) == torch.device("cuda:1") + + +def test_non_encoder_node_is_not_offloaded(): + runner = _runner() + GENERATION_DEVICE_POOL.set_generation_devices([torch.device("cuda:0"), torch.device("cuda:1")]) + TorchDevice.set_session_device("cuda:0") + + with runner._maybe_offload_to_idle_gpu(_FakeInvocation(False, "denoise_latents")): + assert TorchDevice.get_session_device() == torch.device("cuda:0") + # Idle device was never borrowed. + assert GENERATION_DEVICE_POOL.try_borrow(exclude=torch.device("cuda:0")) == torch.device("cuda:1") + + +def test_no_offload_when_target_running_a_session(): + """With both GPUs busy (the other holds a session lock), the encoder stays on its own GPU.""" + runner = _runner() + GENERATION_DEVICE_POOL.set_generation_devices([torch.device("cuda:0"), torch.device("cuda:1")]) + TorchDevice.set_session_device("cuda:0") + GENERATION_DEVICE_POOL.acquire_session(torch.device("cuda:1")) + try: + with runner._maybe_offload_to_idle_gpu(_FakeInvocation(True, "flux_text_encoder")): + assert TorchDevice.get_session_device() == torch.device("cuda:0") + finally: + GENERATION_DEVICE_POOL.release_session(torch.device("cuda:1")) + + +def test_flag_off_disables_offload(): + runner = _runner(enabled=False) + GENERATION_DEVICE_POOL.set_generation_devices([torch.device("cuda:0"), torch.device("cuda:1")]) + TorchDevice.set_session_device("cuda:0") + + with runner._maybe_offload_to_idle_gpu(_FakeInvocation(True, "flux_text_encoder")): + assert TorchDevice.get_session_device() == torch.device("cuda:0") + assert GENERATION_DEVICE_POOL.try_borrow(exclude=torch.device("cuda:0")) == torch.device("cuda:1") + + +def test_borrow_released_on_exception(): + runner = _runner() + GENERATION_DEVICE_POOL.set_generation_devices([torch.device("cuda:0"), torch.device("cuda:1")]) + TorchDevice.set_session_device("cuda:0") + + with pytest.raises(RuntimeError): + with runner._maybe_offload_to_idle_gpu(_FakeInvocation(True, "flux_text_encoder")): + raise RuntimeError("node failed") + + # The pin is restored and the borrow lock released even though the node raised. + assert TorchDevice.get_session_device() == torch.device("cuda:0") + assert GENERATION_DEVICE_POOL.try_borrow(exclude=torch.device("cuda:0")) == torch.device("cuda:1") + + +def test_single_gpu_never_offloads(): + runner = _runner() + GENERATION_DEVICE_POOL.set_generation_devices([torch.device("cuda:0")]) + TorchDevice.set_session_device("cuda:0") + + with runner._maybe_offload_to_idle_gpu(_FakeInvocation(True, "flux_text_encoder")): + assert TorchDevice.get_session_device() == torch.device("cuda:0") + + +def test_concurrent_workers_never_share_a_gpu(): + """Regression for the garbled-image bug: two sessions running at once must never use the same + GPU for an encoder concurrently. Each worker holds its own GPU's session lock (as the processor + does) and runs encoder nodes that may borrow the other GPU through the real offload path; we + assert no GPU is ever occupied by two workers at the same time. + + Before the fix, a startup race let one worker offload its encoder onto the other's GPU while + that GPU also ran a native session — both touching the same cached encoder. This test exercises + that exact interleaving and would flag it as occupancy > 1. + """ + GENERATION_DEVICE_POOL.set_generation_devices([torch.device("cuda:0"), torch.device("cuda:1")]) + + occupancy = {"cuda:0": 0, "cuda:1": 0} + occ_lock = threading.Lock() + violations: list[str] = [] + + def occupy(device_str: str) -> None: + with occ_lock: + occupancy[device_str] += 1 + if occupancy[device_str] > 1: + violations.append(device_str) + + def vacate(device_str: str) -> None: + with occ_lock: + occupancy[device_str] -= 1 + + def worker(own: str) -> None: + runner = _runner() + own_device = torch.device(own) + encoder = _FakeInvocation(True, "flux_text_encoder") + for _ in range(150): + # The processor holds the device's session lock for the whole run. + GENERATION_DEVICE_POOL.acquire_session(own_device) + TorchDevice.set_session_device(own_device) + occupy(own) + try: + with runner._maybe_offload_to_idle_gpu(encoder): + current = str(TorchDevice.get_session_device()) + if current != own: + # The node was re-pinned to a borrowed GPU; it must be exclusively ours. + occupy(current) + try: + time.sleep(0.0002) + finally: + vacate(current) + else: + time.sleep(0.0001) + finally: + vacate(own) + TorchDevice.clear_session_device() + GENERATION_DEVICE_POOL.release_session(own_device) + + threads = [threading.Thread(target=worker, args=(d,)) for d in ("cuda:0", "cuda:1")] + for t in threads: + t.start() + for t in threads: + t.join() + + assert not violations, f"GPU(s) used by two workers at once: {set(violations)}" + + +def test_real_nodes_declare_the_marker_correctly(): + """The @invocation(idle_gpu_offloadable=...) marker is wired through to the class, and is set on + encoder nodes but not on ordinary nodes.""" + from invokeai.app.invocations.compel import CompelInvocation + from invokeai.app.invocations.flux_text_encoder import FluxTextEncoderInvocation + from invokeai.app.invocations.primitives import IntegerInvocation + + assert FluxTextEncoderInvocation.idle_gpu_offloadable is True + assert CompelInvocation.idle_gpu_offloadable is True + # A non-encoder node defaults to False (never re-pinned to a borrowed GPU). + assert IntegerInvocation.idle_gpu_offloadable is False diff --git a/tests/backend/util/test_device_pool.py b/tests/backend/util/test_device_pool.py new file mode 100644 index 00000000000..402e496889b --- /dev/null +++ b/tests/backend/util/test_device_pool.py @@ -0,0 +1,158 @@ +"""Tests for the idle generation-device arbiter used by text-encoder offload.""" + +import threading +import time +from collections.abc import Iterator + +import pytest +import torch + +from invokeai.backend.util.device_pool import GENERATION_DEVICE_POOL + + +@pytest.fixture(autouse=True) +def reset_pool() -> Iterator[None]: + """The arbiter is a process-global singleton; reset it around each test.""" + GENERATION_DEVICE_POOL.reset() + try: + yield + finally: + GENERATION_DEVICE_POOL.reset() + + +def test_borrow_picks_lowest_other_device(): + GENERATION_DEVICE_POOL.set_generation_devices([torch.device("cuda:0"), torch.device("cuda:1")]) + assert GENERATION_DEVICE_POOL.try_borrow(exclude=torch.device("cuda:0")) == torch.device("cuda:1") + + +def test_borrow_excludes_requesting_device(): + GENERATION_DEVICE_POOL.set_generation_devices([torch.device("cuda:0"), torch.device("cuda:1")]) + assert GENERATION_DEVICE_POOL.try_borrow(exclude=torch.device("cuda:1")) == torch.device("cuda:0") + + +def test_session_lock_blocks_borrow(): + """A device held by a native session cannot be borrowed.""" + GENERATION_DEVICE_POOL.set_generation_devices([torch.device("cuda:0"), torch.device("cuda:1")]) + GENERATION_DEVICE_POOL.acquire_session(torch.device("cuda:1")) + try: + # The only other device is busy with a session -> no borrow. + assert GENERATION_DEVICE_POOL.try_borrow(exclude=torch.device("cuda:0")) is None + finally: + GENERATION_DEVICE_POOL.release_session(torch.device("cuda:1")) + # Released -> borrowable again. + assert GENERATION_DEVICE_POOL.try_borrow(exclude=torch.device("cuda:0")) == torch.device("cuda:1") + + +def test_borrow_blocks_session_until_released(): + """A native session acquire waits for an in-flight borrow on the same device (startup race).""" + GENERATION_DEVICE_POOL.set_generation_devices([torch.device("cuda:0"), torch.device("cuda:1")]) + borrowed = GENERATION_DEVICE_POOL.try_borrow(exclude=torch.device("cuda:0")) + assert borrowed == torch.device("cuda:1") + + acquired = threading.Event() + + def native_session(): + GENERATION_DEVICE_POOL.acquire_session(torch.device("cuda:1")) + acquired.set() + + t = threading.Thread(target=native_session) + t.start() + # The session must block while the borrow holds cuda:1. + assert not acquired.wait(timeout=0.2) + GENERATION_DEVICE_POOL.release_borrow(torch.device("cuda:1")) + # Now it can proceed. + assert acquired.wait(timeout=2.0) + t.join() + GENERATION_DEVICE_POOL.release_session(torch.device("cuda:1")) + + +def test_two_borrowers_do_not_share_a_device(): + GENERATION_DEVICE_POOL.set_generation_devices([torch.device("cuda:0"), torch.device("cuda:1")]) + first = GENERATION_DEVICE_POOL.try_borrow(exclude=torch.device("cuda:0")) + assert first == torch.device("cuda:1") + # A second borrower (also from cuda:0) finds the only other device already taken -> None. + assert GENERATION_DEVICE_POOL.try_borrow(exclude=torch.device("cuda:0")) is None + GENERATION_DEVICE_POOL.release_borrow(first) + + +def test_single_device_has_no_borrow_target(): + GENERATION_DEVICE_POOL.set_generation_devices([torch.device("cuda:0")]) + assert GENERATION_DEVICE_POOL.try_borrow(exclude=torch.device("cuda:0")) is None + + +def test_deterministic_lowest_order_selection(): + GENERATION_DEVICE_POOL.set_generation_devices( + [torch.device("cuda:0"), torch.device("cuda:1"), torch.device("cuda:2")] + ) + # cuda:1 and cuda:2 are both free; the lowest-order one (cuda:1) is chosen, and the choice is + # stable across calls (release then re-borrow) so a cached encoder can be reused. + for _ in range(3): + device = GENERATION_DEVICE_POOL.try_borrow(exclude=torch.device("cuda:0")) + assert device == torch.device("cuda:1") + GENERATION_DEVICE_POOL.release_borrow(device) + + +def test_non_cuda_devices_ignored(): + GENERATION_DEVICE_POOL.set_generation_devices([torch.device("cpu"), torch.device("cuda:0")]) + # Only cuda:0 registered; nothing else to borrow. + assert GENERATION_DEVICE_POOL.try_borrow(exclude=torch.device("cuda:0")) is None + # A non-cuda requester never borrows, and a non-cuda session acquire is a no-op. + assert GENERATION_DEVICE_POOL.try_borrow(exclude=torch.device("cpu")) is None + GENERATION_DEVICE_POOL.acquire_session(torch.device("cpu")) # must not raise + GENERATION_DEVICE_POOL.release_session(torch.device("cpu")) + + +def test_empty_pool_returns_none(): + assert GENERATION_DEVICE_POOL.try_borrow(exclude=torch.device("cuda:0")) is None + + +def test_concurrent_sessions_and_borrows_never_overlap_on_a_device(): + """Regression: a GPU must never be used by a native session and a borrowed encoder at the same + time. That overlap is exactly what corrupted a shared encoder and produced garbled images. Here + we stress the arbiter from several threads and assert exclusive use is always honored. + + With only the busy-flag approach this used before the fix, a borrow could win against a starting + session and both would "use" the device — which this test would catch as occupancy > 1. + """ + device_strs = ["cuda:0", "cuda:1", "cuda:2"] + GENERATION_DEVICE_POOL.set_generation_devices([torch.device(d) for d in device_strs]) + + occupancy = dict.fromkeys(device_strs, 0) + occ_lock = threading.Lock() + violations: list[str] = [] + + def occupy(device_str: str) -> None: + with occ_lock: + occupancy[device_str] += 1 + if occupancy[device_str] > 1: + violations.append(device_str) + + def vacate(device_str: str) -> None: + with occ_lock: + occupancy[device_str] -= 1 + + def worker(own: str) -> None: + own_device = torch.device(own) + for _ in range(200): + GENERATION_DEVICE_POOL.acquire_session(own_device) + occupy(own) # this thread now exclusively owns `own` (as a native session would) + try: + borrowed = GENERATION_DEVICE_POOL.try_borrow(exclude=own_device) + if borrowed is not None: + occupy(str(borrowed)) + try: + time.sleep(0.0002) # widen the window so any overlap is observed + finally: + vacate(str(borrowed)) + GENERATION_DEVICE_POOL.release_borrow(borrowed) + finally: + vacate(own) + GENERATION_DEVICE_POOL.release_session(own_device) + + threads = [threading.Thread(target=worker, args=(d,)) for d in device_strs] + for t in threads: + t.start() + for t in threads: + t.join() + + assert not violations, f"device(s) used concurrently by a session and a borrow: {set(violations)}" From b61a8e15f280c7ad8e2c00d4d4138a7c16d7963f Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Mon, 29 Jun 2026 08:53:58 -0400 Subject: [PATCH 33/33] fix(multi-gpu): adopt GGUF weights across devices to stop RAM spikes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit _build_meta_shell built meta placeholders with torch.empty_like, which GGMLTensor.__torch_dispatch__ rejects (NotImplemented for aten.empty_like). It threw on the first parameter, hit the silent except, and returned None — so GGUF models (e.g. a Q8_0 transformer) never registered a shell and the second GPU re-loaded the full model from disk, stacking a ~20GB transient on the retained copy and spiking RAM to ~70%. Fall back to a plain meta placeholder (logical shape/dtype) when empty_like isn't implemented by a tensor subclass; verified the adopted GGMLTensor shares the quantized storage, so it's one RAM copy across devices. Peak drops ~66→~46GB. Log shell-build failures at debug so a future un-adoptable family is diagnosable instead of silently double-loading. Also restore log_memory_usage's per-cold-load RAM logging (the capture method had no callers), slimmed to baseline→transient-peak process RAM. Co-Authored-By: Claude Opus 4.8 --- .../model_manager/load/load_default.py | 49 +++++++++++++++++-- .../load/test_shared_weight_adoption.py | 48 ++++++++++++++++++ 2 files changed, 92 insertions(+), 5 deletions(-) diff --git a/invokeai/backend/model_manager/load/load_default.py b/invokeai/backend/model_manager/load/load_default.py index de87c797e8e..7f1637111b7 100644 --- a/invokeai/backend/model_manager/load/load_default.py +++ b/invokeai/backend/model_manager/load/load_default.py @@ -14,6 +14,7 @@ from invokeai.backend.model_manager.configs.base import Diffusers_Config_Base from invokeai.backend.model_manager.configs.factory import AnyModelConfig from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase +from invokeai.backend.model_manager.load.memory_snapshot import GB, MemorySnapshot from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord from invokeai.backend.model_manager.load.model_cache.model_cache import ( MODEL_LOAD_LOCK, @@ -168,9 +169,26 @@ def _load_and_cache(self, config: AnyModelConfig, submodel_type: Optional[SubMod shell_to_register: Optional[torch.nn.Module] = None if loaded_model is None: + # Optional RAM instrumentation for the cold disk-load path (the only place that runs + # `from_pretrained`, whose construction transient can briefly spike RAM past the + # cache's retained budget). Gated on `log_memory_usage`; captures process RAM before + # make_room, after make_room (retained baseline), and after construction (transient + # peak) so the surge can be attributed without guessing. + log_mem = self._app_config.log_memory_usage + ram_before = MemorySnapshot.capture().process_ram if log_mem else 0 self._ram_cache.make_room(self.get_size_fs(config, Path(config.path), submodel_type)) + ram_after_room = MemorySnapshot.capture().process_ram if log_mem else 0 with skip_torch_weight_init(): loaded_model = self._load_model(config, submodel_type) + if log_mem: + ram_peak = MemorySnapshot.capture().process_ram + self._logger.info( + f"Cold load RAM for '{cache_key}': " + f"make_room {ram_before / GB:.2f}->{ram_after_room / GB:.2f}GB " + f"({(ram_after_room - ram_before) / GB:+.2f}), " + f"construct {ram_after_room / GB:.2f}->{ram_peak / GB:.2f}GB " + f"({(ram_peak - ram_after_room) / GB:+.2f}) [transient peak]" + ) # Snapshot a meta-weight clone now — before put() applies custom layers or any VRAM # move — so the next device to load this model can adopt these weights (see above). # Skipped in single-device setups, where no other cache will ever adopt it. @@ -414,6 +432,21 @@ def _build_meta_shell(model: AnyModel) -> Optional[torch.nn.Module]: """ if not isinstance(model, torch.nn.Module): return None + + def _meta_like(t: torch.Tensor) -> torch.Tensor: + # A 0-byte stand-in with the same logical shape/dtype as `t`; replaced by the canonical + # tensor on adoption (load_state_dict(assign=True)), so only its shape needs to match. + # `torch.empty_like` is preferred (preserves layout etc.) but is NOT implemented by some + # tensor subclasses — notably the GGUF `GGMLTensor`, whose `__torch_dispatch__` returns + # NotImplemented for `aten.empty_like`. That made `_build_meta_shell` throw on the first + # parameter of every GGUF model (e.g. a Q8_0 quantized transformer), silently disabling + # cross-device adoption for exactly the largest models. For those, fall back to a plain + # meta tensor built from the subclass's reported (dequantized) shape and dtype. + try: + return torch.empty_like(t, device="meta") + except TypeError: + return torch.empty(t.shape, dtype=t.dtype, device="meta") + try: # Persistent buffers come from the canonical state dict on adoption, so they (like params) # are replaced by meta placeholders. Non-persistent buffers are NOT in the state dict, so @@ -424,15 +457,21 @@ def _build_meta_shell(model: AnyModel) -> Optional[torch.nn.Module]: memo: dict[int, object] = {} for param in model.parameters(recurse=True): - memo[id(param)] = torch.nn.Parameter( - torch.empty_like(param, device="meta"), requires_grad=param.requires_grad - ) + memo[id(param)] = torch.nn.Parameter(_meta_like(param), requires_grad=param.requires_grad) for buffer in model.buffers(recurse=True): if id(buffer) in persistent_buffer_ids: - memo[id(buffer)] = torch.empty_like(buffer, device="meta") + memo[id(buffer)] = _meta_like(buffer) return copy.deepcopy(model, memo) - except Exception: + except Exception as e: + # Best-effort: an un-clonable model simply isn't adoptable (the next device loads it from + # disk). Log at debug so a newly-unadoptable model family can be diagnosed rather than + # silently double-loading on every device. + from invokeai.backend.util.logging import InvokeAILogger + + InvokeAILogger.get_logger().debug( + f"Could not build meta-weight shell for {type(model).__name__} ({e!r}); model won't be adopted." + ) return None # This needs to be implemented in the subclass diff --git a/tests/backend/model_manager/load/test_shared_weight_adoption.py b/tests/backend/model_manager/load/test_shared_weight_adoption.py index 3393e2af91b..72203b218ca 100644 --- a/tests/backend/model_manager/load/test_shared_weight_adoption.py +++ b/tests/backend/model_manager/load/test_shared_weight_adoption.py @@ -53,6 +53,54 @@ def test_meta_shell_has_no_real_weight_storage(): assert torch.equal(shell.scale, model.scale) +def test_adopts_gguf_quantized_weights_and_shares_storage(): + """Regression: GGUF models (GGMLTensor params) must be adoptable. + + `GGMLTensor.__torch_dispatch__` returns NotImplemented for `aten.empty_like`, so the original + `_build_meta_shell` threw on the first parameter and silently returned None -> no shell -> the + second GPU re-loaded the (largest) quantized model from disk on every run, spiking RAM. The + meta-shell builder must fall back to a plain meta placeholder, and adoption must share the + quantized storage (the whole point of dedup). + """ + import math + + import gguf + + from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor + + def _make_q8(logical_shape: tuple[int, int]) -> GGMLTensor: + n = math.prod(logical_shape) + nbytes = (n // 32) * 34 # Q8_0: 34 bytes per 32-element block + return GGMLTensor( + torch.zeros(nbytes, dtype=torch.uint8), + gguf.GGMLQuantizationType.Q8_0, + torch.Size(logical_shape), + torch.bfloat16, + ) + + model = _TinyModel() + # Swap the real weight for a GGMLTensor, mirroring the gguf loader's load_state_dict(assign=True). + model.lin.weight = torch.nn.Parameter(_make_q8((4, 4)), requires_grad=False) + orig_storage = model.lin.weight.quantized_data.untyped_storage().data_ptr() + + # The shell must build despite GGMLTensor not supporting empty_like. + shell = ModelLoader._build_meta_shell(model) + assert shell is not None + assert shell.lin.weight.is_meta + + store = SharedCpuWeightsStore() + store.acquire("g", model.state_dict()) + store.set_shell("g", shell) + + adopted = _loader_with_store(store)._try_adopt_shared_weights("g") + + assert adopted is not None + assert not any(t.is_meta for t in adopted.parameters()) + # The adopted GGMLTensor must share the quantized blob's storage -> one RAM copy across devices. + assert isinstance(adopted.lin.weight, GGMLTensor) + assert adopted.lin.weight.quantized_data.untyped_storage().data_ptr() == orig_storage + + def test_build_meta_shell_returns_none_for_non_module(): assert ModelLoader._build_meta_shell({"not": "a module"}) is None # type: ignore[arg-type]