@@ -132,11 +132,11 @@ def search_index_data_types(self, field, db_type):
132
132
"""
133
133
if field .get_internal_type () == "UUIDField" :
134
134
return "uuid"
135
- if field .get_internal_type () in ( "ObjectIdAutoField" , "ObjectIdField" ) :
135
+ if field .get_internal_type () in { "ObjectIdAutoField" , "ObjectIdField" } :
136
136
return "ObjectId"
137
137
if field .get_internal_type () == "EmbeddedModelField" :
138
138
return "embeddedDocuments"
139
- if db_type in ( "int" , "long" ) :
139
+ if db_type in { "int" , "long" } :
140
140
return "number"
141
141
if db_type == "binData" :
142
142
return "string"
@@ -164,26 +164,24 @@ def get_pymongo_index_model(
164
164
165
165
class VectorSearchIndex (SearchIndex ):
166
166
suffix = "vsi"
167
- ALLOWED_SIMILARITY_FUNCTIONS = frozenset (("euclidean" , "cosine" , "dotProduct" ))
167
+ VALID_SIMILARITIES = frozenset (("euclidean" , "cosine" , "dotProduct" ))
168
168
_error_id_prefix = "django_mongodb_backend.indexes.VectorSearchIndex"
169
169
170
170
def __init__ (self , * expressions , fields = (), similarities = "cosine" , name = None , ** kwargs ):
171
171
super ().__init__ (* expressions , fields = fields , name = name , ** kwargs )
172
172
self .similarities = similarities
173
- for func in similarities if isinstance (similarities , list ) else (similarities ,):
174
- if func not in self .ALLOWED_SIMILARITY_FUNCTIONS :
173
+ self ._multiple_similarities = isinstance (similarities , tuple | list )
174
+ for func in similarities if self ._multiple_similarities else (similarities ,):
175
+ if func not in self .VALID_SIMILARITIES :
175
176
raise ValueError (
176
- f"{ func } isn't a valid similarity function, options "
177
- f"are { ', ' .join (sorted (self .ALLOWED_SIMILARITY_FUNCTIONS ))} "
177
+ f"' { func } ' isn't a valid similarity function "
178
+ f"( { ', ' .join (sorted (self .VALID_SIMILARITIES ))} ). "
178
179
)
179
- viewed = set ()
180
+ seen_fields = set ()
180
181
for field_name , _ in self .fields_orders :
181
- if field_name in viewed :
182
- raise ValueError (
183
- f"Field '{ field_name } ' is defined more than once. Vector and filter "
184
- "fields must use distinct field names." ,
185
- )
186
- viewed .add (field_name )
182
+ if field_name in seen_fields :
183
+ raise ValueError (f"Field '{ field_name } ' is duplicated in fields." )
184
+ seen_fields .add (field_name )
187
185
188
186
def check (self , model , connection ):
189
187
errors = super ().check (model , connection )
@@ -197,20 +195,20 @@ def check(self, model, connection):
197
195
except (ValueError , TypeError ):
198
196
errors .append (
199
197
Error (
200
- f"Atlas vector search requires size on { field_name } ." ,
198
+ f"VectorSearchIndex requires ' size' on field ' { field_name } ' ." ,
201
199
obj = model ,
202
- id = f"{ self ._error_id_prefix } .E001 " ,
200
+ id = f"{ self ._error_id_prefix } .E002 " ,
203
201
)
204
202
)
205
203
if not isinstance (field_ .base_field , FloatField | DecimalField ):
206
204
errors .append (
207
205
Error (
208
- "An Atlas vector search index requires the base "
209
- "field of ArrayField Model.field_name "
210
- "to be FloatField or DecimalField but "
211
- f"is { field_ .base_field .get_internal_type ()} ." ,
206
+ "VectorSearchIndex requires the base field of "
207
+ f" ArrayField ' { field_ . name } ' to be FloatField or "
208
+ "DecimalField but is "
209
+ f"{ field_ .base_field .get_internal_type ()} ." ,
212
210
obj = model ,
213
- id = f"{ self ._error_id_prefix } .E002 " ,
211
+ id = f"{ self ._error_id_prefix } .E003 " ,
214
212
)
215
213
)
216
214
else :
@@ -223,24 +221,22 @@ def check(self, model, connection):
223
221
errors .append (
224
222
Error (
225
223
"VectorSearchIndex does not support "
226
- f"' { field_ .get_internal_type ()} ' { field_name } ." ,
224
+ f"{ field_ .get_internal_type ()} ' { field_name } ' ." ,
227
225
obj = model ,
228
- id = f"{ self ._error_id_prefix } .E003 " ,
226
+ id = f"{ self ._error_id_prefix } .E004 " ,
229
227
)
230
228
)
231
- if isinstance ( self .similarities , list ) and expected_similarities != len (self .similarities ):
229
+ if self ._multiple_similarities and expected_similarities != len (self .similarities ):
232
230
similarity_function_text = (
233
- "similarities functions " if expected_similarities != 1 else "similarity function "
231
+ "similarity " if expected_similarities == 1 else "similarities "
234
232
)
235
233
errors .append (
236
234
Error (
237
- f"An Atlas vector search index requires the same number of similarities and "
238
- f"vector fields, but { expected_similarities } "
239
- f"{ similarity_function_text } were expected and "
240
- f"{ len (self .similarities )} { 'were' if len (self .similarities ) != 1 else 'was' } "
241
- "provided." ,
235
+ f"VectorSearchIndex requires the same number of similarities and "
236
+ f"vector fields; expected { expected_similarities } "
237
+ f"{ similarity_function_text } but got { len (self .similarities )} ." ,
242
238
obj = model ,
243
- id = f"{ self ._error_id_prefix } .E004 " ,
239
+ id = f"{ self ._error_id_prefix } .E005 " ,
244
240
)
245
241
)
246
242
return errors
0 commit comments