|
| 1 | +import { TextGenerationPipeline, type Chat } from '@xenova/transformers'; |
| 2 | +import type { CoreMessage, LanguageModelUsage } from 'ai'; |
| 3 | + |
| 4 | +export const runWithWebLLM = async ({ |
| 5 | + pipeline, |
| 6 | + messages, |
| 7 | + args, |
| 8 | +}: { |
| 9 | + pipeline: TextGenerationPipeline; |
| 10 | + messages: CoreMessage[]; |
| 11 | + args?: { max_tokens?: number; temperature?: number }; |
| 12 | +}) => { |
| 13 | + if (!pipeline) { |
| 14 | + throw new Error('Transformers pipeline is required for generation!'); |
| 15 | + } |
| 16 | + console.log(messages); |
| 17 | + |
| 18 | + // map messages to chat |
| 19 | + const msgs: Chat[] = messages.map((m) => { |
| 20 | + return { |
| 21 | + role: m.role, |
| 22 | + content: m.content as string, |
| 23 | + } as unknown as Chat; |
| 24 | + }); |
| 25 | + |
| 26 | + // generate |
| 27 | + const out = await pipeline(msgs, { |
| 28 | + max_new_tokens: args?.max_tokens, |
| 29 | + temperature: args?.temperature, |
| 30 | + }); |
| 31 | + console.log(JSON.stringify(out, null, 2)); |
| 32 | + // [{'label': 'POSITIVE', 'score': 0.999817686}] |
| 33 | + |
| 34 | + const usage: LanguageModelUsage = { |
| 35 | + completionTokens: 0, |
| 36 | + promptTokens: 0, |
| 37 | + totalTokens: 0, |
| 38 | + }; |
| 39 | + return { |
| 40 | + result: out, |
| 41 | + usage, |
| 42 | + }; |
| 43 | + |
| 44 | + /* const reply = await generate.chat.completions.create({ |
| 45 | + messages: messages as ChatCompletionMessageParam[], |
| 46 | + ...args, |
| 47 | + }); |
| 48 | + const answers = reply.choices; |
| 49 | + const usage: LanguageModelUsage = { |
| 50 | + completionTokens: reply.usage?.completion_tokens ?? 0, |
| 51 | + promptTokens: reply.usage?.prompt_tokens ?? 0, |
| 52 | + totalTokens: reply.usage?.total_tokens ?? 0, |
| 53 | + }; |
| 54 | + console.log(answers.map((a) => a.message.content)); |
| 55 | + const result = answers[0].message.content; |
| 56 | + return { |
| 57 | + result, |
| 58 | + usage, |
| 59 | + };*/ |
| 60 | +}; |
0 commit comments