Skip to content

Commit 69a808c

Browse files
committed
Check similiarities function in __init__
1 parent 5d58c5f commit 69a808c

File tree

3 files changed

+18
-72
lines changed

3 files changed

+18
-72
lines changed

django_mongodb_backend/indexes.py

+7-14
Original file line numberDiff line numberDiff line change
@@ -168,22 +168,15 @@ class VectorSearchIndex(SearchIndex):
168168
def __init__(self, *expressions, similarities="cosine", **kwargs):
169169
super().__init__(*expressions, **kwargs)
170170
self.similarities = similarities
171+
for func in similarities if isinstance(similarities, list) else (similarities,):
172+
if func not in self.ALLOWED_SIMILARITY_FUNCTIONS:
173+
raise ValueError(
174+
f"{func} isn't a valid similarity function, options "
175+
f"are {', '.join(sorted(self.ALLOWED_SIMILARITY_FUNCTIONS))}"
176+
)
171177

172178
def check(self, model, connection):
173179
errors = super().check(model, connection)
174-
similarities = (
175-
self.similarities if isinstance(self.similarities, list) else [self.similarities]
176-
)
177-
for func in similarities:
178-
if func not in self.ALLOWED_SIMILARITY_FUNCTIONS:
179-
errors.append(
180-
Error(
181-
f"{func} isn't a valid similarity function, options "
182-
f"are {', '.join(sorted(self.ALLOWED_SIMILARITY_FUNCTIONS))}",
183-
obj=self,
184-
id=f"{self._error_id_prefix}.E004",
185-
)
186-
)
187180
viewed = set()
188181
expected_similarities = 0
189182
for field_name, _ in self.fields_orders:
@@ -195,7 +188,7 @@ def check(self, model, connection):
195188
obj=self,
196189
hint="If you need different configurations for the same field, "
197190
"create separate indexes.",
198-
id=f"{self._error_id_prefix}.E005",
191+
id=f"{self._error_id_prefix}.E004",
199192
)
200193
)
201194
continue

tests/indexes_/test_atlas_indexes.py

+8
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,14 @@ class VectorSearchIndexTests(TestMixin, TestCase):
104104
using Django's schema editor.
105105
"""
106106

107+
def test_invalid_similarity_function(self):
108+
msg = (
109+
"sum isn't a valid similarity function, options "
110+
f"are {', '.join(sorted(VectorSearchIndex.ALLOWED_SIMILARITY_FUNCTIONS))}"
111+
)
112+
with self.assertRaisesMessage(ValueError, msg):
113+
VectorSearchIndex(fields=["vector_data"], similarities="sum")
114+
107115
@skipUnlessDBFeature("supports_atlas_search")
108116
def test_deconstruct_default_similarity(self):
109117
index = VectorSearchIndex(

tests/indexes_/test_checks.py

+3-58
Original file line numberDiff line numberDiff line change
@@ -131,61 +131,6 @@ class Meta:
131131
],
132132
)
133133

134-
def test_invalid_similarity_function(self):
135-
class Article(models.Model):
136-
vector_data = ArrayField(models.DecimalField(), size=10)
137-
138-
class Meta:
139-
indexes = [
140-
VectorSearchIndex(fields=["vector_data"], similarities="sum"),
141-
]
142-
143-
errors = checks.run_checks(app_configs=self.apps.get_app_configs(), databases={"default"})
144-
self.assertEqual(
145-
errors,
146-
[
147-
checks.Error(
148-
"sum isn't a valid similarity function, "
149-
"options are cosine, dotProduct, euclidean",
150-
id="django_mongodb_backend.indexes.VectorSearchIndex.E004",
151-
obj=Article._meta.indexes[0],
152-
)
153-
],
154-
)
155-
156-
def test_invalid_similarities_function(self):
157-
class Article(models.Model):
158-
vector1 = ArrayField(models.DecimalField(), size=10)
159-
vector2 = ArrayField(models.DecimalField(), size=10)
160-
vector3 = ArrayField(models.DecimalField(), size=10)
161-
162-
class Meta:
163-
indexes = [
164-
VectorSearchIndex(
165-
fields=["vector1", "vector2", "vector3"],
166-
similarities=["sum", "dotProduct", "tangh"],
167-
),
168-
]
169-
170-
errors = checks.run_checks(app_configs=self.apps.get_app_configs(), databases={"default"})
171-
self.assertEqual(
172-
errors,
173-
[
174-
checks.Error(
175-
"sum isn't a valid similarity function, "
176-
"options are cosine, dotProduct, euclidean",
177-
id="django_mongodb_backend.indexes.VectorSearchIndex.E004",
178-
obj=Article._meta.indexes[0],
179-
),
180-
checks.Error(
181-
"tangh isn't a valid similarity function, "
182-
"options are cosine, dotProduct, euclidean",
183-
id="django_mongodb_backend.indexes.VectorSearchIndex.E004",
184-
obj=Article._meta.indexes[0],
185-
),
186-
],
187-
)
188-
189134
def test_define_field_twice(self):
190135
class Article(models.Model):
191136
vector_data = ArrayField(models.DecimalField(), size=10)
@@ -205,7 +150,7 @@ class Meta:
205150
checks.Error(
206151
"Field 'vector_data' is defined more than once. Vector and filter "
207152
"fields must use distinct field names.",
208-
id="django_mongodb_backend.indexes.VectorSearchIndex.E005",
153+
id="django_mongodb_backend.indexes.VectorSearchIndex.E004",
209154
hint="If you need different configurations for the same field,"
210155
" create separate indexes.",
211156
obj=Article._meta.indexes[0],
@@ -233,7 +178,7 @@ class Meta:
233178
"An Atlas vector search index requires the same number of similarities "
234179
"and vector fields, but 1 similarity function were expected and 2 "
235180
"were provided.",
236-
id="django_mongodb_backend.indexes.VectorSearchIndex.E006",
181+
id="django_mongodb_backend.indexes.VectorSearchIndex.E005",
237182
obj=Article._meta.indexes[0],
238183
),
239184
],
@@ -260,7 +205,7 @@ class Meta:
260205
"An Atlas vector search index requires the same number of similarities "
261206
"and vector fields, but 2 similarities functions were expected and 1 "
262207
"was provided.",
263-
id="django_mongodb_backend.indexes.VectorSearchIndex.E006",
208+
id="django_mongodb_backend.indexes.VectorSearchIndex.E005",
264209
obj=Article._meta.indexes[0],
265210
),
266211
],

0 commit comments

Comments
 (0)