Skip to content

Commit 35e268b

Browse files
committed
Add support for custom models
1 parent fa00c9e commit 35e268b

File tree

9 files changed

+123
-21
lines changed

9 files changed

+123
-21
lines changed

AiServer.ServiceInterface/AiProviderFactory.cs

+7-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,12 @@ public interface IOllamaAiProvider
1717
Task<OllamaGenerationResult> GenerateAsync(AiProvider provider, OllamaGenerate request, CancellationToken token = default);
1818
}
1919

20-
public class AiProviderFactory(OpenAiProvider openAiProvider, OllamaAiProvider ollamaAiProvider, GoogleAiProvider googleProvider, AnthropicAiProvider anthropicAiProvider)
20+
public class AiProviderFactory(
21+
OpenAiProvider openAiProvider,
22+
OllamaAiProvider ollamaAiProvider,
23+
GoogleAiProvider googleProvider,
24+
AnthropicAiProvider anthropicAiProvider,
25+
CustomAiProvider customAiProvider)
2126
{
2227
public IOpenAiProvider GetOpenAiProvider(AiProviderType aiProviderType=AiProviderType.OpenAiProvider)
2328
{
@@ -26,6 +31,7 @@ public IOpenAiProvider GetOpenAiProvider(AiProviderType aiProviderType=AiProvide
2631
AiProviderType.OllamaAiProvider => ollamaAiProvider,
2732
AiProviderType.GoogleAiProvider => googleProvider,
2833
AiProviderType.AnthropicAiProvider => anthropicAiProvider,
34+
AiProviderType.CustomOpenAiProvider => customAiProvider,
2935
_ => openAiProvider
3036
};
3137
}

AiServer.ServiceInterface/OpenAiChatServices.cs

+86-19
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,7 @@ public object Any(ActiveAiModels request)
4242
.SelectMany(x => x.Models.Select(m => appData.GetQualifiedModel(m.Model)))
4343
.Where(x => x != null)
4444
.Select(x => x!) // Non-null assertion after filtering out null values
45-
.Distinct()
46-
.OrderBy(x => x)
47-
.ToList();
45+
.ToSet();
4846

4947
if (request.Vision == true)
5048
{
@@ -53,12 +51,32 @@ public object Any(ActiveAiModels request)
5351
.Where(x => x.Vision == true)
5452
.Select(x => x.Id)
5553
.ToSet();
56-
activeModels = activeModels.Where(x => allVisionModels.Contains(x.LeftPart(':'))).ToList();
54+
activeModels = activeModels.Where(x => allVisionModels.Contains(x.LeftPart(':'))).ToSet();
55+
}
56+
57+
var customModels = appData.AiProviders
58+
.Where(x => x.AiTypeId == "Custom")
59+
.SelectMany(x => x.SelectedModels);
60+
foreach (var customModel in customModels)
61+
{
62+
activeModels.Add(customModel);
5763
}
5864

5965
return new StringsResponse
6066
{
61-
Results = activeModels
67+
Results = activeModels.OrderBy(x => x).ToList()
68+
};
69+
}
70+
71+
public object Any(ActiveCustomAiModels request)
72+
{
73+
return new StringsResponse
74+
{
75+
Results = appData.AiProviders
76+
.Where(x => x.AiTypeId == "Custom")
77+
.SelectMany(x => x.SelectedModels)
78+
.OrderBy(x => x)
79+
.ToList()
6280
};
6381
}
6482

@@ -104,6 +122,23 @@ public object GetModelImage(string model)
104122
}
105123
}
106124

