Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: initialize tool selection test runner #16

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/providers/constants/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ export const PROVIDERS = {
PPL_GENERATOR: 'ppl_generator',
ML_COMMONS: 'ml_commons',
AGENT_FRAMEWORK: 'agent_framework',
TOOL_SELECTION: 'tool_selection',
} as const;

export const OPENSEARCH_CONFIG = {
Expand Down
5 changes: 4 additions & 1 deletion src/providers/factory/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import { ApiProvider } from 'promptfoo';
import { PROVIDERS } from '../constants';
import { AgentFrameworkApiProvider, MlCommonsApiProvider } from '../ml_commons';
import { AgentFrameworkApiProvider, MlCommonsApiProvider, ToolSelectionApiProvider } from '../ml_commons';
import { OllyApiProvider, OllyPPLGeneratorApiProvider } from '../olly';

/** Union of the string values declared in the PROVIDERS constant. */
type Provider = (typeof PROVIDERS)[keyof typeof PROVIDERS];
Expand All @@ -31,6 +31,9 @@ export class ApiProviderFactory {
case PROVIDERS.AGENT_FRAMEWORK:
return new AgentFrameworkApiProvider(undefined, options.agentIdKey);

case PROVIDERS.TOOL_SELECTION:
return new ToolSelectionApiProvider(options.agentIdKey);

default:
console.info(`$API_PROVIDER unset or invalid, defaulting to ${PROVIDERS.OLLY} provider`);
case PROVIDERS.OLLY:
Expand Down
1 change: 1 addition & 0 deletions src/providers/ml_commons/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@

// Re-export the provider classes implemented in this directory.
export { AgentFrameworkApiProvider } from './agent_framework';
export { MlCommonsApiProvider } from './ml_commons';
export { ToolSelectionApiProvider } from './tool_selection';
114 changes: 114 additions & 0 deletions src/providers/ml_commons/tool_selection.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

import { ApiResponse } from '@opensearch-project/opensearch';
import { ApiProvider } from 'promptfoo';
import { openSearchClient } from '../clients/opensearch';
import { PROVIDERS } from '../constants';
import { OpenSearchProviderResponse } from '../types';

/** A single named entry in an inference result's output list. */
interface AgentOutput {
  name: string;
  result?: string;
}

/** One inference result, carrying the list of outputs attached to it. */
interface AgentInferenceResult {
  output: AgentOutput[];
}

/** Body shape this provider reads from the agent `_execute` response. */
interface AgentResponse {
  inference_results: AgentInferenceResult[];
}

/**
 * Api Provider that executes an agent with the prompt and reports which tool
 * the agent selected.
 *
 * The selected tool is read from the traces of the interaction created by the
 * agent execution: it is the `origin` of the first trace whose origin is not
 * 'LLM'.
 */
export class ToolSelectionApiProvider implements ApiProvider {
  /**
   * @param agentIdKey - name of the environment variable that holds the agent id
   * @param maxAttempts - how many times the agent is asked before giving up (at least 1)
   * @param retryDelayMs - pause between two attempts, in milliseconds
   */
  constructor(
    private readonly agentIdKey = 'ROOT_AGENT_ID',
    private readonly maxAttempts = 3,
    private readonly retryDelayMs = 1000,
  ) {}

  id() {
    return PROVIDERS.TOOL_SELECTION;
  }

  /**
   * Reads the agent id from the configured environment variable.
   *
   * @throws Error when the environment variable is unset or empty
   */
  private getAgentId() {
    const id = process.env[this.agentIdKey];
    if (!id) throw new Error(`${this.agentIdKey} environment variable not set`);
    return id;
  }

  /**
   * Executes the agent once and looks up the tool it used.
   *
   * @param prompt - question sent to the agent as the `question` parameter
   * @param context - optional vars, spread into the request parameters after `question`
   * @returns the origin of the first non-LLM trace, or '' when no such trace exists
   * @throws Error when the agent response carries no interaction id, or when a request fails
   */
  private async getToolSelectedByPrompt(
    prompt?: string,
    context?: { vars: Record<string, string | object> },
  ) {
    const agentId = this.getAgentId();
    const response = (await openSearchClient.transport.request(
      {
        method: 'POST',
        path: `/_plugins/_ml/agents/${agentId}/_execute`,
        body: JSON.stringify({ parameters: { question: prompt, ...context?.vars } }),
      },
      {
        /**
         * It is time-consuming for LLM to generate final answer
         * Give it a large timeout window
         */
        requestTimeout: 5 * 60 * 1000,
        /**
         * Do not retry
         */
        maxRetries: 0,
      },
    )) as ApiResponse<AgentResponse, unknown>;

    // Navigate defensively: a response without inference results or outputs
    // must end in the explicit error below, not in an opaque TypeError.
    const outputs = response.body?.inference_results?.[0]?.output ?? [];
    // Prefer the output named 'parent_interaction_id'; otherwise use the first output.
    const outputResponse = outputs.find((output) => output.name === 'parent_interaction_id') ?? outputs[0];
    const interactionId = outputResponse?.result;
    if (!interactionId) throw new Error('Cannot find interaction id from agent response');

    const tracesResp = (await openSearchClient.transport.request({
      method: 'GET',
      path: `/_plugins/_ml/memory/message/${interactionId}/traces`,
    })) as ApiResponse<{
      traces: Array<{
        message_id: string;
        create_time: string;
        input: string;
        response: string;
        origin: string;
        trace_number: number;
      }>;
    }>;

    // First trace that has an origin other than 'LLM' and a truthy trace number.
    const firstTrace = tracesResp.body.traces?.find(
      (item) => item.origin && item.trace_number && item.origin !== 'LLM',
    );
    return firstTrace?.origin || '';
  }

  /**
   * Asks the agent which tool it selects for the prompt, retrying when no tool
   * is found or a request fails.
   *
   * @param prompt - question sent to the agent
   * @param context - optional vars forwarded as additional agent parameters
   * @returns `{ output }` with the selected tool, or `{ error }` once all
   * attempts are used up; the error text contains the most recent exception, if any
   */
  async callApi(
    prompt?: string,
    context?: { vars: Record<string, string | object> },
  ): Promise<OpenSearchProviderResponse> {
    const attempts = Math.max(1, this.maxAttempts);
    let error: unknown;

    for (let attempt = 1; attempt <= attempts; attempt++) {
      try {
        const toolSelected = await this.getToolSelectedByPrompt(prompt, context);
        if (toolSelected) return { output: toolSelected };
      } catch (e) {
        // Keep the failure so it can be reported if every attempt fails.
        error = e;
      }

      // No pause (and no warning) after the final attempt.
      if (attempt < attempts) {
        console.warn(`No tool selected, retry prompt: ${prompt}, retryTimes: ${attempt}`);
        await new Promise((resolve) => setTimeout(resolve, this.retryDelayMs));
      }
    }

    return { error: `question: ${prompt}, API call error: ${String(error || '')}` };
  }
}
28 changes: 28 additions & 0 deletions src/runners/tool-selection/tool-selection-runner.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

import { ApiProvider } from 'promptfoo';
import { OpenSearchProviderResponse } from '../../providers/types';
import { TestResult, TestRunner, TestSpec } from '../test_runner';

/**
 * Test spec for a tool selection case: a TestSpec plus the expected tool.
 */
interface ToolSelectionSpec extends TestSpec {
/** Expected tool; ToolSelectionRunner.evaluate compares it to the received output with `===`. */
tool: string;
}

/**
 * Test runner that passes a case when the provider output is strictly equal
 * to the tool named in the spec.
 */
export class ToolSelectionRunner extends TestRunner<ToolSelectionSpec, ApiProvider> {
  /**
   * Compares the received output with the expected tool and logs a summary.
   *
   * @param received - provider response whose `output` is the selected tool
   * @param spec - test spec carrying the question and the expected tool
   * @returns a result with score 1 on an exact match and 0 otherwise
   */
  public async evaluate(received: OpenSearchProviderResponse, spec: ToolSelectionSpec): Promise<TestResult> {
    const { output: selectedTool } = received;
    const { question, tool: expectedTool } = spec;

    const report = [
      `Question: ${question}`,
      `Received selected tool: ${selectedTool}`,
      `Expected selected tool: ${expectedTool}`,
    ].join('\n');
    console.info(report);

    const isMatch = selectedTool === expectedTool;
    const score = isMatch ? 1 : 0;
    return {
      pass: isMatch,
      message: () => report,
      score,
      extras: { exception: null },
    };
  }
}
Loading
Loading