1
+ import os
2
+ import glob
3
+ import re
4
+ import json
5
+ import torch
6
+ import torch .utils .data
7
+ from transformers import AutoTokenizer , AutoModel
8
+ from tqdm import tqdm
9
+
10
# Load the ChatGLM2-6B tokenizer and model from the Hugging Face hub.
# trust_remote_code is required because ChatGLM2 ships its own modeling code.
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
# bfloat16 halves weight memory; the model is moved to the default CUDA device.
model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).bfloat16().cuda()

# C-Eval questions are multiple choice over options A-D.  Pre-compute the first
# token id of each option letter so that answer logits can later be compared
# directly against these four ids.
choices = ["A", "B", "C", "D"]
choice_tokens = [tokenizer.encode(choice, add_special_tokens=False)[0] for choice in choices]
15
+
16
+
17
def build_prompt(text):
    """Wrap *text* in the single-round ChatGLM2 chat template.

    The round number is fixed at 1 because every evaluation query is a
    fresh, stateless conversation.
    """
    return f"[Round 1]\n\n问:{text}\n\n答:"
19
+
20
+
21
# Follow-up prompt appended after the model's free-form answer.  Asking the
# model to name the correct option makes the final choice recoverable from the
# next-token logits over the A/B/C/D token ids.
extraction_prompt = '综上所述,ABCD中正确的选项是:'
22
+
23
# Per-file accuracy and example counts, keyed by the jsonl path.
accuracy_dict, count_dict = {}, {}
with torch.no_grad():  # inference only; no gradients needed
    for entry in glob.glob("./CEval/val/**/*.jsonl", recursive=True):
        # Each line of the file is one JSON-encoded multiple-choice example.
        with open(entry, encoding='utf-8') as file:
            dataset = [json.loads(line) for line in file]
        if not dataset:
            # Guard: an empty file would make `correct / len(dataset)` below
            # raise ZeroDivisionError; skip it instead.
            continue
        correct = 0
        # Default collate turns the list of dicts into a dict of batched columns.
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=8)
        for batch in tqdm(dataloader):
            texts = batch["inputs_pretokenized"]
            # Stage 1: let the model reason freely about each question
            # (greedy decoding so results are deterministic).
            queries = [build_prompt(query) for query in texts]
            inputs = tokenizer(queries, padding=True, return_tensors="pt",
                               truncation=True, max_length=2048).to('cuda')
            outputs = model.generate(**inputs, do_sample=False, max_new_tokens=512)
            # Hoisted out of the per-row loop: the original called
            # outputs.tolist() once per row, copying the whole batch each time.
            generated = outputs.tolist()
            intermediate_outputs = []
            for idx in range(len(generated)):
                # Drop the prompt tokens; decode only the newly generated text.
                output = generated[idx][len(inputs["input_ids"][idx]):]
                intermediate_outputs.append(tokenizer.decode(output))
            # Stage 2: re-prompt with question + reasoning + extraction prompt
            # and read the next-token logits to pick among A/B/C/D.
            answer_texts = [text + intermediate + "\n" + extraction_prompt
                            for text, intermediate in zip(texts, intermediate_outputs)]
            input_tokens = [build_prompt(answer_text) for answer_text in answer_texts]
            inputs = tokenizer(input_tokens, padding=True, return_tensors="pt",
                               truncation=True, max_length=2048).to('cuda')
            outputs = model(**inputs, return_last_logit=True)
            logits = outputs.logits[:, -1]     # logits for the next token only
            logits = logits[:, choice_tokens]  # restrict to the A/B/C/D token ids
            preds = logits.argmax(dim=-1)
            correct += (preds.cpu() == batch["label"]).sum().item()
        accuracy = correct / len(dataset)
        print(entry, accuracy)
        accuracy_dict[entry] = accuracy
        count_dict[entry] = len(dataset)
55
+
56
# Overall accuracy: per-file accuracies weighted by the number of examples
# in each file (i.e. total correct / total examples across all files).
acc_total, count_total = 0.0, 0
for entry, accuracy in accuracy_dict.items():
    acc_total += accuracy * count_dict[entry]
    count_total += count_dict[entry]
if count_total:
    print(acc_total / count_total)
else:
    # Guard: without this, no matching ./CEval/val/**/*.jsonl files would
    # raise ZeroDivisionError instead of reporting the real problem.
    print("No evaluation files found under ./CEval/val/")
0 commit comments