Skip to content

Commit 11739ed

Browse files
vpetrovykhaljazerzen
authored andcommitted
Add codecs for dealing with pgsparse vector. (#478)
Add codecs for converting between regular arrays and sparse vectors.
1 parent 113ed0d commit 11739ed

File tree

2 files changed

+230
-0
lines changed

2 files changed

+230
-0
lines changed

edgedb/protocol/codecs/codecs.pyx

+114
Original file line numberDiff line numberDiff line change
@@ -798,6 +798,113 @@ cdef pgvector_decode(pgproto.CodecContext settings, FRBuffer *buf):
798798
return val
799799

800800

801+
# The pg_sparse extension uses a signed int16 when reading the dimension in
# binary format.
DEF PGSPARSE_MAX_DIM = (1 << 15) - 1


cdef pgsparse_encode(pgproto.CodecContext settings, WriteBuffer buf,
                     object obj):
    """Encode *obj* as a pg_sparse binary sparse vector.

    Wire layout written here: int32 payload length, int16 count of
    non-zero elements, int16 dimension, int16 zero, then one
    (int32 index, float32 value) pair per non-zero element.

    Raises ValueError if the vector has more than PGSPARSE_MAX_DIM
    elements, and TypeError if *obj* is not a sized iterable.
    """
    cdef:
        int16_t n_elem = 0
        int64_t dim
        Py_ssize_t i
        float[:] memview

    # If we can take a typed memoryview of the object, we use that.
    # That is good, because it means we can consume array.array and
    # numpy.ndarray without needing to unbox.
    # Otherwise we take the slow path, indexing into the array using
    # the normal protocol.
    try:
        memview = obj
    except (ValueError, TypeError):
        # Not buffer-compatible; fall through to the generic path below.
        pass
    else:
        # The actual dimensionality of the vector is the size of the raw
        # array.
        dim = len(memview)
        if dim > PGSPARSE_MAX_DIM:
            raise ValueError('too many elements in vector value')

        # First pass to count the number of non-zero elements
        for i in range(dim):
            if memview[i] != 0:
                n_elem += 1

        # 6 header bytes (n_elem, dim, reserved) + 8 bytes per element.
        buf.write_int32(6 + n_elem*8)
        buf.write_int16(n_elem)
        buf.write_int16(<int16_t>dim)
        buf.write_int16(0)
        # Second pass will write the actual non-zero elements
        for i in range(dim):
            if memview[i] != 0:
                buf.write_int32(i)
                buf.write_float(memview[i])
        return

    # Annoyingly, this is literally identical code to the fast path...
    # but the types are different in critical ways.
    if not _is_array_iterable(obj):
        raise TypeError(
            'a sized iterable container expected (got type {!r})'.format(
                type(obj).__name__))

    # The actual dimensionality of the vector is the size of the raw array.
    dim = len(obj)
    if dim > PGSPARSE_MAX_DIM:
        raise ValueError('too many elements in vector value')

    # First pass to count the number of non-zero elements
    for i in range(dim):
        if obj[i] != 0:
            n_elem += 1

    # 6 header bytes (n_elem, dim, reserved) + 8 bytes per element.
    buf.write_int32(6 + n_elem*8)
    buf.write_int16(n_elem)
    # Cast explicitly, matching the fast path above; dim was bounds-checked
    # against PGSPARSE_MAX_DIM so the narrowing is safe.
    buf.write_int16(<int16_t>dim)
    buf.write_int16(0)
    # Second pass will write the actual non-zero elements
    for i in range(dim):
        if obj[i] != 0:
            buf.write_int32(i)
            buf.write_float(obj[i])
872+
873+
cdef pgsparse_decode(pgproto.CodecContext settings, FRBuffer *buf):
    # Decode a pg_sparse binary sparse vector into a dense float array.
    # Wire layout read here: int16 count of non-zero elements, int16
    # dimension, 2 skipped bytes, then one (int32 index, float32 value)
    # pair per non-zero element.
    cdef:
        int16_t n_elem
        int16_t dim
        Py_ssize_t i
        int32_t index
        float[::1] array_view

    n_elem = hton.unpack_int16(frb_read(buf, 2))
    dim = hton.unpack_int16(frb_read(buf, 2))
    # Skip the third int16 of the header (written as 0 by the encoder).
    frb_read(buf, 2)

    # Create a float array with size dim
    val = ONE_EL_ARRAY * dim
    array_view = val

    # The underlying sparse Vector representation supports int32 as the
    # dimension and index, but when converting to binary format the dimensions
    # are maxed out at int16. So indexes beyond the truncated dimension will
    # cause an exception.
    if dim < 0:
        # This is actually an indicator of overflow when converting from int32
        # down to int16.
        raise ValueError('too many elements in vector value')
    try:
        # Fill the non-zero elements
        for i in range(n_elem):
            index = hton.unpack_int32(frb_read(buf, 4))
            array_view[index] = hton.unpack_float(frb_read(buf, 4))
    except IndexError:
        # An element index past the (possibly truncated) dimension.
        raise ValueError('too many elements in vector value')

    return val
906+
907+
801908
cdef checked_decimal_encode(
802909
pgproto.CodecContext settings, WriteBuffer buf, obj
803910
):
@@ -1007,5 +1114,12 @@ cdef register_base_scalar_codecs():
10071114
uuid.UUID('9565dd88-04f5-11ee-a691-0b6ebe179825'),
10081115
)
10091116

