-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathgenerate_q_a_dataset.py
More file actions
316 lines (264 loc) · 13.8 KB
/
generate_q_a_dataset.py
File metadata and controls
316 lines (264 loc) · 13.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
import json
from finetuning.finetuning_prompts import FinetuningPrompts
from backporting_handler import BackportingHandler
from logger_refactored import FinetuneLogger
from finetuning.constants import COMMITS_DETAILS, FINETUNE_MODEL_NAME, QnA_DATA_DIR, CUSTOM_COMMIT_DETAILS_DIR, TRAINING_DATA_DIR, PACKAGE_NAME, PACKAGE_LANGUAGE
from llm_handler import RunLLM
from finetuning.azureLLM_handler import AzureLLMHandler
import os
import re
from datetime import datetime
class Generate_Q_A_Dataset:
    """Generate question/answer pairs from commit details and CVE patches via an LLM.

    Pairs are appended, one JSON object per line, to a timestamped
    ``qna_<DD-MON-YYYY_HH-MM>.jsonl`` file under
    ``QnA_DATA_DIR/<FINETUNE_MODEL_NAME>``. Model names starting with
    ``"gpt-"`` are served through ``AzureLLMHandler``; any other model goes
    through ``RunLLM``.
    """

    def __init__(self):
        self.prompts = FinetuningPrompts()
        self.logger = FinetuneLogger()
        self.llm = AzureLLMHandler() if self._is_azure_model() else RunLLM()
        print("Starting Q&A Dataset Generation Process")
        self.logger.log_info("Starting Q&A Dataset Generation Process")
        timestamp = datetime.now().strftime("%d-%b-%Y_%H-%M").upper()
        self.QnA_dir = os.path.join(QnA_DATA_DIR, f"{FINETUNE_MODEL_NAME}")
        os.makedirs(self.QnA_dir, exist_ok=True)
        self.QnA_file = os.path.join(self.QnA_dir, f"qna_{timestamp}.jsonl")

    @staticmethod
    def _is_azure_model():
        """Return True when FINETUNE_MODEL_NAME selects the Azure OpenAI backend ("gpt-" prefix)."""
        return FINETUNE_MODEL_NAME.startswith("gpt-")

    def fetch_cve_patches(self):
        """Load the CVE list and each CVE's upstream patch.

        Sets ``self.backporting_handler``, ``self.all_cves`` and
        ``self.patch_data`` (CVE id -> value returned by
        ``getUpstreamPatchForCVE``).
        """
        self.logger.log_info("Fetching CVE List and Corresponding Patches")
        self.patch_data = {}
        self.backporting_handler = BackportingHandler()
        self.all_cves = self.backporting_handler.getCVEList()
        for cve in self.all_cves:
            self.patch_data[cve] = self.backporting_handler.getUpstreamPatchForCVE(cve)
        self.logger.log_info(f"Total CVEs fetched: {len(self.all_cves)}")
        print(f"Total CVEs fetched: {len(self.all_cves)}")

    def generate_from_llm(self, system_prompt, user_prompt, prompt_type, commitfile, cve=None):
        """Send one system/user prompt pair to the configured LLM and return its raw output.

        For Azure ("gpt-") models the output comes from
        ``call_azure_openai_for_qna_schema``; otherwise from
        ``generate_base_output_with_separate_prompts``. The output is also
        logged via ``log_generated_output``.
        """
        statement = f"Generating Q&A pairs for {prompt_type}, commit {commitfile}"
        if cve:
            statement += f", CVE {cve}"
        self.logger.log_info(f"{statement}...")
        print(statement + "...")
        if self._is_azure_model():
            # Default: temperature=0.7, max_tokens=4000, top_p=0.9
            output = self.llm.call_azure_openai_for_qna_schema(system_prompt, user_prompt, max_tokens=12000, temperature=0.1)
        else:
            # Default: temperature=0.0, top_p=1.0, top_k=0
            output = self.llm.generate_base_output_with_separate_prompts(system_prompt, user_prompt, max_new_tokens=8000, temperature=0.1)
        print(f"Q&A pairs for {prompt_type} generated from LLM")
        self.logger.log_info(f"Q&A pairs for {prompt_type} generated from LLM")
        self.logger.log_generated_output(prompt_type, str(output), commit=commitfile, cve_number=cve)
        return output

    def getPrompts(self, prompt_type, commit_data, commitfile, cve=None, cve_hunk=None):
        """Build and log the (system_prompt, user_prompt) pair for ``prompt_type``."""
        self.logger.log_info(f"Generating prompts for {prompt_type}")
        system_prompt, user_prompt = self.prompts.getPrompts(prompt_type, patch_hunk=cve_hunk, commit_data=commit_data)
        self.logger.log_info(f"Prompts generated for {prompt_type}")
        self.logger.log_prompt(f"{prompt_type}_system", system_prompt, commit=commitfile, cve_number=cve)
        self.logger.log_prompt(f"{prompt_type}_user", user_prompt, commit=commitfile, cve_number=cve)
        return system_prompt, user_prompt

    @staticmethod
    def _format_qna_line(qna, normalize=True):
        """Serialise one Q&A item as a single JSONL line, including the trailing newline.

        With ``normalize`` true only the stripped ``question``/``answer`` fields
        are kept; otherwise the item is dumped as-is. ``json.dumps`` escapes
        quotes, backslashes and newlines, so the line is always valid JSON
        (hand-built f-string lines were not).
        """
        if not normalize:
            return json.dumps(qna) + "\n"
        record = {
            "question": qna.get("question", "").strip(),
            "answer": qna.get("answer", "").strip(),
        }
        return json.dumps(record) + "\n"

    def store_qna_pairs(self, qna_pairs, prompt_type):
        """Append the LLM's Q&A pairs to ``self.QnA_file`` as JSON lines.

        Azure output is used as-is, presumably already a list of dicts
        (NOTE(review): confirm against ``call_azure_openai_for_qna_schema``).
        Other backends' output is parsed with ``json.loads``, and each pair is
        reduced to stripped ``question``/``answer`` fields.

        Raises:
            json.JSONDecodeError: if non-Azure output is not valid JSON.
            Any error raised while serialising an item; nothing is written then.
        """
        self.logger.log_info(f"Converting Q&A pairs to JSON lines format and appending to {self.QnA_file}")
        is_azure = self._is_azure_model()
        if is_azure:
            data = qna_pairs
        else:
            try:
                data = json.loads(qna_pairs)
            except json.JSONDecodeError as e:
                print(qna_pairs)
                self.logger.log_info(f"❌ ERROR: Failed to parse Q&A pairs from LLM output.: {e}")
                print("❌ ERROR: Failed to parse Q&A pairs from LLM output.")
                raise
        if len(data) == 0:
            self.logger.log_info(f"[info] No Q&A pairs generated by LLM for {prompt_type}")
            return
        self.logger.log_info("Q&A pairs converted to JSON format")
        # Serialise every pair before touching the file, so a bad item fails
        # before anything is appended (a retry would otherwise duplicate lines).
        lines = [self._format_qna_line(qna, normalize=not is_azure) for qna in data]
        with open(self.QnA_file, "a", encoding="utf-8") as qna_file:
            qna_file.writelines(lines)
        print(f"Q&A pairs for {prompt_type} appended to {self.QnA_file}")
        self.logger.log_info(f"✅ Q&A pairs for {prompt_type} appended to {self.QnA_file}")

    def handle_llm_output(self, system_prompt, user_prompt, prompt_type, commitfile, cve=None, retry_count=0):
        """Generate and store Q&A pairs for one prompt, retrying up to twice on failure.

        On any exception the "JSON_ERROR" correction prompts (built from the
        error and the previous output, or None if generation itself failed)
        are appended. The user part is added on every retry and the system
        part only on the first, so it is not duplicated.

        Returns:
            The retry count reached: ``retry_count`` plus the retries made
            here. A value above 2 means the prompt type was skipped.
        """
        max_retries = 2
        while True:
            # Reset every attempt: if generation raises there is no output, and
            # reading an unbound name here used to crash the whole run.
            output = None
            try:
                output = self.generate_from_llm(system_prompt, user_prompt, prompt_type, commitfile, cve)
                self.store_qna_pairs(output, prompt_type)
                return retry_count
            except Exception as e:
                retry_count += 1
                if retry_count > max_retries:
                    self.logger.log_info(f"❌ ERROR: Failed to generate valid Q&A pairs after multiple attempts for {prompt_type}. Skipping...")
                    print(f"❌ ERROR: Failed to generate valid Q&A pairs after multiple attempts for {prompt_type}. Skipping...")
                    return retry_count
                json_error_system_prompt, json_error_user_prompt = self.prompts.getPrompts("JSON_ERROR", error=str(e), output=output)
                user_prompt += json_error_user_prompt
                self.logger.log_info("Retrying with JSON error correction prompts")
                if retry_count == 1:
                    self.logger.log_info(f"First retry attempt for {prompt_type} due to error: {e}")
                    print(f"First retry attempt for {prompt_type} due to error: {e}")
                    system_prompt += json_error_system_prompt
                else:
                    self.logger.log_info(f"Second retry attempt for {prompt_type} due to error: {e}")
                    print(f"Second retry attempt for {prompt_type} due to error: {e}")

    def generate_dataset(self):
        """Run every prompt type over every commit file and every CVE patch.

        For each regular file in ``COMMITS_DETAILS`` (in sorted order), the
        commit-level prompt types run once. The CVE-level prompt types then
        run once per CVE loaded by ``fetch_cve_patches``.
        """
        self.fetch_cve_patches()
        commit_dir = COMMITS_DETAILS  # full commit history since PACKAGE_VERSION
        # commit_dir = CUSTOM_COMMIT_DETAILS_DIR # custom commits, made on top of origin/master HEAD
        commit_prompt_types = ["COMMIT_DETAILS", "FOCUSED_COMMIT_DETAILS"]
        # commit_prompt_types = ["FOCUSED_COMMIT_DETAILS"]
        cve_prompt_types = ["COMMIT_TO_HUNK_CHANGES", "PATCH_BACKPORT"]
        # sorted(): os.listdir order is arbitrary; keep runs reproducible.
        for commitfile in sorted(os.listdir(commit_dir)):
            self.logger.log_info(f"\nProcessing commit file: {commitfile}")
            filepath = os.path.join(commit_dir, commitfile)
            if not os.path.isfile(filepath):
                # A sub-directory used to abort the whole run with IsADirectoryError.
                self.logger.log_info(f"Skipping {commitfile}: not a regular file")
                continue
            # errors="replace": a stray non-UTF-8 byte in a diff should not abort the run.
            with open(filepath, "r", encoding="utf-8", errors="replace") as f:
                commit_data = f.read()
            self.logger.log_info(f"Commit data read from {commitfile}")
            self.logger.log_input("commit_data", commit_data, commit=commitfile)
            for prompt_type in commit_prompt_types:
                system_prompt, user_prompt = self.getPrompts(prompt_type, commit_data=commit_data, commitfile=commitfile)
                self.handle_llm_output(system_prompt, user_prompt, prompt_type, commitfile)
            self.logger.log_info("Processing patch hunks for the commit")
            print("Processing patch hunks for the commit...")
            for cve in self.all_cves:
                hunk = self.patch_data.get(cve, "")
                self.logger.log_info(f"Processing CVE {cve} with its patch data")
                self.logger.log_input("CVE_PATCH", hunk, commit=commitfile, cve_number=cve)
                # NOTE: the whole upstream patch is passed as one "hunk";
                # per-hunk splitting (split_git_patch) is currently disabled.
                for p_type in cve_prompt_types:
                    print(p_type)
                    system_prompt, user_prompt = self.getPrompts(p_type, commit_data=commit_data, commitfile=commitfile, cve_hunk=hunk, cve=cve)
                    self.handle_llm_output(system_prompt, user_prompt, p_type, commitfile, cve)
                self.logger.log_info(f"Completed processing for CVE {cve}")
                print(f"Completed processing for CVE {cve}")
            self.logger.log_info(f"Completed processing for commit file: {commitfile}")
            print(f"Completed processing for commit file: {commitfile}")
        self.logger.log_info("✅ Q&A Dataset Generation Process Completed")
        print("✅ Q&A Dataset Generation Process Completed")
