diff --git a/src/renderer/src/stores/settings.ts b/src/renderer/src/stores/settings.ts index 92cd72415..b3ab3d26f 100644 --- a/src/renderer/src/stores/settings.ts +++ b/src/renderer/src/stores/settings.ts @@ -762,70 +762,104 @@ export const useSettingsStore = defineStore('settings', () => { // 更新本地模型状态,不触发后端请求 const updateLocalModelStatus = (providerId: string, modelId: string, enabled: boolean) => { - // 更新allProviderModels中的模型状态 - const providerIndex = allProviderModels.value.findIndex((p) => p.providerId === providerId) - if (providerIndex !== -1) { - const models = allProviderModels.value[providerIndex].models - const modelIndex = models.findIndex((m) => m.id === modelId) - if (modelIndex !== -1) { - models[modelIndex].enabled = enabled + const provider = allProviderModels.value.find((p) => p.providerId === providerId) + const customProvider = customModels.value.find((p) => p.providerId === providerId) + + const providerModel = provider?.models.find((m) => m.id === modelId) + if (providerModel) { + providerModel.enabled = enabled + } + + const customModel = customProvider?.models.find((m) => m.id === modelId) + if (customModel) { + customModel.enabled = enabled + } + + let enabledProvider = enabledModels.value.find((p) => p.providerId === providerId) + let updatedEnabledModels: { providerId: string; models: RENDERER_MODEL_META[] }[] | null = null + + if (!enabledProvider && enabled) { + enabledProvider = { + providerId, + models: [] } + updatedEnabledModels = [...enabledModels.value, enabledProvider] } - // 更新enabledModels中的模型状态 - const enabledProviderIndex = enabledModels.value.findIndex((p) => p.providerId === providerId) - if (enabledProviderIndex !== -1) { - const models = enabledModels.value[enabledProviderIndex].models + if (enabledProvider) { + const models = enabledProvider.models + const modelIndex = models.findIndex((m) => m.id === modelId) + if (enabled) { - // 如果启用,确保模型在列表中 - const modelIndex = models.findIndex((m) => m.id === modelId) - if (modelIndex === -1) { - // 模型不在启用列表中,从allProviderModels查找并添加 - const provider = allProviderModels.value.find((p) => p.providerId === providerId) - const model = provider?.models.find((m) => m.id === modelId) - if (model) { - models.push({ - ...model, - enabled: true, - vision: model.vision || false, - functionCall: model.functionCall || false, - reasoning: model.reasoning || false, - type: model.type || ModelType.Chat - }) + const sourceModel = providerModel ?? customModel ?? models[modelIndex] + if (sourceModel) { + const normalizedModel: RENDERER_MODEL_META = { + ...sourceModel, + enabled: true, + vision: sourceModel.vision ?? false, + functionCall: sourceModel.functionCall ?? false, + reasoning: sourceModel.reasoning ?? false, + type: sourceModel.type ?? ModelType.Chat + } + + if (modelIndex === -1) { + models.push(normalizedModel) + } else { + models[modelIndex] = normalizedModel } } - } else { - // 如果禁用,从列表中移除 - const modelIndex = models.findIndex((m) => m.id === modelId) - if (modelIndex !== -1) { - models.splice(modelIndex, 1) - } + } else if (modelIndex !== -1) { + models.splice(modelIndex, 1) } - } - // 更新customModels中的模型状态 - const customProviderIndex = customModels.value.findIndex((p) => p.providerId === providerId) - if (customProviderIndex !== -1) { - const models = customModels.value[customProviderIndex].models - const modelIndex = models.findIndex((m) => m.id === modelId) - if (modelIndex !== -1) { - models[modelIndex].enabled = enabled + if (!enabled && enabledProvider.models.length === 0) { + updatedEnabledModels = enabledModels.value.filter((p) => p.providerId !== providerId) } } - // 强制触发响应式更新 - enabledModels.value = [...enabledModels.value] + if (!updatedEnabledModels) { + updatedEnabledModels = [...enabledModels.value] + } + + enabledModels.value = updatedEnabledModels console.log('enabledModels updated:', enabledModels.value) } + const getLocalModelEnabledState = (providerId: string, modelId: string): boolean | null => { + const provider = allProviderModels.value.find((p) => p.providerId === providerId) + const providerModel = provider?.models.find((m) => m.id === modelId) + if (providerModel) { + return !!providerModel.enabled + } + + const customProvider = customModels.value.find((p) => p.providerId === providerId) + const customModel = customProvider?.models.find((m) => m.id === modelId) + if (customModel) { + return !!customModel.enabled + } + + const enabledProvider = enabledModels.value.find((p) => p.providerId === providerId) + if (enabledProvider) { + return enabledProvider.models.some((model) => model.id === modelId) + } + + return null + } + // 更新模型状态 const updateModelStatus = async (providerId: string, modelId: string, enabled: boolean) => { + const previousState = getLocalModelEnabledState(providerId, modelId) + updateLocalModelStatus(providerId, modelId, enabled) + try { await llmP.updateModelStatus(providerId, modelId, enabled) // 调用成功后,刷新该 provider 的模型列表 await refreshProviderModels(providerId) } catch (error) { console.error('Failed to update model status:', error) + if (previousState !== null && previousState !== enabled) { + updateLocalModelStatus(providerId, modelId, previousState) + } } } @@ -966,13 +1000,42 @@ export const useSettingsStore = defineStore('settings', () => { // 更新provider的启用状态 const updateProviderStatus = async (providerId: string, enable: boolean): Promise => { - // 更新时间戳 + const providerIndex = providers.value.findIndex((p) => p.id === providerId) + const previousProvider = providerIndex !== -1 ? { ...providers.value[providerIndex] } : null + + if (providerIndex !== -1) { + const nextProviders = [...providers.value] + nextProviders[providerIndex] = { + ...nextProviders[providerIndex], + enable + } + providers.value = nextProviders + } + + const previousTimestamp = providerTimestamps.value[providerId] providerTimestamps.value[providerId] = Date.now() - // 保存时间戳 - await saveProviderTimestamps() - await updateProviderConfig(providerId, { enable }) - await optimizeProviderOrder(providerId, enable) + try { + await saveProviderTimestamps() + await updateProviderConfig(providerId, { enable }) + await optimizeProviderOrder(providerId, enable) + } catch (error) { + if (providerIndex !== -1 && previousProvider) { + const revertedProviders = [...providers.value] + revertedProviders[providerIndex] = previousProvider + providers.value = revertedProviders + } + + if (previousTimestamp === undefined) { + delete providerTimestamps.value[providerId] + } else { + providerTimestamps.value[providerId] = previousTimestamp + } + + await saveProviderTimestamps() + console.error('Failed to update provider status:', error) + throw error + } } const optimizeProviderOrder = async (providerId: string, enable: boolean): Promise => {