Skip to content

Commit f09dcd0

Browse files
committed
feat(embeddingModel): add embedding model into mongodb
1 parent 7692f71 commit f09dcd0

18 files changed

+194
-86
lines changed

scripts/populate.ts

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ import type { User } from "../src/lib/types/User";
1414
import type { Assistant } from "../src/lib/types/Assistant";
1515
import type { Conversation } from "../src/lib/types/Conversation";
1616
import type { Settings } from "../src/lib/types/Settings";
17-
import { defaultEmbeddingModel } from "../src/lib/server/embeddingModels.ts";
17+
import { getDefaultEmbeddingModel } from "../src/lib/server/embeddingModels.ts";
1818
import { Message } from "../src/lib/types/Message.ts";
1919

2020
import { addChildren } from "../src/lib/utils/tree/addChildren.ts";
@@ -146,6 +146,7 @@ async function seed() {
146146
updatedAt: faker.date.recent({ days: 30 }),
147147
customPrompts: {},
148148
assistants: [],
149+
disableStream: false,
149150
};
150151
await collections.settings.updateOne(
151152
{ userId: user._id },
@@ -214,7 +215,7 @@ async function seed() {
214215
: faker.helpers.maybe(() => faker.hacker.phrase(), { probability: 0.5 })) ?? "";
215216

216217
const messages = await generateMessages(preprompt);
217-
218+
const defaultEmbeddingModel = await getDefaultEmbeddingModel();
218219
const conv = {
219220
_id: new ObjectId(),
220221
userId: user._id,
@@ -224,7 +225,7 @@ async function seed() {
224225
updatedAt: faker.date.recent({ days: 145 }),
225226
model: faker.helpers.arrayElement(modelIds),
226227
title: faker.internet.emoji() + " " + faker.hacker.phrase(),
227-
embeddingModel: defaultEmbeddingModel.id,
228+
embeddingModel: defaultEmbeddingModel.name,
228229
messages,
229230
rootMessageId: messages[0].id,
230231
} satisfies Conversation;

src/hooks.server.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import { initExitHandler } from "$lib/server/exitHandler";
1616
import { ObjectId } from "mongodb";
1717
import { refreshAssistantsCounts } from "$lib/jobs/refresh-assistants-counts";
1818
import { refreshConversationStats } from "$lib/jobs/refresh-conversation-stats";
19+
import { pupulateEmbeddingModel } from "$lib/server/embeddingModels";
1920

2021
// TODO: move this code on a started server hook, instead of using a "building" flag
2122
if (!building) {
@@ -25,6 +26,9 @@ if (!building) {
2526
if (env.ENABLE_ASSISTANTS) {
2627
refreshAssistantsCounts();
2728
}
29+
30+
await pupulateEmbeddingModel();
31+
2832
refreshConversationStats();
2933

3034
// Init metrics server

src/lib/server/database.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import type { AssistantStats } from "$lib/types/AssistantStats";
1616
import { logger } from "$lib/server/logger";
1717
import { building } from "$app/environment";
1818
import { onExit } from "./exitHandler";
19+
import type { EmbeddingModel } from "$lib/types/EmbeddingModel";
1920

2021
export const CONVERSATION_STATS_COLLECTION = "conversations.stats";
2122

@@ -83,6 +84,7 @@ export class Database {
8384
const bucket = new GridFSBucket(db, { bucketName: "files" });
8485
const migrationResults = db.collection<MigrationResult>("migrationResults");
8586
const semaphores = db.collection<Semaphore>("semaphores");
87+
const embeddingModels = db.collection<EmbeddingModel>("embeddingModels");
8688

8789
return {
8890
conversations,
@@ -99,6 +101,7 @@ export class Database {
99101
bucket,
100102
migrationResults,
101103
semaphores,
104+
embeddingModels,
102105
};
103106
}
104107

@@ -120,6 +123,7 @@ export class Database {
120123
sessions,
121124
messageEvents,
122125
semaphores,
126+
embeddingModels,
123127
} = this.getCollections();
124128

125129
conversations
@@ -209,6 +213,8 @@ export class Database {
209213
semaphores
210214
.createIndex({ createdAt: 1 }, { expireAfterSeconds: 60 })
211215
.catch((e) => logger.error(e));
216+
217+
embeddingModels.createIndex({ name: 1 }, { unique: true }).catch((e) => logger.error(e));
212218
}
213219
}
214220

src/lib/server/embeddingEndpoints/embeddingEndpoints.ts

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import {
1212
embeddingEndpointOpenAIParametersSchema,
1313
} from "./openai/embeddingEndpoints";
1414
import { embeddingEndpointHfApi, embeddingEndpointHfApiSchema } from "./hfApi/embeddingHfApi";
15+
import type { EmbeddingModel } from "$lib/types/EmbeddingModel";
1516

1617
// parameters passed when generating text
1718
interface EmbeddingEndpointParameters {
@@ -33,8 +34,8 @@ export const embeddingEndpointSchema = z.discriminatedUnion("type", [
3334
type EmbeddingEndpointTypeOptions = z.infer<typeof embeddingEndpointSchema>["type"];
3435

3536
// generator function that takes in the type discriminator value for defining the endpoint and returns the endpoint
36-
export type EmbeddingEndpointGenerator<T extends EmbeddingEndpointTypeOptions> = (
37-
inputs: Extract<z.infer<typeof embeddingEndpointSchema>, { type: T }>
37+
type EmbeddingEndpointGenerator<T extends EmbeddingEndpointTypeOptions> = (
38+
inputs: Extract<z.infer<typeof embeddingEndpointSchema>, { type: T }> & { model: EmbeddingModel }
3839
) => EmbeddingEndpoint | Promise<EmbeddingEndpoint>;
3940

4041
// list of all endpoint generators

src/lib/server/embeddingEndpoints/hfApi/embeddingHfApi.ts

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,27 @@ import type { EmbeddingEndpoint, Embedding } from "../embeddingEndpoints";
33
import { chunk } from "$lib/utils/chunk";
44
import { env } from "$env/dynamic/private";
55
import { logger } from "$lib/server/logger";
6+
import type { EmbeddingModel } from "$lib/types/EmbeddingModel";
67

78
export const embeddingEndpointHfApiSchema = z.object({
89
weight: z.number().int().positive().default(1),
9-
model: z.any(),
1010
type: z.literal("hfapi"),
1111
authorization: z
1212
.string()
1313
.optional()
1414
.transform((v) => (!v && env.HF_TOKEN ? "Bearer " + env.HF_TOKEN : v)), // if the header is not set but HF_TOKEN is, use it as the authorization header
1515
});
1616

17+
type EmbeddingEndpointHfApiInput = z.input<typeof embeddingEndpointHfApiSchema> & {
18+
model: EmbeddingModel;
19+
};
20+
1721
export async function embeddingEndpointHfApi(
18-
input: z.input<typeof embeddingEndpointHfApiSchema>
22+
input: EmbeddingEndpointHfApiInput
1923
): Promise<EmbeddingEndpoint> {
20-
const { model, authorization } = embeddingEndpointHfApiSchema.parse(input);
21-
const url = "https://api-inference.huggingface.co/models/" + model.id;
24+
const { model } = input;
25+
const { authorization } = embeddingEndpointHfApiSchema.parse(input);
26+
const url = "https://api-inference.huggingface.co/models/" + model.name;
2227

2328
return async ({ inputs }) => {
2429
const batchesInputs = chunk(inputs, 128);

src/lib/server/embeddingEndpoints/openai/embeddingEndpoints.ts

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,25 @@ import { z } from "zod";
22
import type { EmbeddingEndpoint, Embedding } from "../embeddingEndpoints";
33
import { chunk } from "$lib/utils/chunk";
44
import { env } from "$env/dynamic/private";
5+
import type { EmbeddingModel } from "$lib/types/EmbeddingModel";
56

67
export const embeddingEndpointOpenAIParametersSchema = z.object({
78
weight: z.number().int().positive().default(1),
8-
model: z.any(),
99
type: z.literal("openai"),
1010
url: z.string().url().default("https://api.openai.com/v1/embeddings"),
1111
apiKey: z.string().default(env.OPENAI_API_KEY),
1212
defaultHeaders: z.record(z.string()).default({}),
1313
});
1414

15+
type EmbeddingEndpointOpenAIInput = z.input<typeof embeddingEndpointOpenAIParametersSchema> & {
16+
model: EmbeddingModel;
17+
};
18+
1519
export async function embeddingEndpointOpenAI(
16-
input: z.input<typeof embeddingEndpointOpenAIParametersSchema>
20+
input: EmbeddingEndpointOpenAIInput
1721
): Promise<EmbeddingEndpoint> {
18-
const { url, model, apiKey, defaultHeaders } =
19-
embeddingEndpointOpenAIParametersSchema.parse(input);
22+
const { model } = input;
23+
const { url, apiKey, defaultHeaders } = embeddingEndpointOpenAIParametersSchema.parse(input);
2024

2125
const maxBatchSize = model.maxBatchSize || 100;
2226

src/lib/server/embeddingEndpoints/tei/embeddingEndpoints.ts

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@ import type { EmbeddingEndpoint, Embedding } from "../embeddingEndpoints";
33
import { chunk } from "$lib/utils/chunk";
44
import { env } from "$env/dynamic/private";
55
import { logger } from "$lib/server/logger";
6+
import type { EmbeddingModel } from "$lib/types/EmbeddingModel";
67

78
export const embeddingEndpointTeiParametersSchema = z.object({
89
weight: z.number().int().positive().default(1),
9-
model: z.any(),
1010
type: z.literal("tei"),
1111
url: z.string().url(),
1212
authorization: z
@@ -35,10 +35,15 @@ const getModelInfoByUrl = async (url: string, authorization?: string) => {
3535
}
3636
};
3737

38+
type EmbeddingEndpointTeiInput = z.input<typeof embeddingEndpointTeiParametersSchema> & {
39+
model: EmbeddingModel;
40+
};
41+
3842
export async function embeddingEndpointTei(
39-
input: z.input<typeof embeddingEndpointTeiParametersSchema>
43+
input: EmbeddingEndpointTeiInput
4044
): Promise<EmbeddingEndpoint> {
41-
const { url, model, authorization } = embeddingEndpointTeiParametersSchema.parse(input);
45+
const { model } = input;
46+
const { url, authorization } = embeddingEndpointTeiParametersSchema.parse(input);
4247

4348
const { max_client_batch_size, max_batch_tokens } = await getModelInfoByUrl(url);
4449
const maxBatchSize = Math.min(

src/lib/server/embeddingEndpoints/transformersjs/embeddingEndpoints.ts

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@ import { z } from "zod";
22
import type { EmbeddingEndpoint } from "../embeddingEndpoints";
33
import type { Tensor, FeatureExtractionPipeline } from "@xenova/transformers";
44
import { pipeline } from "@xenova/transformers";
5+
import type { EmbeddingModel } from "$lib/types/EmbeddingModel";
56

67
export const embeddingEndpointTransformersJSParametersSchema = z.object({
78
weight: z.number().int().positive().default(1),
8-
model: z.any(),
99
type: z.literal("transformersjs"),
1010
});
1111

@@ -36,10 +36,16 @@ export async function calculateEmbedding(modelName: string, inputs: string[]) {
3636
return output.tolist();
3737
}
3838

39+
type EmbeddingEndpointTransformersJSInput = z.input<
40+
typeof embeddingEndpointTransformersJSParametersSchema
41+
> & {
42+
model: EmbeddingModel;
43+
};
44+
3945
export function embeddingEndpointTransformersJS(
40-
input: z.input<typeof embeddingEndpointTransformersJSParametersSchema>
46+
input: EmbeddingEndpointTransformersJSInput
4147
): EmbeddingEndpoint {
42-
const { model } = embeddingEndpointTransformersJSParametersSchema.parse(input);
48+
const { model } = input;
4349

4450
return async ({ inputs }) => {
4551
return calculateEmbedding(model.name, inputs);

src/lib/server/embeddingModels.ts

Lines changed: 65 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,13 @@ import { sum } from "$lib/utils/sum";
55
import {
66
embeddingEndpoints,
77
embeddingEndpointSchema,
8-
type EmbeddingEndpoint,
98
} from "$lib/server/embeddingEndpoints/embeddingEndpoints";
109
import { embeddingEndpointTransformersJS } from "$lib/server/embeddingEndpoints/transformersjs/embeddingEndpoints";
1110

1211
import JSON5 from "json5";
12+
import type { EmbeddingModel } from "$lib/types/EmbeddingModel";
13+
import { collections } from "./database";
14+
import { ObjectId } from "mongodb";
1315

1416
const modelConfig = z.object({
1517
/** Used as an identifier in DB */
@@ -42,67 +44,77 @@ const rawEmbeddingModelJSON =
4244

4345
const embeddingModelsRaw = z.array(modelConfig).parse(JSON5.parse(rawEmbeddingModelJSON));
4446

45-
const processEmbeddingModel = async (m: z.infer<typeof modelConfig>) => ({
46-
...m,
47-
id: m.id || m.name,
47+
const embeddingModels = embeddingModelsRaw.map((rawEmbeddingModel) => {
48+
const embeddingModel: EmbeddingModel = {
49+
name: rawEmbeddingModel.name,
50+
description: rawEmbeddingModel.description,
51+
websiteUrl: rawEmbeddingModel.websiteUrl,
52+
modelUrl: rawEmbeddingModel.modelUrl,
53+
chunkCharLength: rawEmbeddingModel.chunkCharLength,
54+
maxBatchSize: rawEmbeddingModel.maxBatchSize,
55+
preQuery: rawEmbeddingModel.preQuery,
56+
prePassage: rawEmbeddingModel.prePassage,
57+
_id: new ObjectId(),
58+
createdAt: new Date(),
59+
updatedAt: new Date(),
60+
endpoints: rawEmbeddingModel.endpoints,
61+
};
62+
63+
return embeddingModel;
4864
});
4965

50-
const addEndpoint = (m: Awaited<ReturnType<typeof processEmbeddingModel>>) => ({
51-
...m,
52-
getEndpoint: async (): Promise<EmbeddingEndpoint> => {
53-
if (!m.endpoints) {
54-
return embeddingEndpointTransformersJS({
55-
type: "transformersjs",
56-
weight: 1,
57-
model: m,
58-
});
59-
}
66+
export const getEmbeddingEndpoint = async (embeddingModel: EmbeddingModel) => {
67+
if (!embeddingModel.endpoints) {
68+
return embeddingEndpointTransformersJS({
69+
type: "transformersjs",
70+
weight: 1,
71+
model: embeddingModel,
72+
});
73+
}
6074

61-
const totalWeight = sum(m.endpoints.map((e) => e.weight));
62-
63-
let random = Math.random() * totalWeight;
64-
65-
for (const endpoint of m.endpoints) {
66-
if (random < endpoint.weight) {
67-
const args = { ...endpoint, model: m };
68-
69-
switch (args.type) {
70-
case "tei":
71-
return embeddingEndpoints.tei(args);
72-
case "transformersjs":
73-
return embeddingEndpoints.transformersjs(args);
74-
case "openai":
75-
return embeddingEndpoints.openai(args);
76-
case "hfapi":
77-
return embeddingEndpoints.hfapi(args);
78-
default:
79-
throw new Error(`Unknown endpoint type: ${args}`);
80-
}
75+
const totalWeight = sum(embeddingModel.endpoints.map((e) => e.weight));
76+
77+
let random = Math.random() * totalWeight;
78+
79+
for (const endpoint of embeddingModel.endpoints) {
80+
if (random < endpoint.weight) {
81+
const args = { ...endpoint, model: embeddingModel };
82+
console.log(args.type);
83+
84+
switch (args.type) {
85+
case "tei":
86+
return embeddingEndpoints.tei(args);
87+
case "transformersjs":
88+
return embeddingEndpoints.transformersjs(args);
89+
case "openai":
90+
return embeddingEndpoints.openai(args);
91+
case "hfapi":
92+
return embeddingEndpoints.hfapi(args);
93+
default:
94+
throw new Error(`Unknown endpoint type: ${args}`);
8195
}
82-
83-
random -= endpoint.weight;
8496
}
8597

86-
throw new Error(`Failed to select embedding endpoint`);
87-
},
88-
});
89-
90-
export const embeddingModels = await Promise.all(
91-
embeddingModelsRaw.map((e) => processEmbeddingModel(e).then(addEndpoint))
92-
);
93-
94-
export const defaultEmbeddingModel = embeddingModels[0];
98+
random -= endpoint.weight;
99+
}
95100

96-
const validateEmbeddingModel = (_models: EmbeddingBackendModel[], key: "id" | "name") => {
97-
return z.enum([_models[0][key], ..._models.slice(1).map((m) => m[key])]);
101+
throw new Error(`Failed to select embedding endpoint`);
98102
};
99103

100-
export const validateEmbeddingModelById = (_models: EmbeddingBackendModel[]) => {
101-
return validateEmbeddingModel(_models, "id");
102-
};
104+
export const getDefaultEmbeddingModel = async (): Promise<EmbeddingModel> => {
105+
if (!embeddingModels[0]) {
106+
throw new Error(`Failed to find default embedding endpoint`);
107+
}
108+
109+
const defaultModel = await collections.embeddingModels.findOne({
110+
_id: embeddingModels[0]._id,
111+
});
103112

104-
export const validateEmbeddingModelByName = (_models: EmbeddingBackendModel[]) => {
105-
return validateEmbeddingModel(_models, "name");
113+
return defaultModel ? defaultModel : embeddingModels[0];
106114
};
107115

108-
export type EmbeddingBackendModel = typeof defaultEmbeddingModel;
116+
// to mimic the current behavior of creating embedding models from scratch during server start
117+
export async function pupulateEmbeddingModel() {
118+
await collections.embeddingModels.deleteMany({});
119+
await collections.embeddingModels.insertMany(embeddingModels);
120+
}

0 commit comments

Comments
 (0)