# def split_git_patch(patch_text):
# print("Called this split function")
# file_splits = re.split(r'(?=^diff --git )', patch_text, flags=re.MULTILINE)
# hunks = []
# for file_block in file_splits:
# if not file_block.strip():
# continue
# # find all hunks inside this file block
# hunk_matches = re.finditer(
# r'(^@@.*?$(?:\n.*?)*?)(?=^@@|\Z)',
# file_block,
# flags=re.MULTILINE
# )
# file_header_match = re.split(r'^@@', file_block, 1, flags=re.MULTILINE)
# file_header = file_header_match[0] if len(file_header_match) > 1 else ""
# for match in hunk_matches:
# hunk_content = match.group(1)
# full_hunk = file_header + hunk_content
# hunks.append(full_hunk.strip("\n"))
# return hunks
def prepare_dataset_in_proper_format(split="validation", qna_files=None, output_file=None):
    """Convert Q&A JSONL files into chat-format ("messages") fine-tuning JSONL.

    Every non-blank input line must be a JSON object with "question" and
    "answer" keys. It becomes one record holding a system message, the
    question as the user turn and the answer as the assistant turn.

    Args:
        split: "validation" (default; the previous hard-coded behaviour) or
            "training". Selects the default input files and output name.
        qna_files: Optional list of input paths overriding the split's defaults.
        output_file: Optional output path. Defaults to
            ``TRAINING_DATA_DIR/<FINETUNE_MODEL_NAME>_<split>_data.jsonl``.

    Raises:
        ValueError: for an unknown ``split`` or a malformed input line.
    """
    default_inputs = {
        # qna_17-SEP-2025_17-06.jsonl was listed twice, which would have
        # duplicated its records in the training set.
        "training": [
            '/home/sumsharma/madhur/backporting-llm/training_llm/finetuning/data/QnA/gpt-4o/qna_17-SEP-2025_07-15.jsonl',
            '/home/sumsharma/madhur/backporting-llm/training_llm/finetuning/data/QnA/gpt-4o/qna_17-SEP-2025_14-08.jsonl',
            '/home/sumsharma/madhur/backporting-llm/training_llm/finetuning/data/QnA/gpt-4o/qna_17-SEP-2025_17-06.jsonl',
            '/home/sumsharma/madhur/backporting-llm/training_llm/finetuning/data/QnA/gpt-4o/qna_17-SEP-2025_17-30.jsonl',
            '/home/sumsharma/madhur/backporting-llm/training_llm/finetuning/data/QnA/gpt-4o/qna_17-SEP-2025_17-31.jsonl',
        ],
        "validation": [
            '/home/sumsharma/madhur/backporting-llm/training_llm/finetuning/data/QnA/manual-validation/validation_qna.jsonl',
        ],
    }
    if split not in default_inputs:
        raise ValueError(f"split must be one of {sorted(default_inputs)}, got {split!r}")
    if qna_files is None:
        qna_files = default_inputs[split]
    if output_file is None:
        output_file = os.path.join(TRAINING_DATA_DIR, f"{FINETUNE_MODEL_NAME}_{split}_data.jsonl")

    # NOTE: wording (including "knowlege") is kept byte-for-byte; editing it
    # would make new files' system prompt differ from files written before.
    system_content = f"""
You are an expert software developer with deep knowlege of {PACKAGE_LANGUAGE} programming language.
You have in-depth knowlege about the commit history of {PACKAGE_NAME} package.
Answer questions about how the files, functions and lines of code were changed over range of commits.
"""
    all_chats = []
    for file_path in qna_files:
        with open(file_path, "r", encoding="utf-8") as infile:
            for line_no, line in enumerate(infile, start=1):
                if not line.strip():
                    continue
                try:
                    qa = json.loads(line)
                    question, answer = qa["question"], qa["answer"]
                except (json.JSONDecodeError, KeyError, TypeError) as err:
                    raise ValueError(f"{file_path}:{line_no}: malformed Q&A line: {err}") from err
                all_chats.append({
                    "messages": [
                        {"role": "system", "content": system_content},
                        {"role": "user", "content": question},
                        {"role": "assistant", "content": answer},
                    ]
                })
    os.makedirs(os.path.dirname(output_file) or ".", exist_ok=True)
    with open(output_file, "w", encoding="utf-8") as outfile:
        for chat in all_chats:
            outfile.write(json.dumps(chat, ensure_ascii=False) + "\n")
