Skip to content

Commit 686c604

Browse files
committed
Allow setting default LLM args at the instance level
Closes #20
1 parent d76c46b commit 686c604

File tree

5 files changed

+79
-12
lines changed

5 files changed

+79
-12
lines changed

packages/litlytics/engine/runPrompt.ts

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
1-
import type { CoreMessage, CoreTool } from 'ai';
1+
import type { CoreMessage } from 'ai';
22
import type { LLMProviders } from '../litlytics';
33
import { executeOnLLM } from '../llm/llm';
4-
import type { LLMModel, LLMProvider, LLMRequest } from '../llm/types';
4+
import type { LLMArgs, LLMModel, LLMProvider, LLMRequest } from '../llm/types';
55

66
export interface RunPromptFromMessagesArgs {
77
provider: LLMProviders;
88
key: string;
99
model: LLMModel;
1010
messages: CoreMessage[];
11-
args?: Record<string, CoreTool>;
11+
args?: LLMArgs;
1212
}
1313
export const runPromptFromMessages = async ({
1414
provider,
@@ -36,7 +36,7 @@ export interface RunPromptArgs {
3636
model: LLMModel;
3737
system: string;
3838
user: string;
39-
args?: Record<string, CoreTool>;
39+
args?: LLMArgs;
4040
}
4141
export const runPrompt = async ({
4242
provider,

packages/litlytics/litlytics.ts

+14-3
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import {
99
import { runStep, type RunStepArgs } from './engine/runStep';
1010
import { runLLMStep, type RunLLMStepArgs } from './engine/step/runLLMStep';
1111
import { testPipelineStep } from './engine/testStep';
12-
import type { LLMModel, LLMProvider } from './llm/types';
12+
import type { LLMArgs, LLMModel, LLMProvider } from './llm/types';
1313
import { OUTPUT_ID } from './output/Output';
1414
import {
1515
pipelineFromText,
@@ -38,6 +38,7 @@ export { modelCosts } from './llm/costs';
3838
export {
3939
LLMModelsList,
4040
LLMProvidersList,
41+
type LLMArgs,
4142
type LLMModel,
4243
type LLMProvider,
4344
} from './llm/types';
@@ -70,6 +71,7 @@ export class LitLytics {
7071
// model config
7172
provider?: LLMProviders;
7273
model?: LLMModel;
74+
llmArgs?: LLMArgs;
7375
#llmKey?: string;
7476

7577
// pipeline
@@ -82,14 +84,17 @@ export class LitLytics {
8284
provider,
8385
model,
8486
key,
87+
llmArgs,
8588
}: {
8689
provider: LLMProviders;
8790
model: LLMModel;
8891
key: string;
92+
llmArgs?: LLMArgs;
8993
}) {
9094
this.provider = provider;
9195
this.model = model;
9296
this.#llmKey = key;
97+
this.llmArgs = llmArgs;
9398
}
9499

95100
/**
@@ -179,7 +184,10 @@ export class LitLytics {
179184
key: this.#llmKey,
180185
model: this.model,
181186
messages,
182-
args,
187+
args: {
188+
...args,
189+
...this.llmArgs,
190+
},
183191
});
184192
};
185193

@@ -202,7 +210,10 @@ export class LitLytics {
202210
model: this.model,
203211
system,
204212
user,
205-
args,
213+
args: {
214+
...args,
215+
...this.llmArgs,
216+
},
206217
});
207218
};
208219

packages/litlytics/llm/types.ts

+4-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import type { CoreMessage, CoreTool } from 'ai';
1+
import type { CoreMessage, generateText } from 'ai';
22

33
export const LLMProvidersList = [
44
'openai',
@@ -40,10 +40,12 @@ export const LLMModelsList = {
4040
export type LLMModel =
4141
(typeof LLMModelsList)[keyof typeof LLMModelsList][number];
4242

43+
export type LLMArgs = Partial<Parameters<typeof generateText>[0]>;
44+
4345
export interface LLMRequest {
4446
provider: LLMProvider;
4547
key: string;
4648
model: LLMModel;
4749
messages: CoreMessage[];
48-
modelArgs?: Record<string, CoreTool>;
50+
modelArgs?: LLMArgs;
4951
}

packages/litlytics/step/Step.ts

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
import type { CoreTool, LanguageModelUsage } from 'ai';
1+
import type { LanguageModelUsage } from 'ai';
22
import type { Doc } from '../doc/Document';
3+
import type { LLMArgs } from '../llm/types';
34

45
export interface StepResult {
56
stepId: string;
@@ -46,7 +47,7 @@ export interface ProcessingStep extends BaseStep {
4647
input?: StepInput;
4748
// llm
4849
prompt?: string;
49-
llmArgs?: Record<string, CoreTool>;
50+
llmArgs?: LLMArgs;
5051
// code
5152
code?: string;
5253
codeExplanation?: string;

packages/litlytics/test/litlytics.test.ts

+54-1
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import type { LanguageModelUsage } from 'ai';
2-
import { expect, test } from 'vitest';
2+
import { expect, test, vi } from 'vitest';
3+
import * as run from '../engine/runPrompt';
34
import {
45
LitLytics,
56
OUTPUT_ID,
67
type Doc,
8+
type LLMArgs,
79
type Pipeline,
810
type StepInput,
911
} from '../litlytics';
@@ -229,3 +231,54 @@ test('should generate suggested tasks for current pipeline', async () => {
229231
const newNonTestDoc = litlytics.docs.find((d) => d.id === docNonTest.id);
230232
expect(newNonTestDoc?.summary).toBeUndefined();
231233
});
234+
235+
test('should pass llm args when running prompt', async () => {
236+
const testArgs: LLMArgs = {
237+
temperature: 0.5,
238+
maxTokens: 1000,
239+
};
240+
const litlytics = new LitLytics({
241+
provider: 'openai',
242+
model: 'test',
243+
key: 'test',
244+
llmArgs: testArgs,
245+
});
246+
litlytics.pipeline.pipelineDescription = 'test description';
247+
248+
const testResult = `Step name: Generate Title and Description
249+
Step type: llm
250+
Step input: doc
251+
Step description: Generate an Etsy product title and description based on the provided document describing the product.
252+
253+
---
254+
255+
Step name: Check for Copyrighted Terms
256+
Step type: llm
257+
Step input: result
258+
Step description: Analyze the generated title and description for possible copyrighted terms and suggest edits.
259+
`;
260+
261+
// mock prompt replies
262+
const spy = vi
263+
.spyOn(run, 'runPrompt')
264+
.mockImplementation(
265+
async ({
266+
user,
267+
args,
268+
}: {
269+
system: string;
270+
user: string;
271+
args?: LLMArgs;
272+
}) => {
273+
expect(args).toEqual(testArgs);
274+
expect(user).toEqual('test description');
275+
return { result: testResult, usage: {} as LanguageModelUsage };
276+
}
277+
);
278+
// run generation
279+
await litlytics.generatePipeline();
280+
// check that spy was called
281+
expect(spy).toHaveBeenCalled();
282+
// cleanup
283+
spy.mockClear();
284+
});

0 commit comments

Comments (0)