Skip to content

Commit

Permalink
Add C-specific prompt (#338)
Browse files Browse the repository at this point in the history
This implements the first step of
#337

Adds a harness generation flow that, in comparison to the existing
default builder:
- Provides repository link for the target project.
- Is C-specific, uses no CPP code language or similar.
- Includes post-processing on the generated code to add certain header
files we always want in the harnesses.
- Adds constraints on header files the LLM should include in the
harnesses. Does this by providing absolute paths to header files in the
OSS-Fuzz containers.
- Uses some new fuzz introspector APIs to help with context.

This PR was made to have no intrusion on the existing workflow, i.e.
experiments can continue as they are running now. However, there are
several improvements that can be made and I prefer to have these in
follow-up PRs:

1) The fixing logic relies on the default prompt builder. This is because
the code fixer creates a new prompt builder
https://github.com/google/oss-fuzz-gen/blob/09d2235f3957c4d43367ecbd7f3f88147b487abf/llm_toolkit/code_fixer.py#L408
This in fact means that the C++ default logic is used for fixing JVM
targets. I would like to change the flow here in the medium term such
that the code fixing logic reuses the one we used for main harness
generation. I think this should be changed so the prompt builder comes
closer to a "harness generator" abstraction and has more knowledge of
the target under analysis. But, I prefer to do this later as the PR is
already big.
2) Integrate so we can run experiments in the CI with both or either
harness generation flows.
3) Add new features to the prompt builder.

Ref: #337

---------

Signed-off-by: David Korczynski <[email protected]>
  • Loading branch information
DavidKorczynski authored Jun 18, 2024
1 parent 3bd98aa commit 8d8a8bd
Show file tree
Hide file tree
Showing 6 changed files with 219 additions and 6 deletions.
39 changes: 38 additions & 1 deletion data_prep/introspector.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@
INTROSPECTOR_TYPE = ''
INTROSPECTOR_FUNC_SIG = ''
INTROSPECTOR_ADDR_TYPE = ''
INTROSPECTOR_ALL_HEADER_FILES = ''
INTROSPECTOR_ALL_FUNC_TYPES = ''
INTROSPECTOR_SAMPLE_XREFS = ''


def get_oracle_dict() -> Dict[str, Any]:
Expand All @@ -69,7 +72,8 @@ def set_introspector_endpoints(endpoint):
INTROSPECTOR_FUNCTION_SOURCE, INTROSPECTOR_PROJECT_SOURCE, \
INTROSPECTOR_XREF, INTROSPECTOR_TYPE, INTROSPECTOR_ORACLE_FAR_REACH, \
INTROSPECTOR_ORACLE_KEYWORD, INTROSPECTOR_ADDR_TYPE, \
INTROSPECTOR_ORACLE_EASY_PARAMS
INTROSPECTOR_ALL_HEADER_FILES, INTROSPECTOR_ALL_FUNC_TYPES, \
INTROSPECTOR_SAMPLE_XREFS, INTROSPECTOR_ORACLE_EASY_PARAMS

INTROSPECTOR_ENDPOINT = endpoint
logging.info('Fuzz Introspector endpoint set to %s', INTROSPECTOR_ENDPOINT)
Expand All @@ -88,6 +92,10 @@ def set_introspector_endpoints(endpoint):
INTROSPECTOR_FUNC_SIG = f'{INTROSPECTOR_ENDPOINT}/function-signature'
INTROSPECTOR_ADDR_TYPE = (
f'{INTROSPECTOR_ENDPOINT}/addr-to-recursive-dwarf-info')
INTROSPECTOR_ALL_HEADER_FILES = f'{INTROSPECTOR_ENDPOINT}/all-header-files'
INTROSPECTOR_ALL_FUNC_TYPES = f'{INTROSPECTOR_ENDPOINT}/func-debug-types'
INTROSPECTOR_SAMPLE_XREFS = (
f'{INTROSPECTOR_ENDPOINT}/sample-cross-references')