def main():
    """Build the Q&A dataset end to end: one generator, one full generation pass."""
    Generate_Q_A_Dataset().generate_dataset()
def test_git_patch_split():
    """Manual smoke check for the (currently commented-out) split_git_patch helper.

    Binds a sample patch touching two files in four hunks. The calls that
    would split and print it are disabled together with the helper, so this
    function currently has no effect and returns None.
    """
    sample_patch = """
From a1b2c3d4e5f6g7h8i9j0 Mon Sep 17 00:00:00 2001
From: Jane Doe <jane@example.com>
Date: Mon, 16 Sep 2024 12:34:56 +0530
Subject: [PATCH] Refactor logging and fix buffer handling
Co-Author: John Smith <john@example.com>
---
src/logger.c | 6 +++---
src/buffer.c | 7 ++++---
2 files changed, 7 insertions(+), 6 deletions(-)
diff --git a/src/logger.c b/src/logger.c
index 1234567..89abcde 100644
--- a/src/logger.c
+++ b/src/logger.c
@@ -10,7 +10,7 @@ void init_logger() {
log_level = LOG_INFO;
- fprintf(stderr, "Logger initialized\n");
+ fprintf(stdout, "Logger initialized at INFO level\n");
}
@@ -25,6 +25,7 @@ void log_message(const char *msg) {
- fprintf(stderr, "LOG: %s\n", msg);
+ fprintf(stdout, "[LOG] %s\n", msg);
+ fflush(stdout);
}
diff --git a/src/buffer.c b/src/buffer.c
index abc1234..def5678 100644
--- a/src/buffer.c
+++ b/src/buffer.c
@@ -42,7 +42,7 @@ int buffer_write(Buffer *buf, const char *data, size_t len) {
- if (len > buf->capacity) {
- return -1;
- }
+ if (len >= buf->capacity) {
+ fprintf(stderr, "Buffer overflow attempt\n");
+ return -1;
+ }
memcpy(buf->data, data, len);
buf->size = len;
return 0;
@@ -75,6 +75,7 @@ void buffer_clear(Buffer *buf) {
buf->size = 0;
- memset(buf->data, 0, buf->capacity);
+ memset(buf->data, 0, buf->capacity);
+ fprintf(stdout, "Buffer cleared\n");
}
--
GitLab
"""
    # Re-enable together with split_git_patch:
    # for index, hunk in enumerate(split_git_patch(sample_patch)):
    #     print(f"\n--- HUNK {index} ---\n")
    #     print(hunk)
    return None
if __name__ == "__main__":
    # Choose one entry point:
    #   main()                  -> generate raw Q&A pairs with the LLM
    #   test_git_patch_split()  -> smoke check for the disabled patch splitter
    prepare_dataset_in_proper_format()  # convert Q&A JSONL into chat-format data