Skip to content

Commit ade28e2

Browse files
authored
Fix ollama embedder (#245)
* Fix ollama embedder * Update CHANGELOG
1 parent 0f9d5df commit ade28e2

File tree

3 files changed

+20
-3
lines changed

3 files changed

+20
-3
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22

33
## Next
44

5+
### Fixed
6+
7+
- Fix a bug where the `OllamaEmbedder` would return a `list[list[float]]` instead of the expected `list[float]`.
8+
59
## 1.4.1
610

711
### Fixed

src/neo4j_graphrag/embeddings/ollama.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,12 @@ def embed_query(self, text: str, **kwargs: Any) -> list[float]:
5555
**kwargs,
5656
)
5757

58-
if embeddings_response is None or embeddings_response.embeddings is None:
58+
if embeddings_response is None or not embeddings_response.embeddings:
5959
raise EmbeddingsGenerationError("Failed to retrieve embeddings.")
6060

61-
embedding = embeddings_response.embeddings
61+
embeddings = embeddings_response.embeddings
62+
# client always returns a sequence of sequences
63+
embedding = embeddings[0]
6264
if not isinstance(embedding, list):
6365
raise EmbeddingsGenerationError("Embedding is not a list of floats.")
6466

tests/unit/embeddings/test_ollama_embedder.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import pytest
1818
from neo4j_graphrag.embeddings.ollama import OllamaEmbeddings
19+
from neo4j_graphrag.exceptions import EmbeddingsGenerationError
1920

2021

2122
@patch("builtins.__import__", side_effect=ImportError)
@@ -27,9 +28,19 @@ def test_ollama_embedder_missing_dependency(mock_import: Mock) -> None:
2728
@patch("builtins.__import__")
2829
def test_ollama_embedder_happy_path(mock_import: Mock) -> None:
2930
mock_import.return_value.Client.return_value.embed.return_value = MagicMock(
30-
embeddings=[1.0, 2.0],
31+
embeddings=[[1.0, 2.0]],
3132
)
3233
embedder = OllamaEmbeddings(model="test")
3334
res = embedder.embed_query("my text")
3435
assert isinstance(res, list)
3536
assert res == [1.0, 2.0]
37+
38+
39+
@patch("builtins.__import__")
40+
def test_ollama_embedder_empty_list(mock_import: Mock) -> None:
41+
mock_import.return_value.Client.return_value.embed.return_value = MagicMock(
42+
embeddings=[],
43+
)
44+
embedder = OllamaEmbeddings(model="test")
45+
with pytest.raises(EmbeddingsGenerationError):
46+
embedder.embed_query("my text")

0 commit comments

Comments
 (0)