def _construct_url(api: str, params: dict) -> str:
Expand Down Expand Up @@ -218,6 +226,35 @@ def query_introspector_source_code(project: str, filepath: str, begin_line: int,
return _get_data(resp, 'source_code', '')


def query_introspector_header_files(project: str) -> List[str]:
  """Queries for the header files used in a given project.

  Args:
    project: Name of the OSS-Fuzz project to query.

  Returns:
    A list of header file paths; empty when the API returns no data.
  """
  response = _query_introspector(INTROSPECTOR_ALL_HEADER_FILES,
                                 {'project': project})
  # Fall back to an empty list when the endpoint has nothing for us.
  return _get_data(response, 'all-header-files', [])


def query_introspector_sample_xrefs(project: str, func_sig: str) -> List[str]:
  """Queries for sample references in the form of source code.

  Args:
    project: Name of the OSS-Fuzz project to query.
    func_sig: Signature of the function whose cross references are wanted.

  Returns:
    A list of source-code snippets referencing the function; empty on no data.
  """
  query_params = {'project': project, 'function_signature': func_sig}
  response = _query_introspector(INTROSPECTOR_SAMPLE_XREFS, query_params)
  return _get_data(response, 'source-code-refs', [])


def query_introspector_function_debug_arg_types(project: str,
                                                func_sig: str) -> List[str]:
  """Queries FuzzIntrospector function arguments extracted by way of debug
  info.

  Args:
    project: Name of the OSS-Fuzz project to query.
    func_sig: Signature of the function whose argument types are wanted.

  Returns:
    A list of argument type strings; empty when the API returns no data.
  """
  params = {'project': project, 'function_signature': func_sig}
  response = _query_introspector(INTROSPECTOR_ALL_FUNC_TYPES, params)
  return _get_data(response, 'arg-types', [])


def query_introspector_cross_references(project: str,
func_sig: str) -> list[str]:
"""Queries FuzzIntrospector API for source code of functions
Expand Down
15 changes: 15 additions & 0 deletions experiment/oss_fuzz_checkout.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,3 +153,18 @@ def get_project_language(project: str) -> str:
with open(project_yaml_path, 'r') as benchmark_file:
data = yaml.safe_load(benchmark_file)
return data.get('language', 'C++')


def get_project_repository(project: str) -> str:
  """Returns the |project| repository read from its project.yaml.

  Args:
    project: Name of the OSS-Fuzz project.

  Returns:
    The `main_repo` URL from the project's project.yaml, or an empty string
    when the yaml file is missing or has no `main_repo` entry.
  """
  yaml_path = os.path.join(OSS_FUZZ_DIR, 'projects', project, 'project.yaml')
  if not os.path.isfile(yaml_path):
    logging.warning(
        'Failed to find the project yaml of %s, return empty repository',
        project)
    return ''

  with open(yaml_path, 'r') as project_yaml_file:
    project_config = yaml.safe_load(project_yaml_file)
  return project_config.get('main_repo', '')
126 changes: 125 additions & 1 deletion llm_toolkit/prompt_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
import requests
import yaml

from data_prep import project_targets
from data_prep import introspector, project_targets
from experiment import oss_fuzz_checkout
from experiment.benchmark import Benchmark, FileType
from experiment.fuzz_target_error import SemanticCheckResult
from llm_toolkit import models, prompts
Expand Down Expand Up @@ -70,6 +71,8 @@
FALSE_FUZZED_DATA_PROVIDER_ERROR = 'include/fuzzer/FuzzedDataProvider.h:16:10:'
FALSE_EXTERN_KEYWORD_ERROR = 'expected identifier or \'(\'\nextern "C"'

C_PROMPT_HEADERS_TO_ALWAYS_INCLUDES = ['stdio.h', 'stdlib.h', 'stdint.h']


class PromptBuilder:
"""Prompt builder."""
Expand All @@ -93,6 +96,11 @@ def build_fixer_prompt(self, benchmark: Benchmark, raw_code: str,
errors: list[str]) -> prompts.Prompt:
"""Builds a fixer prompt."""

def post_process_generated_code(self, generated_code: str) -> str:
  """Allows prompt builder to adjust the generated code.

  Hook for subclasses to rewrite the LLM output before it is used
  (e.g. a C-specific builder prepends mandatory #include lines).

  Args:
    generated_code: The raw fuzz-target code produced by the model.

  Returns:
    The possibly adjusted code.
  """
  # return the same by default
  return generated_code


class DefaultTemplateBuilder(PromptBuilder):
"""Default builder for C/C++."""
Expand Down Expand Up @@ -576,3 +584,119 @@ def build_fixer_prompt(self, benchmark: Benchmark, raw_code: str,
"""Builds a fixer prompt."""
# Do nothing for jvm project now.
return self._prompt


class CSpecificBuilder(PromptBuilder):
  """Builder specifically targeted C (and excluding C++)."""

  def __init__(self,
               model: models.LLM,
               benchmark: Benchmark,
               template_dir: str = DEFAULT_TEMPLATE_DIR):
    super().__init__(model)
    self._template_dir = template_dir
    self.benchmark = benchmark

    # Load templates.
    self.priming_template_file = self._find_template(template_dir,
                                                     'c-priming.txt')

  def _find_template(self, template_dir: str, template_name: str) -> str:
    """Finds template file based on |template_dir|."""
    preferred_template = os.path.join(template_dir, template_name)
    # Use the preferred template if it exists.
    if os.path.isfile(preferred_template):
      return preferred_template
    # Fall back to the default template.
    default_template = os.path.join(DEFAULT_TEMPLATE_DIR, template_name)
    return default_template

  def _get_template(self, template_file: str) -> str:
    """Reads the template for prompts."""
    with open(template_file) as file:
      return file.read()

  def build(self,
            function_signature: str,
            target_file_type: FileType,
            example_pair: list[list[str]],
            project_example_content: Optional[list[list[str]]] = None,
            project_context_content: Optional[dict] = None) -> prompts.Prompt:
    """Constructs a prompt using the templates in |self| and saves it.

    Fills the c-priming template's placeholders ({TARGET_REPO},
    {TARGET_FUNCTION}, {TARGET_FUNCTION_SOURCE_CODE}, {TARGET_HEADER_FILES},
    {FUNCTION_ARG_TYPES_MSG}, {ADDITIONAL_INFORMATION}) with data from
    oss_fuzz_checkout and Fuzz Introspector, then stores the result as the
    priming of |self._prompt|.
    """

    with open(self.priming_template_file, 'r') as f:
      prompt_text = f.read()

    # Format the priming: repository link, target function signature and
    # the function's source code.
    target_repository = oss_fuzz_checkout.get_project_repository(
        self.benchmark.project)
    prompt_text = prompt_text.replace('{TARGET_REPO}', target_repository)
    prompt_text = prompt_text.replace('{TARGET_FUNCTION}',
                                      self.benchmark.function_signature)
    function_source = introspector.query_introspector_function_source(
        self.benchmark.project, self.benchmark.function_signature)
    prompt_text = prompt_text.replace('{TARGET_FUNCTION_SOURCE_CODE}',
                                      function_source)

    # Set header inclusion string if there are any headers.
    headers_to_include = introspector.query_introspector_header_files(
        self.benchmark.project)
    header_inclusion_string = ''
    if headers_to_include:
      header_inclusion_string = ', '.join(headers_to_include)

    # TODO: Programmatically select and refine the header.
    prompt_text = prompt_text.replace('{TARGET_HEADER_FILES}',
                                      header_inclusion_string)

    # Add function arg types extracted from debug info, so the LLM can cast
    # harness-provided values appropriately.
    arg_types = introspector.query_introspector_function_debug_arg_types(
        self.benchmark.project, self.benchmark.function_signature)

    arg_types_text = ''
    if arg_types:
      arg_types_text = 'The target function takes the following arguments:\n'
      arg_types_text += '- ' + '- '.join(f'{arg}\n' for arg in arg_types)

      arg_types_text += (
          'You must make sure the arguments passed to the '
          'function match the types of the function. Do this by casting '
          'appropriately.\n')

    prompt_text = prompt_text.replace('{FUNCTION_ARG_TYPES_MSG}',
                                      arg_types_text)

    # Provide sample cross references as inspiration for harness structure.
    sample_cross_references = introspector.query_introspector_sample_xrefs(
        self.benchmark.project, self.benchmark.function_signature)
    if sample_cross_references:
      additional_text = (
          'The target function is used in various places of the target project.'
          'Please see the following samples of code using the target, which '
          'you should use as inspiration for the harness to structure the code:'
          '\n')

      exp_usage = 'Example usage:\n'
      additional_text += exp_usage + exp_usage.join(
          f'```c{elem}\n```\n' for elem in sample_cross_references)
    else:
      additional_text = ''

    prompt_text = prompt_text.replace('{ADDITIONAL_INFORMATION}',
                                      additional_text)

    self._prompt.add_priming(prompt_text)
    return self._prompt

  def build_fixer_prompt(self, benchmark: Benchmark, raw_code: str,
                         error_desc: Optional[str],
                         errors: list[str]) -> prompts.Prompt:
    """Prepares the code-fixing prompt."""
    return self._prompt

  def post_process_generated_code(self, generated_code: str) -> str:
    """Adds specific C headers we always want in the harnesses.

    Fixed: this was previously named `post_proces_generated_code` (missing
    an 's'), so it never overrode PromptBuilder.post_process_generated_code
    and the headers were never added when the caller invoked the hook.
    """
    # TODO: explore if we can make this more precise, by only adding headers
    # if needed.
    for header in C_PROMPT_HEADERS_TO_ALWAYS_INCLUDES:
      generated_code = f'#include <{header}>\n' + generated_code
    return generated_code
24 changes: 24 additions & 0 deletions prompts/template_xml/c-priming.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
<system>
Hello! I need you to write a fuzzing harness. The target codebase is written purely in the C language so the harness should be in pure C.

The codebase we are targeting is located in the repository {TARGET_REPO}.

I would like for you to write the harness targeting the function {TARGET_FUNCTION}

The source code for the function is:

{TARGET_FUNCTION_SOURCE_CODE}

The harness should be in libFuzzer style, with the code wrapped in `int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size)`. Specifically, do not include `extern "C"` in the fuzzer code.

Please wrap all code in <code> tags and you should include nothing else but the code in your reply. Do not include any other text.

Make sure that strings passed to the target are null-terminated.

There is one rule that your harness must satisfy: all of the header files in this library are: {TARGET_HEADER_FILES}. Make sure to not include any header files not in this list and you should use the full path to the header file as outlined in the list.

{FUNCTION_ARG_TYPES_MSG}

The most important part of the harness is that it will build and compile correctly against the target code. Please focus on making the code as simple as possible in order to ensure it can be built.

{ADDITIONAL_INFORMATION}
7 changes: 6 additions & 1 deletion run_all_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,8 @@ def run_experiments(benchmark: benchmarklib.Benchmark,
cloud_experiment_bucket=args.cloud_experiment_bucket,
use_context=args.context,
run_timeout=args.run_timeout,
dry_run=args.dry_run)
dry_run=args.dry_run,
prompt_builder_to_use=args.prompt_builder)
return Result(benchmark, result)
except Exception as e:
print('Exception while running experiment:', e, file=sys.stderr)
Expand Down Expand Up @@ -232,6 +233,10 @@ def parse_args() -> argparse.Namespace:
default=0,
help=('Delay each experiment by certain seconds (e.g., 10s) to avoid '
'exceeding quota limit in large scale experiments.'))
parser.add_argument('-p',
'--prompt-builder',
help='The prompt builder to use for harness generation.',
default='DEFAULT')

args = parser.parse_args()
if args.num_samples:
Expand Down
14 changes: 11 additions & 3 deletions run_one_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def generate_targets(benchmark: Benchmark,
model: models.LLM,
prompt: prompts.Prompt,
work_dirs: WorkDirs,
builder: prompt_builder.PromptBuilder,
debug: bool = DEBUG) -> list[str]:
"""Generates fuzz target with LLM."""
print(f'Generating targets for {benchmark.project} '
Expand All @@ -97,6 +98,7 @@ def generate_targets(benchmark: Benchmark,
continue
raw_output = os.path.join(work_dirs.raw_targets, file)
target_code = output_parser.parse_code(raw_output)
target_code = builder.post_process_generated_code(target_code)
target_id, _ = os.path.splitext(raw_output)
target_file = f'{target_id}{target_ext}'
target_path = os.path.join(work_dirs.raw_targets, target_file)
Expand Down Expand Up @@ -221,7 +223,8 @@ def run(benchmark: Benchmark,
cloud_experiment_bucket: str = '',
use_context: bool = False,
run_timeout: int = RUN_TIMEOUT,
dry_run: bool = False) -> Optional[AggregatedResult]:
dry_run: bool = False,
prompt_builder_to_use: str = 'DEFAULT') -> Optional[AggregatedResult]:
"""Generates code via LLM, and evaluates them."""
model.cloud_setup()
logging.basicConfig(level=logging.INFO)
Expand Down Expand Up @@ -255,8 +258,12 @@ def run(benchmark: Benchmark,
builder = prompt_builder.DefaultJvmTemplateBuilder(
model, benchmark.project, benchmark.params, template_dir)
else:
# For C/C++ projects
builder = prompt_builder.DefaultTemplateBuilder(model, template_dir)
if prompt_builder_to_use == 'CSpecific':
builder = prompt_builder.CSpecificBuilder(model, benchmark,
template_dir)
else:
# Use default
builder = prompt_builder.DefaultTemplateBuilder(model, template_dir)

prompt = builder.build(benchmark.function_signature,
benchmark.file_type,
Expand All @@ -272,6 +279,7 @@ def run(benchmark: Benchmark,
model,
prompt,
work_dirs,
builder,
debug=debug)
generated_targets = fix_code(work_dirs, generated_targets)
return check_targets(model.ai_binary, benchmark, work_dirs, generated_targets,
Expand Down

0 comments on commit 8d8a8bd

Please sign in to comment.