Skip to content

Commit de1be4d

Browse files
rbs333abrookins
authored andcommitted
Add option to normalize vector distances on query (#298)
This pr accomplishes 2 goals: 1. Add an option for users to easily get back a similarity value between 0 and 1 that they might expect to compare against other vector dbs. 2. Fix the current bug that `distance_threshold` is validated to be between 0 and 1 when in reality it can take values between 0 and 2. > Note: after much careful thought I believe it is best that for `0.5.0` we do **not** start enforcing all distance_thresholds between 0 and 1 and move to this option as default behavior. Ideally this metric would be consistent throughout our code and I don't love supporting this flag but I think it provides the value that is scoped for this ticket while inflicting the least amount of pain and confusion. Changes: 1. Adds the `normalize_vector_distance` flag to VectorQuery and VectorRangeQuery. Behavior: - If set to `True` it normalizes values returned from redis to a value between 0 and 1. - For cosine similarity, it applies `(2 - value)/2`. - For L2 distance, it applies normalization `(1/(1+value))`. - For IP, it does nothing and throws a warning since normalized IP is cosine by definition. - For VectorRangeQuery, if `normalize_vector_distance=True` the distance threshold is now validated to be between 0 and 1 and denormalized for execution against the database to make consistent. 2. Relaxes validation for semantic caching and routing to be between 0 and 2 fixing the current bug and aligning with how the database actually functions.
1 parent 4d08030 commit de1be4d

File tree

11 files changed

+258
-32
lines changed

11 files changed

+258
-32
lines changed

redisvl/extensions/llmcache/semantic.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
SemanticCacheIndexSchema,
2323
)
2424
from redisvl.index import AsyncSearchIndex, SearchIndex
25-
from redisvl.query import RangeQuery
25+
from redisvl.query import VectorRangeQuery
2626
from redisvl.query.filter import FilterExpression
2727
from redisvl.query.query import BaseQuery
2828
from redisvl.redis.connection import RedisConnectionFactory
@@ -237,9 +237,9 @@ def set_threshold(self, distance_threshold: float) -> None:
237237
Raises:
238238
ValueError: If the threshold is not between 0 and 1.
239239
"""
240-
if not 0 <= float(distance_threshold) <= 1:
240+
if not 0 <= float(distance_threshold) <= 2:
241241
raise ValueError(
242-
f"Distance must be between 0 and 1, got {distance_threshold}"
242+
f"Distance must be between 0 and 2, got {distance_threshold}"
243243
)
244244
self._distance_threshold = float(distance_threshold)
245245

@@ -389,7 +389,7 @@ def check(
389389
vector = vector or self._vectorize_prompt(prompt)
390390
self._check_vector_dims(vector)
391391

392-
query = RangeQuery(
392+
query = VectorRangeQuery(
393393
vector=vector,
394394
vector_field_name=CACHE_VECTOR_FIELD_NAME,
395395
return_fields=self.return_fields,
@@ -472,14 +472,15 @@ async def acheck(
472472
vector = vector or await self._avectorize_prompt(prompt)
473473
self._check_vector_dims(vector)
474474

475-
query = RangeQuery(
475+
query = VectorRangeQuery(
476476
vector=vector,
477477
vector_field_name=CACHE_VECTOR_FIELD_NAME,
478478
return_fields=self.return_fields,
479479
distance_threshold=distance_threshold,
480480
num_results=num_results,
481481
return_score=True,
482482
filter_expression=filter_expression,
483+
normalize_vector_distance=True,
483484
)
484485

485486
# Search the cache!

redisvl/extensions/router/schema.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ class Route(BaseModel):
1818
"""List of reference phrases for the route."""
1919
metadata: Dict[str, Any] = Field(default={})
2020
"""Metadata associated with the route."""
21-
distance_threshold: Annotated[float, Field(strict=True, gt=0, le=1)] = 0.5
21+
distance_threshold: Annotated[float, Field(strict=True, gt=0, le=2)] = 0.5
2222
"""Distance threshold for matching the route."""
2323

2424
@field_validator("name")

redisvl/extensions/router/semantic.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
SemanticRouterIndexSchema,
1818
)
1919
from redisvl.index import SearchIndex
20-
from redisvl.query import RangeQuery
20+
from redisvl.query import VectorRangeQuery
2121
from redisvl.redis.utils import convert_bytes, hashify, make_dict
2222
from redisvl.utils.log import get_logger
2323
from redisvl.utils.utils import deprecated_argument, model_to_dict
@@ -244,7 +244,7 @@ def _distance_threshold_filter(self) -> str:
244244

245245
def _build_aggregate_request(
246246
self,
247-
vector_range_query: RangeQuery,
247+
vector_range_query: VectorRangeQuery,
248248
aggregation_method: DistanceAggregationMethod,
249249
max_k: int,
250250
) -> AggregateRequest:
@@ -286,7 +286,7 @@ def _get_route_matches(
286286
# therefore you might take the max_threshold and further refine from there.
287287
distance_threshold = max(route.distance_threshold for route in self.routes)
288288

289-
vector_range_query = RangeQuery(
289+
vector_range_query = VectorRangeQuery(
290290
vector=vector,
291291
vector_field_name=ROUTE_VECTOR_FIELD_NAME,
292292
distance_threshold=float(distance_threshold),

redisvl/index/index.py

+27-15
Original file line numberDiff line numberDiff line change
@@ -39,14 +39,15 @@
3939
SchemaValidationError,
4040
)
4141
from redisvl.index.storage import BaseStorage, HashStorage, JsonStorage
42-
from redisvl.query import BaseQuery, CountQuery, FilterQuery
42+
from redisvl.query import BaseQuery, BaseVectorQuery, CountQuery, FilterQuery
4343
from redisvl.query.filter import FilterExpression
4444
from redisvl.redis.connection import (
4545
RedisConnectionFactory,
4646
convert_index_info_to_schema,
4747
)
4848
from redisvl.redis.utils import convert_bytes
4949
from redisvl.schema import IndexSchema, StorageType
50+
from redisvl.schema.fields import VECTOR_NORM_MAP, VectorDistanceMetric
5051
from redisvl.utils.log import get_logger
5152

5253
logger = get_logger(__name__)
@@ -67,7 +68,7 @@
6768

6869

6970
def process_results(
70-
results: "Result", query: BaseQuery, storage_type: StorageType
71+
results: "Result", query: BaseQuery, schema: IndexSchema
7172
) -> List[Dict[str, Any]]:
7273
"""Convert a list of search Result objects into a list of document
7374
dictionaries.
@@ -92,11 +93,24 @@ def process_results(
9293

9394
# Determine if unpacking JSON is needed
9495
unpack_json = (
95-
(storage_type == StorageType.JSON)
96+
(schema.index.storage_type == StorageType.JSON)
9697
and isinstance(query, FilterQuery)
9798
and not query._return_fields # type: ignore
9899
)
99100

101+
if (isinstance(query, BaseVectorQuery)) and query._normalize_vector_distance:
102+
dist_metric = VectorDistanceMetric(
103+
schema.fields[query._vector_field_name].attrs.distance_metric.upper() # type: ignore
104+
)
105+
if dist_metric == VectorDistanceMetric.IP:
106+
warnings.warn(
107+
"Attempting to normalize inner product distance metric. Use cosine distance instead which is normalized inner product by definition."
108+
)
109+
110+
norm_fn = VECTOR_NORM_MAP[dist_metric.value]
111+
else:
112+
norm_fn = None
113+
100114
# Process records
101115
def _process(doc: "Document") -> Dict[str, Any]:
102116
doc_dict = doc.__dict__
@@ -110,6 +124,12 @@ def _process(doc: "Document") -> Dict[str, Any]:
110124
return {"id": doc_dict.get("id"), **json_data}
111125
raise ValueError(f"Unable to parse json data from Redis {json_data}")
112126

127+
if norm_fn:
128+
# convert float back to string to be consistent
129+
doc_dict[query.DISTANCE_ID] = str( # type: ignore
130+
norm_fn(float(doc_dict[query.DISTANCE_ID])) # type: ignore
131+
)
132+
113133
# Remove 'payload' if present
114134
doc_dict.pop("payload", None)
115135

@@ -740,11 +760,7 @@ def batch_query(
740760
)
741761
all_parsed = []
742762
for query, batch_results in zip(queries, results):
743-
parsed = process_results(
744-
batch_results,
745-
query=query,
746-
storage_type=self.schema.index.storage_type,
747-
)
763+
parsed = process_results(batch_results, query=query, schema=self.schema)
748764
# Create separate lists of parsed results for each query
749765
# passed in to the batch_search method, so that callers can
750766
# access the results for each query individually
@@ -754,9 +770,7 @@ def batch_query(
754770
def _query(self, query: BaseQuery) -> List[Dict[str, Any]]:
755771
"""Execute a query and process results."""
756772
results = self.search(query.query, query_params=query.params)
757-
return process_results(
758-
results, query=query, storage_type=self.schema.index.storage_type
759-
)
773+
return process_results(results, query=query, schema=self.schema)
760774

761775
def query(self, query: BaseQuery) -> List[Dict[str, Any]]:
762776
"""Execute a query on the index.
@@ -1398,7 +1412,7 @@ async def batch_query(
13981412
parsed = process_results(
13991413
batch_results,
14001414
query=query,
1401-
storage_type=self.schema.index.storage_type,
1415+
schema=self.schema,
14021416
)
14031417
# Create separate lists of parsed results for each query
14041418
# passed in to the batch_search method, so that callers can
@@ -1410,9 +1424,7 @@ async def batch_query(
14101424
async def _query(self, query: BaseQuery) -> List[Dict[str, Any]]:
14111425
"""Asynchronously execute a query and process results."""
14121426
results = await self.search(query.query, query_params=query.params)
1413-
return process_results(
1414-
results, query=query, storage_type=self.schema.index.storage_type
1415-
)
1427+
return process_results(results, query=query, schema=self.schema)
14161428

14171429
async def query(self, query: BaseQuery) -> List[Dict[str, Any]]:
14181430
"""Asynchronously execute a query on the index.

redisvl/query/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from redisvl.query.query import (
22
BaseQuery,
3+
BaseVectorQuery,
34
CountQuery,
45
FilterQuery,
56
RangeQuery,
@@ -9,6 +10,7 @@
910

1011
__all__ = [
1112
"BaseQuery",
13+
"BaseVectorQuery",
1214
"VectorQuery",
1315
"FilterQuery",
1416
"RangeQuery",

redisvl/query/query.py

+34
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from redisvl.query.filter import FilterExpression
77
from redisvl.redis.utils import array_to_buffer
8+
from redisvl.utils.utils import denorm_cosine_distance
89

910

1011
class BaseQuery(RedisQuery):
@@ -175,6 +176,8 @@ class BaseVectorQuery:
175176
DISTANCE_ID: str = "vector_distance"
176177
VECTOR_PARAM: str = "vector"
177178

179+
_normalize_vector_distance: bool = False
180+
178181

179182
class HybridPolicy(str, Enum):
180183
"""Enum for valid hybrid policy options in vector queries."""
@@ -198,6 +201,7 @@ def __init__(
198201
in_order: bool = False,
199202
hybrid_policy: Optional[str] = None,
200203
batch_size: Optional[int] = None,
204+
normalize_vector_distance: bool = False,
201205
):
202206
"""A query for running a vector search along with an optional filter
203207
expression.
@@ -233,6 +237,12 @@ def __init__(
233237
of vectors to fetch in each batch. Larger values may improve performance
234238
at the cost of memory usage. Only applies when hybrid_policy="BATCHES".
235239
Defaults to None, which lets Redis auto-select an appropriate batch size.
240+
normalize_vector_distance (bool): Redis supports 3 distance metrics: L2 (euclidean),
241+
IP (inner product), and COSINE. By default, L2 distance returns an unbounded value.
242+
COSINE distance returns a value between 0 and 2. IP returns a value determined by
243+
the magnitude of the vector. Setting this flag to true converts COSINE and L2 distance
244+
to a similarity score between 0 and 1. Note: setting this flag to true for IP will
245+
throw a warning since by definition COSINE similarity is normalized IP.
236246
237247
Raises:
238248
TypeError: If filter_expression is not of type redisvl.query.FilterExpression
@@ -246,6 +256,7 @@ def __init__(
246256
self._num_results = num_results
247257
self._hybrid_policy: Optional[HybridPolicy] = None
248258
self._batch_size: Optional[int] = None
259+
self._normalize_vector_distance = normalize_vector_distance
249260
self.set_filter(filter_expression)
250261
query_string = self._build_query_string()
251262

@@ -394,6 +405,7 @@ def __init__(
394405
in_order: bool = False,
395406
hybrid_policy: Optional[str] = None,
396407
batch_size: Optional[int] = None,
408+
normalize_vector_distance: bool = False,
397409
):
398410
"""A query for running a filtered vector search based on semantic
399411
distance threshold.
@@ -437,6 +449,19 @@ def __init__(
437449
of vectors to fetch in each batch. Larger values may improve performance
438450
at the cost of memory usage. Only applies when hybrid_policy="BATCHES".
439451
Defaults to None, which lets Redis auto-select an appropriate batch size.
452+
normalize_vector_distance (bool): Redis supports 3 distance metrics: L2 (euclidean),
453+
IP (inner product), and COSINE. By default, L2 distance returns an unbounded value.
454+
COSINE distance returns a value between 0 and 2. IP returns a value determined by
455+
the magnitude of the vector. Setting this flag to true converts COSINE and L2 distance
456+
to a similarity score between 0 and 1. Note: setting this flag to true for IP will
457+
throw a warning since by definition COSINE similarity is normalized IP.
458+
459+
Raises:
460+
TypeError: If filter_expression is not of type redisvl.query.FilterExpression
461+
462+
Note:
463+
Learn more about vector range queries: https://redis.io/docs/interact/search-and-query/search/vectors/#range-query
464+
440465
"""
441466
self._vector = vector
442467
self._vector_field_name = vector_field_name
@@ -456,6 +481,7 @@ def __init__(
456481
if batch_size is not None:
457482
self.set_batch_size(batch_size)
458483

484+
self._normalize_vector_distance = normalize_vector_distance
459485
self.set_distance_threshold(distance_threshold)
460486
self.set_filter(filter_expression)
461487
query_string = self._build_query_string()
@@ -493,6 +519,14 @@ def set_distance_threshold(self, distance_threshold: float):
493519
raise TypeError("distance_threshold must be of type float or int")
494520
if distance_threshold < 0:
495521
raise ValueError("distance_threshold must be non-negative")
522+
if self._normalize_vector_distance:
523+
if distance_threshold > 1:
524+
raise ValueError(
525+
"distance_threshold must be between 0 and 1 when normalize_vector_distance is set to True"
526+
)
527+
528+
# User sets normalized value 0-1 denormalize for use in DB
529+
distance_threshold = denorm_cosine_distance(distance_threshold)
496530
self._distance_threshold = distance_threshold
497531

498532
# Reset the query string

redisvl/schema/fields.py

+8
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,14 @@
1616
from redis.commands.search.field import TextField as RedisTextField
1717
from redis.commands.search.field import VectorField as RedisVectorField
1818

19+
from redisvl.utils.utils import norm_cosine_distance, norm_l2_distance
20+
21+
VECTOR_NORM_MAP = {
22+
"COSINE": norm_cosine_distance,
23+
"L2": norm_l2_distance,
24+
"IP": None, # normalized inner product is cosine similarity by definition
25+
}
26+
1927

2028
class FieldTypes(str, Enum):
2129
TAG = "tag"

redisvl/utils/utils.py

+19
Original file line numberDiff line numberDiff line change
@@ -191,3 +191,22 @@ def wrapper():
191191
return
192192

193193
return wrapper
194+
195+
196+
def norm_cosine_distance(value: float) -> float:
197+
"""
198+
Normalize the cosine distance to a similarity score between 0 and 1.
199+
"""
200+
return max((2 - value) / 2, 0)
201+
202+
203+
def denorm_cosine_distance(value: float) -> float:
204+
"""Denormalize the distance threshold from [0, 1] to [0, 1] for our db."""
205+
return max(2 - 2 * value, 0)
206+
207+
208+
def norm_l2_distance(value: float) -> float:
209+
"""
210+
Normalize the L2 distance.
211+
"""
212+
return 1 / (1 + value)

tests/conftest.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ def sample_data(sample_datetimes):
151151
"last_updated": sample_datetimes["high"].timestamp(),
152152
"credit_score": "medium",
153153
"location": "-110.0839,37.3861",
154-
"user_embedding": [0.9, 0.9, 0.1],
154+
"user_embedding": [-0.1, -0.1, -0.5],
155155
},
156156
]
157157

0 commit comments

Comments
 (0)