Skip to content

Commit b1312cf

Browse files
authored
feat: add with_scores support to search command (#59)
Co-authored-by: Guy Korland <[email protected]>
1 parent bc3436e commit b1312cf

File tree

5 files changed

+51
-10
lines changed

5 files changed

+51
-10
lines changed

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ res = client.search("search engine")
5555
res = client.search("search engine", snippet_sizes = {'body': 50})
5656

5757
# Searching with complext parameters:
58-
q = Query("search engine").verbatim().no_content().paging(0,5)
58+
q = Query("search engine").verbatim().no_content().with_scores().paging(0,5)
5959
res = client.search(q)
6060

6161

redisearch/client.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -361,7 +361,8 @@ def search(self, query):
361361
return Result(res,
362362
not query._no_content,
363363
duration=(time.time() - st) * 1000.0,
364-
has_payload=query._with_payloads)
364+
has_payload=query._with_payloads,
365+
with_scores=query._with_scores)
365366

366367
def explain(self, query):
367368
args, query_text = self._mk_query_args(query)

redisearch/query.py

+11
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def __init__(self, query_string):
2323
self._fields = None
2424
self._verbatim = False
2525
self._with_payloads = False
26+
self._with_scores = False
2627
self._filters = list()
2728
self._ids = None
2829
self._slop = -1
@@ -157,6 +158,9 @@ def get_args(self):
157158

158159
if self._with_payloads:
159160
args.append('WITHPAYLOADS')
161+
162+
if self._with_scores:
163+
args.append('WITHSCORES')
160164

161165
if self._ids:
162166
args.append('INKEYS')
@@ -225,6 +229,13 @@ def with_payloads(self):
225229
"""
226230
self._with_payloads = True
227231
return self
232+
233+
def with_scores(self):
234+
"""
235+
Ask the engine to return document search scores
236+
"""
237+
self._with_scores = True
238+
return self
228239

229240
def limit_fields(self, *fields):
230241
"""

redisearch/result.py

+13-8
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ class Result(object):
88
Represents the result of a search query, and has an array of Document objects
99
"""
1010

11-
def __init__(self, res, hascontent, duration=0, has_payload = False):
11+
def __init__(self, res, hascontent, duration=0, has_payload = False, with_scores = False):
1212
"""
1313
- **snippets**: An optional dictionary of the form {field: snippet_size} for snippet formatting
1414
"""
@@ -19,15 +19,20 @@ def __init__(self, res, hascontent, duration=0, has_payload = False):
1919

2020
step = 1
2121
if hascontent:
22-
step = 3 if has_payload else 2
23-
else:
24-
# we can't have nocontent and payloads in the same response
25-
has_payload = False
22+
step = step + 1
23+
if has_payload:
24+
step = step + 1
25+
if with_scores:
26+
step = step + 1
27+
28+
offset = 2 if with_scores else 1
2629

2730
for i in xrange(1, len(res), step):
2831
id = to_string(res[i])
29-
payload = to_string(res[i+1]) if has_payload else None
30-
fields_offset = 2 if has_payload else 1
32+
payload = to_string(res[i+offset]) if has_payload else None
33+
#fields_offset = 2 if has_payload else 1
34+
fields_offset = offset+1 if has_payload else offset
35+
score = float(res[i+1]) if with_scores else None
3136

3237
fields = {}
3338
if hascontent:
@@ -40,7 +45,7 @@ def __init__(self, res, hascontent, duration=0, has_payload = False):
4045
except KeyError:
4146
pass
4247

43-
doc = Document(id, payload=payload, **fields)
48+
doc = Document(id, score=score, payload=payload, **fields) if with_scores else Document(id, payload=payload, **fields)
4449
self.docs.append(doc)
4550

4651
def __repr__(self):

test/test.py

+24
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,30 @@ def testPayloads(self):
194194
self.assertEqual('foo baz', res.docs[1].payload)
195195
self.assertIsNone(res.docs[0].payload)
196196

197+
def testScores(self):
198+
199+
conn = self.redis()
200+
201+
with conn as r:
202+
# Creating a client with a given index name
203+
client = Client('idx', port=conn.port)
204+
client.redis.flushdb()
205+
client.create_index((TextField('txt'),))
206+
207+
client.add_document('doc1', txt = 'foo baz')
208+
client.add_document('doc2', txt = 'foo bar')
209+
210+
q = Query("foo ~bar").with_scores()
211+
res = client.search(q)
212+
print("RES", res)
213+
self.assertEqual(2, res.total)
214+
215+
self.assertEqual('doc2', res.docs[0].id)
216+
self.assertEqual(3.0, res.docs[0].score)
217+
218+
self.assertEqual('doc1', res.docs[1].id)
219+
self.assertEqual(0.2, res.docs[1].score)
220+
197221
def testReplace(self):
198222

199223
conn = self.redis()

0 commit comments

Comments
 (0)