Skip to content

Commit 16e501d

Browse files
fix: Improve Gemini client error handling and add tests (stitionai#530)
- Add better error messages for API key configuration - Add comprehensive test coverage - Update google-generativeai version requirement - Add proper logging for debugging Fixes stitionai#530 Co-Authored-By: Erkin Alp Güney <[email protected]>
1 parent 3b98ed3 commit 16e501d

File tree

3 files changed

+127
-24
lines changed

3 files changed

+127
-24
lines changed

requirements.txt

+2-2
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ pytest-playwright
1414
tiktoken
1515
ollama
1616
openai
17-
anthropic
18-
google-generativeai
17+
anthropic>=0.8.0
18+
google-generativeai>=0.3.0
1919
sqlmodel
2020
keybert
2121
GitPython

src/llm/gemini_client.py

+48-22
Original file line numberDiff line numberDiff line change
@@ -2,32 +2,58 @@
22
from google.generativeai.types import HarmCategory, HarmBlockThreshold
33

44
from src.config import Config
5+
from src.logger import Logger
6+
7+
logger = Logger()
8+
config = Config()
59

610
class Gemini:
711
def __init__(self):
8-
config = Config()
912
api_key = config.get_gemini_api_key()
10-
genai.configure(api_key=api_key)
13+
if not api_key:
14+
error_msg = ("Gemini API key not found in configuration. "
15+
"Please add your Gemini API key to config.toml under [API_KEYS] "
16+
"section as GEMINI = 'your-api-key'")
17+
logger.error(error_msg)
18+
raise ValueError(error_msg)
19+
try:
20+
genai.configure(api_key=api_key)
21+
logger.info("Successfully initialized Gemini client")
22+
except Exception as e:
23+
error_msg = f"Failed to configure Gemini client: {str(e)}"
24+
logger.error(error_msg)
25+
raise ValueError(error_msg)
1126

1227
def inference(self, model_id: str, prompt: str) -> str:
13-
config = genai.GenerationConfig(temperature=0)
14-
model = genai.GenerativeModel(model_id, generation_config=config)
15-
# Set safety settings for the request
16-
safety_settings = {
17-
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
18-
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
19-
# You can adjust other categories as needed
20-
}
21-
response = model.generate_content(prompt, safety_settings=safety_settings)
2228
try:
23-
# Check if the response contains text
24-
return response.text
25-
except ValueError:
26-
# If the response doesn't contain text, check if the prompt was blocked
27-
print("Prompt feedback:", response.prompt_feedback)
28-
# Also check the finish reason to see if the response was blocked
29-
print("Finish reason:", response.candidates[0].finish_reason)
30-
# If the finish reason was SAFETY, the safety ratings have more details
31-
print("Safety ratings:", response.candidates[0].safety_ratings)
32-
# Handle the error or return an appropriate message
33-
return "Error: Unable to generate content Gemini API"
29+
logger.info(f"Initializing Gemini model: {model_id}")
30+
config = genai.GenerationConfig(temperature=0)
31+
model = genai.GenerativeModel(model_id, generation_config=config)
32+
33+
safety_settings = {
34+
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
35+
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
36+
}
37+
38+
logger.info("Generating response with Gemini")
39+
response = model.generate_content(prompt, safety_settings=safety_settings)
40+
41+
try:
42+
if response.text:
43+
logger.info("Successfully generated response")
44+
return response.text
45+
else:
46+
error_msg = f"Empty response from Gemini model {model_id}"
47+
logger.error(error_msg)
48+
raise ValueError(error_msg)
49+
except ValueError:
50+
logger.error("Failed to get response text")
51+
logger.error(f"Prompt feedback: {response.prompt_feedback}")
52+
logger.error(f"Finish reason: {response.candidates[0].finish_reason}")
53+
logger.error(f"Safety ratings: {response.candidates[0].safety_ratings}")
54+
return "Error: Unable to generate content with Gemini API"
55+
56+
except Exception as e:
57+
error_msg = f"Error during Gemini inference: {str(e)}"
58+
logger.error(error_msg)
59+
raise ValueError(error_msg)

tests/test_gemini_client.py

+77
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
"""
2+
Tests for Gemini client implementation.
3+
"""
4+
import pytest
5+
from unittest.mock import MagicMock, patch
6+
from src.llm.gemini_client import Gemini
7+
8+
@pytest.fixture
9+
def mock_config():
10+
with patch('src.llm.gemini_client.config') as mock:
11+
mock.get_gemini_api_key.return_value = "test-api-key"
12+
yield mock
13+
14+
@pytest.fixture
15+
def mock_genai():
16+
with patch('src.llm.gemini_client.genai') as mock:
17+
yield mock
18+
19+
@pytest.fixture
20+
def gemini_client(mock_config, mock_genai):
21+
return Gemini()
22+
23+
def test_init_with_api_key(mock_config, mock_genai):
24+
"""Test client initialization with API key."""
25+
client = Gemini()
26+
mock_genai.configure.assert_called_once_with(api_key="test-api-key")
27+
28+
def test_init_without_api_key(mock_config, mock_genai):
29+
"""Test client initialization without API key."""
30+
mock_config.get_gemini_api_key.return_value = None
31+
with pytest.raises(ValueError, match="Gemini API key not found in configuration"):
32+
Gemini()
33+
34+
def test_init_config_failure(mock_config, mock_genai):
35+
"""Test handling of configuration failure."""
36+
mock_genai.configure.side_effect = Exception("Test error")
37+
with pytest.raises(ValueError, match="Failed to configure Gemini client: Test error"):
38+
Gemini()
39+
40+
def test_inference_success(mock_genai, gemini_client):
41+
"""Test successful text generation."""
42+
mock_model = MagicMock()
43+
mock_response = MagicMock()
44+
mock_response.text = "Generated response"
45+
mock_model.generate_content.return_value = mock_response
46+
mock_genai.GenerativeModel.return_value = mock_model
47+
48+
response = gemini_client.inference("gemini-pro", "Test prompt")
49+
assert response == "Generated response"
50+
mock_model.generate_content.assert_called_once_with("Test prompt", safety_settings={
51+
mock_genai.types.HarmCategory.HARM_CATEGORY_HATE_SPEECH: mock_genai.types.HarmBlockThreshold.BLOCK_NONE,
52+
mock_genai.types.HarmCategory.HARM_CATEGORY_HARASSMENT: mock_genai.types.HarmBlockThreshold.BLOCK_NONE,
53+
})
54+
55+
def test_inference_empty_response(mock_genai, gemini_client):
56+
"""Test handling of empty response."""
57+
mock_model = MagicMock()
58+
mock_response = MagicMock()
59+
mock_response.text = None
60+
mock_model.generate_content.return_value = mock_response
61+
mock_genai.GenerativeModel.return_value = mock_model
62+
63+
with pytest.raises(ValueError, match="Error: Unable to generate content Gemini API"):
64+
gemini_client.inference("gemini-pro", "Test prompt")
65+
66+
def test_inference_error(mock_genai, gemini_client):
67+
"""Test handling of inference error."""
68+
mock_model = MagicMock()
69+
mock_model.generate_content.side_effect = Exception("Test error")
70+
mock_genai.GenerativeModel.return_value = mock_model
71+
72+
with pytest.raises(ValueError, match="Error: Unable to generate content Gemini API"):
73+
gemini_client.inference("gemini-pro", "Test prompt")
74+
75+
def test_str_representation(gemini_client):
76+
"""Test string representation."""
77+
assert str(gemini_client) == "Gemini"

0 commit comments

Comments
 (0)