125+
var customModels = appData.AiProviders
126+
.Where(x => x.AiTypeId == "Custom")
127+
.SelectMany(x => x.SelectedModels)
128+
.ToSet(StringComparer.OrdinalIgnoreCase);
129+
130+
if (customModels.Contains(model))
131+
{
132+
return new HttpResult(
133+
"""
134+
<svg xmlns="http://www.w3.org/2000/svg" width="32" height="32" viewBox="0 0 32 32">
135+
<path fill="currentColor" d="M19 22v-2h1v-7h-1v-2h4v2h-1v7h1v2zm-3.5 0h2L14 11h-3L7.503 22h2l.601-2h4.778zm-4.794-4l1.628-5.411l.256-.003L14.264 18zM32 4h-4V0h-2v4h-4v2h4v4h2V6h4zm-2 8h2v2h-2zM18 0h2v2h-2z"/>
136+
<path fill="currentColor" d="M32 32H0V0h14v2H2v28h28V18h2z"/>
137+
</svg>
138+
""",
139+
MimeTypes.ImageSvg);
140+
}
141+
107142
return new HttpResult(
108143
"""
109144
<svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 48 48">
@@ -217,11 +252,19 @@ public QueueOpenAiChatResponse Any(QueueOpenAiChatCompletion request)
217252

218253
if (request.Request.Messages.IsNullOrEmpty())
219254
throw new ArgumentNullException(nameof(request.Request.Messages));
220-
255+
221256
var qualifiedModel = appData.GetQualifiedModel(request.Request.Model);
222-
if (qualifiedModel == null)
257+
258+
var customModel = appData.AiProviders
259+
.Where(x => x.AiTypeId == "Custom")
260+
.SelectMany(x => x.SelectedModels)
261+
.FirstOrDefault(x => x == request.Request.Model);
262+
263+
if (qualifiedModel == null && customModel == null)
223264
throw HttpError.NotFound($"Model {request.Request.Model} not found");
224265

266+
qualifiedModel ??= customModel;
267+
225268
var queueCounts = jobs.GetWorkerQueueCounts();
226269
var providerQueueCount = int.MaxValue;
227270
AiProvider? useProvider = null;
@@ -497,13 +540,23 @@ public object Any(CreateAiProvider request)
497540
request.Models ??= [];
498541
foreach (var selectedModel in request.SelectedModels)
499542
{
500-
var qualifiedModel = appData.GetQualifiedModel(selectedModel);
501-
if (qualifiedModel == null)
502-
continue;
503-
request.Models.Add(new()
543+
if (request.AiTypeId == "Custom")
544+
{
545+
request.Models.Add(new()
546+
{
547+
Model = selectedModel
548+
});
549+
}
550+
else
504551
{
505-
Model = qualifiedModel
506-
});
552+
var qualifiedModel = appData.GetQualifiedModel(selectedModel);
553+
if (qualifiedModel == null)
554+
continue;
555+
request.Models.Add(new()
556+
{
557+
Model = qualifiedModel
558+
});
559+
}
507560
}
508561
}
509562

@@ -515,20 +568,34 @@ public object Any(CreateAiProvider request)
515568
public object Any(UpdateAiProvider request)
516569
{
517570
var ignore = new[] { nameof(request.Id), nameof(request.SelectedModels) };
571+
var provider = Db.SingleById<AiProvider>(request.Id);
572+
if (provider == null)
573+
throw HttpError.NotFound("Provider not found");
574+
518575
// Only call AutoQuery Update if there's something to update
519576
IdResponse? response = null;
520577
if (request.SelectedModels is { Count: > 0 })
521578
{
522579
request.Models ??= [];
523580
foreach (var selectedModel in request.SelectedModels)
524581
{
525-
var qualifiedModel = appData.GetQualifiedModel(selectedModel);
526-
if (qualifiedModel == null)
527-
continue;
528-
request.Models.Add(new()
582+
if (provider.AiTypeId == "Custom")
529583
{
530-
Model = qualifiedModel
531-
});
584+
request.Models.Add(new()
585+
{
586+
Model = selectedModel
587+
});
588+
}
589+
else
590+
{
591+
var qualifiedModel = appData.GetQualifiedModel(selectedModel);
592+
if (qualifiedModel == null)
593+
continue;
594+
request.Models.Add(new()
595+
{
596+
Model = qualifiedModel
597+
});
598+
}
532599
}
533600
}
534601
if (request.ToObjectDictionary().HasNonDefaultValues(ignoreKeys:ignore) || Request!.QueryString[Keywords.Reset] != null)

AiServer.ServiceInterface/OpenAiProvider.cs

+10
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,14 @@ public class OpenAiProvider(ILogger<OpenAiProvider> log) : OpenAiProviderBase(lo
1111
{
1212
}
1313

14+
public class CustomAiProvider(ILogger<OpenAiProvider> log) : OpenAiProviderBase(log)
15+
{
16+
public override Task<OpenAiChatResult> ChatAsync(AiProvider provider, OpenAiChat request, CancellationToken token = default)
17+
{
18+
return base.ChatAsync(provider, request, token);
19+
}
20+
}
21+
1422
public class OllamaAiProvider(ILogger<OllamaAiProvider> log) : OpenAiProviderBase(log), IOllamaAiProvider
1523
{
1624
protected virtual async Task<OllamaGenerateResponse> SendOllamaGenerateRequestAsync(AiProvider provider, OllamaGenerate request,
@@ -141,6 +149,8 @@ public string GetApiEndpointUrlFor(AiProvider aiProvider, TaskType taskType)
141149
{
142150
var apiBaseUrl = aiProvider.ApiBaseUrl ?? aiProvider.AiType?.ApiBaseUrl
143151
?? throw new NotSupportedException($"[{aiProvider.Name}] No ApiBaseUrl found in AiProvider or AiType");
152+
if (aiProvider.AiTypeId == "Custom")
153+
return apiBaseUrl;
144154
if (taskType == TaskType.OllamaGenerate)
145155
return apiBaseUrl.CombineWith("/api/generate");
146156
if (taskType == TaskType.OpenAiChat)

AiServer.ServiceModel/AiProvider.cs

+1
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ public enum AiProviderType
107107
OpenAiProvider,
108108
GoogleAiProvider,
109109
AnthropicAiProvider,
110+
CustomOpenAiProvider,
110111
}
111112

112113
/// <summary>

AiServer.ServiceModel/ApiAdmin.cs

+5
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,11 @@ public class ActiveAiModels : IGet, IReturn<StringsResponse>
1919
public bool? Vision { get; set; }
2020
}
2121

22+
[Tag(Tags.AiInfo)]
23+
[Api("Active Custom AI Worker Models available in AI Server")]
24+
public class ActiveCustomAiModels : IGet, IReturn<StringsResponse>
25+
{
26+
}
2227

2328
[Tag(Tags.AiInfo)]
2429
[ValidateApiKey]

AiServer/Configure.AppHost.cs

+1
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ public void Configure(IWebHostBuilder builder) => builder
5454
services.AddSingleton<OllamaAiProvider>();
5555
services.AddSingleton<GoogleAiProvider>();
5656
services.AddSingleton<AnthropicAiProvider>();
57+
services.AddSingleton<CustomAiProvider>();
5758
services.AddSingleton<AiProviderFactory>();
5859

5960
services.AddSingleton(new ComfyMediaProviderOptions
Loading

AiServer/wwwroot/lib/data/ai-types.json

+7
Original file line numberDiff line numberDiff line change
@@ -222,5 +222,12 @@
222222
"claude-3-5-sonnet": "claude-3-5-sonnet-latest",
223223
"claude-3-5-haiku": "claude-3-5-haiku-latest"
224224
}
225+
},
226+
{
227+
"id": "Custom",
228+
"provider": "CustomOpenAiProvider",
229+
"website": "https://platform.openai.com/docs/api-reference/chat",
230+
"icon": "/img/providers/custom.svg",
231+
"apiBaseUrl": "http://localhost:8080/v1/chat/completions"
225232
}
226233
]

AiServer/wwwroot/mjs/components/AiProviders.mjs

+5-1
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,11 @@ export default {
148148
</template>
149149
<template #formfooter="{ form, formInstance, apis, type, model, id, updateModel }">
150150
<div class="pl-6">
151-
<SelectModels v-if="model" :aiTypeId="model?.aiTypeId" :edit="model" @update:modelValue="updateModel(model)" />
151+
<div v-if="providerType=='Custom'" class="pr-6">
152+
<TagInput v-if="model" v-model="model.selectedModels" label="Custom Models" @update:modelValue="updateModel(model)" />
153+
<TagInput v-else v-model="formModel.selectedModels" label="Custom Models" @update:modelValue="updateModel(model)" />
154+
</div>
155+
<SelectModels v-else-if="model" :aiTypeId="model?.aiTypeId" :edit="model" @update:modelValue="updateModel(model)" />
152156
<SelectModels v-else-if="formModel" :aiTypeId="providerType" :edit="formInstance.model" @update:modelValue="formInstance.setModel" />
153157
</div>
154158
</template>

0 commit comments

Comments
 (0)