Skip to content

Commit

Permalink
A script to run Gemini models (#345)
Browse files Browse the repository at this point in the history
Adding a script to run Gemini models with a text prompt.
Steps:
```shell
python -m venv .venv
source .venv/bin/activate
pip install -r requirements.txt
gcloud auth login && gcloud auth application-default login && gcloud auth application-default set-quota-project oss-fuzz
python -m experimental.manual.prompter -p prompt.txt -l vertex_ai_gemini-1-5
ls ./responses/
```

See the [`name`
attribute](https://github.com/google/oss-fuzz-gen/blob/main/llm_toolkit/models.py#L422)
for more supported model names.
  • Loading branch information
DonggeLiu authored Jun 18, 2024
1 parent 00507b3 commit 3bd98aa
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 0 deletions.
Empty file added experimental/manual/__init__.py
Empty file.
71 changes: 71 additions & 0 deletions experimental/manual/prompter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
"""Play with Gemini models manually
Usage:
# Under venv.
python -m experimental.manual.prompter -p <prompt_file> -l <model_name>
# <prompt_file> is a plain text file.
# <model_name> is the `name` attribute of one of the classes in llm_toolkit/models.py.
# E.g.,
python -m experimental.manual.prompter -p prompt.txt -l vertex_ai_gemini-1-5
"""

import argparse
import os

from llm_toolkit import models, prompts

NUM_SAMPLES: int = 1
TEMPERATURE: float = 0.4
MAX_TOKENS: int = 4096


def parse_args() -> argparse.Namespace:
"""Parses command line arguments."""
parser = argparse.ArgumentParser(
description='Run all experiments that evaluates all target functions.')
parser.add_argument('-n',
'--num-samples',
type=int,
default=NUM_SAMPLES,
help='The number of samples to request from LLM.')
parser.add_argument(
'-t',
'--temperature',
type=float,
default=TEMPERATURE,
help=('A value between 0 and 2 representing the variety of the targets '
'generated by LLM.'))
parser.add_argument('-l',
'--model',
default=models.DefaultModel.name,
help=('Models available: '
f'{", ".join(models.LLM.all_llm_names())}'))
parser.add_argument('-p', '--prompt', help='Prompt file for LLM.')
parser.add_argument('-r',
'--response-dir',
default='./responses',
help='LLM response directory.')
return parser.parse_args()


def setup_model() -> models.LLM:
return models.LLM.setup(
ai_binary='',
name=args.model,
max_tokens=MAX_TOKENS,
num_samples=args.num_samples,
temperature=args.temperature,
)


def construct_prompt() -> prompts.Prompt:
with open(args.prompt, 'r') as prompt_file:
content = prompt_file.read()
return model.prompt_type()(initial=content)


if __name__ == "__main__":
args = parse_args()
model = setup_model()
prompt = construct_prompt()
os.makedirs(args.response_dir, exist_ok=True)
model.generate_code(prompt, response_dir='responses')

0 comments on commit 3bd98aa

Please sign in to comment.