Skip to content

Commit 3bd98aa

Browse files
authored
A script to run Gemini models (#345)
Adding a script to run Gemini models with a text prompt. Steps: ```shell python -m venv .venv pip install -r requirements.txt gcloud auth login && gcloud auth application-default login && gcloud auth application-default set-quota-project oss-fuzz python -m experimental.manual.prompter -p prompt.txt -l vertex_ai_gemini-1-5 ls `./responses/` ``` See the [`name` attribute](https://github.com/google/oss-fuzz-gen/blob/main/llm_toolkit/models.py#L422) for more supported model names.
1 parent 00507b3 commit 3bd98aa

File tree

2 files changed

+71
-0
lines changed

2 files changed

+71
-0
lines changed

experimental/manual/__init__.py

Whitespace-only changes.

experimental/manual/prompter.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
"""Play with Gemini models manually

Usage:
  # Under venv.
  python -m experimental.manual.prompter -p <prompt_file> -l <model_name>
  # <prompt_file> is a plain text file.
  # <model_name> is `name` attribute of classes in llm_toolkit/models.py.
  # E.g.,
  python -m experimental.manual.prompter -p prompt.txt -l vertex_ai_gemini-1-5
"""
import argparse
import os

from llm_toolkit import models, prompts

# Default number of samples to request from the LLM per prompt.
NUM_SAMPLES: int = 1
# Default sampling temperature; argparse help says a value between 0 and 2,
# where higher values produce more varied output.
TEMPERATURE: float = 0.4
# Maximum number of tokens the LLM may generate per response.
MAX_TOKENS: int = 4096
21+
def parse_args() -> argparse.Namespace:
  """Parses command line arguments.

  Returns:
    Parsed arguments with attributes: num_samples, temperature, model,
    prompt, and response_dir.
  """
  parser = argparse.ArgumentParser(
      # Fixed: the previous description was copy-pasted from the experiment
      # runner and did not describe this script.
      description='Send a plain-text prompt to an LLM and save its responses.')
  parser.add_argument('-n',
                      '--num-samples',
                      type=int,
                      default=NUM_SAMPLES,
                      help='The number of samples to request from LLM.')
  parser.add_argument(
      '-t',
      '--temperature',
      type=float,
      default=TEMPERATURE,
      help=('A value between 0 and 2 representing the variety of the targets '
            'generated by LLM.'))
  parser.add_argument('-l',
                      '--model',
                      default=models.DefaultModel.name,
                      help=('Models available: '
                            f'{", ".join(models.LLM.all_llm_names())}'))
  parser.add_argument('-p', '--prompt', help='Prompt file for LLM.')
  parser.add_argument('-r',
                      '--response-dir',
                      default='./responses',
                      help='LLM response directory.')
  return parser.parse_args()
48+
49+
50+
def setup_model(cli_args=None) -> models.LLM:
  """Instantiates the LLM selected on the command line.

  Args:
    cli_args: Optional parsed arguments (needs .model, .num_samples,
      .temperature). Defaults to the module-level `args` set in the
      __main__ block, preserving the original call signature.

  Returns:
    A configured models.LLM instance.
  """
  # Fall back to the module-level global so existing `setup_model()` callers
  # keep working, while allowing explicit injection for testability.
  cli_args = args if cli_args is None else cli_args
  return models.LLM.setup(
      ai_binary='',
      name=cli_args.model,
      max_tokens=MAX_TOKENS,
      num_samples=cli_args.num_samples,
      temperature=cli_args.temperature,
  )
58+
59+
60+
def construct_prompt() -> prompts.Prompt:
  """Builds a Prompt of the model's preferred type from the prompt file.

  Reads the file named by the module-level `args.prompt` (set in the
  __main__ block) and wraps its content in the prompt class reported by
  the module-level `model`.

  Returns:
    A prompts.Prompt initialized with the file's content.
  """
  # Read as UTF-8 explicitly so behavior does not depend on the platform's
  # default locale encoding.
  with open(args.prompt, 'r', encoding='utf-8') as prompt_file:
    content = prompt_file.read()
  return model.prompt_type()(initial=content)
64+
65+
66+
if __name__ == "__main__":
  args = parse_args()
  model = setup_model()
  prompt = construct_prompt()
  # Create the output directory up front so the model can write into it.
  os.makedirs(args.response_dir, exist_ok=True)
  # Bug fix: honor --response-dir. Previously this passed the hard-coded
  # string 'responses', silently ignoring the -r flag (and the directory
  # created above).
  model.generate_code(prompt, response_dir=args.response_dir)

0 commit comments

Comments
 (0)