From ac79186e4755e96692ee3014a2b726db976160df Mon Sep 17 00:00:00 2001 From: Lucian Hardy Date: Sun, 27 Apr 2025 01:24:30 +1000 Subject: [PATCH 1/3] feat(ui): model relationship management Adds full support for managing model-to-model relationships in the UI and backend. Introduces RelatedModels subpanel for linking and unlinking models in model management. - Adds REST API routes for adding, removing, and retrieving model relationships. - New database migration: creates model_relationships table for bidirectional links. - New service layer (model_relationships) for relationship management. - Updated frontend: Related models float to top of LoRA/Main grouped model comboboxes for quick access. - Added 'Show Only Related' toggle badge to MainModelPicker filter bar **Amended commit to remove changes to ParamMainModelSelect.tsx and MainModelPicker.tsx to avoid conflict with upstream deletion/ rewrite** --- invokeai/app/api/dependencies.py | 6 + .../app/api/routers/model_relationships.py | 196 ++++++++++++ invokeai/app/api_app.py | 2 + invokeai/app/services/invocation_services.py | 6 + .../model_relationship_records_base.py | 57 ++++ .../model_relationship_records_sqlite.py | 89 ++++++ .../model_relationships_base.py | 42 +++ .../model_relationships_common.py | 8 + .../model_relationships_default.py | 30 ++ .../app/services/shared/sqlite/sqlite_util.py | 2 + .../migrations/migration_20.py | 37 +++ invokeai/frontend/web/public/locales/en.json | 2 + .../hooks/useRelatedGroupedModelCombobox.ts | 92 ++++++ .../src/common/hooks/useRelatedModelKeys.ts | 14 + .../src/common/hooks/useSelectedModelKeys.ts | 34 ++ .../features/lora/components/LoRASelect.tsx | 4 +- .../subpanels/ModelPanel/ModelView.tsx | 4 + .../subpanels/ModelPanel/RelatedModels.tsx | 300 ++++++++++++++++++ .../api/endpoints/modelRelationships.ts | 67 ++++ .../frontend/web/src/services/api/index.ts | 1 + .../frontend/web/src/services/api/schema.ts | 264 ++++++++++++++- 21 files changed, 1253 insertions(+), 4 deletions(-) create mode 100644 invokeai/app/api/routers/model_relationships.py create mode 100644 invokeai/app/services/model_relationship_records/model_relationship_records_base.py create mode 100644 invokeai/app/services/model_relationship_records/model_relationship_records_sqlite.py create mode 100644 invokeai/app/services/model_relationships/model_relationships_base.py create mode 100644 invokeai/app/services/model_relationships/model_relationships_common.py create mode 100644 invokeai/app/services/model_relationships/model_relationships_default.py create mode 100644 invokeai/app/services/shared/sqlite_migrator/migrations/migration_20.py create mode 100644 invokeai/frontend/web/src/common/hooks/useRelatedGroupedModelCombobox.ts create mode 100644 invokeai/frontend/web/src/common/hooks/useRelatedModelKeys.ts create mode 100644 invokeai/frontend/web/src/common/hooks/useSelectedModelKeys.ts create mode 100644 invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/RelatedModels.tsx create mode 100644 invokeai/frontend/web/src/services/api/endpoints/modelRelationships.ts diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index a5a7dbef9c7..83b8bb219d8 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -23,6 +23,8 @@ from invokeai.app.services.model_images.model_images_default import ModelImageFileStorageDisk from invokeai.app.services.model_manager.model_manager_default import ModelManagerService from invokeai.app.services.model_records.model_records_sql import ModelRecordServiceSQL +from invokeai.app.services.model_relationships.model_relationships_default import ModelRelationshipsService +from invokeai.app.services.model_relationship_records.model_relationship_records_sqlite import SqliteModelRelationshipRecordStorage from invokeai.app.services.names.names_default import SimpleNameService from invokeai.app.services.object_serializer.object_serializer_disk import ObjectSerializerDisk from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache @@ -136,6 +138,8 @@ def initialize( download_queue=download_queue_service, events=events, ) + model_relationships = ModelRelationshipsService() + model_relationship_records = SqliteModelRelationshipRecordStorage(db=db) names = SimpleNameService() performance_statistics = InvocationStatsService() session_processor = DefaultSessionProcessor(session_runner=DefaultSessionRunner()) @@ -161,6 +165,8 @@ def initialize( logger=logger, model_images=model_images_service, model_manager=model_manager, + model_relationships=model_relationships, + model_relationship_records=model_relationship_records, download_queue=download_queue_service, names=names, performance_statistics=performance_statistics, diff --git a/invokeai/app/api/routers/model_relationships.py b/invokeai/app/api/routers/model_relationships.py new file mode 100644 index 00000000000..d550cdcb787 --- /dev/null +++ b/invokeai/app/api/routers/model_relationships.py @@ -0,0 +1,196 @@ +"""FastAPI route for model relationship records.""" + +from fastapi import HTTPException, APIRouter, Path, Body, status +from pydantic import BaseModel, Field +from typing import List +from invokeai.app.api.dependencies import ApiDependencies + +model_relationships_router = APIRouter( + prefix="/v1/model_relationships", + tags=["model_relationships"] +) + +# === Schemas === + +class ModelRelationshipCreateRequest(BaseModel): + model_key_1: str = Field(..., description="The key of the first model in the relationship", examples=[ + "aa3b247f-90c9-4416-bfcd-aeaa57a5339e", + "ac32b914-10ab-496e-a24a-3068724b9c35", + "d944abfd-c7c3-42e2-a4ff-da640b29b8b4", + "b1c2d3e4-f5a6-7890-abcd-ef1234567890", + "12345678-90ab-cdef-1234-567890abcdef", + "fedcba98-7654-3210-fedc-ba9876543210" + ]) + model_key_2: str = Field(..., description="The key of the second model in the relationship", examples=[ + "3bb7c0eb-b6c8-469c-ad8c-4d69c06075e4", + "f0c3da4e-d9ff-42b5-a45c-23be75c887c9", + "38170dd8-f1e5-431e-866c-2c81f1277fcc", + "c57fea2d-7646-424c-b9ad-c0ba60fc68be", + "10f7807b-ab54-46a9-ab03-600e88c630a1", + "f6c1d267-cf87-4ee0-bee0-37e791eacab7" + ]) + +class ModelRelationshipBatchRequest(BaseModel): + model_keys: List[str] = Field(..., description="List of model keys to fetch related models for", examples= + [[ + "aa3b247f-90c9-4416-bfcd-aeaa57a5339e", + "ac32b914-10ab-496e-a24a-3068724b9c35", + ],[ + "b1c2d3e4-f5a6-7890-abcd-ef1234567890", + "12345678-90ab-cdef-1234-567890abcdef", + "fedcba98-7654-3210-fedc-ba9876543210" + ],[ + "3bb7c0eb-b6c8-469c-ad8c-4d69c06075e4", + ]]) + +# === Routes === + +@model_relationships_router.get( + "/i/{model_key}", + operation_id="get_related_models", + response_model=list[str], + responses={ + 200: { + "description": "A list of related model keys was retrieved successfully", + "content": { + "application/json": { + "example": [ + "15e9eb28-8cfe-47c9-b610-37907a79fc3c", + "71272e82-0e5f-46d5-bca9-9a61f4bd8a82", + "a5d7cd49-1b98-4534-a475-aeee4ccf5fa2" + ] + } + }, + }, + 404: {"description": "The specified model could not be found"}, + 422: {"description": "Validation error"}, + }, +) +async def get_related_models( + model_key: str = Path(..., description="The key of the model to get relationships for") + ) -> list[str]: + """ + Get a list of model keys related to a given model. + """ + try: + return ApiDependencies.invoker.services.model_relationships.get_related_model_keys(model_key) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@model_relationships_router.post( + "/", + status_code=status.HTTP_204_NO_CONTENT, + responses={ + 204: {"description": "The relationship was successfully created"}, + 400: {"description": "Invalid model keys or self-referential relationship"}, + 409: {"description": "The relationship already exists"}, + 422: {"description": "Validation error"}, + 500: {"description": "Internal server error"}, + }, + summary="Add Model Relationship", + description="Creates a **bidirectional** relationship between two models, allowing each to reference the other as related.", +) +async def add_model_relationship( + req: ModelRelationshipCreateRequest = Body(..., description="The model keys to relate") +) -> None: + """ + Add a relationship between two models. + + Relationships are bidirectional and will be accessible from both models. + + - Raises 400 if keys are invalid or identical. + - Raises 409 if the relationship already exists. + """ + try: + if req.model_key_1 == req.model_key_2: + raise HTTPException(status_code=400, detail="Cannot relate a model to itself.") + + ApiDependencies.invoker.services.model_relationships.add_model_relationship( + req.model_key_1, + req.model_key_2, + ) + except ValueError as e: + raise HTTPException(status_code=409, detail=str(e)) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@model_relationships_router.delete( + "/", + status_code=status.HTTP_204_NO_CONTENT, + responses={ + 204: {"description": "The relationship was successfully removed"}, + 400: {"description": "Invalid model keys or self-referential relationship"}, + 404: {"description": "The relationship does not exist"}, + 422: {"description": "Validation error"}, + 500: {"description": "Internal server error"}, + }, + summary="Remove Model Relationship", + description="Removes a **bidirectional** relationship between two models. The relationship must already exist." +) +async def remove_model_relationship( + req: ModelRelationshipCreateRequest = Body(..., description="The model keys to disconnect") +) -> None: + """ + Removes a bidirectional relationship between two model keys. + + - Raises 400 if attempting to unlink a model from itself. + - Raises 404 if the relationship was not found. + """ + try: + if req.model_key_1 == req.model_key_2: + raise HTTPException(status_code=400, detail="Cannot unlink a model from itself.") + + ApiDependencies.invoker.services.model_relationships.remove_model_relationship( + req.model_key_1, + req.model_key_2, + ) + except ValueError as e: + raise HTTPException(status_code=404, detail=str(e)) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + +@model_relationships_router.post( + "/batch", + operation_id="get_related_models_batch", + response_model=List[str], + responses={ + 200: { + "description": "Related model keys retrieved successfully", + "content": { + "application/json": { + "example": [ + "ca562b14-995e-4a42-90c1-9528f1a5921d", + "cc0c2b8a-c62e-41d6-878e-cc74dde5ca8f", + "18ca7649-6a9e-47d5-bc17-41ab1e8cec81", + "7c12d1b2-0ef9-4bec-ba55-797b2d8f2ee1", + "c382eaa3-0e28-4ab0-9446-408667699aeb", + "71272e82-0e5f-46d5-bca9-9a61f4bd8a82", + "a5d7cd49-1b98-4534-a475-aeee4ccf5fa2" + ] + } + } + }, + 422: {"description": "Validation error"}, + 500: {"description": "Internal server error"}, + }, + summary="Get Related Model Keys (Batch)", + description="Retrieves all **unique related model keys** for a list of given models. This is useful for contextual suggestions or filtering." +) +async def get_related_models_batch( + req: ModelRelationshipBatchRequest = Body(..., description="Model keys to check for related connections") + ) -> list[str]: + """ + Accepts multiple model keys and returns a flat list of all unique related keys. + + Useful when working with multiple selections in the UI or cross-model comparisons. + """ + try: + all_related: set[str] = set() + for key in req.model_keys: + related = ApiDependencies.invoker.services.model_relationships.get_related_model_keys(key) + all_related.update(related) + return list(all_related) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) \ No newline at end of file diff --git a/invokeai/app/api_app.py b/invokeai/app/api_app.py index fda232496e7..22b77748cf2 100644 --- a/invokeai/app/api_app.py +++ b/invokeai/app/api_app.py @@ -22,6 +22,7 @@ download_queue, images, model_manager, + model_relationships, session_queue, style_presets, utilities, @@ -125,6 +126,7 @@ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint): app.include_router(images.images_router, prefix="/api") app.include_router(boards.boards_router, prefix="/api") app.include_router(board_images.board_images_router, prefix="/api") +app.include_router(model_relationships.model_relationships_router, prefix="/api") app.include_router(app_info.app_router, prefix="/api") app.include_router(session_queue.session_queue_router, prefix="/api") app.include_router(workflows.workflows_router, prefix="/api") diff --git a/invokeai/app/services/invocation_services.py b/invokeai/app/services/invocation_services.py index 933c57b4a08..3dbb2686adf 100644 --- a/invokeai/app/services/invocation_services.py +++ b/invokeai/app/services/invocation_services.py @@ -27,6 +27,8 @@ from invokeai.app.services.invocation_stats.invocation_stats_base import InvocationStatsServiceBase from invokeai.app.services.model_images.model_images_base import ModelImageFileStorageBase from invokeai.app.services.model_manager.model_manager_base import ModelManagerServiceBase + from invokeai.app.services.model_relationship_records.model_relationship_records_base import ModelRelationshipRecordStorageBase + from invokeai.app.services.model_relationships.model_relationships_base import ModelRelationshipsServiceABC from invokeai.app.services.names.names_base import NameServiceBase from invokeai.app.services.session_processor.session_processor_base import SessionProcessorBase from invokeai.app.services.session_queue.session_queue_base import SessionQueueBase @@ -54,6 +56,8 @@ def __init__( logger: "Logger", model_images: "ModelImageFileStorageBase", model_manager: "ModelManagerServiceBase", + model_relationships: "ModelRelationshipsServiceABC", + model_relationship_records: "ModelRelationshipRecordStorageBase", download_queue: "DownloadQueueServiceBase", performance_statistics: "InvocationStatsServiceBase", session_queue: "SessionQueueBase", @@ -81,6 +85,8 @@ def __init__( self.logger = logger self.model_images = model_images self.model_manager = model_manager + self.model_relationships = model_relationships + self.model_relationship_records = model_relationship_records self.download_queue = download_queue self.performance_statistics = performance_statistics self.session_queue = session_queue diff --git a/invokeai/app/services/model_relationship_records/model_relationship_records_base.py b/invokeai/app/services/model_relationship_records/model_relationship_records_base.py new file mode 100644 index 00000000000..7921523db46 --- /dev/null +++ b/invokeai/app/services/model_relationship_records/model_relationship_records_base.py @@ -0,0 +1,57 @@ +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from invokeai.backend.model_manager.config import AnyModelConfig + +class ModelRelationshipRecordStorageBase(ABC): + """Abstract base class for model-to-model relationship record storage.""" + + @abstractmethod + def add_model_relationship(self, model_key_1: str, model_key_2: str) -> None: + """Creates a relationship between two models by keys.""" + pass + + @abstractmethod + def remove_model_relationship(self, model_key_1: str, model_key_2: str) -> None: + """Removes a relationship between two models by keys.""" + pass + + @abstractmethod + def get_related_model_keys(self, model_key: str) -> list[str]: + """Gets all models keys related to a given model key.""" + pass + + @abstractmethod + def get_related_model_keys_batch(self, model_keys: list[str]) -> list[str]: + """Get related model keys for multiple models given a list of keys.""" + pass + + @abstractmethod + def get_related_model_key_count(self, model_key: str) -> int: + """Gets the number of relations for a given model key.""" + pass + + """ Below are methods that use ModelConfigs instead of model keys, as convenience methods. + These methods are not required to be implemented, but they are potentially useful for later development. + They are not used in the current codebase.""" + + @abstractmethod + def add_relationship_from_models(self, model_1: "AnyModelConfig", model_2: "AnyModelConfig") -> None: + """Creates a relationship between two models using ModelConfigs.""" + pass + + @abstractmethod + def remove_relationship_from_models(self, model_1: "AnyModelConfig", model_2: "AnyModelConfig") -> None: + """Removes a relationship between two models using ModelConfigs.""" + pass + + @abstractmethod + def get_related_keys_from_model(self, model: "AnyModelConfig") -> list[str]: + """Gets all model keys related to a given model using it's config.""" + pass + + @abstractmethod + def get_related_model_key_count_from_model(self, model: "AnyModelConfig") -> int: + """Gets the number of relations for a given model config.""" + pass \ No newline at end of file diff --git a/invokeai/app/services/model_relationship_records/model_relationship_records_sqlite.py b/invokeai/app/services/model_relationship_records/model_relationship_records_sqlite.py new file mode 100644 index 00000000000..4f87f5ef4c2 --- /dev/null +++ b/invokeai/app/services/model_relationship_records/model_relationship_records_sqlite.py @@ -0,0 +1,89 @@ +from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase +import sqlite3 +from typing import cast, TYPE_CHECKING +from invokeai.app.services.model_relationship_records.model_relationship_records_base import ModelRelationshipRecordStorageBase +if TYPE_CHECKING: + from invokeai.backend.model_manager.config import AnyModelConfig + +class SqliteModelRelationshipRecordStorage(ModelRelationshipRecordStorageBase): + def __init__(self, db: SqliteDatabase) -> None: + super().__init__() + self._conn = db.conn + + def add_model_relationship(self, model_key_1: str, model_key_2: str) -> None: + if model_key_1 == model_key_2: + raise ValueError("Cannot relate a model to itself.") + a, b = sorted([model_key_1, model_key_2]) + try: + cursor = self._conn.cursor() + cursor.execute( + "INSERT OR IGNORE INTO model_relationships (model_key_1, model_key_2) VALUES (?, ?)", + (a, b), + ) + self._conn.commit() + except sqlite3.Error as e: + self._conn.rollback() + raise e + + def remove_model_relationship(self, model_key_1: str, model_key_2: str) -> None: + a, b = sorted([model_key_1, model_key_2]) + try: + cursor = self._conn.cursor() + cursor.execute( + "DELETE FROM model_relationships WHERE model_key_1 = ? AND model_key_2 = ?", + (a, b), + ) + self._conn.commit() + except sqlite3.Error as e: + self._conn.rollback() + raise e + + def get_related_model_keys(self, model_key: str) -> list[str]: + cursor = self._conn.cursor() + cursor.execute( + """ + SELECT model_key_2 FROM model_relationships WHERE model_key_1 = ? + UNION + SELECT model_key_1 FROM model_relationships WHERE model_key_2 = ? + """, + (model_key, model_key), + ) + return [row[0] for row in cursor.fetchall()] + + def get_related_model_keys_batch(self, model_keys: list[str]) -> list[str]: + cursor = self._conn.cursor() + + key_list = ','.join('?' for _ in model_keys) + cursor.execute(f""" + SELECT model_key_2 FROM model_relationships WHERE model_key_1 IN ({key_list}) + UNION + SELECT model_key_1 FROM model_relationships WHERE model_key_2 IN ({key_list}) + """, + model_keys + model_keys + ) + return [row[0] for row in cursor.fetchall()] + + def get_related_model_key_count(self, model_key: str) -> int: + cursor = self._conn.execute( + """ + SELECT COUNT(*) FROM ( + SELECT model_key_2 FROM model_relationships WHERE model_key_1 = ? + UNION + SELECT model_key_1 FROM model_relationships WHERE model_key_2 = ? + ) + """, + (model_key, model_key), + ) + return cast(int, cursor.fetchone()[0]) + + def add_relationship_from_models(self, model_1: "AnyModelConfig", model_2: "AnyModelConfig") -> None: + self.add_model_relationship(model_1.key, model_2.key) + + def remove_relationship_from_models(self, model_1: "AnyModelConfig", model_2: "AnyModelConfig") -> None: + self.remove_model_relationship(model_1.key, model_2.key) + + def get_related_keys_from_model(self, model: "AnyModelConfig") -> list[str]: + return self.get_related_model_keys(model.key) + + def get_related_model_key_count_from_model(self, model: "AnyModelConfig") -> int: + return self.get_related_model_key_count(model.key) \ No newline at end of file diff --git a/invokeai/app/services/model_relationships/model_relationships_base.py b/invokeai/app/services/model_relationships/model_relationships_base.py new file mode 100644 index 00000000000..b60404d5710 --- /dev/null +++ b/invokeai/app/services/model_relationships/model_relationships_base.py @@ -0,0 +1,42 @@ +from abc import ABC, abstractmethod + +from invokeai.backend.model_manager.config import AnyModelConfig + + +class ModelRelationshipsServiceABC(ABC): + """High-level service for managing model-to-model relationships.""" + + @abstractmethod + def add_model_relationship(self, model_key_1: str, model_key_2: str) -> None: + """Creates a relationship between two models keys.""" + pass + + @abstractmethod + def remove_model_relationship(self, model_key_1: str, model_key_2: str) -> None: + """Removes a relationship between two models keys.""" + pass + + @abstractmethod + def get_related_model_keys(self, model_key: str) -> list[str]: + """Gets all models keys related to a given model key.""" + pass + + @abstractmethod + def get_related_model_keys_batch(self, model_keys: list[str]) -> list[str]: + """Get related model keys for multiple models.""" + pass + + @abstractmethod + def add_relationship_from_models(self, model_1: AnyModelConfig, model_2: AnyModelConfig) -> None: + """Creates a relationship from model objects.""" + pass + + @abstractmethod + def remove_relationship_from_models(self, model_1: AnyModelConfig, model_2: AnyModelConfig) -> None: + """Removes a relationship from model objects.""" + pass + + @abstractmethod + def get_related_keys_from_model(self, model: AnyModelConfig) -> list[str]: + """Gets all model keys related to a given model object.""" + pass \ No newline at end of file diff --git a/invokeai/app/services/model_relationships/model_relationships_common.py b/invokeai/app/services/model_relationships/model_relationships_common.py new file mode 100644 index 00000000000..6170b549d92 --- /dev/null +++ b/invokeai/app/services/model_relationships/model_relationships_common.py @@ -0,0 +1,8 @@ +from invokeai.app.util.model_exclude_null import BaseModelExcludeNull +from datetime import datetime + + +class ModelRelationship(BaseModelExcludeNull): + model_key_1: str + model_key_2: str + created_at: datetime \ No newline at end of file diff --git a/invokeai/app/services/model_relationships/model_relationships_default.py b/invokeai/app/services/model_relationships/model_relationships_default.py new file mode 100644 index 00000000000..1e6f338661e --- /dev/null +++ b/invokeai/app/services/model_relationships/model_relationships_default.py @@ -0,0 +1,30 @@ +from invokeai.backend.model_manager.config import AnyModelConfig +from .model_relationships_base import ModelRelationshipsServiceABC +from invokeai.app.services.invoker import Invoker + +class ModelRelationshipsService(ModelRelationshipsServiceABC): + __invoker: Invoker + + def start(self, invoker: Invoker) -> None: + self.__invoker = invoker + + def add_model_relationship(self, model_key_1: str, model_key_2: str) -> None: + self.__invoker.services.model_relationship_records.add_model_relationship(model_key_1, model_key_2) + + def remove_model_relationship(self, model_key_1: str, model_key_2: str) -> None: + self.__invoker.services.model_relationship_records.remove_model_relationship(model_key_1, model_key_2) + + def get_related_model_keys(self, model_key: str) -> list[str]: + return self.__invoker.services.model_relationship_records.get_related_model_keys(model_key) + + def add_relationship_from_models(self, model_1: AnyModelConfig, model_2: AnyModelConfig) -> None: + self.add_model_relationship(model_1.key, model_2.key) + + def remove_relationship_from_models(self, model_1: AnyModelConfig, model_2: AnyModelConfig) -> None: + self.remove_model_relationship(model_1.key, model_2.key) + + def get_related_keys_from_model(self, model: AnyModelConfig) -> list[str]: + return self.get_related_model_keys(model.key) + + def get_related_model_keys_batch(self, model_keys: list[str]) -> list[str]: + return self.__invoker.services.model_relationship_records.get_related_model_keys_batch(model_keys) \ No newline at end of file diff --git a/invokeai/app/services/shared/sqlite/sqlite_util.py b/invokeai/app/services/shared/sqlite/sqlite_util.py index 233bb72cda2..7c825616c16 100644 --- a/invokeai/app/services/shared/sqlite/sqlite_util.py +++ b/invokeai/app/services/shared/sqlite/sqlite_util.py @@ -22,6 +22,7 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_17 import build_migration_17 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_18 import build_migration_18 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_19 import build_migration_19 +from invokeai.app.services.shared.sqlite_migrator.migrations.migration_20 import build_migration_20 from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator @@ -61,6 +62,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto migrator.register_migration(build_migration_17()) migrator.register_migration(build_migration_18()) migrator.register_migration(build_migration_19(app_config=config)) + migrator.register_migration(build_migration_20()) migrator.run_migrations() return db diff --git a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_20.py b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_20.py new file mode 100644 index 00000000000..6b2050fac0d --- /dev/null +++ b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_20.py @@ -0,0 +1,37 @@ +import sqlite3 + +from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration + +class Migration20Callback: + + def __call__(self, cursor: sqlite3.Cursor) -> None: + cursor.execute( + """ + -- many-to-many relationship table for models + CREATE TABLE IF NOT EXISTS model_relationships ( + -- model_key_1 and model_key_2 are the same as the key(primary key) in the models table + model_key_1 TEXT NOT NULL, + model_key_2 TEXT NOT NULL, + created_at TEXT DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), + PRIMARY KEY (model_key_1, model_key_2), + -- model_key_1 < model_key_2, to ensure uniqueness and prevent duplicates + FOREIGN KEY (model_key_1) REFERENCES models(id) ON DELETE CASCADE, + FOREIGN KEY (model_key_2) REFERENCES models(id) ON DELETE CASCADE + ); + """ + ) + cursor.execute( + """ + -- Creates an index to keep performance equal when searching for model_key_1 or model_key_2 + CREATE INDEX IF NOT EXISTS keyx_model_relationships_model_key_2 + ON model_relationships(model_key_2) + """ + ) + + +def build_migration_20() -> Migration: + return Migration( + from_version=19, + to_version=20, + callback=Migration20Callback(), + ) \ No newline at end of file diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index b885df54528..f6d07a4831d 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -840,6 +840,8 @@ "predictionType": "Prediction Type", "prune": "Prune", "pruneTooltip": "Prune finished imports from queue", + "relatedModels": "Related Models", + "showOnlyRelatedModels": "Related", "repo_id": "Repo ID", "repoVariant": "Repo Variant", "scanFolder": "Scan Folder", diff --git a/invokeai/frontend/web/src/common/hooks/useRelatedGroupedModelCombobox.ts b/invokeai/frontend/web/src/common/hooks/useRelatedGroupedModelCombobox.ts new file mode 100644 index 00000000000..af14f5460e6 --- /dev/null +++ b/invokeai/frontend/web/src/common/hooks/useRelatedGroupedModelCombobox.ts @@ -0,0 +1,92 @@ +import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library'; +import type { GroupBase } from 'chakra-react-select'; +import type { ModelIdentifierField } from 'features/nodes/types/common'; +import { useTranslation } from 'react-i18next'; +import type { AnyModelConfig } from 'services/api/types'; + +import { useGroupedModelCombobox } from './useGroupedModelCombobox'; +import { useRelatedModelKeys } from './useRelatedModelKeys'; +import { useSelectedModelKeys } from './useSelectedModelKeys'; + +type UseRelatedGroupedModelComboboxArg = { + modelConfigs: T[]; + selectedModel?: ModelIdentifierField | null; + onChange: (value: T | null) => void; + getIsDisabled?: (model: T) => boolean; + isLoading?: boolean; + groupByType?: boolean; +}; + +// Custom hook to overlay the grouped model combobox with related models on top! +// Cleaner than hooking into useGroupedModelCombobox with a flag to enable/disable the related models +// Also allows for related models to be shown conditionally with some pretty simple logic if it ends up as a config flag. + +type UseRelatedGroupedModelComboboxReturn = { + value: ComboboxOption | undefined | null; + options: GroupBase[]; + onChange: ComboboxOnChange; + placeholder: string; + noOptionsMessage: () => string; +}; + +export function useRelatedGroupedModelCombobox({ + modelConfigs, + selectedModel, + onChange, + isLoading = false, + getIsDisabled, + groupByType, +}: UseRelatedGroupedModelComboboxArg): UseRelatedGroupedModelComboboxReturn { + const { t } = useTranslation(); + + const selectedKeys = useSelectedModelKeys(); + + const relatedKeys = useRelatedModelKeys(selectedKeys); + + // Base grouped options + const base = useGroupedModelCombobox({ + modelConfigs, + selectedModel, + onChange, + getIsDisabled, + isLoading, + groupByType, + }); + + // If no related models selected, just return base + if (relatedKeys.size === 0) { + return base; + } + + const relatedOptions: ComboboxOption[] = []; + const updatedGroups: GroupBase[] = []; + + for (const group of base.options) { + const remainingOptions: ComboboxOption[] = []; + + for (const option of group.options) { + if (relatedKeys.has(option.value)) { + relatedOptions.push({ ...option, label: `* ${option.label}` }); + } else { + remainingOptions.push(option); + } + } + + if (remainingOptions.length > 0) { + updatedGroups.push({ + label: group.label, + options: remainingOptions, + }); + } + } + + const finalOptions: GroupBase[] = + relatedOptions.length > 0 + ? [{ label: t('modelManager.relatedModels'), options: relatedOptions }, ...updatedGroups] + : updatedGroups; + + return { + ...base, + options: finalOptions, + }; +} diff --git a/invokeai/frontend/web/src/common/hooks/useRelatedModelKeys.ts b/invokeai/frontend/web/src/common/hooks/useRelatedModelKeys.ts new file mode 100644 index 00000000000..fc0711b969e --- /dev/null +++ b/invokeai/frontend/web/src/common/hooks/useRelatedModelKeys.ts @@ -0,0 +1,14 @@ +import { useMemo } from 'react'; +import { useGetRelatedModelIdsBatchQuery } from 'services/api/endpoints/modelRelationships'; + +/** + * Fetches related model keys for a given set of selected model keys. + * Returns a Set for fast lookup. + */ +export const useRelatedModelKeys = (selectedKeys: Set) => { + const { data: related = [] } = useGetRelatedModelIdsBatchQuery([...selectedKeys], { + skip: selectedKeys.size === 0, + }); + + return useMemo(() => new Set(related), [related]); +}; diff --git a/invokeai/frontend/web/src/common/hooks/useSelectedModelKeys.ts b/invokeai/frontend/web/src/common/hooks/useSelectedModelKeys.ts new file mode 100644 index 00000000000..83e1d3ac597 --- /dev/null +++ b/invokeai/frontend/web/src/common/hooks/useSelectedModelKeys.ts @@ -0,0 +1,34 @@ +import { useAppSelector } from 'app/store/storeHooks'; + +/** + * Gathers all currently selected model keys from parameters and loras. + * This includes the main model, VAE, refiner model, controlnet, and loras. + */ +export const useSelectedModelKeys = () => { + return useAppSelector((state) => { + const keys = new Set(); + const main = state.params.model; + const vae = state.params.vae; + const refiner = state.params.refinerModel; + const controlnet = state.params.controlLora; + const loras = state.loras.loras.map((l) => l.model); + + if (main) { + keys.add(main.key); + } + if (vae) { + keys.add(vae.key); + } + if (refiner) { + keys.add(refiner.key); + } + if (controlnet) { + keys.add(controlnet.key); + } + for (const lora of loras) { + keys.add(lora.key); + } + + return keys; + }); +}; diff --git a/invokeai/frontend/web/src/features/lora/components/LoRASelect.tsx b/invokeai/frontend/web/src/features/lora/components/LoRASelect.tsx index 90dacac0fa4..c6e8091c824 100644 --- a/invokeai/frontend/web/src/features/lora/components/LoRASelect.tsx +++ b/invokeai/frontend/web/src/features/lora/components/LoRASelect.tsx @@ -3,7 +3,7 @@ import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library'; import { createSelector } from '@reduxjs/toolkit'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover'; -import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox'; +import { useRelatedGroupedModelCombobox } from 'common/hooks/useRelatedGroupedModelCombobox'; import { loraAdded, selectLoRAsSlice } from 'features/controlLayers/store/lorasSlice'; import { selectBase } from 'features/controlLayers/store/paramsSlice'; import { NavigateToModelManagerButton } from 'features/parameters/components/MainModel/NavigateToModelManagerButton'; @@ -38,7 +38,7 @@ const LoRASelect = () => { [dispatch] ); - const { options, onChange } = useGroupedModelCombobox({ + const { options, onChange } = useRelatedGroupedModelCombobox({ modelConfigs, getIsDisabled, onChange: _onChange, diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelView.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelView.tsx index 42e0689733e..cb17dcf00c0 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelView.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelView.tsx @@ -11,6 +11,7 @@ import type { AnyModelConfig } from 'services/api/types'; import { MainModelDefaultSettings } from './MainModelDefaultSettings/MainModelDefaultSettings'; import { ModelAttrView } from './ModelAttrView'; +import { RelatedModels } from './RelatedModels'; type Props = { modelConfig: AnyModelConfig; @@ -83,6 +84,9 @@ export const ModelView = memo(({ modelConfig }: Props) => { )} )} + + + ); diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/RelatedModels.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/RelatedModels.tsx new file mode 100644 index 00000000000..f1846c57d76 --- /dev/null +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/RelatedModels.tsx @@ -0,0 +1,300 @@ +/** + * RelatedModels.tsx + * + * Panel for managing and displaying model-to-model relationships. + * + * Allows adding/removing bidirectional links between models, organized visually + * with color-coded tags, dividers between types, and sorted dropdown selection. + */ + +import { + Box, + Button, + Combobox, + Divider, + Flex, + FormControl, + FormErrorMessage, + FormLabel, + Tag, + TagCloseButton, + TagLabel, + Tooltip, +} from '@invoke-ai/ui-library'; +import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox'; +import { memo, useCallback, useMemo, useState } from 'react'; +import { useTranslation } from 'react-i18next'; +import { PiPlusBold } from 'react-icons/pi'; +import { + useAddModelRelationshipMutation, + useGetRelatedModelIdsQuery, + useRemoveModelRelationshipMutation, +} from 'services/api/endpoints/modelRelationships'; +import { useGetModelConfigsQuery } from 'services/api/endpoints/models'; +import type { AnyModelConfig } from 'services/api/types'; + +type Props = { + modelConfig: AnyModelConfig; +}; + +// Determines if two models are compatible for relationship linking based on their base type. +// +// Models with a base of 'any' are considered universally compatible. +// This is a known flaw: 'any'-based links may allow relationships that are +// meaningless in practice and could bloat the database over time. +// +// TODO: In the future, refine this logic to more strictly validate +// relationships based on model types or actual usage patterns. +const isBaseCompatible = (a: AnyModelConfig, b: AnyModelConfig): boolean => { + if (a.base === 'any' || b.base === 'any') { + return true; + } + return a.base === b.base; +}; + +export const RelatedModels = memo(({ modelConfig }: Props) => { + const { t } = useTranslation(); + const [addModelRelationship, { isLoading: isAdding }] = useAddModelRelationshipMutation(); + const [removeModelRelationship, { isLoading: isRemoving }] = useRemoveModelRelationshipMutation(); + const isLoading = isAdding || isRemoving; + const [selectedKey, setSelectedKey] = useState(''); + const { data: modelConfigs } = useGetModelConfigsQuery(); + const { data: relatedModels = [] } = useGetRelatedModelIdsQuery(modelConfig.key); + const relatedIDs = useMemo(() => new Set(relatedModels), [relatedModels]); + // Used to prioritize certain model types in UI sorting + const MODEL_TYPE_PRIORITY = useMemo(() => ['main', 'lora'], []); + + //Get all modelConfigs that are not already related to the current model. + const availableModels = useMemo(() => { + if (!modelConfigs) { + return []; + } + + return Object.values(modelConfigs.entities).filter( + (m): m is AnyModelConfig => + !!m && + m.key !== modelConfig.key && + !relatedIDs.has(m.key) && + isBaseCompatible(modelConfig, m) && + !(modelConfig.type === 'main' && m.type === 'main') // still block main↔main + ); + }, [modelConfigs, modelConfig, relatedIDs]); + + // Tracks validation errors for current input (e.g., duplicate key or no selection). + const errors = useMemo(() => { + const errs: string[] = []; + if (!selectedKey) { + return errs; + } + if (relatedIDs.has(selectedKey)) { + errs.push('Item already promoted'); + } + return errs; + }, [selectedKey, relatedIDs]); + + // Handles linking a selected model to the current one via API. + const handleAdd = useCallback(async () => { + const target = availableModels.find((m) => m.key === selectedKey); + if (!target) { + return; + } + + setSelectedKey(''); + await Promise.all([addModelRelationship({ model_key_1: modelConfig.key, model_key_2: target.key })]); + }, [modelConfig, availableModels, addModelRelationship, selectedKey]); + + const { + options, + onChange: comboboxOnChange, + placeholder, + noOptionsMessage, + } = useGroupedModelCombobox({ + modelConfigs: availableModels, + selectedModel: null, + onChange: (model) => { + if (!model) { + return; + } + setSelectedKey(model.key); + }, + groupByType: true, + }); + + // Unlinks an existing related model via API. + const handleRemove = useCallback( + async (id: string) => { + const target = modelConfigs?.entities[id]; + if (!target) { + return; + } + + await Promise.all([removeModelRelationship({ model_key_1: modelConfig.key, model_key_2: target.key })]); + }, + [modelConfig, modelConfigs, removeModelRelationship] + ); + + // Finds the selected model's combobox option to control current dropdown state. + const selectedOption = useMemo(() => { + return options.flatMap((group) => group.options).find((o) => o.value === selectedKey) ?? null; + }, [selectedKey, options]); + + const makeRemoveHandler = useCallback((id: string) => () => handleRemove(id), [handleRemove]); + + // Defines custom tag colors for model types in the UI. + // + // The default UI color scheme (mostly grey and orange) felt too flat, + // so this mapping provides a slightly more expressive color flow. + // + // Note: This is purely aesthetic. Safe to remove if project preferences change. + const getModelTagColor = (type: string): string => { + switch (type) { + case 'main': + case 'checkpoint': + return 'orange'; + case 'lora': + case 'lycoris': + return 'purple'; + case 'embedding': + case 'embedding_file': + return 'teal'; + case 'vae': + return 'blue'; + case 'controlnet': + case 'ip_adapter': + case 't2i_adapter': + return 'cyan'; + case 'onnx': + case 'bnb_quantized_int8b': + case 'bnb_quantized_nf4b': + case 'gguf_quantized': + return 'pink'; + case 't5_encoder': + case 'clip_embed': + case 'clip_vision': + case 'siglip': + return 'green'; + default: + return 'base'; + } + }; + + // Force group priority order: Main first, then LoRA + const getTypeFromLabel = (label: string): string => label.split('/')[1]?.trim().toLowerCase() || ''; + + const sortedOptions = useMemo(() => { + return [...options].sort((a, b) => { + const aType = getTypeFromLabel(a.label ?? ''); + const bType = getTypeFromLabel(b.label ?? ''); + + const aIndex = MODEL_TYPE_PRIORITY.indexOf(aType); + const bIndex = MODEL_TYPE_PRIORITY.indexOf(bType); + + const aScore = aIndex === -1 ? 99 : aIndex; + const bScore = bIndex === -1 ? 99 : bIndex; + + return aScore - bScore; + }); + }, [options, MODEL_TYPE_PRIORITY]); + + return ( + + {t('modelManager.relatedModels')} + 0}> + + + + + {errors.map((error) => ( + {error} + ))} + + + + { + // Render the related model tags as styled components. + // + // Models are grouped visually by type, sorted with 'main' and 'lora' types at the front. + // A vertical Divider is inserted when the type changes between adjacent models. + // Tags include: + // - Colored background based on model type (via getModelTagColor) + // - Tooltip showing ": " + // - Ellipsis-truncated tag name for compact layout + // - A close button to remove the relationship + [...relatedModels] + .sort((aKey, bKey) => { + const a = modelConfigs?.entities[aKey]; + const b = modelConfigs?.entities[bKey]; + if (!a || !b) { + return 0; + } + + // Floats Mains and LoRAs to the front + const aPriority = MODEL_TYPE_PRIORITY.indexOf(a.type); + const bPriority = MODEL_TYPE_PRIORITY.indexOf(b.type); + + const aScore = aPriority === -1 ? 99 : aPriority; + const bScore = bPriority === -1 ? 99 : bPriority; + + return aScore - bScore || a.type.localeCompare(b.type) || a.name.localeCompare(b.name); + }) + .reduce((acc, id, index, arr) => { + const model = modelConfigs?.entities[id]; + if (!model) { + return acc; + } + + const modelName = model.name ?? id; + const modelType = model.type ?? 'unknown'; + const modelTypeLabel = modelType.replace(/_/g, ' ').replace(/\b\w/g, (c) => c.toUpperCase()); + + // Create a divider if the previous model is of a different type. Just a small dash of visual flair. + const prevId = index > 0 ? arr[index - 1] : undefined; + const prevModel = prevId ? modelConfigs?.entities[prevId] : null; + const needsDivider = prevModel && prevModel.type !== model.type; + + if (needsDivider) { + acc.push(); + } + + acc.push( + + + + {modelName} + + + + + ); + + return acc; + }, []) + } + + + + ); +}); + +RelatedModels.displayName = 'RelatedModels'; diff --git a/invokeai/frontend/web/src/services/api/endpoints/modelRelationships.ts b/invokeai/frontend/web/src/services/api/endpoints/modelRelationships.ts new file mode 100644 index 00000000000..ffb815aac25 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/endpoints/modelRelationships.ts @@ -0,0 +1,67 @@ +/** + * modelRelationships.ts + * + * RTK Query API slice for managing model-to-model relationships. + * + * Endpoints provided: + * - Fetch related models for a single model + * - Add a relationship between two models + * - Remove a relationship between two models + * - Fetch related models for multiple models in batch + * + * Provides and invalidates cache tags for seamless UI updates after add/remove operations. + */ + +import { api } from '..'; + +const REL_TAG = 'ModelRelationships'; // Needed for UI updates on relationship changes. + +const modelRelationshipsApi = api.injectEndpoints({ + endpoints: (build) => ({ + getRelatedModelIds: build.query({ + query: (model_key) => `/api/v1/model_relationships/i/${model_key}`, + providesTags: (result, error, model_key) => [{ type: REL_TAG, id: model_key }], + }), + + addModelRelationship: build.mutation({ + query: (payload) => ({ + url: `/api/v1/model_relationships/`, + method: 'POST', + body: payload, + }), + invalidatesTags: (result, error, { model_key_1, model_key_2 }) => [ + { type: REL_TAG, id: model_key_1 }, + { type: REL_TAG, id: model_key_2 }, + ], + }), + + removeModelRelationship: build.mutation({ + query: (payload) => ({ + url: `/api/v1/model_relationships/`, + method: 'DELETE', + body: payload, + }), + invalidatesTags: (result, error, { model_key_1, model_key_2 }) => [ + { type: REL_TAG, id: model_key_1 }, + { type: REL_TAG, id: model_key_2 }, + ], + }), + + getRelatedModelIdsBatch: build.query({ + query: (model_keys) => ({ + url: `/api/v1/model_relationships/batch`, + method: 'POST', + body: { model_keys }, + }), + providesTags: (result, error, model_keys) => model_keys.map((key) => ({ type: 'ModelRelationships', id: key })), + }), + }), + overrideExisting: false, +}); + +export const { + useGetRelatedModelIdsQuery, + useAddModelRelationshipMutation, + useRemoveModelRelationshipMutation, + useGetRelatedModelIdsBatchQuery, +} = modelRelationshipsApi; diff --git a/invokeai/frontend/web/src/services/api/index.ts b/invokeai/frontend/web/src/services/api/index.ts index 8740e465b6f..7bc59202a46 100644 --- a/invokeai/frontend/web/src/services/api/index.ts +++ b/invokeai/frontend/web/src/services/api/index.ts @@ -34,6 +34,7 @@ const tagTypes = [ 'InvocationCacheStatus', 'ModelConfig', 'ModelInstalls', + 'ModelRelationships', 'ModelScanFolderResults', 'T2IAdapterModel', 'MainModel', diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index 7eb6a9e4689..ba2c63f73d3 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -867,6 +867,70 @@ export type paths = { patch?: never; trace?: never; }; + "/api/v1/model_relationships/i/{model_key}": { + parameters: { + query?: never; + header?: never; + path?: never; + cookie?: never; + }; + /** + * Get Related Models + * @description Get a list of model keys related to a given model. + */ + get: operations["get_related_models"]; + put?: never; + post?: never; + delete?: never; + options?: never; + head?: never; + patch?: never; + trace?: never; + }; + "/api/v1/model_relationships/": { + parameters: { + query?: never; + header?: never; + path?: never; + cookie?: never; + }; + get?: never; + put?: never; + /** + * Add Model Relationship + * @description Creates a **bidirectional** relationship between two models, allowing each to reference the other as related. + */ + post: operations["add_model_relationship_api_v1_model_relationships__post"]; + /** + * Remove Model Relationship + * @description Removes a **bidirectional** relationship between two models. The relationship must already exist. + */ + delete: operations["remove_model_relationship_api_v1_model_relationships__delete"]; + options?: never; + head?: never; + patch?: never; + trace?: never; + }; + "/api/v1/model_relationships/batch": { + parameters: { + query?: never; + header?: never; + path?: never; + cookie?: never; + }; + get?: never; + put?: never; + /** + * Get Related Model Keys (Batch) + * @description Retrieves all **unique related model keys** for a list of given models. This is useful for contextual suggestions or filtering. + */ + post: operations["get_related_models_batch"]; + delete?: never; + options?: never; + head?: never; + patch?: never; + trace?: never; + }; "/api/v1/app/version": { parameters: { query?: never; @@ -11922,14 +11986,14 @@ export type components = { * Convert Cache Dir * Format: path * @description Path to the converted models cache directory (DEPRECATED, but do not delete because it is needed for migration from previous versions). - * @default models/.convert_cache + * @default models\.convert_cache */ convert_cache_dir?: string; /** * Download Cache Dir * Format: path * @description Path to the directory that contains dynamically downloaded models. - * @default models/.download_cache + * @default models\.download_cache */ download_cache_dir?: string; /** @@ -16442,6 +16506,27 @@ export type components = { */ config_path?: string | null; }; + /** ModelRelationshipBatchRequest */ + ModelRelationshipBatchRequest: { + /** + * Model Keys + * @description List of model keys to fetch related models for + */ + model_keys: string[]; + }; + /** ModelRelationshipCreateRequest */ + ModelRelationshipCreateRequest: { + /** + * Model Key 1 + * @description The key of the first model in the relationship + */ + model_key_1: string; + /** + * Model Key 2 + * @description The key of the second model in the relationship + */ + model_key_2: string; + }; /** * ModelRepoVariant * @description Various hugging face variants on the diffusers format. @@ -23618,6 +23703,181 @@ export interface operations { }; }; }; + get_related_models: { + parameters: { + query?: never; + header?: never; + path: { + /** @description The key of the model to get relationships for */ + model_key: string; + }; + cookie?: never; + }; + requestBody?: never; + responses: { + /** @description A list of related model keys was retrieved successfully */ + 200: { + headers: { + [name: string]: unknown; + }; + content: { + "application/json": string[]; + }; + }; + /** @description The specified model could not be found */ + 404: { + headers: { + [name: string]: unknown; + }; + content?: never; + }; + /** @description Validation error */ + 422: { + headers: { + [name: string]: unknown; + }; + content?: never; + }; + }; + }; + add_model_relationship_api_v1_model_relationships__post: { + parameters: { + query?: never; + header?: never; + path?: never; + cookie?: never; + }; + requestBody: { + content: { + "application/json": components["schemas"]["ModelRelationshipCreateRequest"]; + }; + }; + responses: { + /** @description The relationship was successfully created */ + 204: { + headers: { + [name: string]: unknown; + }; + content?: never; + }; + /** @description Invalid model keys or self-referential relationship */ + 400: { + headers: { + [name: string]: unknown; + }; + content?: never; + }; + /** @description The relationship already exists */ + 409: { + headers: { + [name: string]: unknown; + }; + content?: never; + }; + /** @description Validation error */ + 422: { + headers: { + [name: string]: unknown; + }; + content?: never; + }; + /** @description Internal server error */ + 500: { + headers: { + [name: string]: unknown; + }; + content?: never; + }; + }; + }; + remove_model_relationship_api_v1_model_relationships__delete: { + parameters: { + query?: never; + header?: never; + path?: never; + cookie?: never; + }; + requestBody: { + content: { + "application/json": components["schemas"]["ModelRelationshipCreateRequest"]; + }; + }; + responses: { + /** @description The relationship was successfully removed */ + 204: { + headers: { + [name: string]: unknown; + }; + content?: never; + }; + /** @description Invalid model keys or self-referential relationship */ + 400: { + headers: { + [name: string]: unknown; + }; + content?: never; + }; + /** @description The relationship does not exist */ + 404: { + headers: { + [name: string]: unknown; + }; + content?: never; + }; + /** @description Validation error */ + 422: { + headers: { + [name: string]: unknown; + }; + content?: never; + }; + /** @description Internal server error */ + 500: { + headers: { + [name: string]: unknown; + }; + content?: never; + }; + }; + }; + get_related_models_batch: { + parameters: { + query?: never; + header?: never; + path?: never; + cookie?: never; + }; + requestBody: { + content: { + "application/json": components["schemas"]["ModelRelationshipBatchRequest"]; + }; + }; + responses: { + /** @description Related model keys retrieved successfully */ + 200: { + headers: { + [name: string]: unknown; + }; + content: { + "application/json": string[]; + }; + }; + /** @description Validation error */ + 422: { + headers: { + [name: string]: unknown; + }; + content?: never; + }; + /** @description Internal server error */ + 500: { + headers: { + [name: string]: unknown; + }; + content?: never; + }; + }; + }; app_version: { parameters: { query?: never; From 89d6070d281bfd5c09b97d1dca1b29bb824919b5 Mon Sep 17 00:00:00 2001 From: Lucian Hardy Date: Tue, 6 May 2025 16:15:35 +1000 Subject: [PATCH 2/3] chore(backend): Removed unused model_relationship methods removed unused AnyModelConfig related methods, removed unused get_related_model_key_count method. --- .../model_relationship_records_base.py | 33 ------------------- .../model_relationship_records_sqlite.py | 30 +---------------- .../model_relationships_base.py | 18 ---------- 3 files changed, 1 insertion(+), 80 deletions(-) diff --git a/invokeai/app/services/model_relationship_records/model_relationship_records_base.py b/invokeai/app/services/model_relationship_records/model_relationship_records_base.py index 7921523db46..ee60f2aca1d 100644 --- a/invokeai/app/services/model_relationship_records/model_relationship_records_base.py +++ b/invokeai/app/services/model_relationship_records/model_relationship_records_base.py @@ -1,8 +1,4 @@ from abc import ABC, abstractmethod -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from invokeai.backend.model_manager.config import AnyModelConfig class ModelRelationshipRecordStorageBase(ABC): """Abstract base class for model-to-model relationship record storage.""" @@ -25,33 +21,4 @@ def get_related_model_keys(self, model_key: str) -> list[str]: @abstractmethod def get_related_model_keys_batch(self, model_keys: list[str]) -> list[str]: """Get related model keys for multiple models given a list of keys.""" - pass - - @abstractmethod - def get_related_model_key_count(self, model_key: str) -> int: - """Gets the number of relations for a given model key.""" - pass - - """ Below are methods that use ModelConfigs instead of model keys, as convenience methods. - These methods are not required to be implemented, but they are potentially useful for later development. - They are not used in the current codebase.""" - - @abstractmethod - def add_relationship_from_models(self, model_1: "AnyModelConfig", model_2: "AnyModelConfig") -> None: - """Creates a relationship between two models using ModelConfigs.""" - pass - - @abstractmethod - def remove_relationship_from_models(self, model_1: "AnyModelConfig", model_2: "AnyModelConfig") -> None: - """Removes a relationship between two models using ModelConfigs.""" - pass - - @abstractmethod - def get_related_keys_from_model(self, model: "AnyModelConfig") -> list[str]: - """Gets all model keys related to a given model using it's config.""" - pass - - @abstractmethod - def get_related_model_key_count_from_model(self, model: "AnyModelConfig") -> int: - """Gets the number of relations for a given model config.""" pass \ No newline at end of file diff --git a/invokeai/app/services/model_relationship_records/model_relationship_records_sqlite.py b/invokeai/app/services/model_relationship_records/model_relationship_records_sqlite.py index 4f87f5ef4c2..cd6f3cf5ec9 100644 --- a/invokeai/app/services/model_relationship_records/model_relationship_records_sqlite.py +++ b/invokeai/app/services/model_relationship_records/model_relationship_records_sqlite.py @@ -1,9 +1,6 @@ from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase import sqlite3 -from typing import cast, TYPE_CHECKING from invokeai.app.services.model_relationship_records.model_relationship_records_base import ModelRelationshipRecordStorageBase -if TYPE_CHECKING: - from invokeai.backend.model_manager.config import AnyModelConfig class SqliteModelRelationshipRecordStorage(ModelRelationshipRecordStorageBase): def __init__(self, db: SqliteDatabase) -> None: @@ -61,29 +58,4 @@ def get_related_model_keys_batch(self, model_keys: list[str]) -> list[str]: """, model_keys + model_keys ) - return [row[0] for row in cursor.fetchall()] - - def get_related_model_key_count(self, model_key: str) -> int: - cursor = self._conn.execute( - """ - SELECT COUNT(*) FROM ( - SELECT model_key_2 FROM model_relationships WHERE model_key_1 = ? - UNION - SELECT model_key_1 FROM model_relationships WHERE model_key_2 = ? - ) - """, - (model_key, model_key), - ) - return cast(int, cursor.fetchone()[0]) - - def add_relationship_from_models(self, model_1: "AnyModelConfig", model_2: "AnyModelConfig") -> None: - self.add_model_relationship(model_1.key, model_2.key) - - def remove_relationship_from_models(self, model_1: "AnyModelConfig", model_2: "AnyModelConfig") -> None: - self.remove_model_relationship(model_1.key, model_2.key) - - def get_related_keys_from_model(self, model: "AnyModelConfig") -> list[str]: - return self.get_related_model_keys(model.key) - - def get_related_model_key_count_from_model(self, model: "AnyModelConfig") -> int: - return self.get_related_model_key_count(model.key) \ No newline at end of file + return [row[0] for row in cursor.fetchall()] \ No newline at end of file diff --git a/invokeai/app/services/model_relationships/model_relationships_base.py b/invokeai/app/services/model_relationships/model_relationships_base.py index b60404d5710..94ba6774b6c 100644 --- a/invokeai/app/services/model_relationships/model_relationships_base.py +++ b/invokeai/app/services/model_relationships/model_relationships_base.py @@ -1,8 +1,5 @@ from abc import ABC, abstractmethod -from invokeai.backend.model_manager.config import AnyModelConfig - - class ModelRelationshipsServiceABC(ABC): """High-level service for managing model-to-model relationships.""" @@ -24,19 +21,4 @@ def get_related_model_keys(self, model_key: str) -> list[str]: @abstractmethod def get_related_model_keys_batch(self, model_keys: list[str]) -> list[str]: """Get related model keys for multiple models.""" - pass - - @abstractmethod - def add_relationship_from_models(self, model_1: AnyModelConfig, model_2: AnyModelConfig) -> None: - """Creates a relationship from model objects.""" - pass - - @abstractmethod - def remove_relationship_from_models(self, model_1: AnyModelConfig, model_2: AnyModelConfig) -> None: - """Removes a relationship from model objects.""" - pass - - @abstractmethod - def get_related_keys_from_model(self, model: AnyModelConfig) -> list[str]: - """Gets all model keys related to a given model object.""" pass \ No newline at end of file From 3dc2e7121e30eb3145fa267073b919ba3390590f Mon Sep 17 00:00:00 2001 From: Lucian Hardy Date: Tue, 6 May 2025 16:17:13 +1000 Subject: [PATCH 3/3] chore(ui): Refactor RelatedModels.tsx MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Major cleanup of RelatedModels.tsx for improved readability, structure, and maintainability. Dried out repetitive logic Consolidated model type sorting into reusable helpers Added disallowed model type relationships to prevent broken connections (e.g. VAE ↔ LoRA) - Aware this introduces a new constraint—open to feedback (see PR comment) Some naming and types may still need refinement; happy to revisit --- .../subpanels/ModelPanel/RelatedModels.tsx | 323 ++++++++++-------- 1 file changed, 187 insertions(+), 136 deletions(-) diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/RelatedModels.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/RelatedModels.tsx index f1846c57d76..74dd23f40c5 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/RelatedModels.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/RelatedModels.tsx @@ -37,6 +37,13 @@ type Props = { modelConfig: AnyModelConfig; }; +type ModelGroup = { + type: string; + label: string; + color: string; + models: AnyModelConfig[]; +}; + // Determines if two models are compatible for relationship linking based on their base type. // // Models with a base of 'any' are considered universally compatible. @@ -52,6 +59,49 @@ const isBaseCompatible = (a: AnyModelConfig, b: AnyModelConfig): boolean => { return a.base === b.base; }; +// Drying out and setting up for potential export + +// Defines custom tag colors for model types in the UI. +// +// The default UI color scheme (mostly grey and orange) felt too flat, +// so this mapping provides a slightly more expressive color flow. +// +// Note: This is purely aesthetic. Safe to remove if project preferences change. +const getModelTagColor = (type: string): string => { + switch (type) { + case 'main': + case 'checkpoint': + return 'orange'; + case 'lora': + case 'lycoris': + return 'purple'; + case 'embedding': + case 'embedding_file': + return 'teal'; + case 'vae': + return 'blue'; + case 'controlnet': + case 'ip_adapter': + case 't2i_adapter': + return 'cyan'; + case 'onnx': + case 'bnb_quantized_int8b': + case 'bnb_quantized_nf4b': + case 'gguf_quantized': + return 'pink'; + case 't5_encoder': + case 'clip_embed': + case 'clip_vision': + case 'siglip': + return 'green'; + default: + return 'base'; + } +}; + +// Extracts model type from a label string (e.g., 'Base/LoRA' → 'lora') +const getTypeFromLabel = (label: string): string => label.split('/')[1]?.trim().toLowerCase() || ''; + export const RelatedModels = memo(({ modelConfig }: Props) => { const { t } = useTranslation(); const [addModelRelationship, { isLoading: isAdding }] = useAddModelRelationshipMutation(); @@ -61,14 +111,49 @@ export const RelatedModels = memo(({ modelConfig }: Props) => { const { data: modelConfigs } = useGetModelConfigsQuery(); const { data: relatedModels = [] } = useGetRelatedModelIdsQuery(modelConfig.key); const relatedIDs = useMemo(() => new Set(relatedModels), [relatedModels]); - // Used to prioritize certain model types in UI sorting + + // Defines model types to prioritize first in UI sorting. + // Types not listed here will appear afterward in default order. const MODEL_TYPE_PRIORITY = useMemo(() => ['main', 'lora'], []); + // Defines disallowed connection types. + const DISALLOWED_RELATIONSHIPS = useMemo( + () => + new Set([ + 'main|main', + 'vae|vae', + 'controlnet|controlnet', + 'clip_vision|clip_vision', + 'control_lora|control_lora', + 'clip_embed|clip_embed', + 'spandrel_image_to_image|spandrel_image_to_image', + 'siglip|siglip', + 'flux_redux|flux_redux', + ]), + [] + ); + + // Drying out sorting + const prioritySort = useCallback( + (a: string, b: string): number => { + const aIndex = MODEL_TYPE_PRIORITY.indexOf(a); + const bIndex = MODEL_TYPE_PRIORITY.indexOf(b); + + const aScore = aIndex === -1 ? 99 : aIndex; + const bScore = bIndex === -1 ? 99 : bIndex; + + return aScore - bScore; + }, + [MODEL_TYPE_PRIORITY] + ); + //Get all modelConfigs that are not already related to the current model. const availableModels = useMemo(() => { if (!modelConfigs) { return []; } + const isDisallowedRelationship = (a: string, b: string): boolean => + DISALLOWED_RELATIONSHIPS.has(`${a}|${b}`) || DISALLOWED_RELATIONSHIPS.has(`${b}|${a}`); return Object.values(modelConfigs.entities).filter( (m): m is AnyModelConfig => @@ -76,9 +161,9 @@ export const RelatedModels = memo(({ modelConfig }: Props) => { m.key !== modelConfig.key && !relatedIDs.has(m.key) && isBaseCompatible(modelConfig, m) && - !(modelConfig.type === 'main' && m.type === 'main') // still block main↔main + !isDisallowedRelationship(modelConfig.type, m.type) ); - }, [modelConfigs, modelConfig, relatedIDs]); + }, [modelConfigs, modelConfig, relatedIDs, DISALLOWED_RELATIONSHIPS]); // Tracks validation errors for current input (e.g., duplicate key or no selection). const errors = useMemo(() => { @@ -100,7 +185,7 @@ export const RelatedModels = memo(({ modelConfig }: Props) => { } setSelectedKey(''); - await Promise.all([addModelRelationship({ model_key_1: modelConfig.key, model_key_2: target.key })]); + await addModelRelationship({ model_key_1: modelConfig.key, model_key_2: target.key }); }, [modelConfig, availableModels, addModelRelationship, selectedKey]); const { @@ -120,81 +205,65 @@ export const RelatedModels = memo(({ modelConfig }: Props) => { groupByType: true, }); - // Unlinks an existing related model via API. - const handleRemove = useCallback( - async (id: string) => { - const target = modelConfigs?.entities[id]; - if (!target) { - return; - } - - await Promise.all([removeModelRelationship({ model_key_1: modelConfig.key, model_key_2: target.key })]); - }, - [modelConfig, modelConfigs, removeModelRelationship] - ); - // Finds the selected model's combobox option to control current dropdown state. const selectedOption = useMemo(() => { return options.flatMap((group) => group.options).find((o) => o.value === selectedKey) ?? null; }, [selectedKey, options]); - const makeRemoveHandler = useCallback((id: string) => () => handleRemove(id), [handleRemove]); - - // Defines custom tag colors for model types in the UI. - // - // The default UI color scheme (mostly grey and orange) felt too flat, - // so this mapping provides a slightly more expressive color flow. - // - // Note: This is purely aesthetic. Safe to remove if project preferences change. - const getModelTagColor = (type: string): string => { - switch (type) { - case 'main': - case 'checkpoint': - return 'orange'; - case 'lora': - case 'lycoris': - return 'purple'; - case 'embedding': - case 'embedding_file': - return 'teal'; - case 'vae': - return 'blue'; - case 'controlnet': - case 'ip_adapter': - case 't2i_adapter': - return 'cyan'; - case 'onnx': - case 'bnb_quantized_int8b': - case 'bnb_quantized_nf4b': - case 'gguf_quantized': - return 'pink'; - case 't5_encoder': - case 'clip_embed': - case 'clip_vision': - case 'siglip': - return 'green'; - default: - return 'base'; + const sortedOptions = useMemo(() => { + return [...options].sort((a, b) => prioritySort(getTypeFromLabel(a.label ?? ''), getTypeFromLabel(b.label ?? ''))); + }, [options, prioritySort]); + + const groupedModelConfigs = useMemo(() => { + if (!modelConfigs) { + return []; } - }; - // Force group priority order: Main first, then LoRA - const getTypeFromLabel = (label: string): string => label.split('/')[1]?.trim().toLowerCase() || ''; + const models = [...relatedModels].map((id) => modelConfigs.entities[id]).filter((m): m is AnyModelConfig => !!m); - const sortedOptions = useMemo(() => { - return [...options].sort((a, b) => { - const aType = getTypeFromLabel(a.label ?? ''); - const bType = getTypeFromLabel(b.label ?? ''); + models.sort((a, b) => prioritySort(a.type, b.type) || a.type.localeCompare(b.type) || a.name.localeCompare(b.name)); - const aIndex = MODEL_TYPE_PRIORITY.indexOf(aType); - const bIndex = MODEL_TYPE_PRIORITY.indexOf(bType); + const groupsMap = new Map(); - const aScore = aIndex === -1 ? 99 : aIndex; - const bScore = bIndex === -1 ? 99 : bIndex; + for (const model of models) { + if (!groupsMap.has(model.type)) { + groupsMap.set(model.type, { + type: model.type, + label: model.type.replace(/_/g, ' ').replace(/\b\w/g, (c) => c.toUpperCase()), + color: getModelTagColor(model.type), + models: [], + }); + } + groupsMap.get(model.type)!.models.push(model); + } - return aScore - bScore; - }); - }, [options, MODEL_TYPE_PRIORITY]); + return Array.from(groupsMap.values()); + }, [modelConfigs, relatedModels, prioritySort]); + + const removeHandlers = useMemo(() => { + const map = new Map void>(); + if (!modelConfigs) { + return map; + } + + for (const group of groupedModelConfigs) { + for (const model of group.models) { + map.set(model.key, () => { + const target = modelConfigs.entities[model.key]; + if (!target) { + return; + } + + removeModelRelationship({ + model_key_1: modelConfig.key, + model_key_2: model.key, + }).unwrap(); + }); + } + } + + return map; + }, [groupedModelConfigs, modelConfig.key, modelConfigs, removeModelRelationship]); return ( @@ -204,7 +273,7 @@ export const RelatedModels = memo(({ modelConfig }: Props) => { @@ -224,77 +293,59 @@ export const RelatedModels = memo(({ modelConfig }: Props) => { - { - // Render the related model tags as styled components. - // - // Models are grouped visually by type, sorted with 'main' and 'lora' types at the front. - // A vertical Divider is inserted when the type changes between adjacent models. - // Tags include: - // - Colored background based on model type (via getModelTagColor) - // - Tooltip showing ": " - // - Ellipsis-truncated tag name for compact layout - // - A close button to remove the relationship - [...relatedModels] - .sort((aKey, bKey) => { - const a = modelConfigs?.entities[aKey]; - const b = modelConfigs?.entities[bKey]; - if (!a || !b) { - return 0; - } - - // Floats Mains and LoRAs to the front - const aPriority = MODEL_TYPE_PRIORITY.indexOf(a.type); - const bPriority = MODEL_TYPE_PRIORITY.indexOf(b.type); - - const aScore = aPriority === -1 ? 99 : aPriority; - const bScore = bPriority === -1 ? 99 : bPriority; - - return aScore - bScore || a.type.localeCompare(b.type) || a.name.localeCompare(b.name); - }) - .reduce((acc, id, index, arr) => { - const model = modelConfigs?.entities[id]; - if (!model) { - return acc; - } - - const modelName = model.name ?? id; - const modelType = model.type ?? 'unknown'; - const modelTypeLabel = modelType.replace(/_/g, ' ').replace(/\b\w/g, (c) => c.toUpperCase()); - - // Create a divider if the previous model is of a different type. Just a small dash of visual flair. - const prevId = index > 0 ? arr[index - 1] : undefined; - const prevModel = prevId ? modelConfigs?.entities[prevId] : null; - const needsDivider = prevModel && prevModel.type !== model.type; - - if (needsDivider) { - acc.push(); - } - - acc.push( - - - - {modelName} - - - - - ); - - return acc; - }, []) - } + {groupedModelConfigs.map((group, i) => { + const withDivider = i < groupedModelConfigs.length - 1; + + return ( + + + {withDivider && } + + ); + })} ); }); +const ModelTag = ({ + model, + onRemove, + isLoading, +}: { + model: AnyModelConfig; + onRemove: () => void; + isLoading: boolean; +}) => { + return ( + + + + {model.name} + + + + + ); +}; + +const ModelTagGroup = ({ + group, + isLoading, + removeHandlers, +}: { + group: ModelGroup; + isLoading: boolean; + removeHandlers: Map void>; +}) => { + return ( + + {group.models.map((model) => ( + + ))} + + ); +}; + RelatedModels.displayName = 'RelatedModels';