From b6740d4a51066de9c910c6544f4d15584f743f11 Mon Sep 17 00:00:00 2001 From: Gohankaiju <167270541+Gohankaiju@users.noreply.github.com> Date: Thu, 25 Jun 2026 04:53:34 +0000 Subject: [PATCH 1/2] feat(nodes): add managed file inputs --- invokeai/app/api/dependencies.py | 3 + invokeai/app/api/routers/files.py | 101 +++++++++ invokeai/app/api_app.py | 62 +++++- invokeai/app/invocations/fields.py | 6 + .../app/services/config/config_default.py | 1 + invokeai/app/services/files/__init__.py | 23 ++ invokeai/app/services/files/files_base.py | 33 +++ invokeai/app/services/files/files_common.py | 65 ++++++ invokeai/app/services/files/files_disk.py | 178 +++++++++++++++ invokeai/app/services/invocation_services.py | 3 + .../app/services/shared/invocation_context.py | 50 ++++- invokeai/app/util/custom_openapi.py | 3 +- invokeai/frontend/web/public/locales/en.json | 5 +- .../Invocation/fields/InputFieldRenderer.tsx | 10 + .../fields/inputs/FileFieldInputComponent.tsx | 207 ++++++++++++++++++ .../src/features/nodes/store/nodesSlice.ts | 6 + .../src/features/nodes/types/common.test-d.ts | 2 + .../web/src/features/nodes/types/common.ts | 6 + .../web/src/features/nodes/types/constants.ts | 1 + .../web/src/features/nodes/types/field.ts | 30 +++ .../util/schema/buildFieldInputInstance.ts | 1 + .../util/schema/buildFieldInputTemplate.ts | 16 ++ .../web/src/services/api/endpoints/files.ts | 41 ++++ .../frontend/web/src/services/api/index.ts | 1 + .../frontend/web/src/services/api/schema.ts | 207 ++++++++++++++++++ .../frontend/web/src/services/api/types.ts | 17 ++ invokeai/invocation_api/__init__.py | 2 + tests/app/services/files/__init__.py | 0 tests/app/services/files/test_files_disk.py | 164 ++++++++++++++ 29 files changed, 1240 insertions(+), 4 deletions(-) create mode 100644 invokeai/app/api/routers/files.py create mode 100644 invokeai/app/services/files/__init__.py create mode 100644 invokeai/app/services/files/files_base.py create mode 100644 invokeai/app/services/files/files_common.py create mode 100644 invokeai/app/services/files/files_disk.py create mode 100644 invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/FileFieldInputComponent.tsx create mode 100644 invokeai/frontend/web/src/services/api/endpoints/files.ts create mode 100644 tests/app/services/files/__init__.py create mode 100644 tests/app/services/files/test_files_disk.py diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index e7468c1bca4..75fe3bedd9e 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -24,6 +24,7 @@ SeedreamProvider, ) from invokeai.app.services.external_generation.startup import sync_configured_external_starter_models +from invokeai.app.services.files.files_disk import DiskFileService from invokeai.app.services.image_files.image_files_disk import DiskImageFileStorage from invokeai.app.services.image_records.image_records_sqlite import SqliteImageRecordStorage from invokeai.app.services.images.images_default import ImageService @@ -107,6 +108,7 @@ def initialize( raise ValueError("Output folder is not set") image_files = DiskImageFileStorage(f"{output_folder}/images") + files = DiskFileService(output_folder / "files", max_file_size=config.max_file_upload_size_bytes) model_images_folder = config.models_path style_presets_folder = config.style_presets_path @@ -197,6 +199,7 @@ def initialize( bulk_download=bulk_download, configuration=configuration, events=events, + files=files, image_files=image_files, image_records=image_records, images=images, diff --git a/invokeai/app/api/routers/files.py b/invokeai/app/api/routers/files.py new file mode 100644 index 00000000000..b9c5a97753f --- /dev/null +++ b/invokeai/app/api/routers/files.py @@ -0,0 +1,101 @@ +from fastapi import HTTPException, Path, Response, UploadFile +from fastapi.concurrency import run_in_threadpool +from fastapi.routing import APIRouter + +from invokeai.app.api.auth_dependencies import CurrentUserOrDefault +from invokeai.app.api.dependencies import ApiDependencies +from invokeai.app.services.files.files_base import FileServiceBase +from invokeai.app.services.files.files_common import ( + FileAccessDeniedException, + FileDTO, + FileMetadataException, + FileNotFoundException, + FileTooLargeException, + UnsupportedFileTypeException, +) + +files_router = APIRouter(prefix="/v1/files", tags=["files"]) + + +def _get_file_service() -> FileServiceBase: + file_service = ApiDependencies.invoker.services.files + if file_service is None: + raise HTTPException(status_code=503, detail="Managed file service is not available") + return file_service + + +@files_router.post( + "/upload", + operation_id="upload_file", + responses={ + 201: {"description": "The file was uploaded successfully"}, + 413: {"description": "The file is too large"}, + 415: {"description": "The file type is not supported"}, + }, + status_code=201, + response_model=FileDTO, +) +async def upload_file( + current_user: CurrentUserOrDefault, + file: UploadFile, + response: Response, +) -> FileDTO: + """Uploads a managed file for node inputs.""" + try: + file_dto = await run_in_threadpool( + _get_file_service().save, + file_name=file.filename or "", + content_type=file.content_type, + file=file.file, + user_id=current_user.user_id, + ) + response.status_code = 201 + return file_dto + except UnsupportedFileTypeException as e: + raise HTTPException(status_code=415, detail=str(e)) + except FileTooLargeException as e: + raise HTTPException(status_code=413, detail=str(e)) + except FileMetadataException as e: + raise HTTPException(status_code=500, detail=str(e)) + finally: + await file.close() + + +@files_router.get( + "/i/{file_id}", + operation_id="get_file_dto", + response_model=FileDTO, +) +async def get_file_dto( + current_user: CurrentUserOrDefault, + file_id: str = Path(description="The managed file ID."), +) -> FileDTO: + """Gets metadata for a managed file.""" + try: + return await run_in_threadpool(_get_file_service().get_dto, file_id, user_id=current_user.user_id) + except FileAccessDeniedException: + raise HTTPException(status_code=403, detail="Not authorized to access this file") + except FileNotFoundException: + raise HTTPException(status_code=404, detail="File not found") + except FileMetadataException as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@files_router.delete( + "/i/{file_id}", + operation_id="delete_file", + status_code=204, +) +async def delete_file( + current_user: CurrentUserOrDefault, + file_id: str = Path(description="The managed file ID."), +) -> None: + """Deletes a managed file.""" + try: + await run_in_threadpool(_get_file_service().delete, file_id, user_id=current_user.user_id) + except FileAccessDeniedException: + raise HTTPException(status_code=403, detail="Not authorized to delete this file") + except FileNotFoundException: + raise HTTPException(status_code=404, detail="File not found") + except FileMetadataException as e: + raise HTTPException(status_code=500, detail=str(e)) diff --git a/invokeai/app/api_app.py b/invokeai/app/api_app.py index 4b79e1eeb0c..4e15842d056 100644 --- a/invokeai/app/api_app.py +++ b/invokeai/app/api_app.py @@ -7,10 +7,11 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.gzip import GZipMiddleware from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html -from fastapi.responses import HTMLResponse, RedirectResponse +from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse from fastapi_events.handlers.local import local_handler from fastapi_events.middleware import EventHandlerASGIMiddleware from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint +from starlette.types import ASGIApp, Message, Receive, Scope, Send import invokeai.frontend.web as web_dir from invokeai.app.api.dependencies import ApiDependencies @@ -23,6 +24,7 @@ client_state, custom_nodes, download_queue, + files, images, model_manager, model_relationships, @@ -125,6 +127,62 @@ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint): return response +MANAGED_FILE_UPLOAD_MULTIPART_OVERHEAD_BYTES = 1024 * 1024 + + +class RequestBodyTooLarge(Exception): + pass + + +class ManagedFileUploadSizeLimitMiddleware: + def __init__(self, app: ASGIApp, max_file_bytes: int) -> None: + self.app = app + self.max_file_bytes = max_file_bytes + self.max_body_bytes = max_file_bytes + MANAGED_FILE_UPLOAD_MULTIPART_OVERHEAD_BYTES + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope.get("type") != "http" or scope.get("method") != "POST" or scope.get("path") != "/api/v1/files/upload": + await self.app(scope, receive, send) + return + + headers = dict(scope.get("headers", [])) + content_length_header = headers.get(b"content-length") + content_length = None + if content_length_header is not None: + try: + content_length = int(content_length_header) + except ValueError: + content_length = None + + if content_length is not None and content_length > self.max_body_bytes: + response = JSONResponse( + {"detail": f"File exceeds the maximum size of {self.max_file_bytes} bytes."}, + status_code=413, + ) + await response(scope, receive, send) + return + + received_bytes = 0 + + async def limited_receive() -> Message: + nonlocal received_bytes + message = await receive() + if message["type"] == "http.request": + received_bytes += len(message.get("body", b"")) + if received_bytes > self.max_body_bytes: + raise RequestBodyTooLarge + return message + + try: + await self.app(scope, limited_receive, send) + except RequestBodyTooLarge: + response = JSONResponse( + {"detail": f"File exceeds the maximum size of {self.max_file_bytes} bytes."}, + status_code=413, + ) + await response(scope, receive, send) + + class RedirectRootWithQueryStringMiddleware(BaseHTTPMiddleware): """When a request is made to the root path with a query string, redirect to the root path without the query string. @@ -146,6 +204,7 @@ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint): # Add the middleware app.add_middleware(RedirectRootWithQueryStringMiddleware) app.add_middleware(SlidingWindowTokenMiddleware) +app.add_middleware(ManagedFileUploadSizeLimitMiddleware, max_file_bytes=app_config.max_file_upload_size_bytes) # Add event handler @@ -176,6 +235,7 @@ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint): app.include_router(utilities.utilities_router, prefix="/api") app.include_router(model_manager.model_manager_router, prefix="/api") app.include_router(download_queue.download_queue_router, prefix="/api") +app.include_router(files.files_router, prefix="/api") app.include_router(images.images_router, prefix="/api") app.include_router(boards.boards_router, prefix="/api") app.include_router(board_images.board_images_router, prefix="/api") diff --git a/invokeai/app/invocations/fields.py b/invokeai/app/invocations/fields.py index e53aeb417b2..c79561342ad 100644 --- a/invokeai/app/invocations/fields.py +++ b/invokeai/app/invocations/fields.py @@ -240,6 +240,12 @@ class ImageField(BaseModel): image_name: str = Field(description="The name of the image") +class FileField(BaseModel): + """A managed file primitive field""" + + file_id: str = Field(description="The id of the managed file") + + class BoardField(BaseModel): """A board primitive field""" diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py index e6cc7c2798c..fbeeace6af4 100644 --- a/invokeai/app/services/config/config_default.py +++ b/invokeai/app/services/config/config_default.py @@ -216,6 +216,7 @@ class InvokeAIAppConfig(BaseSettings): max_queue_size: int = Field(default=10000, gt=0, description="Maximum number of items in the session queue.") clear_queue_on_startup: bool = Field(default=False, description="Empties session queue on startup. If true, disables `max_queue_history`.") max_queue_history: Optional[int] = Field(default=None, ge=0, description="Keep the last N completed, failed, and canceled queue items. Older items are deleted on startup. Set to 0 to prune all terminal items. Ignored if `clear_queue_on_startup` is true.") + max_file_upload_size_bytes: int = Field(default=50 * 1024 * 1024, gt=0, description="Maximum size in bytes for managed file uploads.") # NODES allow_nodes: Optional[list[str]] = Field(default=None, description="List of nodes to allow. Omit to allow all.") diff --git a/invokeai/app/services/files/__init__.py b/invokeai/app/services/files/__init__.py new file mode 100644 index 00000000000..85f98c6bdfe --- /dev/null +++ b/invokeai/app/services/files/__init__.py @@ -0,0 +1,23 @@ +from invokeai.app.services.files.files_base import FileServiceBase +from invokeai.app.services.files.files_common import ( + FileAccessDeniedException, + FileDTO, + FileMetadataException, + FileNotFoundException, + FileStorageException, + FileTooLargeException, + UnsupportedFileTypeException, +) +from invokeai.app.services.files.files_disk import DiskFileService + +__all__ = [ + "DiskFileService", + "FileAccessDeniedException", + "FileDTO", + "FileMetadataException", + "FileNotFoundException", + "FileServiceBase", + "FileStorageException", + "FileTooLargeException", + "UnsupportedFileTypeException", +] diff --git a/invokeai/app/services/files/files_base.py b/invokeai/app/services/files/files_base.py new file mode 100644 index 00000000000..dd8f1e02f5b --- /dev/null +++ b/invokeai/app/services/files/files_base.py @@ -0,0 +1,33 @@ +from abc import ABC, abstractmethod +from pathlib import Path +from typing import IO, Any, BinaryIO + +from invokeai.app.services.files.files_common import FileDTO + + +class FileServiceBase(ABC): + @abstractmethod + def save( + self, + file_name: str, + content_type: str | None, + file: BinaryIO, + user_id: str | None, + ) -> FileDTO: + """Saves a managed file and returns its DTO.""" + + @abstractmethod + def get_dto(self, file_id: str, user_id: str | None = None) -> FileDTO: + """Gets a managed file DTO.""" + + @abstractmethod + def get_path(self, file_id: str, user_id: str | None = None) -> Path: + """Gets the path to a managed file.""" + + @abstractmethod + def open(self, file_id: str, mode: str = "rb", user_id: str | None = None) -> IO[Any]: + """Opens a managed file.""" + + @abstractmethod + def delete(self, file_id: str, user_id: str | None = None) -> None: + """Deletes a managed file.""" diff --git a/invokeai/app/services/files/files_common.py b/invokeai/app/services/files/files_common.py new file mode 100644 index 00000000000..140e0937c5c --- /dev/null +++ b/invokeai/app/services/files/files_common.py @@ -0,0 +1,65 @@ +from datetime import datetime + +from pydantic import BaseModel, Field + +DEFAULT_FILE_UPLOAD_MAX_BYTES = 50 * 1024 * 1024 + +SUPPORTED_FILE_EXTENSIONS = frozenset( + { + ".csv", + ".json", + ".md", + ".markdown", + ".pdf", + ".txt", + ".yaml", + ".yml", + } +) + +SUPPORTED_FILE_MIME_TYPES = frozenset( + { + "application/json", + "application/pdf", + "application/x-yaml", + "application/yaml", + "text/csv", + "text/markdown", + "text/plain", + "text/x-markdown", + "text/x-yaml", + "text/yaml", + } +) + + +class FileDTO(BaseModel): + file_id: str = Field(description="The managed file ID.") + file_name: str = Field(description="The original file name.") + content_type: str = Field(description="The uploaded file content type.") + size_bytes: int = Field(description="The size of the file in bytes.", ge=0) + created_at: datetime = Field(description="When the file was uploaded.") + + +class FileStorageException(Exception): + """Base exception for managed file storage errors.""" + + +class FileNotFoundException(FileStorageException): + """Raised when a managed file cannot be found.""" + + +class UnsupportedFileTypeException(FileStorageException): + """Raised when a managed file upload is not allowed.""" + + +class FileTooLargeException(FileStorageException): + """Raised when a managed file upload exceeds the size limit.""" + + +class FileAccessDeniedException(FileStorageException): + """Raised when a user cannot access a managed file.""" + + +class FileMetadataException(FileStorageException): + """Raised when managed file metadata cannot be read or validated.""" diff --git a/invokeai/app/services/files/files_disk.py b/invokeai/app/services/files/files_disk.py new file mode 100644 index 00000000000..9377ebb7fe5 --- /dev/null +++ b/invokeai/app/services/files/files_disk.py @@ -0,0 +1,178 @@ +import json +from datetime import datetime, timezone +from json import JSONDecodeError +from pathlib import Path +from typing import IO, Any, BinaryIO + +from pydantic import Field, TypeAdapter, ValidationError + +from invokeai.app.services.files.files_base import FileServiceBase +from invokeai.app.services.files.files_common import ( + DEFAULT_FILE_UPLOAD_MAX_BYTES, + SUPPORTED_FILE_EXTENSIONS, + SUPPORTED_FILE_MIME_TYPES, + FileAccessDeniedException, + FileDTO, + FileMetadataException, + FileNotFoundException, + FileTooLargeException, + UnsupportedFileTypeException, +) +from invokeai.app.util.misc import uuid_string +from invokeai.backend.util.logging import InvokeAILogger + + +class FileRecord(FileDTO): + stored_file_name: str = Field(description="The file name used on disk.") + user_id: str | None = Field(default=None, description="The user who uploaded this file.") + + +FileRecordAdapter = TypeAdapter(FileRecord) + + +class DiskFileService(FileServiceBase): + def __init__(self, storage_path: Path, max_file_size: int = DEFAULT_FILE_UPLOAD_MAX_BYTES) -> None: + self._storage_path = storage_path + self._max_file_size = max_file_size + self._logger = InvokeAILogger.get_logger(name=self.__class__.__name__) + self._validate_storage_folder() + + def save( + self, + file_name: str, + content_type: str | None, + file: BinaryIO, + user_id: str | None, + ) -> FileDTO: + safe_file_name = self._sanitize_file_name(file_name) + extension = self._get_supported_extension(safe_file_name) + normalized_content_type = self._normalize_content_type(content_type) + self._validate_content_type(normalized_content_type) + + file_id = uuid_string() + stored_file_name = f"{file_id}{extension}" + file_path = self._get_file_path(stored_file_name) + metadata_path = self._get_metadata_path(file_id) + + size_bytes = 0 + try: + with open(file_path, "wb") as out_file: + while True: + chunk = file.read(1024 * 1024) + if not chunk: + break + size_bytes += len(chunk) + if size_bytes > self._max_file_size: + raise FileTooLargeException(f"File exceeds the maximum size of {self._max_file_size} bytes.") + out_file.write(chunk) + + record = FileRecord( + file_id=file_id, + file_name=safe_file_name, + stored_file_name=stored_file_name, + content_type=normalized_content_type, + size_bytes=size_bytes, + created_at=datetime.now(timezone.utc), + user_id=user_id, + ) + metadata_path.write_text(record.model_dump_json(), encoding="utf-8") + return self._record_to_dto(record) + except Exception: + file_path.unlink(missing_ok=True) + metadata_path.unlink(missing_ok=True) + raise + + def get_dto(self, file_id: str, user_id: str | None = None) -> FileDTO: + record = self._get_record(file_id, user_id=user_id) + self._get_existing_file_path(record) + return self._record_to_dto(record) + + def get_path(self, file_id: str, user_id: str | None = None) -> Path: + record = self._get_record(file_id, user_id=user_id) + return self._get_existing_file_path(record) + + def open(self, file_id: str, mode: str = "rb", user_id: str | None = None) -> IO[Any]: + if mode not in {"rb", "r"}: + raise ValueError("Managed files may only be opened for reading.") + path = self.get_path(file_id, user_id=user_id) + if mode == "r": + return open(path, mode, encoding="utf-8") + return open(path, mode) + + def delete(self, file_id: str, user_id: str | None = None) -> None: + record = self._get_record(file_id, user_id=user_id) + self._get_file_path(record.stored_file_name).unlink(missing_ok=True) + self._get_metadata_path(file_id).unlink(missing_ok=True) + + def _get_record(self, file_id: str, user_id: str | None = None) -> FileRecord: + metadata_path = self._get_metadata_path(file_id) + try: + data = json.loads(metadata_path.read_text(encoding="utf-8")) + record = FileRecordAdapter.validate_python(data) + except FileNotFoundError as e: + raise FileNotFoundException(f"File not found: {file_id}") from e + except (JSONDecodeError, ValidationError) as e: + self._logger.warning(f"Invalid managed file metadata for file: {file_id}") + raise FileMetadataException(f"Invalid file metadata: {file_id}") from e + except OSError as e: + self._logger.warning(f"Unable to read managed file metadata for file: {file_id}: {e}") + raise FileMetadataException(f"Unable to read file metadata: {file_id}") from e + + if record.user_id != user_id: + raise FileAccessDeniedException(f"Not authorized to access file: {file_id}") + return record + + def _get_existing_file_path(self, record: FileRecord) -> Path: + path = self._get_file_path(record.stored_file_name) + if not path.exists(): + raise FileNotFoundException(f"File not found: {record.file_id}") + return path + + def _get_file_path(self, stored_file_name: str) -> Path: + path = self._storage_path / stored_file_name + resolved_base = self._storage_path.resolve() + resolved_path = path.resolve() + if not resolved_path.is_relative_to(resolved_base): + raise ValueError("File path outside storage folder, potential directory traversal detected.") + return resolved_path + + def _get_metadata_path(self, file_id: str) -> Path: + if Path(file_id).name != file_id: + raise ValueError("Invalid file ID, potential directory traversal detected.") + return self._get_file_path(f"{file_id}.meta.json") + + def _validate_storage_folder(self) -> None: + self._storage_path.mkdir(parents=True, exist_ok=True) + + @staticmethod + def _record_to_dto(record: FileRecord) -> FileDTO: + return FileDTO(**record.model_dump(exclude={"stored_file_name", "user_id"})) + + @staticmethod + def _sanitize_file_name(file_name: str) -> str: + safe_file_name = Path(file_name.replace("\x00", "")).name.strip() + if not safe_file_name: + raise UnsupportedFileTypeException("Missing file name.") + return safe_file_name + + @staticmethod + def _get_supported_extension(file_name: str) -> str: + extension = Path(file_name).suffix.lower() + if extension not in SUPPORTED_FILE_EXTENSIONS: + raise UnsupportedFileTypeException(f"Unsupported file extension: {extension}") + return extension + + @staticmethod + def _normalize_content_type(content_type: str | None) -> str: + if not content_type: + return "application/octet-stream" + return content_type.split(";", 1)[0].strip().lower() + + @staticmethod + def _validate_content_type(content_type: str) -> None: + # Some browsers or reverse proxies send application/octet-stream for document files. The extension + # allowlist remains authoritative in that case. + if content_type in {"application/octet-stream", ""}: + return + if content_type not in SUPPORTED_FILE_MIME_TYPES: + raise UnsupportedFileTypeException(f"Unsupported content type: {content_type}") diff --git a/invokeai/app/services/invocation_services.py b/invokeai/app/services/invocation_services.py index 2c95f87b41d..dd10603e721 100644 --- a/invokeai/app/services/invocation_services.py +++ b/invokeai/app/services/invocation_services.py @@ -22,6 +22,7 @@ from invokeai.app.services.download import DownloadQueueServiceBase from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.external_generation.external_generation_base import ExternalGenerationServiceBase + from invokeai.app.services.files.files_base import FileServiceBase from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase from invokeai.app.services.image_records.image_records_base import ImageRecordStorageBase from invokeai.app.services.images.images_base import ImageServiceABC @@ -79,6 +80,7 @@ def __init__( workflow_thumbnails: "WorkflowThumbnailServiceBase", client_state_persistence: "ClientStatePersistenceABC", users: "UserServiceBase", + files: "FileServiceBase | None" = None, ): self.board_images = board_images self.board_image_records = board_image_records @@ -111,3 +113,4 @@ def __init__( self.workflow_thumbnails = workflow_thumbnails self.client_state_persistence = client_state_persistence self.users = users + self.files = files diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index e38766d5ba2..4870b939f82 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -1,7 +1,7 @@ from copy import deepcopy from dataclasses import dataclass from pathlib import Path -from typing import TYPE_CHECKING, Callable, Optional, Union +from typing import IO, TYPE_CHECKING, Any, Callable, Optional, Union from PIL.Image import Image from pydantic.networks import AnyHttpUrl @@ -12,6 +12,7 @@ from invokeai.app.services.board_records.board_records_common import BoardRecordOrderBy from invokeai.app.services.boards.boards_common import BoardDTO from invokeai.app.services.config.config_default import InvokeAIAppConfig +from invokeai.app.services.files.files_common import FileDTO from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin from invokeai.app.services.images.images_common import ImageDTO from invokeai.app.services.invocation_services import InvocationServices @@ -164,6 +165,47 @@ def error(self, message: str) -> None: self._services.logger.error(message) +class FilesInterface(InvocationContextInterface): + def get_dto(self, file_id: str) -> FileDTO: + """Gets metadata for a managed file. + + Args: + file_id: The managed file ID. + + Returns: + The file DTO. + """ + return self._get_files_service().get_dto(file_id, user_id=self._data.queue_item.user_id) + + def get_path(self, file_id: str) -> Path: + """Gets the server-side path to a managed file. + + Args: + file_id: The managed file ID. + + Returns: + The path to the managed file on the InvokeAI server. + """ + return self._get_files_service().get_path(file_id, user_id=self._data.queue_item.user_id) + + def open(self, file_id: str, mode: str = "rb") -> IO[Any]: + """Opens a managed file for reading. + + Args: + file_id: The managed file ID. + mode: The read mode to use. Supported values are "rb" and "r". + + Returns: + A readable file object. + """ + return self._get_files_service().open(file_id, mode=mode, user_id=self._data.queue_item.user_id) + + def _get_files_service(self): + if self._services.files is None: + raise RuntimeError("The managed files service is not available.") + return self._services.files + + class ImagesInterface(InvocationContextInterface): def __init__(self, services: InvocationServices, data: InvocationContextData, util: "UtilInterface") -> None: super().__init__(services, data) @@ -731,6 +773,7 @@ class InvocationContext: config (ConfigInterface): The app config. util (UtilInterface): Utility methods, including a method to check if an invocation was canceled and step callbacks. boards (BoardsInterface): Methods to interact with boards. + files (FilesInterface): Methods to read managed files uploaded for node inputs. """ def __init__( @@ -743,6 +786,7 @@ def __init__( config: ConfigInterface, util: UtilInterface, boards: BoardsInterface, + files: FilesInterface, data: InvocationContextData, services: InvocationServices, ) -> None: @@ -762,6 +806,8 @@ def __init__( """Utility methods, including a method to check if an invocation was canceled and step callbacks.""" self.boards = boards """Methods to interact with boards.""" + self.files = files + """Methods to read managed files uploaded for node inputs.""" self._data = data """An internal API providing access to data about the current queue item and invocation. You probably shouldn't use this. It may change without warning.""" self._services = services @@ -791,6 +837,7 @@ def build_invocation_context( models = ModelsInterface(services=services, data=data, util=util) images = ImagesInterface(services=services, data=data, util=util) boards = BoardsInterface(services=services, data=data) + files = FilesInterface(services=services, data=data) ctx = InvocationContext( images=images, @@ -803,6 +850,7 @@ def build_invocation_context( conditioning=conditioning, services=services, boards=boards, + files=files, ) return ctx diff --git a/invokeai/app/util/custom_openapi.py b/invokeai/app/util/custom_openapi.py index f674fa76218..8eacd42b7d8 100644 --- a/invokeai/app/util/custom_openapi.py +++ b/invokeai/app/util/custom_openapi.py @@ -8,7 +8,7 @@ InvocationRegistry, UIConfigBase, ) -from invokeai.app.invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra +from invokeai.app.invocations.fields import FileField, InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra from invokeai.app.invocations.model import ModelIdentifierField from invokeai.app.services.events.events_common import EventBase from invokeai.app.services.session_processor.session_processor_common import ProgressImage @@ -124,6 +124,7 @@ def openapi() -> dict[str, Any]: OutputFieldJSONSchemaExtra, ModelIdentifierField, ProgressImage, + FileField, ] additional_schemas = models_json_schema( diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 75367a502db..4228e8c9999 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -13,7 +13,8 @@ "toggleRightPanel": "Toggle Right Panel (G)", "toggleLeftPanel": "Toggle Left Panel (T)", "uploadImage": "Upload Image", - "uploadImages": "Upload Images" + "uploadImages": "Upload Images", + "uploadFile": "Upload File" }, "auth": { "login": { @@ -1904,6 +1905,7 @@ "imageSavingFailed": "Image Saving Failed", "imageUploaded": "Image Uploaded", "imageUploadFailed": "Image Upload Failed", + "fileUploadFailed": "File Upload Failed", "importFailed": "Import Failed", "importSuccessful": "Import Successful", "invalidUpload": "Invalid Upload", @@ -1958,6 +1960,7 @@ "uploadFailedInvalidUploadDesc_withCount_one": "Must be maximum of 1 PNG, JPEG or WEBP image.", "uploadFailedInvalidUploadDesc_withCount_other": "Must be maximum of {{count}} PNG, JPEG or WEBP images.", "uploadFailedInvalidUploadDesc": "Must be PNG, JPEG or WEBP images.", + "fileUploadFailedInvalidUploadDesc": "Must be PDF, Markdown, text, CSV, JSON or YAML files.", "workflowLoaded": "Workflow Loaded", "problemRetrievingWorkflow": "Problem Retrieving Workflow", "workflowDeleted": "Workflow Deleted", diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx index 60a3f8e472a..08f3190ff42 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx @@ -27,6 +27,8 @@ import { isColorFieldInputTemplate, isEnumFieldInputInstance, isEnumFieldInputTemplate, + isFileFieldInputInstance, + isFileFieldInputTemplate, isFloatFieldCollectionInputInstance, isFloatFieldCollectionInputTemplate, isFloatFieldInputInstance, @@ -67,6 +69,7 @@ import BoardFieldInputComponent from './inputs/BoardFieldInputComponent'; import BooleanFieldInputComponent from './inputs/BooleanFieldInputComponent'; import ColorFieldInputComponent from './inputs/ColorFieldInputComponent'; import EnumFieldInputComponent from './inputs/EnumFieldInputComponent'; +import FileFieldInputComponent from './inputs/FileFieldInputComponent'; import ImageFieldInputComponent from './inputs/ImageFieldInputComponent'; import SchedulerFieldInputComponent from './inputs/SchedulerFieldInputComponent'; import StylePresetFieldInputComponent from './inputs/StylePresetFieldInputComponent'; @@ -202,6 +205,13 @@ export const InputFieldRenderer = memo(({ nodeId, fieldName, settings }: Props) return ; } + if (isFileFieldInputTemplate(template)) { + if (!isFileFieldInputInstance(field)) { + return null; + } + return ; + } + if (isBoardFieldInputTemplate(template)) { if (!isBoardFieldInputInstance(field)) { return null; diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/FileFieldInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/FileFieldInputComponent.tsx new file mode 100644 index 00000000000..d834741feb7 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/FileFieldInputComponent.tsx @@ -0,0 +1,207 @@ +import type { SystemStyleObject } from '@invoke-ai/ui-library'; +import { Button, Flex, Icon, IconButton, Text } from '@invoke-ai/ui-library'; +import { useStore } from '@nanostores/react'; +import { skipToken } from '@reduxjs/toolkit/query'; +import { useAppDispatch } from 'app/store/storeHooks'; +import { fieldFileValueChanged } from 'features/nodes/store/nodesSlice'; +import { NO_DRAG_CLASS } from 'features/nodes/types/constants'; +import type { FileFieldInputInstance, FileFieldInputTemplate } from 'features/nodes/types/field'; +import { toast } from 'features/toast/toast'; +import { filesize } from 'filesize'; +import { memo, type MouseEvent, useCallback, useEffect } from 'react'; +import type { Accept, FileRejection } from 'react-dropzone'; +import { useDropzone } from 'react-dropzone'; +import { useTranslation } from 'react-i18next'; +import { PiFileTextBold, PiUploadBold, PiXBold } from 'react-icons/pi'; +import { useGetFileDTOQuery, useUploadFileMutation } from 'services/api/endpoints/files'; +import type { FileDTO } from 'services/api/types'; +import { $isConnected } from 'services/events/stores'; + +import type { FieldComponentProps } from './types'; + +const addUpperCaseReducer = (acc: string[], ext: string) => { + acc.push(ext); + acc.push(ext.toUpperCase()); + return acc; +}; + +const textFileExtensions = ['.csv', '.json', '.md', '.markdown', '.txt', '.yaml', '.yml'].reduce( + addUpperCaseReducer, + [] as string[] +); + +const fileUploadAccept: Accept = { + 'application/pdf': ['.pdf'].reduce(addUpperCaseReducer, [] as string[]), + 'application/json': ['.json'].reduce(addUpperCaseReducer, [] as string[]), + 'application/yaml': ['.yaml', '.yml'].reduce(addUpperCaseReducer, [] as string[]), + 'application/x-yaml': ['.yaml', '.yml'].reduce(addUpperCaseReducer, [] as string[]), + 'text/csv': ['.csv'].reduce(addUpperCaseReducer, [] as string[]), + 'text/markdown': ['.md', '.markdown'].reduce(addUpperCaseReducer, [] as string[]), + 'text/plain': textFileExtensions, + 'text/x-markdown': ['.md', '.markdown'].reduce(addUpperCaseReducer, [] as string[]), + 'text/x-yaml': ['.yaml', '.yml'].reduce(addUpperCaseReducer, [] as string[]), + 'text/yaml': ['.yaml', '.yml'].reduce(addUpperCaseReducer, [] as string[]), +}; + +const sx = { + '&[data-error=true]': { + borderColor: 'error.500', + borderStyle: 'solid', + }, + '&[data-active=true]': { + borderColor: 'invokeBlue.500', + borderStyle: 'solid', + }, +} satisfies SystemStyleObject; + +const FileFieldInputComponent = (props: FieldComponentProps) => { + const { nodeId, field, fieldTemplate } = props; + const { t } = useTranslation(); + const dispatch = useAppDispatch(); + const isConnected = useStore($isConnected); + const [uploadFile, uploadRequest] = useUploadFileMutation(); + const { currentData: fileDTO, isError } = useGetFileDTOQuery(field.value?.file_id ?? skipToken); + + const setValue = useCallback( + (value: FileDTO | undefined) => { + dispatch( + fieldFileValueChanged({ + nodeId, + fieldName: field.name, + value: value ? { file_id: value.file_id } : undefined, + }) + ); + }, + [dispatch, field.name, nodeId] + ); + + const handleReset = useCallback(() => { + setValue(undefined); + }, [setValue]); + + const handleResetClick = useCallback( + (event: MouseEvent) => { + event.stopPropagation(); + handleReset(); + }, + [handleReset] + ); + + useEffect(() => { + if (isConnected && isError) { + handleReset(); + } + }, [handleReset, isConnected, isError]); + + const onDropAccepted = useCallback( + async (files: File[]) => { + const file = files[0]; + if (!file) { + return; + } + try { + const uploadedFileDTO = await uploadFile({ file }).unwrap(); + setValue(uploadedFileDTO); + } catch { + toast({ + id: 'FILE_UPLOAD_FAILED', + title: t('toast.fileUploadFailed'), + status: 'error', + }); + } + }, + [setValue, t, uploadFile] + ); + + const onDropRejected = useCallback( + (fileRejections: FileRejection[]) => { + if (fileRejections.length === 0) { + return; + } + toast({ + id: 'FILE_UPLOAD_REJECTED', + title: t('toast.uploadFailed'), + description: t('toast.fileUploadFailedInvalidUploadDesc'), + status: 'error', + }); + }, + [t] + ); + + const { getRootProps, getInputProps, isDragActive, open } = useDropzone({ + accept: fileUploadAccept, + multiple: false, + noClick: true, + noKeyboard: true, + onDropAccepted, + onDropRejected, + }); + + const handleOpenClick = useCallback( + (event: MouseEvent) => { + event.stopPropagation(); + open(); + }, + [open] + ); + + return ( + + + {!fileDTO && ( + + )} + {fileDTO && ( + <> + + + + {fileDTO.file_name} + + + {filesize(fileDTO.size_bytes)} + + + } + variant="ghost" + size="sm" + onClick={handleResetClick} + flexShrink={0} + /> + + )} + + ); +}; + +export default memo(FileFieldInputComponent); diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts index 6713ee8fb42..2d13a6320a8 100644 --- a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts +++ b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts @@ -34,6 +34,7 @@ import type { ColorFieldValue, EnumFieldValue, FieldValue, + FileFieldValue, FloatFieldValue, FloatGeneratorFieldValue, ImageFieldCollectionValue, @@ -55,6 +56,7 @@ import { zBooleanFieldValue, zColorFieldValue, zEnumFieldValue, + zFileFieldValue, zFloatFieldCollectionValue, zFloatFieldValue, zFloatGeneratorFieldValue, @@ -515,6 +517,9 @@ const slice = createSlice({ fieldImageValueChanged: (state, action: FieldValueAction) => { fieldValueReducer(state, action, zImageFieldValue); }, + fieldFileValueChanged: (state, action: FieldValueAction) => { + fieldValueReducer(state, action, zFileFieldValue); + }, fieldImageCollectionValueChanged: (state, action: FieldValueAction) => { fieldValueReducer(state, action, zImageFieldCollectionValue); }, @@ -665,6 +670,7 @@ export const { fieldStylePresetValueChanged, fieldEnumModelValueChanged, fieldImageValueChanged, + fieldFileValueChanged, fieldImageCollectionValueChanged, fieldLabelChanged, fieldModelIdentifierValueChanged, diff --git a/invokeai/frontend/web/src/features/nodes/types/common.test-d.ts b/invokeai/frontend/web/src/features/nodes/types/common.test-d.ts index 04e0fea2cc9..835e9075b30 100644 --- a/invokeai/frontend/web/src/features/nodes/types/common.test-d.ts +++ b/invokeai/frontend/web/src/features/nodes/types/common.test-d.ts @@ -3,6 +3,7 @@ import type { Classification, ColorField, ControlField, + FileField, ImageField, ImageOutput, IPAdapterField, @@ -35,6 +36,7 @@ import type z from 'zod'; describe('Common types', () => { // Complex field types test('ImageField', () => assert>()); + test('FileField', () => assert>()); test('BoardField', () => assert>()); test('ColorField', () => assert>()); test('SchedulerField', () => assert['scheduler']>>>()); diff --git a/invokeai/frontend/web/src/features/nodes/types/common.ts b/invokeai/frontend/web/src/features/nodes/types/common.ts index fb2a1ce946a..79eb7b3dd5e 100644 --- a/invokeai/frontend/web/src/features/nodes/types/common.ts +++ b/invokeai/frontend/web/src/features/nodes/types/common.ts @@ -11,6 +11,12 @@ type ImageFieldCollection = z.infer; export const isImageFieldCollection = (field: unknown): field is ImageFieldCollection => zImageFieldCollection.safeParse(field).success; +export const zFileField = z.object({ + file_id: z.string().trim().min(1), +}); +export type FileField = z.infer; +export const isFileField = (field: unknown): field is FileField => zFileField.safeParse(field).success; + export const zBoardField = z.object({ board_id: z.string().trim().min(1), }); diff --git a/invokeai/frontend/web/src/features/nodes/types/constants.ts b/invokeai/frontend/web/src/features/nodes/types/constants.ts index 9da499ab91c..32d95de0f63 100644 --- a/invokeai/frontend/web/src/features/nodes/types/constants.ts +++ b/invokeai/frontend/web/src/features/nodes/types/constants.ts @@ -44,6 +44,7 @@ export const FIELD_COLORS: { [key: string]: string } = { ControlNetModelField: 'teal.500', EnumField: 'blue.500', FloatField: 'orange.500', + FileField: 'cyan.500', ImageField: 'purple.500', ImageBatchField: 'purple.500', IntegerField: 'red.500', diff --git a/invokeai/frontend/web/src/features/nodes/types/field.ts b/invokeai/frontend/web/src/features/nodes/types/field.ts index ffd87ae3984..c419ea7be17 100644 --- a/invokeai/frontend/web/src/features/nodes/types/field.ts +++ b/invokeai/frontend/web/src/features/nodes/types/field.ts @@ -14,6 +14,7 @@ import { zBaseModelType, zBoardField, zColorField, + zFileField, zImageField, zModelFormat, zModelIdentifierField, @@ -161,6 +162,10 @@ const zImageCollectionFieldType = zFieldTypeBase.extend({ cardinality: z.literal(COLLECTION), originalType: zStatelessFieldType.optional(), }); +const zFileFieldType = zFieldTypeBase.extend({ + name: z.literal('FileField'), + originalType: zStatelessFieldType.optional(), +}); export const isImageCollectionFieldType = ( fieldType: FieldType ): fieldType is z.infer => @@ -211,6 +216,7 @@ const zStatefulFieldType = z.union([ zBooleanFieldType, zEnumFieldType, zImageFieldType, + zFileFieldType, zBoardFieldType, zStylePresetFieldType, zModelIdentifierFieldType, @@ -559,6 +565,26 @@ export const isImageFieldInputTemplate = buildTemplateTypeGuard; +export type FileFieldInputInstance = z.infer; +export type FileFieldInputTemplate = z.infer; +export const isFileFieldInputInstance = buildInstanceTypeGuard(zFileFieldInputInstance); +export const isFileFieldInputTemplate = buildTemplateTypeGuard('FileField', ['SINGLE']); +// #endregion + // #region ImageField Collection export const zImageFieldCollectionValue = z.array(zImageField).optional(); const zImageFieldCollectionInputInstance = zFieldInputInstanceBase.extend({ @@ -1284,6 +1310,7 @@ export const zStatefulFieldValue = z.union([ zBooleanFieldValue, zEnumFieldValue, zImageFieldValue, + zFileFieldValue, zImageFieldCollectionValue, zBoardFieldValue, zStylePresetFieldValue, @@ -1312,6 +1339,7 @@ const zStatefulFieldInputInstance = z.union([ zBooleanFieldInputInstance, zEnumFieldInputInstance, zImageFieldInputInstance, + zFileFieldInputInstance, zImageFieldCollectionInputInstance, zBoardFieldInputInstance, zStylePresetFieldInputInstance, @@ -1339,6 +1367,7 @@ const zStatefulFieldInputTemplate = z.union([ zBooleanFieldInputTemplate, zEnumFieldInputTemplate, zImageFieldInputTemplate, + zFileFieldInputTemplate, zImageFieldCollectionInputTemplate, zBoardFieldInputTemplate, zStylePresetFieldInputTemplate, @@ -1367,6 +1396,7 @@ const zStatefulFieldOutputTemplate = z.union([ zBooleanFieldOutputTemplate, zEnumFieldOutputTemplate, zImageFieldOutputTemplate, + zFileFieldOutputTemplate, zImageFieldCollectionOutputTemplate, zBoardFieldOutputTemplate, zStylePresetFieldOutputTemplate, diff --git a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts index ef7b92efdd8..d150b142254 100644 --- a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts +++ b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts @@ -8,6 +8,7 @@ const FIELD_VALUE_FALLBACK_MAP: Record = ColorField: { r: 0, g: 0, b: 0, a: 1 }, FloatField: 0, ImageField: undefined, + FileField: undefined, IntegerField: 0, ModelIdentifierField: undefined, SchedulerField: 'dpmpp_3m_k', diff --git a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts index adaa3f413ce..cbdeb863a1d 100644 --- a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts +++ b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts @@ -7,6 +7,7 @@ import type { EnumFieldInputTemplate, FieldInputTemplate, FieldType, + FileFieldInputTemplate, FloatFieldCollectionInputTemplate, FloatFieldInputTemplate, FloatGeneratorFieldInputTemplate, @@ -318,6 +319,20 @@ const buildImageFieldInputTemplate: FieldInputTemplateBuilder = ({ + schemaObject, + baseField, + fieldType, +}) => { + const template: FileFieldInputTemplate = { + ...baseField, + type: fieldType, + default: schemaObject.default ?? undefined, + }; + + return template; +}; + const buildImageFieldCollectionInputTemplate: FieldInputTemplateBuilder = ({ schemaObject, baseField, @@ -471,6 +486,7 @@ const TEMPLATE_BUILDER_MAP: Record buildV1Url(`files/${path}`); + +export const filesApi = api.injectEndpoints({ + endpoints: (build) => ({ + getFileDTO: build.query({ + query: (file_id) => ({ url: buildFilesUrl(`i/${file_id}`) }), + providesTags: (result, error, file_id) => [{ type: 'File', id: file_id }], + }), + uploadFile: build.mutation({ + query: ({ file }) => { + const formData = new FormData(); + formData.append('file', file); + return { + url: buildFilesUrl('upload'), + method: 'POST', + body: formData, + }; + }, + invalidatesTags: (result) => (result ? [{ type: 'File', id: result.file_id }] : []), + }), + deleteFile: build.mutation({ + query: (file_id) => ({ + url: buildFilesUrl(`i/${file_id}`), + method: 'DELETE', + }), + invalidatesTags: (result, error, file_id) => [{ type: 'File', id: file_id }], + }), + }), +}); + +export const { useDeleteFileMutation, useGetFileDTOQuery, useUploadFileMutation } = filesApi; diff --git a/invokeai/frontend/web/src/services/api/index.ts b/invokeai/frontend/web/src/services/api/index.ts index a586273f3a7..347f5cc9544 100644 --- a/invokeai/frontend/web/src/services/api/index.ts +++ b/invokeai/frontend/web/src/services/api/index.ts @@ -18,6 +18,7 @@ const tagTypes = [ 'BoardImagesTotal', 'BoardAssetsTotal', 'HFTokenStatus', + 'File', 'Image', 'ImageNameList', 'ImageList', diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index 3a5da907f35..c2851fb8126 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -1001,6 +1001,50 @@ export type paths = { patch?: never; trace?: never; }; + "/api/v1/files/upload": { + parameters: { + query?: never; + header?: never; + path?: never; + cookie?: never; + }; + get?: never; + put?: never; + /** + * Upload File + * @description Uploads a managed file for node inputs. + */ + post: operations["upload_file"]; + delete?: never; + options?: never; + head?: never; + patch?: never; + trace?: never; + }; + "/api/v1/files/i/{file_id}": { + parameters: { + query?: never; + header?: never; + path?: never; + cookie?: never; + }; + /** + * Get File Dto + * @description Gets metadata for a managed file. + */ + get: operations["get_file_dto"]; + put?: never; + post?: never; + /** + * Delete File + * @description Deletes a managed file. + */ + delete: operations["delete_file"]; + options?: never; + head?: never; + patch?: never; + trace?: never; + }; "/api/v1/images/upload": { parameters: { query?: never; @@ -4350,6 +4394,14 @@ export type components = { */ is_public: boolean; }; + /** Body_upload_file */ + Body_upload_file: { + /** + * File + * Format: binary + */ + file: Blob; + }; /** Body_upload_image */ Body_upload_image: { /** @@ -9812,6 +9864,46 @@ export type components = { * @enum {string} */ FieldKind: "input" | "output" | "internal" | "node_attribute"; + /** FileDTO */ + FileDTO: { + /** + * File Id + * @description The managed file ID. + */ + file_id: string; + /** + * File Name + * @description The original file name. + */ + file_name: string; + /** + * Content Type + * @description The uploaded file content type. + */ + content_type: string; + /** + * Size Bytes + * @description The size of the file in bytes. + */ + size_bytes: number; + /** + * Created At + * Format: date-time + * @description When the file was uploaded. + */ + created_at: string; + }; + /** + * FileField + * @description A managed file primitive field + */ + FileField: { + /** + * File Id + * @description The id of the managed file + */ + file_id: string; + }; /** * Float Batch * @description Create a batched generation, where the workflow is executed once for each float in the batch. @@ -16560,6 +16652,12 @@ export type components = { * @description Keep the last N completed, failed, and canceled queue items. Older items are deleted on startup. Set to 0 to prune all terminal items. Ignored if `clear_queue_on_startup` is true. */ max_queue_history?: number | null; + /** + * Max File Upload Size Bytes + * @description Maximum size in bytes for managed file uploads. + * @default 52428800 + */ + max_file_upload_size_bytes?: number; /** * Allow Nodes * @description List of nodes to allow. Omit to allow all. @@ -34944,6 +35042,115 @@ export interface operations { }; }; }; + upload_file: { + parameters: { + query?: never; + header?: never; + path?: never; + cookie?: never; + }; + requestBody: { + content: { + "multipart/form-data": components["schemas"]["Body_upload_file"]; + }; + }; + responses: { + /** @description The file was uploaded successfully */ + 201: { + headers: { + [name: string]: unknown; + }; + content: { + "application/json": components["schemas"]["FileDTO"]; + }; + }; + /** @description The file is too large */ + 413: { + headers: { + [name: string]: unknown; + }; + content?: never; + }; + /** @description The file type is not supported */ + 415: { + headers: { + [name: string]: unknown; + }; + content?: never; + }; + /** @description Validation Error */ + 422: { + headers: { + [name: string]: unknown; + }; + content: { + "application/json": components["schemas"]["HTTPValidationError"]; + }; + }; + }; + }; + get_file_dto: { + parameters: { + query?: never; + header?: never; + path: { + /** @description The managed file ID. */ + file_id: string; + }; + cookie?: never; + }; + requestBody?: never; + responses: { + /** @description Successful Response */ + 200: { + headers: { + [name: string]: unknown; + }; + content: { + "application/json": components["schemas"]["FileDTO"]; + }; + }; + /** @description Validation Error */ + 422: { + headers: { + [name: string]: unknown; + }; + content: { + "application/json": components["schemas"]["HTTPValidationError"]; + }; + }; + }; + }; + delete_file: { + parameters: { + query?: never; + header?: never; + path: { + /** @description The managed file ID. */ + file_id: string; + }; + cookie?: never; + }; + requestBody?: never; + responses: { + /** @description Successful Response */ + 204: { + headers: { + [name: string]: unknown; + }; + content?: never; + }; + /** @description Validation Error */ + 422: { + headers: { + [name: string]: unknown; + }; + content: { + "application/json": components["schemas"]["HTTPValidationError"]; + }; + }; + }; + }; upload_image: { parameters: { query: { diff --git a/invokeai/frontend/web/src/services/api/types.ts b/invokeai/frontend/web/src/services/api/types.ts index 27c6fcbf3c3..687a16ccb78 100644 --- a/invokeai/frontend/web/src/services/api/types.ts +++ b/invokeai/frontend/web/src/services/api/types.ts @@ -90,6 +90,16 @@ const _zImageDTO = z.object({ export type ImageDTO = z.infer; assert>(); +const _zFileDTO = z.object({ + file_id: z.string(), + file_name: z.string(), + content_type: z.string(), + size_bytes: z.number().int().gte(0), + created_at: z.string(), +}); +export type FileDTO = z.infer; +assert>(); + export type BoardDTO = S['BoardDTO']; export type OffsetPaginatedResults_ImageDTO_ = S['OffsetPaginatedResults_ImageDTO_']; @@ -610,5 +620,12 @@ export type UploadImageArg = { resize_to?: Dimensions; }; +export type UploadFileArg = { + /** + * The file object to upload + */ + file: File; +}; + export type ImageUploadEntryResponse = S['ImageUploadEntry']; export type ImageUploadEntryRequest = paths['/api/v1/images/']['post']['requestBody']['content']['application/json']; diff --git a/invokeai/invocation_api/__init__.py b/invokeai/invocation_api/__init__.py index 7cc5e065fd6..cfbd1b3e871 100644 --- a/invokeai/invocation_api/__init__.py +++ b/invokeai/invocation_api/__init__.py @@ -20,6 +20,7 @@ DenoiseMaskField, FieldDescriptions, FieldKind, + FileField, FluxConditioningField, ImageField, Input, @@ -161,6 +162,7 @@ "DenoiseMaskField", "FieldDescriptions", "FieldKind", + "FileField", "FluxConditioningField", "ImageField", "Input", diff --git a/tests/app/services/files/__init__.py b/tests/app/services/files/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/app/services/files/test_files_disk.py b/tests/app/services/files/test_files_disk.py new file mode 100644 index 00000000000..e48a387efd8 --- /dev/null +++ b/tests/app/services/files/test_files_disk.py @@ -0,0 +1,164 @@ +from io import BytesIO +from pathlib import Path + +import pytest + +from invokeai.app.services.files.files_common import ( + FileAccessDeniedException, + FileMetadataException, + FileNotFoundException, + FileTooLargeException, + UnsupportedFileTypeException, +) +from invokeai.app.services.files.files_disk import DiskFileService + + +def test_save_open_get_path_and_delete_round_trip(tmp_path: Path) -> None: + service = DiskFileService(tmp_path) + + dto = service.save( + file_name="report.pdf", + content_type="application/pdf", + file=BytesIO(b"%PDF-1.7 test"), + user_id="user-1", + ) + + assert dto.file_name == "report.pdf" + assert dto.content_type == "application/pdf" + assert dto.size_bytes == len(b"%PDF-1.7 test") + + path = service.get_path(dto.file_id, user_id="user-1") + assert path.is_relative_to(tmp_path.resolve()) + assert path.name == f"{dto.file_id}.pdf" + assert path.read_bytes() == b"%PDF-1.7 test" + + with service.open(dto.file_id, user_id="user-1") as opened_file: + assert opened_file.read() == b"%PDF-1.7 test" + + service.delete(dto.file_id, user_id="user-1") + + with pytest.raises(FileNotFoundException): + service.get_dto(dto.file_id, user_id="user-1") + + +def test_json_upload_does_not_collide_with_metadata(tmp_path: Path) -> None: + service = DiskFileService(tmp_path) + content = b'{"hello": "world"}' + + dto = service.save( + file_name="data.json", + content_type="application/json", + file=BytesIO(content), + user_id="user-1", + ) + + file_path = service.get_path(dto.file_id, user_id="user-1") + metadata_path = tmp_path / f"{dto.file_id}.meta.json" + + assert file_path.name == f"{dto.file_id}.json" + assert file_path.read_bytes() == content + assert metadata_path.exists() + assert metadata_path != file_path + + +def test_open_text_mode_reads_utf8(tmp_path: Path) -> None: + service = DiskFileService(tmp_path) + content = "cafe 日本語" + dto = service.save( + file_name="notes.md", + content_type="text/markdown", + file=BytesIO(content.encode("utf-8")), + user_id=None, + ) + + with service.open(dto.file_id, mode="r", user_id=None) as opened_file: + assert opened_file.read() == content + + +def test_get_dto_fails_when_file_is_missing(tmp_path: Path) -> None: + service = DiskFileService(tmp_path) + dto = service.save(file_name="notes.txt", content_type="text/plain", file=BytesIO(b"notes"), user_id=None) + + service.get_path(dto.file_id).unlink() + + with pytest.raises(FileNotFoundException): + service.get_dto(dto.file_id) + + +def test_get_dto_fails_when_metadata_is_invalid(tmp_path: Path) -> None: + service = DiskFileService(tmp_path) + dto = service.save(file_name="notes.txt", content_type="text/plain", file=BytesIO(b"notes"), user_id=None) + + (tmp_path / f"{dto.file_id}.meta.json").write_text("{", encoding="utf-8") + + with pytest.raises(FileMetadataException): + service.get_dto(dto.file_id) + + +def test_sanitizes_file_name_to_basename(tmp_path: Path) -> None: + service = DiskFileService(tmp_path) + + dto = service.save( + file_name="../nested/data.csv", + content_type="text/csv", + file=BytesIO(b"a,b\n1,2\n"), + user_id=None, + ) + + assert dto.file_name == "data.csv" + assert service.get_path(dto.file_id).is_relative_to(tmp_path.resolve()) + + +@pytest.mark.parametrize("file_name", ["image.png", "archive.zip", "no-extension"]) +def test_rejects_unsupported_extension(tmp_path: Path, file_name: str) -> None: + service = DiskFileService(tmp_path) + + with pytest.raises(UnsupportedFileTypeException): + service.save(file_name=file_name, content_type="application/octet-stream", file=BytesIO(b"data"), user_id=None) + + +@pytest.mark.parametrize("content_type", ["image/png", "application/zip"]) +def test_rejects_unsupported_content_type(tmp_path: Path, content_type: str) -> None: + service = DiskFileService(tmp_path) + + with pytest.raises(UnsupportedFileTypeException): + service.save(file_name="data.json", content_type=content_type, file=BytesIO(b"{}"), user_id=None) + + +@pytest.mark.parametrize("content_type", [None, "", "application/octet-stream"]) +def test_allows_octet_stream_for_allowed_extensions(tmp_path: Path, content_type: str | None) -> None: + service = DiskFileService(tmp_path) + + dto = service.save(file_name="notes.md", content_type=content_type, file=BytesIO(b"# Notes"), user_id=None) + + assert dto.content_type == "application/octet-stream" + + +def test_rejects_files_over_size_limit_and_cleans_up(tmp_path: Path) -> None: + service = DiskFileService(tmp_path, max_file_size=4) + + with pytest.raises(FileTooLargeException): + service.save(file_name="large.txt", content_type="text/plain", file=BytesIO(b"12345"), user_id=None) + + assert list(tmp_path.iterdir()) == [] + + +def test_user_scoped_access(tmp_path: Path) -> None: + service = DiskFileService(tmp_path) + dto = service.save(file_name="private.txt", content_type="text/plain", file=BytesIO(b"secret"), user_id="user-1") + + service.get_dto(dto.file_id, user_id="user-1") + + with pytest.raises(FileAccessDeniedException): + service.get_dto(dto.file_id, user_id=None) + + with pytest.raises(FileAccessDeniedException): + service.get_dto(dto.file_id, user_id="user-2") + + +def test_rejects_write_modes(tmp_path: Path) -> None: + service = DiskFileService(tmp_path) + dto = service.save(file_name="notes.txt", content_type="text/plain", file=BytesIO(b"notes"), user_id=None) + + with pytest.raises(ValueError, match="only be opened for reading"): + service.open(dto.file_id, mode="wb") From 2d3cdb3590a198f8ad7a559c65ee717f6d2b11c2 Mon Sep 17 00:00:00 2001 From: Gohankaiju <167270541+Gohankaiju@users.noreply.github.com> Date: Fri, 26 Jun 2026 04:03:15 +0000 Subject: [PATCH 2/2] fix(nodes): clean up managed file input checks --- invokeai/frontend/web/openapi.json | 204 ++++++++++++++++++ .../fields/inputs/FileFieldInputComponent.tsx | 42 +++- .../web/src/features/nodes/types/common.ts | 1 - .../web/src/services/api/endpoints/files.ts | 2 +- 4 files changed, 241 insertions(+), 8 deletions(-) diff --git a/invokeai/frontend/web/openapi.json b/invokeai/frontend/web/openapi.json index 32fe4f2409f..dafdcbbbf41 100644 --- a/invokeai/frontend/web/openapi.json +++ b/invokeai/frontend/web/openapi.json @@ -4154,6 +4154,144 @@ ] } }, + "/api/v1/files/upload": { + "post": { + "tags": ["files"], + "summary": "Upload File", + "description": "Uploads a managed file for node inputs.", + "operationId": "upload_file", + "requestBody": { + "content": { + "multipart/form-data": { + "schema": { + "$ref": "#/components/schemas/Body_upload_file" + } + } + }, + "required": true + }, + "responses": { + "201": { + "description": "The file was uploaded successfully", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/FileDTO" + } + } + } + }, + "413": { + "description": "The file is too large" + }, + "415": { + "description": "The file type is not supported" + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + }, + "security": [ + { + "HTTPBearer": [] + } + ] + } + }, + "/api/v1/files/i/{file_id}": { + "get": { + "tags": ["files"], + "summary": "Get File Dto", + "description": "Gets metadata for a managed file.", + "operationId": "get_file_dto", + "security": [ + { + "HTTPBearer": [] + } + ], + "parameters": [ + { + "name": "file_id", + "in": "path", + "required": true, + "schema": { + "type": "string", + "description": "The managed file ID.", + "title": "File Id" + }, + "description": "The managed file ID." + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/FileDTO" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + }, + "delete": { + "tags": ["files"], + "summary": "Delete File", + "description": "Deletes a managed file.", + "operationId": "delete_file", + "security": [ + { + "HTTPBearer": [] + } + ], + "parameters": [ + { + "name": "file_id", + "in": "path", + "required": true, + "schema": { + "type": "string", + "description": "The managed file ID.", + "title": "File Id" + }, + "description": "The managed file ID." + } + ], + "responses": { + "204": { + "description": "Successful Response" + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, "/api/v1/images/upload": { "post": { "tags": ["images"], @@ -13035,6 +13173,18 @@ "required": ["is_public"], "title": "Body_update_workflow_is_public" }, + "Body_upload_file": { + "properties": { + "file": { + "type": "string", + "format": "binary", + "title": "File" + } + }, + "type": "object", + "required": ["file"], + "title": "Body_upload_file" + }, "Body_upload_image": { "properties": { "file": { @@ -23584,6 +23734,53 @@ "title": "FieldKind", "type": "string" }, + "FileDTO": { + "properties": { + "file_id": { + "type": "string", + "title": "File Id", + "description": "The managed file ID." + }, + "file_name": { + "type": "string", + "title": "File Name", + "description": "The original file name." + }, + "content_type": { + "type": "string", + "title": "Content Type", + "description": "The uploaded file content type." + }, + "size_bytes": { + "type": "integer", + "minimum": 0.0, + "title": "Size Bytes", + "description": "The size of the file in bytes." + }, + "created_at": { + "type": "string", + "format": "date-time", + "title": "Created At", + "description": "When the file was uploaded." + } + }, + "type": "object", + "required": ["file_id", "file_name", "content_type", "size_bytes", "created_at"], + "title": "FileDTO" + }, + "FileField": { + "description": "A managed file primitive field", + "properties": { + "file_id": { + "description": "The id of the managed file", + "title": "File Id", + "type": "string" + } + }, + "required": ["file_id"], + "title": "FileField", + "type": "object" + }, "FloatBatchInvocation": { "category": "batch", "class": "invocation", @@ -41184,6 +41381,13 @@ "title": "Max Queue History", "description": "Keep the last N completed, failed, and canceled queue items. Older items are deleted on startup. Set to 0 to prune all terminal items. Ignored if `clear_queue_on_startup` is true." }, + "max_file_upload_size_bytes": { + "type": "integer", + "exclusiveMinimum": 0.0, + "title": "Max File Upload Size Bytes", + "description": "Maximum size in bytes for managed file uploads.", + "default": 52428800 + }, "allow_nodes": { "anyOf": [ { diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/FileFieldInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/FileFieldInputComponent.tsx index d834741feb7..45decca0f7c 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/FileFieldInputComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/FileFieldInputComponent.tsx @@ -13,7 +13,7 @@ import type { Accept, FileRejection } from 'react-dropzone'; import { useDropzone } from 'react-dropzone'; import { useTranslation } from 'react-i18next'; import { PiFileTextBold, PiUploadBold, PiXBold } from 'react-icons/pi'; -import { useGetFileDTOQuery, useUploadFileMutation } from 'services/api/endpoints/files'; +import { useDeleteFileMutation, useGetFileDTOQuery, useUploadFileMutation } from 'services/api/endpoints/files'; import type { FileDTO } from 'services/api/types'; import { $isConnected } from 'services/events/stores'; @@ -60,7 +60,9 @@ const FileFieldInputComponent = (props: FieldComponentProps { @@ -75,10 +77,33 @@ const FileFieldInputComponent = (props: FieldComponentProps { + const clearValue = useCallback(() => { setValue(undefined); }, [setValue]); + const deleteManagedFile = useCallback( + async (file_id: string) => { + try { + await deleteFile(file_id).unwrap(); + } catch { + toast({ + id: 'FILE_DELETE_FAILED', + title: t('toast.somethingWentWrong'), + status: 'error', + }); + } + }, + [deleteFile, t] + ); + + const handleReset = useCallback(() => { + const file_id = currentFileId; + clearValue(); + if (file_id) { + void deleteManagedFile(file_id); + } + }, [clearValue, currentFileId, deleteManagedFile]); + const handleResetClick = useCallback( (event: MouseEvent) => { event.stopPropagation(); @@ -89,9 +114,9 @@ const FileFieldInputComponent = (props: FieldComponentProps { if (isConnected && isError) { - handleReset(); + clearValue(); } - }, [handleReset, isConnected, isError]); + }, [clearValue, isConnected, isError]); const onDropAccepted = useCallback( async (files: File[]) => { @@ -100,8 +125,12 @@ const FileFieldInputComponent = (props: FieldComponentProps} variant="ghost" size="sm" + isDisabled={deleteRequest.isLoading} onClick={handleResetClick} flexShrink={0} /> diff --git a/invokeai/frontend/web/src/features/nodes/types/common.ts b/invokeai/frontend/web/src/features/nodes/types/common.ts index 79eb7b3dd5e..14b20d0d698 100644 --- a/invokeai/frontend/web/src/features/nodes/types/common.ts +++ b/invokeai/frontend/web/src/features/nodes/types/common.ts @@ -15,7 +15,6 @@ export const zFileField = z.object({ file_id: z.string().trim().min(1), }); export type FileField = z.infer; -export const isFileField = (field: unknown): field is FileField => zFileField.safeParse(field).success; export const zBoardField = z.object({ board_id: z.string().trim().min(1), diff --git a/invokeai/frontend/web/src/services/api/endpoints/files.ts b/invokeai/frontend/web/src/services/api/endpoints/files.ts index fdf22979583..da1ed76f009 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/files.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/files.ts @@ -10,7 +10,7 @@ import { api, buildV1Url } from '..'; */ const buildFilesUrl = (path: string = '') => buildV1Url(`files/${path}`); -export const filesApi = api.injectEndpoints({ +const filesApi = api.injectEndpoints({ endpoints: (build) => ({ getFileDTO: build.query({ query: (file_id) => ({ url: buildFilesUrl(`i/${file_id}`) }),