-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathgenerate_answer.py
85 lines (82 loc) · 3.32 KB
/
generate_answer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import copy
import json
import os
from tqdm import tqdm
from anthropic import AnthropicVertex
import multiprocessing as mp
from datasets import load_dataset
# Which retriever's top-10 ranked documents are fed to the model; also names
# the score input directory and the output directory below.
retriever = 'qwen'
# Vertex AI credentials — must be filled in before running.
PROJECT_ID = ""
LOCATION = ""
# Client is created at module level so worker processes inherit it.
# NOTE(review): this relies on fork-style multiprocessing — confirm on the
# target platform (spawn would re-import this module per worker).
claude_client = AnthropicVertex(region=LOCATION, project_id=PROJECT_ID)
def worker(arg):
    """Generate one answer with Claude over the retrieved documents.

    Parameters
    ----------
    arg : dict
        Expects keys 'documents' (list[str]), 'query' (str), 'task' (str)
        and 'id' (str); the whole dict is copied into the output record.

    Side effects
    ------------
    Writes ``outputs/{retriever}_retrieval/{task}/{id}.json`` containing the
    input record plus a 'pred' field. Returns None. The API call is retried
    up to 50 times; if every attempt fails the failure is logged and the
    example is skipped (no output file is written).
    """
    # Concatenate retrieved documents with explicit delimiters so the model
    # can tell where each document starts and ends.
    documentation = ''.join(
        f'---------- Document {idx} start ----------\n'
        f'{doc.strip()}\n'
        f'---------- Document {idx} end ----------\n'
        for idx, doc in enumerate(arg['documents'], start=1)
    )
    prompt = (
        f"{documentation}\n"
        f"---------- post start ----------\n"
        f"{arg['query'].strip()}\n"
        f"---------- post end ----------\n\n"
        f"Let's think step by step to address the query and give a detailed answer."
    )
    out_path = os.path.join(
        f'outputs/{retriever}_retrieval/{arg["task"]}', f"{arg['id']}.json"
    )
    for attempt in range(50):
        try:
            # Only the network call sits inside the retry scope; deterministic
            # post-processing bugs should surface, not be retried 50 times.
            message = claude_client.messages.create(
                max_tokens=4096,
                messages=[{"role": "user", "content": prompt}],
                model="claude-3-5-sonnet@20240620",
                temperature=0,  # deterministic decoding
            )
        except Exception as e:
            # Transient API failures (rate limits, timeouts): log and retry.
            print(f"[{arg.get('id')}] attempt {attempt + 1}/50 failed: {e}")
            continue
        # Read the completion text directly from the response object instead
        # of round-tripping it through model_dump_json/json.loads.
        completion = message.content[0].text
        output = copy.deepcopy(arg)
        output['pred'] = completion
        with open(out_path, 'w') as f:
            json.dump(output, f, indent=2)
        return
    # All retries exhausted — make the failure visible instead of silent.
    print(f"[{arg.get('id')}] giving up after 50 failed attempts")
def main():
    """Generate Claude answers for every BRIGHT task using retrieved docs.

    For each task: load the corpus, rank documents per example by the saved
    retrieval scores, keep the top 10, and fan the examples out to a process
    pool of `worker`s that write one JSON answer file per example.
    """
    tasks = ['biology', 'earth_science', 'economics', 'psychology',
             'robotics', 'stackoverflow', 'sustainable_living']
    for task in tasks:
        # Map document id -> document text for this task's corpus.
        documents_hf = load_dataset('xlangai/BRIGHT', 'documents')[task]
        documents = {d['id']: d['content'] for d in documents_hf}

        # Precomputed retrieval scores: {example_id: {doc_id: score}}.
        with open(f"../0617/outputs/{task}_{retriever}_long_False/score.json") as f:
            scores = json.load(f)

        # Single load of the examples split (the original loaded it twice
        # and built an unused id->example dict).
        examples_hf = load_dataset('xlangai/BRIGHT', 'examples')[task]
        data = []
        for e in examples_hf:
            eid = e['id']
            # Keep the 10 highest-scoring documents for this example.
            top = sorted(scores[eid].items(), key=lambda x: x[1], reverse=True)[:10]
            data.append({
                'id': eid,
                'query': e['query'],
                'gold': e['gold_answer'],
                'task': task,
                'documents': [documents[doc_id] for doc_id, _ in top],
            })

        # exist_ok avoids the isdir/makedirs race of the original.
        os.makedirs(f'outputs/{retriever}_retrieval/{task}', exist_ok=True)
        with mp.Pool(64) as pool, tqdm(total=len(data), desc=task) as pbar:
            for _ in pool.imap_unordered(worker, data):
                pbar.update()


if __name__ == "__main__":
    # Guard is required for multiprocessing on spawn-based platforms
    # (Windows/macOS): without it each worker re-executes the module body.
    main()