Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions invokeai/app/api/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
SeedreamProvider,
)
from invokeai.app.services.external_generation.startup import sync_configured_external_starter_models
from invokeai.app.services.files.files_disk import DiskFileService
from invokeai.app.services.image_files.image_files_disk import DiskImageFileStorage
from invokeai.app.services.image_records.image_records_sqlite import SqliteImageRecordStorage
from invokeai.app.services.images.images_default import ImageService
Expand Down Expand Up @@ -107,6 +108,7 @@ def initialize(
raise ValueError("Output folder is not set")

image_files = DiskImageFileStorage(f"{output_folder}/images")
files = DiskFileService(output_folder / "files", max_file_size=config.max_file_upload_size_bytes)

model_images_folder = config.models_path
style_presets_folder = config.style_presets_path
Expand Down Expand Up @@ -197,6 +199,7 @@ def initialize(
bulk_download=bulk_download,
configuration=configuration,
events=events,
files=files,
image_files=image_files,
image_records=image_records,
images=images,
Expand Down
101 changes: 101 additions & 0 deletions invokeai/app/api/routers/files.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
from fastapi import HTTPException, Path, Response, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.routing import APIRouter

from invokeai.app.api.auth_dependencies import CurrentUserOrDefault
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.services.files.files_base import FileServiceBase
from invokeai.app.services.files.files_common import (
FileAccessDeniedException,
FileDTO,
FileMetadataException,
FileNotFoundException,
FileTooLargeException,
UnsupportedFileTypeException,
)

files_router = APIRouter(prefix="/v1/files", tags=["files"])


def _get_file_service() -> FileServiceBase:
file_service = ApiDependencies.invoker.services.files
if file_service is None:
raise HTTPException(status_code=503, detail="Managed file service is not available")
return file_service


@files_router.post(
"/upload",
operation_id="upload_file",
responses={
201: {"description": "The file was uploaded successfully"},
413: {"description": "The file is too large"},
415: {"description": "The file type is not supported"},
},
status_code=201,
response_model=FileDTO,
)
async def upload_file(
current_user: CurrentUserOrDefault,
file: UploadFile,
response: Response,
) -> FileDTO:
"""Uploads a managed file for node inputs."""
try:
file_dto = await run_in_threadpool(
_get_file_service().save,
file_name=file.filename or "",
content_type=file.content_type,
file=file.file,
user_id=current_user.user_id,
)
response.status_code = 201
return file_dto
except UnsupportedFileTypeException as e:
raise HTTPException(status_code=415, detail=str(e))
except FileTooLargeException as e:
raise HTTPException(status_code=413, detail=str(e))
except FileMetadataException as e:
raise HTTPException(status_code=500, detail=str(e))
finally:
await file.close()


@files_router.get(
"/i/{file_id}",
operation_id="get_file_dto",
response_model=FileDTO,
)
async def get_file_dto(
current_user: CurrentUserOrDefault,
file_id: str = Path(description="The managed file ID."),
) -> FileDTO:
"""Gets metadata for a managed file."""
try:
return await run_in_threadpool(_get_file_service().get_dto, file_id, user_id=current_user.user_id)
except FileAccessDeniedException:
raise HTTPException(status_code=403, detail="Not authorized to access this file")
except FileNotFoundException:
raise HTTPException(status_code=404, detail="File not found")
except FileMetadataException as e:
raise HTTPException(status_code=500, detail=str(e))


@files_router.delete(
"/i/{file_id}",
operation_id="delete_file",
status_code=204,
)
async def delete_file(
current_user: CurrentUserOrDefault,
file_id: str = Path(description="The managed file ID."),
) -> None:
"""Deletes a managed file."""
try:
await run_in_threadpool(_get_file_service().delete, file_id, user_id=current_user.user_id)
except FileAccessDeniedException:
raise HTTPException(status_code=403, detail="Not authorized to delete this file")
except FileNotFoundException:
raise HTTPException(status_code=404, detail="File not found")
except FileMetadataException as e:
raise HTTPException(status_code=500, detail=str(e))
62 changes: 61 additions & 1 deletion invokeai/app/api_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
from fastapi.responses import HTMLResponse, RedirectResponse
from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse
from fastapi_events.handlers.local import local_handler
from fastapi_events.middleware import EventHandlerASGIMiddleware
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.types import ASGIApp, Message, Receive, Scope, Send

import invokeai.frontend.web as web_dir
from invokeai.app.api.dependencies import ApiDependencies
Expand All @@ -23,6 +24,7 @@
client_state,
custom_nodes,
download_queue,
files,
images,
model_manager,
model_relationships,
Expand Down Expand Up @@ -125,6 +127,62 @@ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
return response


MANAGED_FILE_UPLOAD_MULTIPART_OVERHEAD_BYTES = 1024 * 1024


class RequestBodyTooLarge(Exception):
pass


class ManagedFileUploadSizeLimitMiddleware:
def __init__(self, app: ASGIApp, max_file_bytes: int) -> None:
self.app = app
self.max_file_bytes = max_file_bytes
self.max_body_bytes = max_file_bytes + MANAGED_FILE_UPLOAD_MULTIPART_OVERHEAD_BYTES

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope.get("type") != "http" or scope.get("method") != "POST" or scope.get("path") != "/api/v1/files/upload":
await self.app(scope, receive, send)
return

headers = dict(scope.get("headers", []))
content_length_header = headers.get(b"content-length")
content_length = None
if content_length_header is not None:
try:
content_length = int(content_length_header)
except ValueError:
content_length = None

if content_length is not None and content_length > self.max_body_bytes:
response = JSONResponse(
{"detail": f"File exceeds the maximum size of {self.max_file_bytes} bytes."},
status_code=413,
)
await response(scope, receive, send)
return

received_bytes = 0

async def limited_receive() -> Message:
nonlocal received_bytes
message = await receive()
if message["type"] == "http.request":
received_bytes += len(message.get("body", b""))
if received_bytes > self.max_body_bytes:
raise RequestBodyTooLarge
return message

try:
await self.app(scope, limited_receive, send)
except RequestBodyTooLarge:
response = JSONResponse(
{"detail": f"File exceeds the maximum size of {self.max_file_bytes} bytes."},
status_code=413,
)
await response(scope, receive, send)


class RedirectRootWithQueryStringMiddleware(BaseHTTPMiddleware):
"""When a request is made to the root path with a query string, redirect to the root path without the query string.

Expand All @@ -146,6 +204,7 @@ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
# Add the middleware
app.add_middleware(RedirectRootWithQueryStringMiddleware)
app.add_middleware(SlidingWindowTokenMiddleware)
app.add_middleware(ManagedFileUploadSizeLimitMiddleware, max_file_bytes=app_config.max_file_upload_size_bytes)


# Add event handler
Expand Down Expand Up @@ -176,6 +235,7 @@ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
app.include_router(utilities.utilities_router, prefix="/api")
app.include_router(model_manager.model_manager_router, prefix="/api")
app.include_router(download_queue.download_queue_router, prefix="/api")
app.include_router(files.files_router, prefix="/api")
app.include_router(images.images_router, prefix="/api")
app.include_router(boards.boards_router, prefix="/api")
app.include_router(board_images.board_images_router, prefix="/api")
Expand Down
6 changes: 6 additions & 0 deletions invokeai/app/invocations/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,12 @@ class ImageField(BaseModel):
image_name: str = Field(description="The name of the image")


class FileField(BaseModel):
"""A managed file primitive field"""

file_id: str = Field(description="The id of the managed file")


class BoardField(BaseModel):
"""A board primitive field"""

Expand Down
1 change: 1 addition & 0 deletions invokeai/app/services/config/config_default.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ class InvokeAIAppConfig(BaseSettings):
max_queue_size: int = Field(default=10000, gt=0, description="Maximum number of items in the session queue.")
clear_queue_on_startup: bool = Field(default=False, description="Empties session queue on startup. If true, disables `max_queue_history`.")
max_queue_history: Optional[int] = Field(default=None, ge=0, description="Keep the last N completed, failed, and canceled queue items. Older items are deleted on startup. Set to 0 to prune all terminal items. Ignored if `clear_queue_on_startup` is true.")
max_file_upload_size_bytes: int = Field(default=50 * 1024 * 1024, gt=0, description="Maximum size in bytes for managed file uploads.")

# NODES
allow_nodes: Optional[list[str]] = Field(default=None, description="List of nodes to allow. Omit to allow all.")
Expand Down
23 changes: 23 additions & 0 deletions invokeai/app/services/files/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from invokeai.app.services.files.files_base import FileServiceBase
from invokeai.app.services.files.files_common import (
FileAccessDeniedException,
FileDTO,
FileMetadataException,
FileNotFoundException,
FileStorageException,
FileTooLargeException,
UnsupportedFileTypeException,
)
from invokeai.app.services.files.files_disk import DiskFileService

__all__ = [
"DiskFileService",
"FileAccessDeniedException",
"FileDTO",
"FileMetadataException",
"FileNotFoundException",
"FileServiceBase",
"FileStorageException",
"FileTooLargeException",
"UnsupportedFileTypeException",
]
33 changes: 33 additions & 0 deletions invokeai/app/services/files/files_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from abc import ABC, abstractmethod
from pathlib import Path
from typing import IO, Any, BinaryIO

from invokeai.app.services.files.files_common import FileDTO


class FileServiceBase(ABC):
@abstractmethod
def save(
self,
file_name: str,
content_type: str | None,
file: BinaryIO,
user_id: str | None,
) -> FileDTO:
"""Saves a managed file and returns its DTO."""

@abstractmethod
def get_dto(self, file_id: str, user_id: str | None = None) -> FileDTO:
"""Gets a managed file DTO."""

@abstractmethod
def get_path(self, file_id: str, user_id: str | None = None) -> Path:
"""Gets the path to a managed file."""

@abstractmethod
def open(self, file_id: str, mode: str = "rb", user_id: str | None = None) -> IO[Any]:
"""Opens a managed file."""

@abstractmethod
def delete(self, file_id: str, user_id: str | None = None) -> None:
"""Deletes a managed file."""
65 changes: 65 additions & 0 deletions invokeai/app/services/files/files_common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from datetime import datetime

from pydantic import BaseModel, Field

DEFAULT_FILE_UPLOAD_MAX_BYTES = 50 * 1024 * 1024

SUPPORTED_FILE_EXTENSIONS = frozenset(
{
".csv",
".json",
".md",
".markdown",
".pdf",
".txt",
".yaml",
".yml",
}
)

SUPPORTED_FILE_MIME_TYPES = frozenset(
{
"application/json",
"application/pdf",
"application/x-yaml",
"application/yaml",
"text/csv",
"text/markdown",
"text/plain",
"text/x-markdown",
"text/x-yaml",
"text/yaml",
}
)


class FileDTO(BaseModel):
file_id: str = Field(description="The managed file ID.")
file_name: str = Field(description="The original file name.")
content_type: str = Field(description="The uploaded file content type.")
size_bytes: int = Field(description="The size of the file in bytes.", ge=0)
created_at: datetime = Field(description="When the file was uploaded.")


class FileStorageException(Exception):
"""Base exception for managed file storage errors."""


class FileNotFoundException(FileStorageException):
"""Raised when a managed file cannot be found."""


class UnsupportedFileTypeException(FileStorageException):
"""Raised when a managed file upload is not allowed."""


class FileTooLargeException(FileStorageException):
"""Raised when a managed file upload exceeds the size limit."""


class FileAccessDeniedException(FileStorageException):
"""Raised when a user cannot access a managed file."""


class FileMetadataException(FileStorageException):
"""Raised when managed file metadata cannot be read or validated."""
Loading