diff --git a/.gitignore b/.gitignore index 7713158b6..046389d07 100644 --- a/.gitignore +++ b/.gitignore @@ -10,4 +10,5 @@ databases/ **/.DS_Store .history addresses/ -tools/mcp.json \ No newline at end of file +tools/mcp.json +tools/image_generation_config.json \ No newline at end of file diff --git a/llm-service/app/main.py b/llm-service/app/main.py index 35b550edf..34490e7e3 100644 --- a/llm-service/app/main.py +++ b/llm-service/app/main.py @@ -61,7 +61,7 @@ from typing import AsyncGenerator import opik -from fastapi import FastAPI, Request, Response +from fastapi import FastAPI, Request, Response, HTTPException from fastapi.middleware.cors import CORSMiddleware from uvicorn.logging import DefaultFormatter @@ -198,3 +198,50 @@ async def log_request_received( app.include_router(index.router) + +# Serve cached images via FastAPI from the llm-service/cache directory +_cache_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "cache")) +os.makedirs(_cache_dir, exist_ok=True) + + +@app.get("/cache/{filename:path}") +async def serve_cached_image(filename: str) -> Response: + """Serve cached images with proper MIME type detection.""" + import mimetypes + from pathlib import Path + + file_path = (Path(_cache_dir) / filename).resolve() + + # Guard against path traversal: the resolved path must stay inside the cache directory + if not file_path.is_relative_to(_cache_dir): + raise HTTPException(status_code=404, detail="Image not found") + + if not file_path.exists() or not file_path.is_file(): + raise HTTPException(status_code=404, detail="Image not found") + + # Determine MIME type based on file extension + mime_type, _ = mimetypes.guess_type(str(file_path)) + + # Fallback MIME types for common image formats + if not mime_type: + file_ext = file_path.suffix.lower() + mime_type = { + ".png": "image/png", + ".jpg": "image/jpeg", + ".jpeg": "image/jpeg", + ".gif": "image/gif", + ".bmp": "image/bmp", + ".webp": "image/webp", + ".svg": "image/svg+xml", + ".ico": "image/x-icon", + }.get(file_ext, "application/octet-stream") + + # Read file content + with open(file_path, "rb") as f: + content = f.read() + + return Response( + content=content, + media_type=mime_type, + headers={"Cache-Control": "public, max-age=31536000"}, # Cache for 1 year + ) diff --git a/llm-service/app/routers/index/tools/__init__.py b/llm-service/app/routers/index/tools/__init__.py index 28145f4d3..4372406f6 100644 --- a/llm-service/app/routers/index/tools/__init__.py +++ b/llm-service/app/routers/index/tools/__init__.py @@ -38,15 +38,23 @@ import json import logging import os +import re from typing import Any, Optional, cast, Annotated from urllib.parse import unquote -import re from fastapi import APIRouter, Header, HTTPException from pydantic import BaseModel from .... import exceptions from ....config import settings +from ....services.models import get_model_source, ModelSource +from ....services.models.providers import BedrockModelProvider +from ....services.query.agents.agent_tools.image_generation import ( + BEDROCK_STABLE_DIFFUSION_MODEL_ID, + BEDROCK_TITAN_IMAGE_MODEL_ID, + ImageGenerationTools, + IMAGE_GENERATION_TOOL_METADATA, +) from ....services.utils import has_admin_rights logger = logging.getLogger(__name__) @@ -74,7 +82,19 @@ class Tool(ToolMetadata): env: Optional[dict[str, str]] = None +class ImageGenerationConfig(BaseModel): + """ + Represents the complete image generation configuration.
+ """ + + enabled: bool = True + selected_tool: Optional[str] = None + + mcp_json_path: str = os.path.join(settings.tools_dir, "mcp.json") +image_generation_config_path: str = os.path.join( + settings.tools_dir, "image_generation_config.json" +) def get_mcp_config() -> dict[str, Any]: @@ -97,6 +117,49 @@ def get_mcp_config() -> dict[str, Any]: ) +def get_image_generation_config() -> ImageGenerationConfig: + """ + Gets the complete image generation configuration (enabled state + selected tool). + """ + if not os.path.exists(image_generation_config_path): + # Default configuration - disabled by default + return ImageGenerationConfig(enabled=False, selected_tool=None) + + try: + with open(image_generation_config_path, "r") as f: + data = json.load(f) + return ImageGenerationConfig(**data) + except Exception: + logger.error( + "Failed to get image generation config from %s", + image_generation_config_path, + ) + # Default configuration - disabled by default + return ImageGenerationConfig(enabled=False, selected_tool=None) + + +def set_image_generation_config(config: ImageGenerationConfig) -> None: + """ + Sets the complete image generation configuration. + """ + os.makedirs(os.path.dirname(image_generation_config_path), exist_ok=True) + with open(image_generation_config_path, "w") as f: + json.dump(config.model_dump(), f, indent=2) + + +def get_selected_image_generation_tool() -> Optional[str]: + """ + Gets the currently selected image generation tool. + Returns None if image generation is disabled or no tool is selected. + """ + config = get_image_generation_config() + + if not config.enabled: + return None + + return config.selected_tool + + @router.get( "", summary="Returns a list of available tools.", @@ -104,9 +167,132 @@ def get_mcp_config() -> dict[str, Any]: ) @exceptions.propagates def tools() -> list[ToolMetadata]: - + # Get MCP tools from config mcp_config = get_mcp_config() - return [ToolMetadata(**server) for server in mcp_config["mcp_servers"]] + tool_list = [ToolMetadata(**server) for server in mcp_config["mcp_servers"]] + + return tool_list + + +def get_image_generation_tool_metadata() -> list[ToolMetadata]: + # Get current model provider + model_source = get_model_source() + # Add image generation tools based on the current model provider + if model_source == ModelSource.OPENAI: + tool_metadata = IMAGE_GENERATION_TOOL_METADATA[ + ImageGenerationTools.OPENAI_IMAGE_GENERATION + ] + return [ + ToolMetadata( + name=ImageGenerationTools.OPENAI_IMAGE_GENERATION, + metadata={ + "description": tool_metadata["description"], + "display_name": tool_metadata["display_name"], + }, + ) + ] + if model_source == ModelSource.BEDROCK: + supported_model_ids = [ + BEDROCK_STABLE_DIFFUSION_MODEL_ID, + BEDROCK_TITAN_IMAGE_MODEL_ID, + ] + available_models = BedrockModelProvider.list_image_generation_models() + supported_bedrock_image_generation_tools = [] + if not available_models: + return [] + for model in available_models: + if model.model_id not in supported_model_ids: + continue + if model.model_id == BEDROCK_STABLE_DIFFUSION_MODEL_ID: + tool_metadata = IMAGE_GENERATION_TOOL_METADATA[ + ImageGenerationTools.BEDROCK_STABLE_DIFFUSION + ] + supported_bedrock_image_generation_tools.append( + ToolMetadata( + name=ImageGenerationTools.BEDROCK_STABLE_DIFFUSION, + metadata={ + "description": tool_metadata["description"], + "display_name": tool_metadata["display_name"], + }, + ) + ) + elif model.model_id == BEDROCK_TITAN_IMAGE_MODEL_ID: + tool_metadata = IMAGE_GENERATION_TOOL_METADATA[ + 
ImageGenerationTools.BEDROCK_TITAN_IMAGE + ] + supported_bedrock_image_generation_tools.append( + ToolMetadata( + name=ImageGenerationTools.BEDROCK_TITAN_IMAGE, + metadata={ + "description": tool_metadata["description"], + "display_name": tool_metadata["display_name"], + }, + ) + ) + return supported_bedrock_image_generation_tools + # Return empty list for other model providers + return [] + + +@router.get( + "/image-generation", + summary="Returns a list of available image generation tools.", + response_model=list[ToolMetadata], +) +@exceptions.propagates +def image_generation_tools() -> list[ToolMetadata]: + """ + Returns a list of available image generation tools based on the current model provider. + """ + return get_image_generation_tool_metadata() + + +@router.get( + "/image-generation/config", + summary="Returns the complete image generation configuration.", + response_model=ImageGenerationConfig, +) +@exceptions.propagates +def get_image_generation_config_endpoint() -> ImageGenerationConfig: + """ + Returns the complete image generation configuration. + """ + return get_image_generation_config() + + +@router.post( + "/image-generation/config", + summary="Sets the complete image generation configuration.", + response_model=ImageGenerationConfig, +) +@exceptions.propagates +def set_image_generation_config_endpoint( + config: ImageGenerationConfig, + remote_user: Annotated[str | None, Header()] = None, + remote_user_perm: Annotated[str, Header()] = None, +) -> ImageGenerationConfig: + """ + Sets the complete image generation configuration. + """ + if not has_admin_rights(remote_user, remote_user_perm): + raise HTTPException( + status_code=401, + detail="You do not have permission to modify tool settings.", + ) + + # If enabling and selecting a tool, validate that the tool exists + if config.enabled and config.selected_tool is not None: + available_tools = get_image_generation_tool_metadata() + available_tool_names = [tool.name for tool in available_tools] + + if config.selected_tool not in available_tool_names: + raise HTTPException( + status_code=400, + detail=f"Invalid tool selection. 
Available tools: {available_tool_names}", + ) + + set_image_generation_config(config) + return config @router.post( @@ -169,6 +355,16 @@ def delete_tool( decoded_name = unquote(name) + # Prevent deletion of image generation tools + available_image_tools = get_image_generation_tool_metadata() + image_tool_names = [tool.name for tool in available_image_tools] + + if decoded_name in image_tool_names: + raise HTTPException( + status_code=400, + detail="Image generation tools cannot be deleted.", + ) + mcp_config = get_mcp_config() # Find the tool with the given name diff --git a/llm-service/app/services/chat/chat.py b/llm-service/app/services/chat/chat.py index 91102b00d..0984154a3 100644 --- a/llm-service/app/services/chat/chat.py +++ b/llm-service/app/services/chat/chat.py @@ -156,6 +156,11 @@ def finalize_response( evaluations.append(Evaluation(name="relevance", value=relevance)) evaluations.append(Evaluation(name="faithfulness", value=faithfulness)) response_source_nodes = format_source_nodes(chat_response) + + # remove the sandbox prefix from the response for OpenAI image generation responses + if "(sandbox:" in chat_response.response: + chat_response.response = chat_response.response.replace("(sandbox:", "(") + new_chat_message = RagStudioChatMessage( id=response_id, session_id=session.id, diff --git a/llm-service/app/services/chat/streaming_chat.py b/llm-service/app/services/chat/streaming_chat.py index 9bda03723..d720498ce 100644 --- a/llm-service/app/services/chat/streaming_chat.py +++ b/llm-service/app/services/chat/streaming_chat.py @@ -39,7 +39,11 @@ import uuid from typing import Optional, Generator -from llama_index.core.base.llms.types import ChatResponse, ChatMessage +from llama_index.core.base.llms.types import ( + ChatResponse, + ChatMessage, + TextBlock, +) from llama_index.core.chat_engine.types import ( AgentChatResponse, StreamingAgentChatResponse, @@ -95,7 +99,6 @@ def stream_chat( if not query_configuration.use_tool_calling and ( len(session.get_all_data_source_ids()) == 0 or total_data_sources_size == 0 ): - # put a poison pill in the queue to stop the tool events stream return _stream_direct_llm_chat(session, response_id, query, user_name) condensed_question, streaming_chat_response = build_streamer( @@ -122,7 +125,9 @@ def _run_streaming_chat( streaming_chat_response: StreamingAgentChatResponse, condensed_question: Optional[str] = None, ) -> Generator[ChatResponse, None, None]: - response: ChatResponse = ChatResponse(message=ChatMessage(content=query)) + response: ChatResponse = ChatResponse( + message=ChatMessage(blocks=[TextBlock(text=query)]) + ) if streaming_chat_response.chat_stream: for response in streaming_chat_response.chat_stream: response.additional_kwargs["response_id"] = response_id @@ -162,7 +167,9 @@ def build_streamer( chat_history = retrieve_chat_history(session.id) chat_messages = list( map( - lambda message: ChatMessage(role=message.role, content=message.content), + lambda message: ChatMessage( + role=message.role, blocks=[TextBlock(text=message.content)] + ), chat_history, ) ) @@ -190,9 +197,17 @@ def _stream_direct_llm_chat( record_direct_llm_mlflow_run(response_id, session, user_name) chat_response = llm_completion.stream_completion( - session.id, query, session.inference_model + session.id, + ChatMessage( + blocks=[ + TextBlock(text=query), + ] + ), + session.inference_model, + ) + response: ChatResponse = ChatResponse( + message=ChatMessage(blocks=[TextBlock(text=query)]) ) - response: ChatResponse = 
ChatResponse(message=ChatMessage(content=query)) for response in chat_response: response.additional_kwargs["response_id"] = response_id yield response diff --git a/llm-service/app/services/chat_history/simple_chat_history_manager.py b/llm-service/app/services/chat_history/simple_chat_history_manager.py index 1176134fb..69d88d4d6 100644 --- a/llm-service/app/services/chat_history/simple_chat_history_manager.py +++ b/llm-service/app/services/chat_history/simple_chat_history_manager.py @@ -40,7 +40,7 @@ import os from typing import List -from llama_index.core.base.llms.types import ChatMessage, MessageRole +from llama_index.core.base.llms.types import ChatMessage, MessageRole, TextBlock from llama_index.core.storage.chat_store import SimpleChatStore from app.config import settings @@ -137,7 +137,9 @@ def append_to_history( self._build_chat_key(session_id), ChatMessage( role=MessageRole.USER, - content=message.rag_message.user, + blocks=[ + TextBlock(text=message.rag_message.user) + ], additional_kwargs={ "id": message.id, }, @@ -147,7 +149,9 @@ def append_to_history( self._build_chat_key(session_id), ChatMessage( role=MessageRole.ASSISTANT, - content=message.rag_message.assistant, + blocks=[ + TextBlock(text=message.rag_message.assistant) + ], additional_kwargs={ "id": message.id, "source_nodes": message.source_nodes, diff --git a/llm-service/app/services/llm_completion.py b/llm-service/app/services/llm_completion.py index 922ea5861..5b1c834c9 100644 --- a/llm-service/app/services/llm_completion.py +++ b/llm-service/app/services/llm_completion.py @@ -71,7 +71,7 @@ def completion(session_id: int, question: str, model_name: str) -> ChatResponse: def stream_completion( - session_id: int, question: str, model_name: str + session_id: int, query: ChatMessage, model_name: str ) -> Generator[ChatResponse, None, None]: """ Streamed version of the completion function. @@ -84,7 +84,7 @@ def stream_completion( map(lambda x: make_chat_messages(x), chat_history) ) ) - messages.append(ChatMessage.from_str(question, role="user")) + messages.append(query) stream = model.stream_chat(messages) return stream diff --git a/llm-service/app/services/models/__init__.py b/llm-service/app/services/models/__init__.py index da248cc22..143632ee2 100644 --- a/llm-service/app/services/models/__init__.py +++ b/llm-service/app/services/models/__init__.py @@ -41,7 +41,7 @@ from .reranking import Reranking from ._model_source import ModelSource -__all__ = ["Embedding", "LLM", "Reranking", "ModelSource"] +__all__ = ["Embedding", "LLM", "Reranking", "ModelSource", "get_model_source"] def get_model_source() -> ModelSource: diff --git a/llm-service/app/services/models/llm.py b/llm-service/app/services/models/llm.py index e8283ac32..c58161ed3 100644 --- a/llm-service/app/services/models/llm.py +++ b/llm-service/app/services/models/llm.py @@ -38,7 +38,7 @@ from typing import Literal, Optional from fastapi import HTTPException from llama_index.core import llms -from llama_index.core.base.llms.types import ChatMessage, MessageRole +from llama_index.core.base.llms.types import ChatMessage, MessageRole, TextBlock from . 
import _model_type, _noop from .providers._model_provider import ModelProvider @@ -76,7 +76,7 @@ def test_llm_chat(cls, model_name: str) -> Literal["ok"]: messages=[ ChatMessage( role=MessageRole.USER, - content="Are you available to answer questions?", + blocks=[TextBlock(text="Are you available to answer questions?")], ) ] ) diff --git a/llm-service/app/services/models/providers/_model_provider.py b/llm-service/app/services/models/providers/_model_provider.py index 954d93b2f..5794170b4 100644 --- a/llm-service/app/services/models/providers/_model_provider.py +++ b/llm-service/app/services/models/providers/_model_provider.py @@ -105,6 +105,12 @@ def list_embedding_models() -> list[ModelResponse]: def list_reranking_models() -> list[ModelResponse]: """Return names and IDs of available reranking models.""" raise NotImplementedError + + @staticmethod + @abc.abstractmethod + def list_image_generation_models() -> list[ModelResponse]: + """Return names and IDs of available image generation models.""" + raise NotImplementedError @staticmethod @abc.abstractmethod diff --git a/llm-service/app/services/models/providers/azure.py b/llm-service/app/services/models/providers/azure.py index c079e1aa3..1cb4a44bc 100644 --- a/llm-service/app/services/models/providers/azure.py +++ b/llm-service/app/services/models/providers/azure.py @@ -81,6 +81,10 @@ def list_embedding_models() -> list[ModelResponse]: @staticmethod def list_reranking_models() -> list[ModelResponse]: return [] + + @staticmethod + def list_image_generation_models() -> list[ModelResponse]: + return [] @staticmethod def get_llm_model(name: str) -> AzureOpenAI: diff --git a/llm-service/app/services/models/providers/bedrock.py b/llm-service/app/services/models/providers/bedrock.py index f74ab7f70..2ffb9178b 100644 --- a/llm-service/app/services/models/providers/bedrock.py +++ b/llm-service/app/services/models/providers/bedrock.py @@ -278,6 +278,23 @@ def list_embedding_models() -> list[ModelResponse]: ) return models + + @staticmethod + def list_image_generation_models() -> list[ModelResponse]: + modality: BedrockModality = TypeAdapter(BedrockModality).validate_python( + "IMAGE" + ) + available_models = BedrockModelProvider.list_available_models(modality) + models = [] + for model in available_models: + models.append( + ModelResponse( + model_id=model["modelId"], + name=model["modelName"], + available=True, + ) + ) + return models @staticmethod def list_reranking_models() -> list[ModelResponse]: diff --git a/llm-service/app/services/models/providers/caii.py b/llm-service/app/services/models/providers/caii.py index 52d0c5b10..392594ec7 100644 --- a/llm-service/app/services/models/providers/caii.py +++ b/llm-service/app/services/models/providers/caii.py @@ -78,6 +78,11 @@ def list_embedding_models() -> list[ModelResponse]: @timed_lru_cache(maxsize=1, seconds=300) def list_reranking_models() -> list[ModelResponse]: return get_caii_reranking_models() + + @staticmethod + @timed_lru_cache(maxsize=32, seconds=300) + def list_image_generation_models() -> list[ModelResponse]: + return [] @staticmethod @timed_lru_cache(maxsize=32, seconds=300) diff --git a/llm-service/app/services/models/providers/openai.py b/llm-service/app/services/models/providers/openai.py index 9bea7cacb..31ed119ba 100644 --- a/llm-service/app/services/models/providers/openai.py +++ b/llm-service/app/services/models/providers/openai.py @@ -81,6 +81,11 @@ def list_embedding_models() -> list[ModelResponse]: @staticmethod def list_reranking_models() -> list[ModelResponse]: return [] + 
+ @staticmethod + def list_image_generation_models() -> list[ModelResponse]: + # TODO: Implement this when openAI model discovery is implemented + return [] @staticmethod def _http_client() -> Optional[httpx.Client]: diff --git a/llm-service/app/services/query/agents/agent_tools/image_generation.py b/llm-service/app/services/query/agents/agent_tools/image_generation.py new file mode 100644 index 000000000..79eab821b --- /dev/null +++ b/llm-service/app/services/query/agents/agent_tools/image_generation.py @@ -0,0 +1,346 @@ +# +# CLOUDERA APPLIED MACHINE LEARNING PROTOTYPE (AMP) +# (C) Cloudera, Inc. 2025 +# All rights reserved. +# +# Applicable Open Source License: Apache 2.0 +# +# NOTE: Cloudera open source products are modular software products +# made up of hundreds of individual components, each of which was +# individually copyrighted. Each Cloudera open source product is a +# collective work under U.S. Copyright Law. Your license to use the +# collective work is as provided in your written agreement with +# Cloudera. Used apart from the collective work, this file is +# licensed for your use pursuant to the open source license +# identified above. +# +# This code is provided to you pursuant a written agreement with +# (i) Cloudera, Inc. or (ii) a third-party authorized to distribute +# this code. If you do not have a written agreement with Cloudera nor +# with an authorized and properly licensed third party, you do not +# have any rights to access nor to use this code. +# +# Absent a written agreement with Cloudera, Inc. ("Cloudera") to the +# contrary, A) CLOUDERA PROVIDES THIS CODE TO YOU WITHOUT WARRANTIES OF ANY +# KIND; (B) CLOUDERA DISCLAIMS ANY AND ALL EXPRESS AND IMPLIED +# WARRANTIES WITH RESPECT TO THIS CODE, INCLUDING BUT NOT LIMITED TO +# IMPLIED WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY AND +# FITNESS FOR A PARTICULAR PURPOSE; (C) CLOUDERA IS NOT LIABLE TO YOU, +# AND WILL NOT DEFEND, INDEMNIFY, NOR HOLD YOU HARMLESS FOR ANY CLAIMS +# ARISING FROM OR RELATED TO THE CODE; AND (D)WITH RESPECT TO YOUR EXERCISE +# OF ANY RIGHTS GRANTED TO YOU FOR THE CODE, CLOUDERA IS NOT LIABLE FOR ANY +# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, PUNITIVE OR +# CONSEQUENTIAL DAMAGES INCLUDING, BUT NOT LIMITED TO, DAMAGES +# RELATED TO LOST REVENUE, LOST PROFITS, LOSS OF INCOME, LOSS OF +# BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF +# DATA. 
+# +import abc +import base64 +import json +import os +from enum import Enum +from pathlib import Path +from typing import Any, Dict, Literal, Optional + +import boto3 +from llama_index.core.tools.tool_spec.base import BaseToolSpec +from llama_index.tools.openai import ( + OpenAIImageGenerationToolSpec as LlamaIndexOpenAIImageGenerationToolSpec, +) +from llama_index.tools.openai.image_generation.base import DEFAULT_SIZE + +from app.services.query.agents.agent_tools.stable_diffusion_types import ( + StableDiffusionRequest, + GenerationMode, + AspectRatio, +) +from app.services.query.agents.agent_tools.titan_image_types import ( + TitanImageRequest, + TextToImageParams, + TitanImageGenerationConfig, + ValidTitanImageSizes, +) + +BEDROCK_STABLE_DIFFUSION_MODEL_ID = "stability.sd3-5-large-v1:0" +BEDROCK_TITAN_IMAGE_MODEL_ID = "amazon.titan-image-generator-v2:0" +OPENAI_IMAGE_GENERATION_MODEL_ID = "dall-e-3" + + +# Define image generation tool IDs for different providers +class ImageGenerationTools(str, Enum): + """Enum for image generation tool IDs.""" + + OPENAI_IMAGE_GENERATION = "openai-image-generation" + BEDROCK_STABLE_DIFFUSION = "bedrock-stable-diffusion" + BEDROCK_TITAN_IMAGE = "bedrock-titan-image" + + +# Tool metadata for UI display +IMAGE_GENERATION_TOOL_METADATA: Dict[str, Dict[str, Any]] = { + ImageGenerationTools.OPENAI_IMAGE_GENERATION: { + "display_name": "OpenAI Image Generation", + "description": "Generate images using OpenAI's DALL-E model", + }, + ImageGenerationTools.BEDROCK_STABLE_DIFFUSION: { + "display_name": "Stable Diffusion (Bedrock)", + "description": "Generate images using Stable Diffusion models on AWS Bedrock", + }, + ImageGenerationTools.BEDROCK_TITAN_IMAGE: { + "display_name": "Titan Image Generator (Bedrock)", + "description": "Generate images using Amazon's Titan Image Generator on AWS Bedrock", + }, +} + + +class ImageGeneratorToolSpec(abc.ABC, BaseToolSpec): + """Base class for image generation tool specs.""" + + spec_functions = ["image_generation"] + + @staticmethod + def get_cache_dir() -> str: + """Return the cache directory.""" + return "./cache" + + @abc.abstractmethod + def image_generation(self, *args: Any, **kwargs: Any) -> str: + """Generate an image based on the provided parameters.""" + raise NotImplementedError("Subclasses must implement this method.") + + +class OpenAIImageGenerationToolSpec( + LlamaIndexOpenAIImageGenerationToolSpec, + ImageGeneratorToolSpec, +): + """OpenAI Image Generation tool spec.""" + + def __init__(self, api_key: Optional[str] = None) -> None: + """Initialize with parameters.""" + super().__init__( + api_key=api_key, + cache_dir=ImageGeneratorToolSpec.get_cache_dir(), + ) + + def image_generation( + self, + text: str, + model: Optional[str] = OPENAI_IMAGE_GENERATION_MODEL_ID, + quality: Optional[str] = "standard", + num_images: Optional[int] = 1, + size: Optional[str] = DEFAULT_SIZE, + style: Optional[str] = "vivid", + timeout: Optional[int] = None, + download: Optional[bool] = True, # For suppressing signature error + ) -> str: + """ + This tool accepts a natural language string and will use OpenAI's DALL-E model to generate an image. + + Args: + text: The text to generate an image from. + + model: The model to use for image generation. Defaults to `dall-e-3`. + Must be one of `dall-e-2` or `dall-e-3`. + + num_images: The number of images to generate. Defaults to 1. + Must be between 1 and 10. For `dall-e-3`, only `n=1` is supported. + + quality: The quality of the image that will be generated. Defaults to `standard`. 
+ Must be one of `standard` or `hd`. `hd` creates images with finer + details and greater consistency across the image. This param is only supported + for `dall-e-3`. + + size: The size of the generated images. Defaults to `1024x1024`. + Must be one of `256x256`, `512x512`, or `1024x1024` for `dall-e-2`. + Must be one of `1024x1024`, `1792x1024`, or `1024x1792` for `dall-e-3` models. + + style: The style of the generated images. Defaults to `vivid`. + Must be one of `vivid` or `natural`. + Vivid causes the model to lean towards generating hyper-real and dramatic images. + Natural causes the model to produce more natural, less hyper-real looking images. + This param is only supported for `dall-e-3`. + + timeout: Override the client-level default timeout for this request, in seconds. Defaults to `None`. + download: For suppressing signature error, not used in this implementation. + Returns: + str: Path to the generated image in the cache directory. + """ + image_path = super().image_generation( + text=text, + model=model, + quality=quality, + num_images=num_images, + size=size, + style=style, + timeout=timeout, + download=True, + ) + image_name = Path(image_path[0]).name + return f"/llm-service/cache/{image_name}" + + +class BedrockStableDiffusionToolSpec(ImageGeneratorToolSpec): + """Bedrock Stable Diffusion Image Generation tool spec.""" + + def __init__( + self, + model: str = BEDROCK_STABLE_DIFFUSION_MODEL_ID, + **kwargs: Any, + ) -> None: + """Initialize with parameters.""" + super().__init__(**kwargs) + self.client = boto3.client("bedrock-runtime") + self.model = model + + def image_generation( + self, + text: str, + image_name: str, + seed: int = 42, + negative_text: str = "", + aspect_ratio: Optional[AspectRatio] = AspectRatio.RATIO_5_4, + ) -> str: + """ + Generate an image using Stable Diffusion models on Bedrock. + + Parameters: + text (str): The prompt for image generation. + image_name (str): The name to save the generated image as. + seed (int, optional): Random seed for generation. + negative_text (str, optional): Negative prompt for image generation. + aspect_ratio (AspectRatio, optional): Aspect ratio for the generated image. + + Returns: + str: Path to the generated image in the cache directory. + """ + # Create Stable Diffusion Pydantic model instance + sd_request = StableDiffusionRequest( + prompt=text, + negative_prompt=negative_text, + mode=GenerationMode.TEXT_TO_IMAGE, + seed=seed, + aspect_ratio=aspect_ratio, + ) + + # Convert to JSON for API request + request = sd_request.model_dump_json() + + # Call the Bedrock API + return _get_image_from_bedrock( + boto3_bedrock_client=self.client, + model=self.model, + request=request, + image_name=image_name, + cache_dir=self.get_cache_dir(), + ) + + +class BedrockTitanImageToolSpec(ImageGeneratorToolSpec): + """Bedrock Titan Image Generation tool spec.""" + + def __init__( + self, + model: str = BEDROCK_TITAN_IMAGE_MODEL_ID, + **kwargs: Any, + ) -> None: + """Initialize with parameters.""" + super().__init__(**kwargs) + self.client = boto3.client("bedrock-runtime") + self.model = model + + def image_generation( + self, + text: str, + image_name: str, + quality: Literal["standard", "premium"] = "standard", + num_images: int = 1, + seed: int = 42, + negative_text: str = "", + cfg_scale: float = 8.0, + size: ValidTitanImageSizes = ValidTitanImageSizes.SMALL, + ) -> str: + """ + Generate an image using Amazon Titan Image Generator on Bedrock. + + Parameters: + text (str): The prompt for image generation. 
+ image_name (str): The name to save the generated image as. + quality (Literal["standard", "premium"], optional): Image quality. + num_images (int, optional): Number of images to generate. + seed (int, optional): Random seed for generation. + negative_text (str, optional): Negative prompt for image generation. + cfg_scale (float, optional): Configuration scale for generation. + size (ValidTitanImageSizes, optional): Image size for generation. + + Returns: + str: Path to the generated image in the cache directory. + """ + # Create Titan Pydantic model instances + text_to_image_params = TextToImageParams(text=text, negativeText=negative_text) + image_generation_config = TitanImageGenerationConfig( + numberOfImages=num_images, + quality=quality, + cfgScale=cfg_scale, + width=size[0], + height=size[1], + seed=seed, + ) + + titan_request = TitanImageRequest( + taskType="TEXT_IMAGE", + textToImageParams=text_to_image_params, + imageGenerationConfig=image_generation_config, + ) + + # Convert to JSON for API request + request = titan_request.model_dump_json() + + # Call the Bedrock API + return _get_image_from_bedrock( + boto3_bedrock_client=self.client, + model=self.model, + request=request, + image_name=image_name, + cache_dir=self.get_cache_dir(), + ) + + +def _get_image_from_bedrock( + boto3_bedrock_client: Any, + model: str, + request: str, + image_name: str, + cache_dir: str, +) -> str: + """ + Helper function to get an image from Bedrock and save it to the cache directory. + + Parameters: + boto3_bedrock_client: The Bedrock client. + model: The model ID to use for image generation. + request: The request payload for the model. + image_name (str): The name to save the generated image as. + cache_dir (str): The directory to save the image in. + + Returns: + str: Path to the generated image in the cache directory. + """ + response = boto3_bedrock_client.invoke_model(modelId=model, body=request) + model_response = json.loads(response["body"].read()) + base64_image_data = model_response["images"][0] + image_data = base64.b64decode(base64_image_data) + + # Save the image to the cache directory + image_path = os.path.join(cache_dir, f"{image_name}.png") + os.makedirs(cache_dir, exist_ok=True) + if os.path.exists(image_path): + # the name is taken; append an incrementing suffix to the base name until it is unique + base_image_name = image_name + i = 1 + while os.path.exists(image_path): + image_name = f"{base_image_name}_{i}" + image_path = os.path.join(cache_dir, f"{image_name}.png") + i += 1 + with open(image_path, "wb") as file: + file.write(image_data) + return f"/llm-service/cache/{image_name}.png" diff --git a/llm-service/app/services/query/agents/agent_tools/stable_diffusion_types.py b/llm-service/app/services/query/agents/agent_tools/stable_diffusion_types.py new file mode 100644 index 000000000..f88199e6e --- /dev/null +++ b/llm-service/app/services/query/agents/agent_tools/stable_diffusion_types.py @@ -0,0 +1,134 @@ +# +# CLOUDERA APPLIED MACHINE LEARNING PROTOTYPE (AMP) +# (C) Cloudera, Inc. 2025 +# All rights reserved. +# +# Applicable Open Source License: Apache 2.0 +# +# +# This code is provided to you pursuant a written agreement with +# (i) Cloudera, Inc. or (ii) a third-party authorized to distribute +# this code. If you do not have a written agreement with Cloudera nor +# with an authorized and properly licensed third party, you do not +# have any rights to access nor to use this code. +# +# Absent a written agreement with Cloudera, Inc.
("Cloudera") to the +# contrary, A) CLOUDERA PROVIDES THIS CODE TO YOU WITHOUT WARRANTIES OF ANY +# KIND; (B) CLOUDERA DISCLAIMS ANY AND ALL EXPRESS AND IMPLIED +# WARRANTIES WITH RESPECT TO THIS CODE, INCLUDING BUT NOT LIMITED TO +# IMPLIED WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY AND +# FITNESS FOR A PARTICULAR PURPOSE; (C) CLOUDERA IS NOT LIABLE TO YOU, +# AND WILL NOT DEFEND, INDEMNIFY, NOR HOLD YOU HARMLESS FOR ANY CLAIMS +# ARISING FROM OR RELATED TO THE CODE; AND (D)WITH RESPECT TO YOUR EXERCISE +# OF ANY RIGHTS GRANTED TO YOU FOR THE CODE, CLOUDERA IS NOT LIABLE FOR ANY +# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, PUNITIVE OR +# CONSEQUENTIAL DAMAGES INCLUDING, BUT NOT LIMITED TO, DAMAGES +# RELATED TO LOST REVENUE, LOST PROFITS, LOSS OF INCOME, LOSS OF +# BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF +# DATA. +# +from enum import Enum +from typing import Optional, TypeVar + +from pydantic import BaseModel, Field, field_validator +from pydantic_core.core_schema import ValidationInfo + +T = TypeVar("T") + + +class GenerationMode(str, Enum): + """Generation mode for Stable Diffusion models.""" + + TEXT_TO_IMAGE = "text-to-image" + IMAGE_TO_IMAGE = "image-to-image" + + +class AspectRatio(str, Enum): + """Available aspect ratios for generated images.""" + + RATIO_16_9 = "16:9" + RATIO_1_1 = "1:1" + RATIO_21_9 = "21:9" + RATIO_2_3 = "2:3" + RATIO_3_2 = "3:2" + RATIO_4_5 = "4:5" + RATIO_5_4 = "5:4" + RATIO_9_16 = "9:16" + RATIO_9_21 = "9:21" + + +class OutputFormat(str, Enum): + """Available output formats for generated images.""" + + JPEG = "jpeg" + PNG = "png" + + +class StableDiffusionRequest(BaseModel): + """ + Stable Diffusion Request model for Amazon Bedrock. + + This model includes all parameters needed for both text-to-image and image-to-image generation. + """ + + # Required fields + prompt: str = Field( + ..., + min_length=1, + max_length=10000, + description="What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results.", + ) + + # Optional fields with defaults + mode: GenerationMode = Field( + default=GenerationMode.TEXT_TO_IMAGE, + description="Controls whether this is a text-to-image or image-to-image generation.", + ) + + seed: int = Field( + default=0, + ge=0, + le=4294967294, + description="A specific value that is used to guide the 'randomness' of the generation. (Use 0 for a random seed.)", + ) + + output_format: OutputFormat = Field( + default=OutputFormat.PNG, + description="Dictates the content-type of the generated image.", + ) + + negative_prompt: str = Field( + default="", + max_length=10000, + description="Keywords of what you do not wish to see in the output image.", + ) + + aspect_ratio: Optional[AspectRatio] = Field( + default=AspectRatio.RATIO_1_1, + description="Controls the aspect ratio of the generated image. 
Only valid for text-to-image requests.", + ) + + @field_validator("*") + @classmethod + def validate_mode_specific_fields(cls, v: T, info: ValidationInfo) -> T: + """Validate that the appropriate fields are provided based on the generation mode.""" + field_name = info.field_name + mode = info.data.get("mode") + + # Skip validation if mode is not set yet + if not mode: + return v + + if mode == GenerationMode.TEXT_TO_IMAGE: + # For text-to-image, image and strength should not be provided + if field_name in ["image", "strength"] and v is not None: + raise ValueError( + f"{field_name} should not be provided for text-to-image generation" + ) + + return v + + class Config: + """Configuration for the StableDiffusionRequest model.""" + + use_enum_values = True # Use the string values of enums diff --git a/llm-service/app/services/query/agents/agent_tools/titan_image_types.py b/llm-service/app/services/query/agents/agent_tools/titan_image_types.py new file mode 100644 index 000000000..717e8cfe8 --- /dev/null +++ b/llm-service/app/services/query/agents/agent_tools/titan_image_types.py @@ -0,0 +1,83 @@ +# CLOUDERA APPLIED MACHINE LEARNING PROTOTYPE (AMP) +# (C) Cloudera, Inc. 2025 +# All rights reserved. +# +# Applicable Open Source License: Apache 2.0 +# +# +# This code is provided to you pursuant a written agreement with +# (i) Cloudera, Inc. or (ii) a third-party authorized to distribute +# this code. If you do not have a written agreement with Cloudera nor +# with an authorized and properly licensed third party, you do not +# have any rights to access nor to use this code. +# +# Absent a written agreement with Cloudera, Inc. ("Cloudera") to the +# contrary, A) CLOUDERA PROVIDES THIS CODE TO YOU WITHOUT WARRANTIES OF ANY +# KIND; (B) CLOUDERA DISCLAIMS ANY AND ALL EXPRESS AND IMPLIED +# WARRANTIES WITH RESPECT TO THIS CODE, INCLUDING BUT NOT LIMITED TO +# IMPLIED WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY AND +# FITNESS FOR A PARTICULAR PURPOSE; (C) CLOUDERA IS NOT LIABLE TO YOU, +# AND WILL NOT DEFEND, INDEMNIFY, NOR HOLD YOU HARMLESS FOR ANY CLAIMS +# ARISING FROM OR RELATED TO THE CODE; AND (D)WITH RESPECT TO YOUR EXERCISE +# OF ANY RIGHTS GRANTED TO YOU FOR THE CODE, CLOUDERA IS NOT LIABLE FOR ANY +# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, PUNITIVE OR +# CONSEQUENTIAL DAMAGES INCLUDING, BUT NOT LIMITED TO, DAMAGES +# RELATED TO LOST REVENUE, LOST PROFITS, LOSS OF INCOME, LOSS OF +# BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF +# DATA. 
+# +import enum +from typing import Literal, Optional + +from pydantic import BaseModel, Field + + +class TextToImageParams(BaseModel): + """Parameters for text to image generation.""" + + text: str = Field(..., description="The text prompt for image generation") + negativeText: Optional[str] = Field( + None, description="A text prompt to define what not to include in the image" + ) + + +class TitanImageGenerationConfig(BaseModel): + """Configuration parameters for image generation.""" + + numberOfImages: int = Field( + 1, ge=1, le=5, description="Number of images to generate (1-5)" + ) + quality: Literal["standard", "premium"] = Field( + "standard", description="Quality of the generated image" + ) + cfgScale: float = Field( + 8.0, ge=1.1, le=10.0, description="Configuration scale parameter (1.1-10.0)" + ) + width: int = Field(512, description="Width of the generated image in pixels") + height: int = Field(512, description="Height of the generated image in pixels") + seed: Optional[int] = Field( + 42, + ge=0, + le=2147483646, + description="Seed for reproducible results (0-2,147,483,646)", + ) + + +class TitanImageRequest(BaseModel): + """Titan Image Generation Request model matching the Amazon Bedrock Titan model requirements.""" + + taskType: Literal["TEXT_IMAGE"] = Field( + default="TEXT_IMAGE", description="Type of task for image generation" + ) + textToImageParams: TextToImageParams = Field( + ..., description="Text to image parameters" + ) + imageGenerationConfig: TitanImageGenerationConfig = Field( + ..., description="Image generation configuration" + ) + + +class ValidTitanImageSizes(tuple[int, int], enum.Enum): + LARGE = (1024, 1024) + MEDIUM = (768, 768) + SMALL = (512, 512) diff --git a/llm-service/app/services/query/agents/tool_calling_querier.py b/llm-service/app/services/query/agents/tool_calling_querier.py index 104f50ba2..5f6bc5310 100644 --- a/llm-service/app/services/query/agents/tool_calling_querier.py +++ b/llm-service/app/services/query/agents/tool_calling_querier.py @@ -51,7 +51,12 @@ AgentInput, AgentSetup, ) -from llama_index.core.base.llms.types import ChatMessage, MessageRole, ChatResponse +from llama_index.core.base.llms.types import ( + ChatMessage, + MessageRole, + ChatResponse, + TextBlock, +) from llama_index.core.chat_engine.types import StreamingAgentChatResponse from llama_index.core.llms.function_calling import FunctionCallingLLM from llama_index.core.schema import NodeWithScore @@ -61,8 +66,19 @@ from llama_index.llms.bedrock_converse.utils import get_model_name from app.ai.indexing.summary_indexer import SummaryIndexer -from app.services.metadata_apis.session_metadata_api import Session +from app.config import settings +from app.services.metadata_apis.session_metadata_api import ( + Session, + SessionQueryConfiguration, +) +from app.services.models import get_model_source, ModelSource from app.services.models.providers import BedrockModelProvider +from app.services.query.agents.agent_tools.image_generation import ( + BedrockStableDiffusionToolSpec, + BedrockTitanImageToolSpec, + OpenAIImageGenerationToolSpec, +) +from app.services.query.agents.agent_tools.image_generation import ImageGenerationTools from app.services.query.agents.agent_tools.mcp import get_llama_index_tools from app.services.query.agents.agent_tools.retriever import ( build_retriever_tool, @@ -116,6 +132,34 @@ def should_use_retrieval( return len(data_source_ids) > 0, data_source_summaries +def get_bedrock_image_generation_tools( + model_id: Optional[str] = None, +) -> list[BaseTool]: + 
""" + Get the appropriate Bedrock image generation tool based on model type. + + Args: + model_id: Optional model ID to determine the tool type. If not provided, + defaults to Stable Diffusion. + + Returns: + List of BaseTool objects for image generation. + """ + tools: list[BaseTool] + if not model_id: + # Default to Stable Diffusion if no model specified + tools = BedrockStableDiffusionToolSpec().to_tool_list() + else: + model_id_lower = model_id.lower() + if "titan" in model_id_lower: + tools = BedrockTitanImageToolSpec(model=model_id).to_tool_list() + elif "stability" in model_id_lower or "sd" in model_id_lower: + tools = BedrockStableDiffusionToolSpec(model=model_id).to_tool_list() + else: + tools = [] + return tools + + DEFAULT_AGENT_PROMPT = """\ ### DATE AND TIME Today's date is {date} and the current time is {time}. This date and time \ @@ -141,7 +185,11 @@ def should_use_retrieval( you need to answer the question with the provided sources or tools, \ you truthfully say you do not know and let the user know how you arrived \ at the response and what information you used (links if any) to arrive \ -at it and ask for clarification or more information. +at it and ask for clarification or more information. +6. If you use the image generation tool, you will provide the image \ +name in the response as a markdown image link ![](), \ +where is the name of the image and \ +is the URL or path of the image generated by the image generation tool. \ ### OUTPUT FORMAT As the agent, you will provide an answer based solely on the provided \ @@ -204,13 +252,24 @@ def stream_chat( if session.query_configuration and session.query_configuration.selected_tools: for tool_name in session.query_configuration.selected_tools: try: + # Check if it's an image generation tool + if tool_name in [t.value for t in ImageGenerationTools]: + # Skip adding here, will be handled separately + continue mcp_tools.extend(get_llama_index_tools(tool_name)) except ValueError as e: logger.warning(f"Could not create adapter for tool {tool_name}: {e}") continue - # Use the existing chat engine with the enhanced query for streaming response tools: list[BaseTool] = mcp_tools + + # Add image generation tools only if they are selected in the session + image_generator_tools = get_image_generator_tools(session.query_configuration) + + # Add the image generator tools to the list of tools + if image_generator_tools: + tools.extend(image_generator_tools) + # Use tool calling only if retrieval is not the only tool to optimize performance if tools and use_retrieval and chat_engine: retrieval_tool = build_retriever_tool( @@ -227,6 +286,64 @@ def stream_chat( return StreamingAgentChatResponse(chat_stream=gen, source_nodes=source_nodes) +def get_image_generator_tools( + query_configuration: SessionQueryConfiguration, +) -> list[BaseTool]: + image_generator_tools: list[BaseTool] = [] + model_source = get_model_source() + + # Import the function to get currently selected image generation tool + from app.routers.index.tools import get_selected_image_generation_tool + + # Get the globally selected image generation tool + selected_image_tool = get_selected_image_generation_tool() + + # Process user-selected image generation tools only if they match the globally selected tool + if query_configuration and query_configuration.selected_tools: + for tool_name in query_configuration.selected_tools: + if tool_name in [t.value for t in ImageGenerationTools]: + # Only allow this tool if it's the globally selected one + if tool_name != 
selected_image_tool: + logger.warning( + "Image generation tool %s is selected in session but " + "not globally available. Skipping.", + tool_name, + ) + continue + + # Check if the tool matches the current model provider and is available + if ( + tool_name == ImageGenerationTools.OPENAI_IMAGE_GENERATION + and model_source == ModelSource.OPENAI + ): + if settings.openai_api_key is None: + logger.warning( + "OpenAI image generation tool selected but API key not set" + ) + continue + image_generator_tools.extend( + OpenAIImageGenerationToolSpec( + api_key=settings.openai_api_key + ).to_tool_list() + ) + elif ( + tool_name == ImageGenerationTools.BEDROCK_STABLE_DIFFUSION + and model_source == ModelSource.BEDROCK + ): + image_generator_tools.extend( + BedrockStableDiffusionToolSpec().to_tool_list() + ) + elif ( + tool_name == ImageGenerationTools.BEDROCK_TITAN_IMAGE + and model_source == ModelSource.BEDROCK + ): + image_generator_tools.extend( + BedrockTitanImageToolSpec().to_tool_list() + ) + + return image_generator_tools + + def _run_streamer( chat_engine: Optional[FlexibleContextChatEngine], chat_messages: list[ChatMessage], @@ -253,12 +370,26 @@ def _run_streamer( # If no chat engine is provided, we can use the LLM directly direct_chat_gen = llm.stream_chat( messages=chat_messages - + [ChatMessage(role=MessageRole.USER, content=enhanced_query)] + + [ + ChatMessage( + role=MessageRole.USER, + blocks=[ + TextBlock(text=enhanced_query), + ], + ) + ] ) return direct_chat_gen, source_nodes async def agen() -> AsyncGenerator[ChatResponse, None]: - handler = agent.run(user_msg=enhanced_query, chat_history=chat_messages) + handler = agent.run( + user_msg=ChatMessage( + blocks=[ + TextBlock(text=enhanced_query), + ] + ), + chat_history=chat_messages, + ) async for event in handler.stream_events(): if isinstance(event, AgentSetup): diff --git a/llm-service/pyproject.toml b/llm-service/pyproject.toml index 72eb9002c..9692bd852 100644 --- a/llm-service/pyproject.toml +++ b/llm-service/pyproject.toml @@ -35,7 +35,7 @@ dependencies = [ "llama-index-postprocessor-bedrock-rerank>=0.3.1", "llama-index-postprocessor-nvidia-rerank>=0.4.0", "mlflow==2.22.1", - "llama-index-llms-azure-openai>=0.3.0", + "llama-index-llms-azure-openai>=0.3.4", "llama-index-embeddings-azure-openai>=0.3.0", "llama-index-llms-nvidia>=0.3.2", "llama-index-storage-kvstore-s3>=0.3.0", @@ -48,6 +48,7 @@ dependencies = [ "llama-index-tools-mcp>=0.2.5", "llama-index-readers-docling>=0.3.3", "llama-index-node-parser-docling>=0.3.2", + "llama-index-tools-openai-image-generation>=0.4.0", ] requires-python = ">=3.10,<3.13" readme = "README.md" diff --git a/llm-service/uv.lock b/llm-service/uv.lock index b64716420..324e2f60d 100644 --- a/llm-service/uv.lock +++ b/llm-service/uv.lock @@ -2279,6 +2279,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/42/c4/7ef21261cd61a65e0394d7d9bd8da59990e9948aa9314406c80a02a5ac3c/llama_index_tools_mcp-0.4.0-py3-none-any.whl", hash = "sha256:1312cccfc6bc35a10af5a184baa3b1242a94090782262909fe75f693260ad7db", size = 13812 }, ] +[[package]] +name = "llama-index-tools-openai-image-generation" +version = "0.4.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "llama-index-core" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/3b/47/0a40292f7f9a98bcf959018b9ca4ef6df1aa52c9cb9812dfbd3f7a11c105/llama_index_tools_openai_image_generation-0.4.0.tar.gz", hash = "sha256:ffb80420e26d55e936eded4ca1bbbc4174216a8fbfad609759833dd6fc17e869", size = 3697 } +wheels = 
[ + { url = "https://files.pythonhosted.org/packages/48/a6/de9611742d91ab358093e06d6b14d24e580212bee2063fdbccdd9b969219/llama_index_tools_openai_image_generation-0.4.0-py3-none-any.whl", hash = "sha256:cdc14c29c232229673f3898d64ecd23bc7c9e9882858fd2e854f7bf3f0ab2fec", size = 4207 }, +] + [[package]] name = "llama-index-vector-stores-opensearch" version = "0.6.0" @@ -2348,6 +2360,7 @@ dependencies = [ { name = "llama-index-readers-file" }, { name = "llama-index-storage-kvstore-s3" }, { name = "llama-index-tools-mcp" }, + { name = "llama-index-tools-openai-image-generation" }, { name = "llama-index-vector-stores-opensearch" }, { name = "llama-index-vector-stores-qdrant" }, { name = "llvmlite" }, @@ -2395,7 +2408,7 @@ requires-dist = [ { name = "llama-index-embeddings-azure-openai", specifier = ">=0.3.0" }, { name = "llama-index-embeddings-bedrock", specifier = ">=0.2.1" }, { name = "llama-index-embeddings-openai", specifier = ">=0.1.11" }, - { name = "llama-index-llms-azure-openai", specifier = ">=0.3.0" }, + { name = "llama-index-llms-azure-openai", specifier = ">=0.3.4" }, { name = "llama-index-llms-bedrock", specifier = ">=0.3.4" }, { name = "llama-index-llms-bedrock-converse", specifier = ">=0.7.6" }, { name = "llama-index-llms-nvidia", specifier = ">=0.3.2" }, @@ -2407,6 +2420,7 @@ requires-dist = [ { name = "llama-index-readers-file", specifier = ">=0.1.33" }, { name = "llama-index-storage-kvstore-s3", specifier = ">=0.3.0" }, { name = "llama-index-tools-mcp", specifier = ">=0.2.5" }, + { name = "llama-index-tools-openai-image-generation", specifier = ">=0.4.0" }, { name = "llama-index-vector-stores-opensearch", specifier = ">=0.5.5" }, { name = "llama-index-vector-stores-qdrant", specifier = ">=0.2.17" }, { name = "llvmlite", specifier = "==0.43.0" }, diff --git a/tools/image_generation_config.json b/tools/image_generation_config.json new file mode 100644 index 000000000..21f3c7ff7 --- /dev/null +++ b/tools/image_generation_config.json @@ -0,0 +1,4 @@ +{ + "enabled": false, + "selected_tool": null +} \ No newline at end of file diff --git a/ui/src/api/toolsApi.ts b/ui/src/api/toolsApi.ts index 6177ad106..ece0b0971 100644 --- a/ui/src/api/toolsApi.ts +++ b/ui/src/api/toolsApi.ts @@ -68,6 +68,11 @@ export interface AddToolFormValues { description: string; } +export interface ImageGenerationConfig { + enabled: boolean; + selected_tool: string | null; +} + export const getTools = async (): Promise => { return getRequest(`${llmServicePath}/tools`); }; @@ -100,6 +105,52 @@ export const useAddToolMutation = ({ }); }; +export const getImageGenerationTools = async (): Promise => { + return getRequest(`${llmServicePath}/tools/image-generation`); +}; + +export const useImageGenerationToolsQuery = () => { + return useQuery({ + queryKey: [QueryKeys.getImageGenerationTools], + queryFn: getImageGenerationTools, + }); +}; + +export const getImageGenerationConfig = async (): Promise => { + return getRequest(`${llmServicePath}/tools/image-generation/config`); +}; + +export const useImageGenerationConfigQuery = () => { + return useQuery({ + queryKey: [QueryKeys.getImageGenerationConfig], + queryFn: getImageGenerationConfig, + }); +}; + +export const setImageGenerationConfig = async ( + config: ImageGenerationConfig +): Promise => { + return postRequest(`${llmServicePath}/tools/image-generation/config`, config); +}; + +export const useSetImageGenerationConfigMutation = ({ + onSuccess, + onError, +}: UseMutationType) => { + const queryClient = useQueryClient(); + return useMutation({ + mutationFn: 
setImageGenerationConfig, + onSuccess: (config) => { + void queryClient.invalidateQueries({ queryKey: [QueryKeys.getImageGenerationConfig] }); + void queryClient.invalidateQueries({ queryKey: [QueryKeys.getTools] }); + if (onSuccess) { + onSuccess(config); + } + }, + onError, + }); +}; + export const deleteTool = async (name: string): Promise => { return deleteRequest(`${llmServicePath}/tools/${name}`); }; diff --git a/ui/src/api/utils.ts b/ui/src/api/utils.ts index 4c6110549..944186567 100644 --- a/ui/src/api/utils.ts +++ b/ui/src/api/utils.ts @@ -113,6 +113,8 @@ export enum QueryKeys { "getSessionsForProject" = "getSessionsForProject", "getAmpConfig" = "getAmpConfig", "getTools" = "getTools", + "getImageGenerationTools" = "getImageGenerationTools", + "getImageGenerationConfig" = "getImageGenerationConfig", "getPollingAmpConfig" = "getPollingAmpConfig", "getCAIIModelStatus" = "getCAIIModelStatus", } diff --git a/ui/src/pages/DataSources/ManageTab/UploadedFilesTable.tsx b/ui/src/pages/DataSources/ManageTab/UploadedFilesTable.tsx index d4ff6d48f..a9df9d4aa 100644 --- a/ui/src/pages/DataSources/ManageTab/UploadedFilesTable.tsx +++ b/ui/src/pages/DataSources/ManageTab/UploadedFilesTable.tsx @@ -45,7 +45,8 @@ import { Tooltip, Typography, } from "antd"; -import Icon, { DeleteOutlined, DownloadOutlined } from "@ant-design/icons"; +import Icon, { DeleteOutlined } from "@ant-design/icons"; +import { DownloadOutlined } from "@ant-design/icons"; import { RagDocumentResponseType, useDeleteDocumentMutation, @@ -57,19 +58,20 @@ import AiAssistantIcon from "src/cuix/icons/AiAssistantIcon"; import { useState } from "react"; import messageQueue from "src/utils/messageQueue.ts"; import { useQueryClient } from "@tanstack/react-query"; -import { paths, QueryKeys, ragPath } from "src/api/utils.ts"; +import { QueryKeys } from "src/api/utils.ts"; import useModal from "src/utils/useModal.ts"; import { cdlWhite } from "src/cuix/variables.ts"; import ReadyColumn from "pages/DataSources/ManageTab/ReadyColumn.tsx"; import SummaryColumn from "pages/DataSources/ManageTab/SummaryColumn.tsx"; import { ColumnsType } from "antd/es/table"; +import { paths, ragPath } from "src/api/utils.ts"; import { downloadFile } from "src/utils/downloadFile.ts"; const columns = ( dataSourceId: string, handleDeleteFile: (document: RagDocumentResponseType) => void, simpleColumns: boolean, - summarizationModel?: string + summarizationModel?: string, ): TableProps["columns"] => { let columns: ColumnsType = [ { @@ -253,7 +255,7 @@ const UploadedFilesTable = ({ dataSourceId, handleDeleteFileModal, simplifiedTable, - summarizationModel + summarizationModel, )} /> { children={data.rag_message.assistant.trimStart()} components={{ img: ( - props: ComponentProps<"img">, + props: ComponentProps<"img"> ): ReactElement | undefined => { + // check if the image src starts with `sandbox:` and replace it with `` + if (props.src?.startsWith("sandbox:")) { + props.src = props.src.replace("sandbox:", ""); + } return {props.alt}; }, a: ( - props: ComponentProps<"a">, + props: ComponentProps<"a"> ): ReactElement | undefined => { const { href, className, children, ...other } = props; if (className === "rag_citation") { @@ -67,7 +71,7 @@ export const MarkdownResponse = ({ data }: { data: ChatMessageType }) => { } const { source_nodes } = data; const sourceNodeIndex = source_nodes.findIndex( - (source_node) => source_node.node_id === href, + (source_node) => source_node.node_id === href ); if (sourceNodeIndex >= 0) { return ( diff --git 
a/ui/src/pages/RagChatTab/FooterComponents/ToolsManager.tsx b/ui/src/pages/RagChatTab/FooterComponents/ToolsManager.tsx index cdc5e021e..276f12b17 100644 --- a/ui/src/pages/RagChatTab/FooterComponents/ToolsManager.tsx +++ b/ui/src/pages/RagChatTab/FooterComponents/ToolsManager.tsx @@ -47,13 +47,18 @@ import { Tooltip, Typography, } from "antd"; -import { useToolsQuery } from "src/api/toolsApi.ts"; +import { + useToolsQuery, + useImageGenerationConfigQuery, + useImageGenerationToolsQuery, +} from "src/api/toolsApi.ts"; import { Dispatch, ReactNode, SetStateAction, useContext, useState, + useMemo, } from "react"; import { ToolOutlined } from "@ant-design/icons"; import { cdlBlue600, cdlOrange500, cdlWhite } from "src/cuix/variables.ts"; @@ -70,15 +75,12 @@ import { Link } from "@tanstack/react-router"; import { getAmpConfigQueryOptions } from "src/api/ampMetadataApi.ts"; const ToolsManagerContent = ({ activeSession }: { activeSession: Session }) => { - const { data, isLoading } = useToolsQuery(); + const { data: regularTools, isLoading: regularToolsLoading } = + useToolsQuery(); + const { data: imageConfig } = useImageGenerationConfigQuery(); + const { data: imageTools } = useImageGenerationToolsQuery(); const { data: config } = useSuspenseQuery(getAmpConfigQueryOptions); - const toolsList = data?.map((tool) => ({ - name: tool.name, - displayName: tool.metadata.display_name, - description: tool.metadata.description, - })); - const queryClient = useQueryClient(); const updateSession = useUpdateSessionMutation({ @@ -113,12 +115,44 @@ const ToolsManagerContent = ({ activeSession }: { activeSession: Session }) => { } else { handleUpdateSession( activeSession.queryConfiguration.selectedTools.filter( - (tool) => tool !== title, - ), + (tool) => tool !== title + ) ); } }; + // Combine regular tools with the selected image generation tool (if enabled) + const toolsList = useMemo(() => { + const tools = []; + + // Add regular tools + if (regularTools) { + tools.push( + ...regularTools.map((tool) => ({ + name: tool.name, + displayName: tool.metadata.display_name, + description: tool.metadata.description, + })) + ); + } + + // Only add image generation tool if it's enabled AND selected + if (imageConfig?.enabled && imageConfig.selected_tool && imageTools) { + const selectedTool = imageTools.find( + (tool) => tool.name === imageConfig.selected_tool + ); + if (selectedTool) { + tools.push({ + name: selectedTool.name, + displayName: selectedTool.metadata.display_name, + description: selectedTool.metadata.description, + }); + } + } + + return tools; + }, [regularTools, imageConfig, imageTools]); + return ( @@ -148,8 +182,8 @@ const ToolsManagerContent = ({ activeSession }: { activeSession: Session }) => { ( { avatar={ { handleCheck(item.name, e.target.checked); diff --git a/ui/src/pages/Tools/ToolsPage.tsx b/ui/src/pages/Tools/ToolsPage.tsx index e7a6b480e..e1d976e91 100644 --- a/ui/src/pages/Tools/ToolsPage.tsx +++ b/ui/src/pages/Tools/ToolsPage.tsx @@ -36,7 +36,7 @@ * DATA. 
*/ -import { useState } from "react"; +import { useState, useEffect } from "react"; import { Alert, Button, @@ -44,14 +44,25 @@ import { Flex, Layout, Modal, + Radio, + RadioChangeEvent, Table, Typography, + Switch, } from "antd"; -import { DeleteOutlined, PlusOutlined } from "@ant-design/icons"; +import { + DeleteOutlined, + PlusOutlined, + SaveOutlined, + CloseOutlined, +} from "@ant-design/icons"; import { Tool, useDeleteToolMutation, useToolsQuery, + useImageGenerationToolsQuery, + useImageGenerationConfigQuery, + useSetImageGenerationConfigMutation, } from "src/api/toolsApi.ts"; import messageQueue from "src/utils/messageQueue.ts"; import useModal from "src/utils/useModal.ts"; @@ -60,7 +71,24 @@ import { AddNewToolModal } from "pages/Tools/AddNewToolModal.tsx"; const ToolsPage = () => { const confirmDeleteModal = useModal(); const { data: tools = [], isLoading, error: toolsError } = useToolsQuery(); + const { data: imageTools = [], isLoading: imageToolsLoading } = + useImageGenerationToolsQuery(); + const { data: imageConfig } = useImageGenerationConfigQuery(); const [isModalVisible, setIsModalVisible] = useState(false); + const [tempSelectedImageTool, setTempSelectedImageTool] = useState< + string | null + >(null); + const [tempEnabled, setTempEnabled] = useState(false); + const [isInitialized, setIsInitialized] = useState(false); + + // Update temp values when the actual config loads + useEffect(() => { + if (imageConfig !== undefined && !isInitialized) { + setTempSelectedImageTool(imageConfig.selected_tool ?? null); + setTempEnabled(imageConfig.enabled); + setIsInitialized(true); + } + }, [imageConfig, isInitialized]); const deleteToolMutation = useDeleteToolMutation({ onSuccess: () => { @@ -72,6 +100,59 @@ const ToolsPage = () => { }, }); + const setImageConfigMutation = useSetImageGenerationConfigMutation({ + onSuccess: (savedConfig) => { + messageQueue.success("Image generation configuration updated"); + // Update temp states to match the saved config + setTempEnabled(savedConfig.enabled); + setTempSelectedImageTool(savedConfig.selected_tool); + }, + onError: (error) => { + messageQueue.error(`Failed to update configuration: ${error.message}`); + }, + }); + + // Filter out image generation tools from regular tools + const imageToolNames = imageTools.map((tool) => tool.name); + const regularTools = tools.filter( + (tool) => !imageToolNames.includes(tool.name), + ); + + const handleTempImageToolSelectionChange = (e: RadioChangeEvent) => { + const value = e.target.value as string; + setTempSelectedImageTool(value); + }; + + const handleSaveImageConfig = () => { + let selectedTool = tempSelectedImageTool; + + // Auto-select single tool if enabled but no tool selected + if (tempEnabled && !selectedTool && imageTools.length === 1) { + selectedTool = imageTools[0].name; + } + + setImageConfigMutation.mutate({ + enabled: tempEnabled, + selected_tool: tempEnabled ? selectedTool : null, + }); + }; + + const handleToggleEnabled = (enabled: boolean) => { + setTempEnabled(enabled); + }; + + const handleUnselectImageTool = () => { + setTempSelectedImageTool(null); + }; + + const isImageGenerationTool = (toolName: string) => { + return imageToolNames.includes(toolName); + }; + + const hasConfigChanged = + tempEnabled !== (imageConfig?.enabled ?? false) || + tempSelectedImageTool !== (imageConfig?.selected_tool ?? 
null); + const columns = [ { title: "Internal Name", @@ -94,14 +175,16 @@ const ToolsPage = () => { width: 80, render: (_: unknown, tool: Tool) => ( <> - { isModalVisible={isModalVisible} /> + {imageTools.length > 0 && ( + + + (Beta) Image Generation Tools + + + {/* Global Enable/Disable Toggle */} + + + + Image Generation + + {tempEnabled + ? "Image generation tools are available for use in chat sessions" + : "Image generation is disabled for all chat sessions"} + + + + + + {tempEnabled ? "Enabled" : "Disabled"} + + + + + + {/* Tool Selection Section */} +
+ + {imageTools.length > 1 + ? "Select which image generation tool to make available for use in chat sessions. Only one image generation tool can be active at a time." + : "The following image generation tool is available for use in chat sessions."} + + + {imageToolsLoading ? ( + + Loading image generation tools... + + ) : ( + + {imageTools.length === 1 ? ( + // Show simple display for single tool + + + + {imageTools[0].metadata.display_name || + imageTools[0].name} + + + {imageTools[0].metadata.description} + + + + ) : ( + // Show radio buttons for multiple tools + + + + {imageTools.map((tool) => ( + + + + {tool.metadata.display_name || tool.name} + + + {tool.metadata.description} + + + + ))} + + + {imageTools.length > 1 && ( + + + + )} + + )} + + )} +
+ + {/* Save Configuration Button - Always Visible */} + + + + + +
+ )} ); }; diff --git a/ui/src/utils/downloadFile.ts b/ui/src/utils/downloadFile.ts index e5adeaf78..0fb011ada 100644 --- a/ui/src/utils/downloadFile.ts +++ b/ui/src/utils/downloadFile.ts @@ -41,7 +41,7 @@ import messageQueue from "src/utils/messageQueue.ts"; export const downloadFile = async ( url: string, filename: string, - options?: { pageNumber?: string }, + options?: { pageNumber?: string } ) => { const isPdf = filename.toLowerCase().endsWith(".pdf"); @@ -62,7 +62,7 @@ export const downloadFile = async ( window.open( `${objectUrl}#page=${options.pageNumber}`, "_blank", - "noopener", + "noopener" ); // Note: do not revoke immediately to avoid breaking the viewer tab } catch {
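For reviewers, a minimal sketch of how the new /cache/{filename} endpoint in main.py behaves end to end; the base URL is an assumption for a local run of the llm-service, and example.png stands in for a generated file:

import httpx  # already a dependency of the llm-service

BASE_URL = "http://localhost:8000"  # assumed local dev address

resp = httpx.get(f"{BASE_URL}/cache/example.png")  # example.png is hypothetical
if resp.status_code == 200:
    # The Content-Type comes from mimetypes.guess_type, falling back to the
    # hard-coded extension map, then to application/octet-stream.
    print(resp.headers["content-type"], resp.headers["cache-control"])
else:
    print(resp.status_code)  # 404 when the file is not in the cache directory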
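The on-disk contract behind get_image_generation_config / set_image_generation_config is small enough to sketch in full; the path below is illustrative, since the service resolves it from settings.tools_dir:

import json
import os

config_path = "tools/image_generation_config.json"  # derived from settings.tools_dir

# set_image_generation_config: persist the model as pretty-printed JSON.
os.makedirs(os.path.dirname(config_path), exist_ok=True)
with open(config_path, "w") as f:
    json.dump({"enabled": True, "selected_tool": "bedrock-titan-image"}, f, indent=2)

# get_selected_image_generation_tool: a tool is only surfaced to sessions when
# image generation is globally enabled AND a specific tool is selected.
with open(config_path) as f:
    config = json.load(f)
selected_tool = config["selected_tool"] if config.get("enabled") else None
print(selected_tool)  # -> "bedrock-titan-image"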
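Likewise, a sketch of the Bedrock round trip that _get_image_from_bedrock performs for the Titan path; it assumes AWS credentials with access to amazon.titan-image-generator-v2:0 in the configured region, and the prompt and dimensions are examples:

import base64
import json

import boto3

client = boto3.client("bedrock-runtime")
body = json.dumps(
    {
        # Approximates the payload TitanImageRequest serializes via model_dump_json()
        "taskType": "TEXT_IMAGE",
        "textToImageParams": {"text": "a lighthouse at dawn"},
        "imageGenerationConfig": {
            "numberOfImages": 1,
            "quality": "standard",
            "cfgScale": 8.0,
            "width": 512,
            "height": 512,
            "seed": 42,
        },
    }
)
response = client.invoke_model(modelId="amazon.titan-image-generator-v2:0", body=body)
payload = json.loads(response["body"].read())

# Titan returns base64-encoded PNG bytes under "images"; this is the shape the
# helper decodes before writing the file into the cache directory.
with open("lighthouse.png", "wb") as f:
    f.write(base64.b64decode(payload["images"][0]))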
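Finally, a sketch of driving the new configuration endpoints; localhost:8000 is an assumed address, and the remote-user / remote-user-perm header values are placeholders whose accepted forms depend on has_admin_rights in a given deployment:

import httpx

BASE = "http://localhost:8000"  # assumed local dev address

# Reads are open to any caller.
print(httpx.get(f"{BASE}/tools/image-generation/config").json())
# e.g. {"enabled": false, "selected_tool": null}

# Writes are admin-gated (401 otherwise), and an unknown tool name is rejected
# with a 400 that lists the valid choices for the current model provider.
resp = httpx.post(
    f"{BASE}/tools/image-generation/config",
    json={"enabled": True, "selected_tool": "openai-image-generation"},
    headers={"remote-user": "admin", "remote-user-perm": "ADMIN"},  # placeholder values
)
print(resp.status_code, resp.json())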