Skip to content

Commit 60da99a

Browse files
msullivanaljazerzen
authored andcommitted
Support passing dicts and namedtuples for namedtuple arguments (#473)
We support dicts (and other mappings), and for named tuples we do lookup by name instead of by position. Fixes #374.
1 parent ff08d9d commit 60da99a

File tree

2 files changed

+134
-0
lines changed

2 files changed

+134
-0
lines changed

edgedb/protocol/codecs/base.pyx

+82
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919

2020
import codecs
2121

22+
from collections.abc import Mapping as MappingABC
23+
2224

2325
cdef uint64_t RECORD_ENCODER_CHECKED = 1 << 0
2426
cdef uint64_t RECORD_ENCODER_INVALID = 1 << 1
@@ -225,6 +227,86 @@ cdef class BaseNamedRecordCodec(BaseRecordCodec):
225227
(<BaseCodec>codec).dump(level + 1).strip()))
226228
return '\n'.join(buf)
227229

230+
cdef encode(self, WriteBuffer buf, object obj):
231+
cdef:
232+
WriteBuffer elem_data
233+
Py_ssize_t objlen
234+
Py_ssize_t i
235+
BaseCodec sub_codec
236+
Py_ssize_t is_dict
237+
Py_ssize_t is_namedtuple
238+
239+
self._check_encoder()
240+
241+
# We check in this order (dict, _is_array_iterable,
242+
# MappingABC) so that in the common case of dict or tuple, we
243+
# never do an ABC check.
244+
if cpython.PyDict_Check(obj):
245+
is_dict = True
246+
elif _is_array_iterable(obj):
247+
is_dict = False
248+
elif isinstance(obj, MappingABC):
249+
is_dict = True
250+
else:
251+
raise TypeError(
252+
'a sized iterable container or mapping '
253+
'expected (got type {!r})'.format(
254+
type(obj).__name__))
255+
is_namedtuple = not is_dict and hasattr(obj, '_fields')
256+
257+
objlen = len(obj)
258+
if objlen == 0:
259+
buf.write_bytes(EMPTY_RECORD_DATA)
260+
return
261+
262+
if objlen > _MAXINT32:
263+
raise ValueError('too many elements for a tuple')
264+
265+
if objlen != len(self.fields_codecs):
266+
raise ValueError(
267+
f'expected {len(self.fields_codecs)} elements in the tuple, '
268+
f'got {objlen}')
269+
270+
elem_data = WriteBuffer.new()
271+
for i in range(objlen):
272+
if is_dict:
273+
name = datatypes.record_desc_pointer_name(self.descriptor, i)
274+
try:
275+
item = obj[name]
276+
except KeyError:
277+
raise ValueError(
278+
f"named tuple dict is missing '{name}' key",
279+
) from None
280+
elif is_namedtuple:
281+
name = datatypes.record_desc_pointer_name(self.descriptor, i)
282+
try:
283+
item = getattr(obj, name)
284+
except AttributeError:
285+
raise ValueError(
286+
f"named tuple is missing '{name}' attribute",
287+
) from None
288+
else:
289+
item = obj[i]
290+
291+
elem_data.write_int32(0) # reserved bytes
292+
if item is None:
293+
elem_data.write_int32(-1)
294+
else:
295+
sub_codec = <BaseCodec>(self.fields_codecs[i])
296+
try:
297+
sub_codec.encode(elem_data, item)
298+
except (TypeError, ValueError) as e:
299+
value_repr = repr(item)
300+
if len(value_repr) > 40:
301+
value_repr = value_repr[:40] + '...'
302+
raise errors.InvalidArgumentError(
303+
'invalid input for query argument'
304+
' ${n}: {v} ({msg})'.format(
305+
n=i, v=value_repr, msg=e)) from e
306+
307+
buf.write_int32(4 + elem_data.len()) # buffer length
308+
buf.write_int32(<int32_t><uint32_t>objlen)
309+
buf.write_buffer(elem_data)
228310

229311
@cython.final
230312
cdef class EdegDBCodecContext(pgproto.CodecContext):

tests/test_namedtuples.py

+52
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
#
2+
# This source file is part of the EdgeDB open source project.
3+
#
4+
# Copyright 2019-present MagicStack Inc. and the EdgeDB authors.
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License");
7+
# you may not use this file except in compliance with the License.
8+
# You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
#
18+
19+
20+
from collections import namedtuple, UserDict
21+
22+
import edgedb
23+
from edgedb import _testbase as tb
24+
25+
26+
class TestNamedTupleTypes(tb.SyncQueryTestCase):
27+
28+
async def test_namedtuple_01(self):
29+
NT1 = namedtuple('NT2', ['x', 'y'])
30+
NT2 = namedtuple('NT2', ['y', 'x'])
31+
32+
ctors = [dict, UserDict, NT1, NT2]
33+
for ctor in ctors:
34+
val = ctor(x=10, y='y')
35+
res = self.client.query_single('''
36+
select <tuple<x: int64, y: str>>$0
37+
''', val)
38+
39+
self.assertEqual(res, (10, 'y'))
40+
41+
async def test_namedtuple_02(self):
42+
NT1 = namedtuple('NT2', ['x', 'z'])
43+
44+
with self.assertRaisesRegex(edgedb.InvalidArgumentError, 'is missing'):
45+
self.client.query_single('''
46+
select <tuple<x: int64, y: str>>$0
47+
''', dict(x=20, z='test'))
48+
49+
with self.assertRaisesRegex(edgedb.InvalidArgumentError, 'is missing'):
50+
self.client.query_single('''
51+
select <tuple<x: int64, y: str>>$0
52+
''', NT1(x=20, z='test'))

0 commit comments

Comments
 (0)