Skip to content

Commit 3fac8f9

Browse files
committed
SA20: Add CrateDB-specific patches to CrateCompilerSA20
This effectively concludes the support for SQLAlchemy 2.0, by implementing the same strategy as for the previous versions: After vendoring the vanilla dialect's compiler's `visit_update` and `_get_crud_params` methods, they are patched at a few spots to accommodate CrateDB's features. The changed behavior is mostly for running update statements on nested `OBJECT` or `ARRAY` data types.
1 parent bf2b6f5 commit 3fac8f9

File tree

4 files changed

+61
-5
lines changed

4 files changed

+61
-5
lines changed

src/crate/client/sqlalchemy/compat/core20.py

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,9 @@
2121

2222
from typing import Any, Dict, List, MutableMapping, Optional, Tuple, Union
2323

24+
import sqlalchemy as sa
2425
from sqlalchemy import ColumnClause, ValuesBase, cast, exc
25-
from sqlalchemy.sql import crud, dml
26+
from sqlalchemy.sql import dml
2627
from sqlalchemy.sql.base import _from_objects
2728
from sqlalchemy.sql.compiler import SQLCompiler
2829
from sqlalchemy.sql.crud import (REQUIRED, _as_dml_column, _create_bind_param,
@@ -47,6 +48,11 @@ def visit_update(self, update_stmt, **kw):
4748
)
4849
update_stmt = compile_state.statement
4950

51+
# [20] CrateDB patch.
52+
if not compile_state._dict_parameters and \
53+
not hasattr(update_stmt, '_crate_specific'):
54+
return super().visit_update(update_stmt, **kw)
55+
5056
toplevel = not self.stack
5157
if toplevel:
5258
self.isupdate = True
@@ -87,7 +93,8 @@ def visit_update(self, update_stmt, **kw):
8793
table_text = self.update_tables_clause(
8894
update_stmt, update_stmt.table, render_extra_froms, **kw
8995
)
90-
crud_params_struct = crud._get_crud_params(
96+
# [20] CrateDB patch.
97+
crud_params_struct = _get_crud_params(
9198
self, update_stmt, compile_state, toplevel, **kw
9299
)
93100
crud_params = crud_params_struct.single_params
@@ -105,12 +112,38 @@ def visit_update(self, update_stmt, **kw):
105112
text += table_text
106113

107114
text += " SET "
115+
116+
# [20] CrateDB patch begin.
117+
include_table = extra_froms and \
118+
self.render_table_with_column_in_update_from
119+
120+
set_clauses = []
121+
122+
for c, expr, value, _ in crud_params:
123+
key = c._compiler_dispatch(self, include_table=include_table)
124+
clause = key + ' = ' + value
125+
set_clauses.append(clause)
126+
127+
for k, v in compile_state._dict_parameters.items():
128+
if isinstance(k, str) and '[' in k:
129+
bindparam = sa.sql.bindparam(k, v)
130+
clause = k + ' = ' + self.process(bindparam)
131+
set_clauses.append(clause)
132+
133+
text += ', '.join(set_clauses)
134+
# [20] CrateDB patch end.
135+
136+
"""
137+
# TODO: Complete SA20 migration.
138+
# This is the column name/value joining code from SA20.
139+
# It may be sensible to use this procedure instead of the old one.
108140
text += ", ".join(
109141
expr + "=" + value
110142
for _, expr, value, _ in cast(
111143
"List[Tuple[Any, str, str, Any]]", crud_params
112144
)
113145
)
146+
"""
114147

115148
if self.implicit_returning or update_stmt._returning:
116149
if self.returning_precedes_values:
@@ -356,6 +389,24 @@ def _get_crud_params(
356389
kw,
357390
)
358391

392+
# [20] CrateDB patch.
393+
#
394+
# This sanity check performed by SQLAlchemy currently needs to be
395+
# deactivated in order to satisfy the rewriting logic of the CrateDB
396+
# dialect in `rewrite_update` and `visit_update`.
397+
#
398+
# It can be quickly reproduced by activating this section and running the
399+
# test cases::
400+
#
401+
# ./bin/test -vvvv -t dict_test
402+
#
403+
# That croaks like::
404+
#
405+
# sqlalchemy.exc.CompileError: Unconsumed column names: characters_name
406+
#
407+
# TODO: Investigate why this is actually happening and eventually mitigate
408+
# the root cause.
409+
"""
359410
if parameters and stmt_parameter_tuples:
360411
check = (
361412
set(parameters)
@@ -367,6 +418,7 @@ def _get_crud_params(
367418
"Unconsumed column names: %s"
368419
% (", ".join("%s" % (c,) for c in check))
369420
)
421+
"""
370422

371423
if (
372424
_compile_state_isinsert(compile_state)

src/crate/client/sqlalchemy/dialect.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
CrateDDLCompiler
3333
)
3434
from crate.client.exceptions import TimezoneUnawareException
35-
from .sa_version import SA_VERSION, SA_1_4
35+
from .sa_version import SA_VERSION, SA_1_4, SA_2_0
3636
from .types import Object, ObjectArray
3737

3838
TYPES_MAP = {
@@ -155,7 +155,10 @@ def process(value):
155155
}
156156

157157

158-
if SA_VERSION >= SA_1_4:
158+
if SA_VERSION >= SA_2_0:
159+
from .compat.core20 import CrateCompilerSA20
160+
statement_compiler = CrateCompilerSA20
161+
elif SA_VERSION >= SA_1_4:
159162
from .compat.core14 import CrateCompilerSA14
160163
statement_compiler = CrateCompilerSA14
161164
else:

src/crate/client/sqlalchemy/sa_version.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,4 @@
2525
SA_VERSION = V(sa.__version__)
2626

2727
SA_1_4 = V('1.4.0b1')
28+
SA_2_0 = V('2.0.0b1')

src/crate/client/sqlalchemy/tests/bulk_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,4 +78,4 @@ def test_bulk_save(self):
7878
('Banshee', 26),
7979
('Callisto', 37)
8080
)
81-
self.assertEqual(expected_bulk_args, bulk_args)
81+
self.assertSequenceEqual(expected_bulk_args, bulk_args)

0 commit comments

Comments
 (0)