1117+
    # Codec for ext::pgsparse::vector; the client-side representation is a
    # dense float array (see pgsparse_encode/pgsparse_decode).
    register_base_scalar_codec(
        'ext::pgsparse::vector',
        pgsparse_encode,
        pgsparse_decode,
        uuid.UUID('b646ace0-266d-47ce-8263-1224c38a4a12'),
    )
1123+
10101124

10111125
register_base_scalar_codecs()

tests/test_vector.py

+116
Original file line numberDiff line numberDiff line change
@@ -129,3 +129,119 @@ async def test_vector_01(self):
129129
''',
130130
'foo',
131131
)
132+
133+
134+
class TestSparseVector(tb.SyncQueryTestCase):
    """Round-trip tests for the ext::pgsparse::vector codec."""

    def setUp(self):
        super().setUp()

        # Skip unless the server has the pgsparse extension package
        # available.
        if not self.client.query_required_single('''
            select exists (
                select sys::ExtensionPackage filter .name = 'pgsparse'
            )
        '''):
            self.skipTest("feature not implemented")

        self.client.execute('''
            create extension pgsparse;
        ''')

    def tearDown(self):
        try:
            self.client.execute('''
                drop extension pgsparse;
            ''')
        finally:
            super().tearDown()

    async def test_vector_01(self):
        # Sparse vectors decode to a dense float array on the client.
        val = self.client.query_single('''
            select <ext::pgsparse::vector>[1.5,0,0,0,2.0,3.8]
        ''')
        self.assertIsInstance(val, array.array)
        self.assertEqual(val, array.array('f', [1.5, 0, 0, 0, 2.0, 3.8]))

        # A plain list is accepted as a vector argument.
        val = self.client.query_single(
            '''
                select <json><ext::pgsparse::vector>$0
            ''',
            [3.0, 9.0, -42.5],
        )
        self.assertEqual(val, '[3, 9, -42.5]')

        # A float array is accepted via the typed-memoryview fast path.
        val = self.client.query_single(
            '''
                select <json><ext::pgsparse::vector>$0
            ''',
            array.array('f', [3.0, 9.0, -42.5])
        )
        self.assertEqual(val, '[3, 9, -42.5]')

        # An int array must also work (generic slow path).
        val = self.client.query_single(
            '''
                select <json><ext::pgsparse::vector>$0
            ''',
            array.array('i', [1, 2, 3]),
        )
        self.assertEqual(val, '[1, 2, 3]')

        # Mostly-zero vectors within the int16 dimension limit round-trip.
        val = self.client.query_single(
            '''
                select <ext::pgsparse::vector>$0
            ''',
            array.array('f', ([0] * 10000) + [1, 2]),
        )
        self.assertEqual(val, array.array('f', ([0] * 10000) + [1, 2]))

        val = self.client.query_single(
            '''
                with zeros := array_agg(
                    (for x in range_unpack(range(0, 20000)) union 0)
                )
                select <ext::pgsparse::vector>(zeros ++ [1, 2]);
            ''',
        )
        self.assertEqual(val, array.array('f', ([0] * 20000) + [1, 2]))

        # Some sad path tests
        with self.assertRaises(edgedb.InvalidArgumentError):
            self.client.query_single(
                '''
                    select <ext::pgsparse::vector>$0
                ''',
                [3.0, None, -42.5],
            )

        with self.assertRaises(edgedb.InvalidArgumentError):
            self.client.query_single(
                '''
                    select <ext::pgsparse::vector>$0
                ''',
                [3.0, 'x', -42.5],
            )

        with self.assertRaises(edgedb.InvalidArgumentError):
            self.client.query_single(
                '''
                    select <ext::pgsparse::vector>$0
                ''',
                'foo',
            )

        # Client-side rejection: dimension exceeds the int16 wire limit.
        with self.assertRaises(edgedb.InvalidArgumentError):
            self.client.query_single(
                '''
                    select <ext::pgsparse::vector>$0
                ''',
                array.array('f', ([0] * 50000) + [1, 2]),
            )

        # Server-side vector too large to decode over the binary protocol.
        with self.assertRaises(edgedb.ClientError):
            self.client.query_single(
                '''
                    with zeros := array_agg(
                        (for x in range_unpack(range(0, 50000)) union 0)
                    )
                    select <ext::pgsparse::vector>(zeros ++ [1, 2]);
                ''',
            )

0 commit comments

Comments
 (0)