Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: implement tests for search index tool #11

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .env.sample
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ API_PROVIDER=olly_chat
# Agents, required if API_PROVIDER is agent_framework
ROOT_AGENT_ID=
PPL_AGENT_ID=
SEARCH_INDEX_AGENT_ID=

# Evaluation models, required if LLM or embeddings based test evaluations are used. fallbacks to reading `.chat-assistant-config` if not provided
# ml-commons model id for LLM based requests
Expand Down
2 changes: 1 addition & 1 deletion src/providers/factory/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ export class ApiProviderFactory {
return new MlCommonsApiProvider();

case PROVIDERS.AGENT_FRAMEWORK:
return new AgentFrameworkApiProvider();
return new AgentFrameworkApiProvider(undefined, options.agentIdKey);

default:
console.info(`$API_PROVIDER unset or invalid, defaulting to ${PROVIDERS.OLLY} provider`);
Expand Down
117 changes: 117 additions & 0 deletions src/runners/search_index/search_index_runner.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

import { ApiResponse } from '@opensearch-project/opensearch';
import { SearchResponse } from '@opensearch-project/opensearch/api/types';
import { ResponseError } from '@opensearch-project/opensearch/lib/errors';
import { ApiProvider } from 'promptfoo';
import { LevenshteinMatcher } from '../../matchers/levenshtein';
import { PythonMatcher } from '../../matchers/python';
import { openSearchClient } from '../../providers/clients/opensearch';
import { OpenSearchProviderResponse } from '../../providers/types';
import { OpenSearchTestIndices } from '../../utils/indices';
import { TestResult, TestRunner, TestSpec } from '../test_runner';

interface DSLSpec extends TestSpec {
question: string;
gold_query: string;
index: string;
}

class GoldQueryError extends Error {
constructor(message: string) {
super(message);
this.name = 'GoldQueryError';
}
}

export class SearchIndexRunner extends TestRunner<DSLSpec, ApiProvider> {
levenshtein = new LevenshteinMatcher();
query_eval = new PythonMatcher('os_query_eval/eval.py');

protected async beforeAll(clusterStateId: string): Promise<void> {
await OpenSearchTestIndices.init();
await OpenSearchTestIndices.create(clusterStateId, { ignoreExisting: true });
}

private async runDSL(query: string, index: string): Promise<ApiResponse<SearchResponse<string>>> {
return (await openSearchClient.transport.request({
method: 'GET',
path: index + '/_search',
body: query,
})) as ApiResponse<SearchResponse<string>>;
}

private parseDSLSearchResponse(expected: ApiResponse<SearchResponse<string>>): string {
let expected_string = '';
try {
for (let i = 0; i < expected.body.hits.hits.length; i++) {
const responseItem = expected.body.hits.hits[i];
expected_string += JSON.stringify({
_index: responseItem._index,
_source: responseItem._source,
_id: responseItem._id,
_score: responseItem._score,
});
}
} catch (error) {
console.log('error parsing expected DSL response:', error);
}
return expected_string;
}

protected buildInput(spec: DSLSpec): {
prompt: Parameters<ApiProvider['callApi']>[0];
context: Parameters<ApiProvider['callApi']>[1];
} {
return {
prompt: '',
context: { vars: { input: spec.question } },
};
}

public async evaluate(received: OpenSearchProviderResponse, spec: DSLSpec): Promise<TestResult> {
try {
const actual = received.output ?? '';
const expected = await this.runDSL(spec.gold_query, spec.index);
const expected_string: string = this.parseDSLSearchResponse(expected);

const editDistance = (await this.levenshtein.calculateScore(actual, expected_string)).score;
console.info(
` : ${actual}\nExpected query: ${spec.gold_query}\nEdit distance: ${editDistance}`,
);
return {
pass: editDistance >= 0.9,
message: () => `Score ${editDistance} is above 0.9`,
score: editDistance,
extras: {
editDistance,
exception: null,
},
};
} catch (error) {
const result: TestResult & Required<Pick<TestResult, 'extras'>> = {
pass: false,
message: () => `failed to execute query: ${String(error)}`,
score: 0,
extras: {},
};
if (error instanceof GoldQueryError) {
console.error(`[${spec.id}] Invalid gold query: ${spec.gold_query}`);
result.extras.exception = 'Gold query error';
} else if (error instanceof ResponseError) {
const respError = (error as ResponseError<string>).body;
const dslError = JSON.parse(respError) as {
error: { reason: string; details: string; type: string };
status: number;
};
result.extras.exception = dslError.error.type;
} else {
throw error;
}
return result;
}
}
}
18 changes: 18 additions & 0 deletions src/tests/search_index/search_index.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

import path from 'path';
import { PROVIDERS } from '../../providers/constants';
import { ApiProviderFactory } from '../../providers/factory';
import { SearchIndexRunner } from '../../runners/search_index/search_index_runner';

const provider = ApiProviderFactory.create(PROVIDERS.AGENT_FRAMEWORK, {
agentIdKey: 'SEARCH_INDEX_AGENT_ID',
});
const runner = new SearchIndexRunner(provider);
const specDirectory = path.join(__dirname, 'specs');
const specFiles = [path.join(specDirectory, 'olly_search_index_eval.jsonl')];

runner.run(specFiles);
Loading
Loading