Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

prompt_builder: add list of header files in code fixing prompt #318

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion data_prep/introspector.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
INTROSPECTOR_TYPE = ''
INTROSPECTOR_FUNC_SIG = ''
INTROSPECTOR_ADDR_TYPE = ''
INTROSPECTOR_ALL_HEADER_FILES = ''


def get_oracle_dict() -> Dict[str, Any]:
Expand All @@ -66,7 +67,8 @@ def set_introspector_endpoints(endpoint):
global INTROSPECTOR_ENDPOINT, INTROSPECTOR_CFG, INTROSPECTOR_FUNC_SIG, \
INTROSPECTOR_FUNCTION_SOURCE, INTROSPECTOR_PROJECT_SOURCE, \
INTROSPECTOR_XREF, INTROSPECTOR_TYPE, INTROSPECTOR_ORACLE_FAR_REACH, \
INTROSPECTOR_ORACLE_KEYWORD, INTROSPECTOR_ADDR_TYPE
INTROSPECTOR_ORACLE_KEYWORD, INTROSPECTOR_ADDR_TYPE, \
INTROSPECTOR_ALL_HEADER_FILES

INTROSPECTOR_ENDPOINT = endpoint
logging.info('Fuzz Introspector endpoint set to %s', INTROSPECTOR_ENDPOINT)
Expand All @@ -83,6 +85,7 @@ def set_introspector_endpoints(endpoint):
INTROSPECTOR_FUNC_SIG = f'{INTROSPECTOR_ENDPOINT}/function-signature'
INTROSPECTOR_ADDR_TYPE = (
f'{INTROSPECTOR_ENDPOINT}/addr-to-recursive-dwarf-info')
INTROSPECTOR_ALL_HEADER_FILES = (f'{INTROSPECTOR_ENDPOINT}/all-header-files')
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need for the parentheses if this fits on one line.



def _construct_url(api: str, params: dict) -> str:
Expand Down Expand Up @@ -182,6 +185,13 @@ def query_introspector_cfg(project: str) -> dict:
return _get_data(resp, 'project', {})


def query_introspector_header_files(project: str) -> List[str]:
  """Queries FuzzIntrospector API for the header files of |project|."""
  query_params = {'project': project}
  response = _query_introspector(INTROSPECTOR_ALL_HEADER_FILES, query_params)
  # `[]` is handed to _get_data as its third argument (presumably the fallback
  # used when 'all-header-files' is missing -- confirm against _get_data).
  return _get_data(response, 'all-header-files', [])


def query_introspector_function_source(project: str, func_sig: str) -> str:
"""Queries FuzzIntrospector API for source code of |func_sig|."""
resp = _query_introspector(INTROSPECTOR_FUNCTION_SOURCE, {
Expand Down
19 changes: 16 additions & 3 deletions llm_toolkit/prompt_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import requests
import yaml

from data_prep import project_targets
from data_prep import introspector, project_targets
from experiment.benchmark import Benchmark, FileType
from experiment.fuzz_target_error import SemanticCheckResult
from llm_toolkit import models, prompts
Expand All @@ -46,6 +46,7 @@
'jansi_colors-problem.txt')
FDP_JVM_EXAMPLE_2_SOLUTION = os.path.join(EXAMPLE_PATH,
'jansi_colors-solution.java')
HEADER_FIXER_PROMPT = os.path.join(DEFAULT_TEMPLATE_DIR, 'header_fixer.txt')
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you add this file yet?


EXAMPLES = {
'c++': [
Expand Down Expand Up @@ -271,7 +272,7 @@ def build_fixer_prompt(self, benchmark: Benchmark, raw_code: str,
"""Prepares the code-fixing prompt."""
priming, priming_weight = self._format_fixer_priming()
problem = self._format_fixer_problem(raw_code, error_desc, errors,
priming_weight)
priming_weight, benchmark)

self._prepare_prompt(priming, problem)
return self._prompt
Expand All @@ -287,7 +288,8 @@ def _format_fixer_priming(self) -> Tuple[str, int]:
return priming, priming_weight

def _format_fixer_problem(self, raw_code: str, error_desc: Optional[str],
errors: list[str], priming_weight: int) -> str:
errors: list[str], priming_weight: int,
benchmark: Benchmark) -> str:
"""Formats a problem for code fixer based on the template."""
with open(self.fixer_problem_template_file) as f:
problem = f.read().strip()
Expand All @@ -297,6 +299,17 @@ def _format_fixer_problem(self, raw_code: str, error_desc: Optional[str],
else:
# Build error does not pass error desc.
error_summary = BUILD_ERROR_SUMMARY
headers_to_avoid = introspector.query_introspector_header_files(
benchmark.project)
if len(headers_to_avoid) > 0:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: use `if headers_to_avoid:` for simplicity.

with open(HEADER_FIXER_PROMPT, 'r') as f:
header_avoid_string = f.read()
for header_file in headers_to_avoid:
header_avoid_string += '- %s\n' % (os.path.basename(header_file))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

f-string, please :)

else:
header_avoid_string = ''
problem = problem.replace('{ADDITIONAL_MESSAGE}', header_avoid_string)

problem = problem.replace('{ERROR_SUMMARY}', error_summary)

problem_prompt = self._prompt.create_prompt_piece(problem, 'user')
Expand Down
3 changes: 2 additions & 1 deletion prompts/template_xml/fixer_problem.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@ Fix code:
2. Choose a solution that can maximize fuzzing result, which is utilizing the function under test and feeding it not null input.
3. Apply the solutions to the original code.
It's important to show the complete code, not only the fixed line.
<solution>
{ADDITIONAL_MESSAGE}
<solution>