Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions backend/app/routes/global_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from app.usecases.global_config import (
get_logo_path,
get_global_available_models,
get_default_model,
)

router = APIRouter(tags=["config"])
Expand All @@ -12,8 +13,10 @@
def get_global_config():
"""Get global configuration including available models."""
global_models = get_global_available_models()
default_model = get_default_model()
logo_path = get_logo_path()
return {
"globalAvailableModels": global_models,
"defaultModel": default_model,
"logoPath": logo_path,
}
5 changes: 4 additions & 1 deletion backend/app/usecases/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
)
from app.stream import ConverseApiStreamHandler, OnStopInput, OnThinking
from app.usecases.bot import fetch_bot, modify_bot_last_used_time, modify_bot_stats
from app.usecases.global_config import get_default_model
from app.user import User
from app.utils import get_current_time
from app.vector_search import (
Expand Down Expand Up @@ -633,8 +634,10 @@ def chat_output_from_message(
def propose_conversation_title(
user_id: str,
conversation_id: str,
model: type_model_name = "claude-v3-haiku",
) -> str:
# Use the configured default model for generating conversation titles
model = get_default_model()

PROMPT = """Reading the conversation above, what is the appropriate title for the conversation? When answering the title, please follow the rules below:
<rules>
- Title length must be from 15 to 20 characters.
Expand Down
10 changes: 10 additions & 0 deletions backend/app/usecases/global_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@
import logging
import os

from app.routes.schemas.conversation import type_model_name

logging.basicConfig(level=logging.INFO, format="%(levelname)s:%(name)s - %(message)s")
logger = logging.getLogger(__name__)

GLOBAL_AVAILABLE_MODELS = os.environ.get("GLOBAL_AVAILABLE_MODELS")
DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL")
LOGO_PATH = os.environ.get("LOGO_PATH", "")


Expand Down Expand Up @@ -39,6 +42,13 @@ def get_global_available_models() -> list[str]:
return []


def get_default_model() -> type_model_name:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As you suggested, there are several "default models" on this application:

  • Title generation
  • New chat screen
  • etc

Why don't we have get_default_model which returns BaseSchema? e.g.

class DefaultModelSchema(BaseSchema):
  title_generation_model: type_model_name
  chat_model: type_model_name
  # ...

"""Return the configured default model."""
if not DEFAULT_MODEL:
raise ValueError("DEFAULT_MODEL environment variable must be set")
return DEFAULT_MODEL # type: ignore[return-value]


def get_logo_path() -> str:
"""Return the configured drawer logo path."""
return LOGO_PATH
1 change: 1 addition & 0 deletions cdk/bin/bedrock-chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ const chat = new BedrockChatStack(
enableBotStoreReplicas: params.enableBotStoreReplicas,
botStoreLanguage: params.botStoreLanguage,
globalAvailableModels: params.globalAvailableModels,
defaultModel: params.defaultModel,
tokenValidMinutes: params.tokenValidMinutes,
devAccessIamRoleArn: params.devAccessIamRoleArn,
allowedCountries: params.allowedCountries,
Expand Down
2 changes: 2 additions & 0 deletions cdk/lib/bedrock-chat-stack.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ export interface BedrockChatStackProps extends StackProps {
readonly enableBotStoreReplicas: boolean;
readonly botStoreLanguage: Language;
readonly globalAvailableModels?: string[];
readonly defaultModel?: string;
readonly tokenValidMinutes: number;
readonly alternateDomainName?: string;
readonly hostedZoneId?: string;
Expand Down Expand Up @@ -225,6 +226,7 @@ export class BedrockChatStack extends cdk.Stack {
enableLambdaSnapStart: props.enableLambdaSnapStart,
openSearchEndpoint: botStore?.openSearchEndpoint,
globalAvailableModels: props.globalAvailableModels,
defaultModel: props.defaultModel,
logoPath: props.logoPath,
});
props.documentBucket.grantReadWrite(backendApi.handler);
Expand Down
2 changes: 2 additions & 0 deletions cdk/lib/constructs/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ export interface ApiProps {
readonly enableLambdaSnapStart: boolean;
readonly openSearchEndpoint?: string;
readonly globalAvailableModels?: string[];
readonly defaultModel?: string;
readonly logoPath?: string;
}

Expand Down Expand Up @@ -268,6 +269,7 @@ export class Api extends Construct {
GLOBAL_AVAILABLE_MODELS: props.globalAvailableModels
? JSON.stringify(props.globalAvailableModels)
: "[]",
DEFAULT_MODEL: props.defaultModel!,
OPENSEARCH_DOMAIN_ENDPOINT: props.openSearchEndpoint || "",
LOGO_PATH: props.logoPath || "",
USE_STRANDS: "true",
Expand Down
4 changes: 4 additions & 0 deletions cdk/lib/utils/parameter-models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,9 @@ const BedrockChatParametersSchema = BaseParametersSchema.extend({
// If not configured (empty array), all models are available
globalAvailableModels: z.array(z.string()).default([]),

// Default model to be selected when user first visits the app
defaultModel: z.string().default("claude-v3.7-sonnet"),

// Frontend branding
logoPath: z.string().default(""),

Expand Down Expand Up @@ -245,6 +248,7 @@ export function resolveBedrockChatParameters(
enableBotStoreReplicas: app.node.tryGetContext("EnableBotStoreReplicas"),
botStoreLanguage: app.node.tryGetContext("botStoreLanguage"),
globalAvailableModels: app.node.tryGetContext("globalAvailableModels"),
defaultModel: app.node.tryGetContext("defaultModel"),
logoPath: app.node.tryGetContext("logoPath"),
devAccessIamRoleArn: app.node.tryGetContext("devAccessIamRoleArn"),
};
Expand Down
1 change: 1 addition & 0 deletions frontend/src/@types/global-config.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import { AVAILABLE_MODEL_KEYS } from '../constants/index';

export interface GlobalConfig {
globalAvailableModels: string[];
defaultModel?: string;
logoPath?: string;
}

Expand Down
30 changes: 25 additions & 5 deletions frontend/src/components/SwitchBedrockModel.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import { BaseProps } from '../@types/common';
import useModel from '../hooks/useModel';
import { Popover, Transition } from '@headlessui/react';
import { Fragment } from 'react/jsx-runtime';
import { useMemo } from 'react';
import { useMemo, useEffect } from 'react';
import { PiCaretDown, PiCheck } from 'react-icons/pi';
import { ActiveModels } from '../@types/bot';
import { toCamelCase } from '../utils/StringUtils';
Expand All @@ -17,6 +17,7 @@ const SwitchBedrockModel: React.FC<Props> = (props) => {
availableModels: allModels,
modelId,
setModelId,
getDefaultModel,
} = useModel(props.botId, props.activeModels);

const availableModels = useMemo(() => {
Expand All @@ -32,11 +33,30 @@ const SwitchBedrockModel: React.FC<Props> = (props) => {
});
}, [allModels, props.activeModels]);

const modelName = useMemo(() => {
return (
availableModels.find((model) => model.modelId === modelId)?.label ?? ''
// Automatically switch to the default model if the current model is not available
useEffect(() => {
const isCurrentModelAvailable = availableModels.some(
(model) => model.modelId === modelId
);
}, [availableModels, modelId]);

if (!isCurrentModelAvailable && availableModels.length > 0) {
const defaultModelId = getDefaultModel();
if (defaultModelId) {
setModelId(defaultModelId);
}
}
}, [availableModels, modelId, setModelId, getDefaultModel]);

const modelName = useMemo(() => {
const foundModel = availableModels.find((model) => model.modelId === modelId);
if (foundModel) {
return foundModel.label;
}
// Fallback to the default model's label if the current model is not found
const defaultModelId = getDefaultModel();
const defaultModel = availableModels.find((model) => model.modelId === defaultModelId);
return defaultModel?.label ?? '';
}, [availableModels, modelId, getDefaultModel]);

return (
<div className="">
Expand Down
31 changes: 18 additions & 13 deletions frontend/src/hooks/useModel.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,11 @@ const LLAMA_SUPPORTED_MEDIA_TYPES = [
'image/webp',
];

const DEFAULT_MODEL: Model = 'claude-v3.7-sonnet';

const useModelState = create<{
modelId: Model;
setModelId: (m: Model) => void;
}>((set) => ({
modelId: DEFAULT_MODEL,
modelId: '' as Model, // Will be set by useEffect based on config/localStorage
setModelId: (m) => {
set({
modelId: m,
Expand Down Expand Up @@ -285,7 +283,7 @@ const useModel = (botId?: string | null, activeModels?: ActiveModels) => {
const { modelId, setModelId } = useModelState();
const [recentUseModelId, setRecentUseModelId] = useLocalStorage(
'recentUseModelId',
DEFAULT_MODEL
'' // Will use getDefaultModel() if localStorage is empty
);

// Save the model id by each bot
Expand All @@ -306,16 +304,22 @@ const useModel = (botId?: string | null, activeModels?: ActiveModels) => {
}, [processedActiveModels, availableModels]);

const getDefaultModel = useCallback(() => {
// check default model is available
const defaultModelAvailable = filteredModels.some(
(m: ModelItem) => m.modelId === DEFAULT_MODEL
);
if (defaultModelAvailable) {
return DEFAULT_MODEL;
// Use the default model from global config if available
const configDefaultModel = globalConfig?.defaultModel as Model | undefined;

if (configDefaultModel) {
// Check if the configured default model is available
const defaultModelAvailable = filteredModels.some(
(m: ModelItem) => m.modelId === configDefaultModel
);
if (defaultModelAvailable) {
return configDefaultModel;
}
}
// If the default model is not available, select the first model on the list
return filteredModels[0]?.modelId ?? DEFAULT_MODEL;
}, [filteredModels]);

// If config default is not available or not set yet, select the first model
return filteredModels[0]?.modelId;
}, [filteredModels, globalConfig?.defaultModel]);

// select the model via list of activeModels
const selectModel = useCallback(
Expand Down Expand Up @@ -405,6 +409,7 @@ const useModel = (botId?: string | null, activeModels?: ActiveModels) => {
}) ?? [],
availableModels: filteredModels,
forceReasoningEnabled: model?.forceReasoningEnabled ?? false,
getDefaultModel,
};
};

Expand Down