@@ -2,32 +2,58 @@
 from google.generativeai.types import HarmCategory, HarmBlockThreshold
 
 from src.config import Config
+from src.logger import Logger
+
+logger = Logger()
+config = Config()
 
 class Gemini:
     def __init__(self):
-        config = Config()
         api_key = config.get_gemini_api_key()
-        genai.configure(api_key=api_key)
+        if not api_key:
+            error_msg = ("Gemini API key not found in configuration. "
+                         "Please add your Gemini API key to config.toml under [API_KEYS] "
+                         "section as GEMINI = 'your-api-key'")
+            logger.error(error_msg)
+            raise ValueError(error_msg)
+        try:
+            genai.configure(api_key=api_key)
+            logger.info("Successfully initialized Gemini client")
+        except Exception as e:
+            error_msg = f"Failed to configure Gemini client: {str(e)}"
+            logger.error(error_msg)
+            raise ValueError(error_msg)
 
     def inference(self, model_id: str, prompt: str) -> str:
-        config = genai.GenerationConfig(temperature=0)
-        model = genai.GenerativeModel(model_id, generation_config=config)
-        # Set safety settings for the request
-        safety_settings = {
-            HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
-            HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
-            # You can adjust other categories as needed
-        }
-        response = model.generate_content(prompt, safety_settings=safety_settings)
         try:
-            # Check if the response contains text
-            return response.text
-        except ValueError:
-            # If the response doesn't contain text, check if the prompt was blocked
-            print("Prompt feedback:", response.prompt_feedback)
-            # Also check the finish reason to see if the response was blocked
-            print("Finish reason:", response.candidates[0].finish_reason)
-            # If the finish reason was SAFETY, the safety ratings have more details
-            print("Safety ratings:", response.candidates[0].safety_ratings)
-            # Handle the error or return an appropriate message
-            return "Error: Unable to generate content Gemini API"
+            logger.info(f"Initializing Gemini model: {model_id}")
+            config = genai.GenerationConfig(temperature=0)
+            model = genai.GenerativeModel(model_id, generation_config=config)
+
+            safety_settings = {
+                HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
+                HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
+            }
+
+            logger.info("Generating response with Gemini")
+            response = model.generate_content(prompt, safety_settings=safety_settings)
+
+            try:
+                if response.text:
+                    logger.info("Successfully generated response")
+                    return response.text
+                else:
+                    error_msg = f"Empty response from Gemini model {model_id}"
+                    logger.error(error_msg)
+                    raise ValueError(error_msg)
+            except ValueError:
+                logger.error("Failed to get response text")
+                logger.error(f"Prompt feedback: {response.prompt_feedback}")
+                logger.error(f"Finish reason: {response.candidates[0].finish_reason}")
+                logger.error(f"Safety ratings: {response.candidates[0].safety_ratings}")
+                return "Error: Unable to generate content with Gemini API"
+
+        except Exception as e:
+            error_msg = f"Error during Gemini inference: {str(e)}"
+            logger.error(error_msg)
+            raise ValueError(error_msg)
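
For reference, a minimal usage sketch of the refactored client after this change. The import path and model id below are assumptions for illustration (the diff does not show where the Gemini class lives); it only assumes config.toml carries a GEMINI key under [API_KEYS], as the new error message describes.

    # Hypothetical usage sketch; import path and model id are assumptions, not part of this diff.
    from src.llm.gemini import Gemini  # adjust to the module that defines the Gemini class

    llm = Gemini()  # raises ValueError if the API key is missing or genai.configure() fails
    reply = llm.inference("gemini-1.5-flash", "Summarize this repository in one sentence.")
    print(reply)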