Skip to content

Commit 140a057

Browse files
authored
Add OllamaLLM and OllamaEmbeddings classes (#231)
* Add OllamaLLM and OllamaEmbeddings classes using the ollama python client * Try removing import * :( * Add tests + reformat import in ollama embeddings for consistency with all other imports * Fix after merge
1 parent ff6862e commit 140a057

File tree

16 files changed

+293
-64
lines changed

16 files changed

+293
-64
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
- Integrated json-repair package to handle and repair invalid JSON generated by LLMs.
77
- Introduced InvalidJSONError exception for handling cases where JSON repair fails.
88
- Ability to create a Pipeline or SimpleKGPipeline from a config file. See [the example](examples/build_graph/from_config_files/simple_kg_pipeline_from_config_file.py).
9+
- Added `OllamaLLM` and `OllamaEmbeddings` classes to make Ollama support more explicit. Implementations using the `OpenAILLM` and `OpenAIEmbeddings` classes will still work.
910

1011
## Changed
1112
- Updated LLM prompts to include stricter instructions for generating valid JSON.

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ This package has some optional features that can be enabled using
3131
the extra dependencies described below:
3232

3333
- LLM providers (at least one is required for RAG and KG Builder Pipeline):
34+
- **ollama**: LLMs from Ollama
3435
- **openai**: LLMs from OpenAI (including AzureOpenAI)
3536
- **google**: LLMs from Vertex AI
3637
- **cohere**: LLMs from Cohere

docs/source/api.rst

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,12 @@ AzureOpenAIEmbeddings
239239
.. autoclass:: neo4j_graphrag.embeddings.openai.AzureOpenAIEmbeddings
240240
:members:
241241

242+
OllamaEmbeddings
243+
================
244+
245+
.. autoclass:: neo4j_graphrag.embeddings.ollama.OllamaEmbeddings
246+
:members:
247+
242248
VertexAIEmbeddings
243249
==================
244250

@@ -286,6 +292,12 @@ AzureOpenAILLM
286292
:members:
287293
:undoc-members: get_messages, client_class, async_client_class
288294

295+
OllamaLLM
296+
---------
297+
298+
.. autoclass:: neo4j_graphrag.llm.ollama_llm.OllamaLLM
299+
:members:
300+
289301

290302
VertexAILLM
291303
-----------

docs/source/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ Extra dependencies can be installed with:
8888
List of extra dependencies:
8989

9090
- LLM providers (at least one is required for RAG and KG Builder Pipeline):
91+
- **ollama**: LLMs from Ollama
9192
- **openai**: LLMs from OpenAI (including AzureOpenAI)
9293
- **google**: LLMs from Vertex AI
9394
- **cohere**: LLMs from Cohere

docs/source/user_guide_rag.rst

Lines changed: 5 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -218,14 +218,13 @@ See :ref:`coherellm`.
218218
Using a Local Model via Ollama
219219
-------------------------------
220220

221-
Similarly to the official OpenAI Python client, the `OpenAILLM` can be
222-
used with Ollama. Assuming Ollama is running on the default address `127.0.0.1:11434`,
221+
Assuming Ollama is running on the default address `127.0.0.1:11434`,
223222
it can be queried using the following:
224223

225224
.. code:: python
226225
227-
from neo4j_graphrag.llm import OpenAILLM
228-
llm = OpenAILLM(api_key="ollama", base_url="http://127.0.0.1:11434/v1", model_name="orca-mini")
226+
from neo4j_graphrag.llm import OllamaLLM
227+
llm = OllamaLLM(model_name="orca-mini")
229228
llm.invoke("say something")
230229
231230
@@ -428,6 +427,7 @@ Currently, this package supports the following embedders:
428427
- :ref:`mistralaiembeddings`
429428
- :ref:`cohereembeddings`
430429
- :ref:`azureopenaiembeddings`
430+
- :ref:`ollamaembeddings`
431431

432432
The `OpenAIEmbeddings` was illustrated previously. Here is how to use the `SentenceTransformerEmbeddings`:
433433

@@ -438,31 +438,7 @@ The `OpenAIEmbeddings` was illustrated previously. Here is how to use the `Sente
438438
embedder = SentenceTransformerEmbeddings(model="all-MiniLM-L6-v2") # Note: this is the default model
439439
440440
441-
If another embedder is desired, a custom embedder can be created. For example, consider
442-
the following implementation of an embedder that wraps the `OllamaEmbedding` model from LlamaIndex:
443-
444-
.. code:: python
445-
446-
from llama_index.embeddings.ollama import OllamaEmbedding
447-
from neo4j_graphrag.embeddings.base import Embedder
448-
449-
class OllamaEmbedder(Embedder):
450-
def __init__(self, ollama_embedding):
451-
self.embedder = ollama_embedding
452-
453-
def embed_query(self, text: str) -> list[float]:
454-
embedding = self.embedder.get_text_embedding_batch(
455-
[text], show_progress=True
456-
)
457-
return embedding[0]
458-
459-
ollama_embedding = OllamaEmbedding(
460-
model_name="llama3",
461-
base_url="http://localhost:11434",
462-
ollama_additional_kwargs={"mirostat": 0},
463-
)
464-
embedder = OllamaEmbedder(ollama_embedding)
465-
vector = embedder.embed_query("some text")
441+
If another embedder is desired, a custom embedder can be created, using the `Embedder` interface.
466442

467443

468444
Other Vector Retriever Configuration

examples/customize/embeddings/ollama_embeddings.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,10 @@
11
"""This example demonstrate how to embed a text into a vector
2-
using OpenAI models and API.
2+
using a local model served by Ollama.
33
"""
44

5-
from neo4j_graphrag.embeddings import OpenAIEmbeddings
5+
from neo4j_graphrag.embeddings import OllamaEmbeddings
66

7-
# not used but needs to be provided
8-
api_key = "ollama"
9-
10-
embeder = OpenAIEmbeddings(
11-
base_url="http://localhost:11434/v1",
12-
api_key=api_key,
7+
embeder = OllamaEmbeddings(
138
model="<model_name>",
149
)
1510
res = embeder.embed_query("my question")

examples/customize/llms/ollama_llm.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
1-
from neo4j_graphrag.llm import LLMResponse, OpenAILLM
1+
"""This example demonstrate how to invoke an LLM using a local model
2+
served by Ollama.
3+
"""
24

3-
# not used but needs to be provided
4-
api_key = "ollama"
5+
from neo4j_graphrag.llm import LLMResponse, OllamaLLM
56

6-
llm = OpenAILLM(
7-
base_url="http://localhost:11434/v1",
7+
llm = OllamaLLM(
88
model_name="<model_name>",
9-
api_key=api_key,
109
)
1110
res: LLMResponse = llm.invoke("What is the additive color model?")
1211
print(res.content)

poetry.lock

Lines changed: 36 additions & 19 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ anthropic = { version = "^0.36.0", optional = true}
5050
sentence-transformers = {version = "^3.0.0", optional = true }
5151
json-repair = "^0.30.2"
5252
types-pyyaml = "^6.0.12.20240917"
53+
ollama = {version = "^0.4.4", optional = true}
5354

5455
[tool.poetry.group.dev.dependencies]
5556
urllib3 = "<2"
@@ -69,6 +70,7 @@ pinecone = ["pinecone-client"]
6970
google = ["google-cloud-aiplatform"]
7071
cohere = ["cohere"]
7172
anthropic = ["anthropic"]
73+
ollama = ["ollama"]
7274
openai = ["openai"]
7375
mistralai = ["mistralai"]
7476
qdrant = ["qdrant-client"]

src/neo4j_graphrag/embeddings/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,15 @@
1515
from .base import Embedder
1616
from .cohere import CohereEmbeddings
1717
from .mistral import MistralAIEmbeddings
18+
from .ollama import OllamaEmbeddings
1819
from .openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
1920
from .sentence_transformers import SentenceTransformerEmbeddings
2021
from .vertexai import VertexAIEmbeddings
2122

2223
__all__ = [
2324
"Embedder",
2425
"SentenceTransformerEmbeddings",
26+
"OllamaEmbeddings",
2527
"OpenAIEmbeddings",
2628
"AzureOpenAIEmbeddings",
2729
"VertexAIEmbeddings",

src/neo4j_graphrag/embeddings/mistral.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,7 @@ def embed_query(self, text: str, **kwargs: Any) -> list[float]:
5757
**kwargs (Any): Additional keyword arguments to pass to the Mistral AI client.
5858
"""
5959
embeddings_batch_response = self.mistral_client.embeddings.create(
60-
model=self.model,
61-
inputs=[text],
60+
model=self.model, inputs=[text], **kwargs
6261
)
6362
if embeddings_batch_response is None or not embeddings_batch_response.data:
6463
raise EmbeddingsGenerationError("Failed to retrieve embeddings.")
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# Copyright (c) "Neo4j"
2+
# Neo4j Sweden AB [https://neo4j.com]
3+
# #
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
# #
8+
# https://www.apache.org/licenses/LICENSE-2.0
9+
# #
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
from __future__ import annotations
17+
18+
from typing import Any
19+
20+
from neo4j_graphrag.embeddings.base import Embedder
21+
from neo4j_graphrag.exceptions import EmbeddingsGenerationError
22+
23+
24+
class OllamaEmbeddings(Embedder):
25+
"""
26+
Ollama embeddings class.
27+
This class uses the ollama Python client to generate vector embeddings for text data.
28+
29+
Args:
30+
model (str): The name of the Mistral AI text embedding model to use. Defaults to "mistral-embed".
31+
"""
32+
33+
def __init__(self, model: str, **kwargs: Any) -> None:
34+
try:
35+
import ollama
36+
except ImportError:
37+
raise ImportError(
38+
"Could not import ollama python client. "
39+
"Please install it with `pip install ollama`."
40+
)
41+
self.model = model
42+
self.client = ollama.Client(**kwargs)
43+
44+
def embed_query(self, text: str, **kwargs: Any) -> list[float]:
45+
"""
46+
Generate embeddings for a given query using an Ollama text embedding model.
47+
48+
Args:
49+
text (str): The text to generate an embedding for.
50+
**kwargs (Any): Additional keyword arguments to pass to the Ollama client.
51+
"""
52+
embeddings_response = self.client.embed(
53+
model=self.model,
54+
input=text,
55+
**kwargs,
56+
)
57+
58+
if embeddings_response is None or embeddings_response.embeddings is None:
59+
raise EmbeddingsGenerationError("Failed to retrieve embeddings.")
60+
61+
embedding = embeddings_response.embeddings
62+
if not isinstance(embedding, list):
63+
raise EmbeddingsGenerationError("Embedding is not a list of floats.")
64+
65+
return embedding

0 commit comments

Comments
 (0)