diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml index 2c709a3b..e63d12bb 100644 --- a/.github/workflows/codeql.yml +++ b/.github/workflows/codeql.yml @@ -30,7 +30,7 @@ jobs: fail-fast: false matrix: language: [ python ] - sqla-version: ['1.3.24', '1.4.46'] + sqla-version: ['1.3.24', '1.4.46', '2.0.0'] steps: - name: Checkout diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 1c0f67ac..d406d61e 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -22,7 +22,7 @@ jobs: os: ['ubuntu-latest', 'macos-latest'] python-version: ['3.7', '3.8', '3.9', '3.10', '3.11'] cratedb-version: ['5.2.0'] - sqla-version: ['1.3.24', '1.4.46'] + sqla-version: ['1.3.24', '1.4.46', '2.0.0'] # To save resources, only use the most recent Python version on macOS. exclude: - os: 'macos-latest' diff --git a/CHANGES.txt b/CHANGES.txt index 7d58509a..5a68b70a 100644 --- a/CHANGES.txt +++ b/CHANGES.txt @@ -5,9 +5,11 @@ Changes for crate Unreleased ========== -- Add deprecation warning about dropping support for SQLAlchemy 1.3 soon, it is +- Added deprecation warning about dropping support for SQLAlchemy 1.3 soon, it is effectively EOL. +- Added support for SQLAlchemy 2.0. + 2022/12/08 0.29.0 ================= diff --git a/docs/by-example/sqlalchemy/crud.rst b/docs/by-example/sqlalchemy/crud.rst index d2840c52..5a62df40 100644 --- a/docs/by-example/sqlalchemy/crud.rst +++ b/docs/by-example/sqlalchemy/crud.rst @@ -130,7 +130,7 @@ Retrieve Using the connection to execute a select statement: - >>> result = connection.execute('select name from locations order by name') + >>> result = connection.execute(text('select name from locations order by name')) >>> result.rowcount 14 diff --git a/setup.cfg b/setup.cfg index d5ca9e8e..f60de556 100644 --- a/setup.cfg +++ b/setup.cfg @@ -2,4 +2,4 @@ universal = 1 [flake8] -ignore = E501, C901, W504 +ignore = E501, C901, W503, W504 diff --git a/setup.py b/setup.py index 506f25a7..f4923c0d 100644 --- a/setup.py +++ b/setup.py @@ -60,7 +60,7 @@ def read(path): }, install_requires=['urllib3>=1.9,<3'], extras_require=dict( - sqlalchemy=['sqlalchemy>=1.0,<1.5', + sqlalchemy=['sqlalchemy>=1.0,<2.1', 'geojson>=2.5.0,<3', 'backports.zoneinfo<1; python_version<"3.9"'], test=['tox>=3,<4', diff --git a/src/crate/client/sqlalchemy/compat/core10.py b/src/crate/client/sqlalchemy/compat/core10.py index e8d7deab..92c62dd8 100644 --- a/src/crate/client/sqlalchemy/compat/core10.py +++ b/src/crate/client/sqlalchemy/compat/core10.py @@ -20,6 +20,7 @@ # software solely pursuant to the terms of the relevant commercial agreement. import sqlalchemy as sa +from sqlalchemy.dialects.postgresql.base import PGCompiler from sqlalchemy.sql.crud import (REQUIRED, _create_bind_param, _extend_values_for_multiparams, _get_multitable_params, @@ -32,6 +33,12 @@ class CrateCompilerSA10(CrateCompiler): + def returning_clause(self, stmt, returning_cols): + """ + Generate RETURNING clause, PostgreSQL-compatible. + """ + return PGCompiler.returning_clause(self, stmt, returning_cols) + def visit_update(self, update_stmt, **kw): """ used to compile expressions diff --git a/src/crate/client/sqlalchemy/compat/core14.py b/src/crate/client/sqlalchemy/compat/core14.py index f37ea827..2dd6670a 100644 --- a/src/crate/client/sqlalchemy/compat/core14.py +++ b/src/crate/client/sqlalchemy/compat/core14.py @@ -20,6 +20,7 @@ # software solely pursuant to the terms of the relevant commercial agreement. import sqlalchemy as sa +from sqlalchemy.dialects.postgresql.base import PGCompiler from sqlalchemy.sql import selectable from sqlalchemy.sql.crud import (REQUIRED, _create_bind_param, _extend_values_for_multiparams, @@ -33,6 +34,12 @@ class CrateCompilerSA14(CrateCompiler): + def returning_clause(self, stmt, returning_cols): + """ + Generate RETURNING clause, PostgreSQL-compatible. + """ + return PGCompiler.returning_clause(self, stmt, returning_cols) + def visit_update(self, update_stmt, **kw): compile_state = update_stmt._compile_state_factory( diff --git a/src/crate/client/sqlalchemy/compat/core20.py b/src/crate/client/sqlalchemy/compat/core20.py new file mode 100644 index 00000000..6f128876 --- /dev/null +++ b/src/crate/client/sqlalchemy/compat/core20.py @@ -0,0 +1,447 @@ +# -*- coding: utf-8; -*- +# +# Licensed to CRATE Technology GmbH ("Crate") under one or more contributor +# license agreements. See the NOTICE file distributed with this work for +# additional information regarding copyright ownership. Crate licenses +# this file to you under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. You may +# obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +# +# However, if you have executed another commercial license agreement +# with Crate these terms will supersede the license and you may use the +# software solely pursuant to the terms of the relevant commercial agreement. + +from typing import Any, Dict, List, MutableMapping, Optional, Tuple, Union + +import sqlalchemy as sa +from sqlalchemy import ColumnClause, ValuesBase, cast, exc +from sqlalchemy.sql import dml +from sqlalchemy.sql.base import _from_objects +from sqlalchemy.sql.compiler import SQLCompiler +from sqlalchemy.sql.crud import (REQUIRED, _as_dml_column, _create_bind_param, + _CrudParamElement, _CrudParams, + _extend_values_for_multiparams, + _get_stmt_parameter_tuples_params, + _get_update_multitable_params, + _key_getters_for_crud_column, _scan_cols, + _scan_insert_from_select_cols, + _setup_delete_return_defaults) +from sqlalchemy.sql.dml import DMLState, _DMLColumnElement +from sqlalchemy.sql.dml import isinsert as _compile_state_isinsert + +from crate.client.sqlalchemy.compiler import CrateCompiler + + +class CrateCompilerSA20(CrateCompiler): + + def visit_update(self, update_stmt, **kw): + compile_state = update_stmt._compile_state_factory( + update_stmt, self, **kw + ) + update_stmt = compile_state.statement + + # [20] CrateDB patch. + if not compile_state._dict_parameters and \ + not hasattr(update_stmt, '_crate_specific'): + return super().visit_update(update_stmt, **kw) + + toplevel = not self.stack + if toplevel: + self.isupdate = True + if not self.dml_compile_state: + self.dml_compile_state = compile_state + if not self.compile_state: + self.compile_state = compile_state + + extra_froms = compile_state._extra_froms + is_multitable = bool(extra_froms) + + if is_multitable: + # main table might be a JOIN + main_froms = set(_from_objects(update_stmt.table)) + render_extra_froms = [ + f for f in extra_froms if f not in main_froms + ] + correlate_froms = main_froms.union(extra_froms) + else: + render_extra_froms = [] + correlate_froms = {update_stmt.table} + + self.stack.append( + { + "correlate_froms": correlate_froms, + "asfrom_froms": correlate_froms, + "selectable": update_stmt, + } + ) + + text = "UPDATE " + + if update_stmt._prefixes: + text += self._generate_prefixes( + update_stmt, update_stmt._prefixes, **kw + ) + + table_text = self.update_tables_clause( + update_stmt, update_stmt.table, render_extra_froms, **kw + ) + # [20] CrateDB patch. + crud_params_struct = _get_crud_params( + self, update_stmt, compile_state, toplevel, **kw + ) + crud_params = crud_params_struct.single_params + + if update_stmt._hints: + dialect_hints, table_text = self._setup_crud_hints( + update_stmt, table_text + ) + else: + dialect_hints = None + + if update_stmt._independent_ctes: + self._dispatch_independent_ctes(update_stmt, kw) + + text += table_text + + text += " SET " + + # [20] CrateDB patch begin. + include_table = extra_froms and \ + self.render_table_with_column_in_update_from + + set_clauses = [] + + for c, expr, value, _ in crud_params: + key = c._compiler_dispatch(self, include_table=include_table) + clause = key + ' = ' + value + set_clauses.append(clause) + + for k, v in compile_state._dict_parameters.items(): + if isinstance(k, str) and '[' in k: + bindparam = sa.sql.bindparam(k, v) + clause = k + ' = ' + self.process(bindparam) + set_clauses.append(clause) + + text += ', '.join(set_clauses) + # [20] CrateDB patch end. + + if self.implicit_returning or update_stmt._returning: + if self.returning_precedes_values: + text += " " + self.returning_clause( + update_stmt, + self.implicit_returning or update_stmt._returning, + populate_result_map=toplevel, + ) + + if extra_froms: + extra_from_text = self.update_from_clause( + update_stmt, + update_stmt.table, + render_extra_froms, + dialect_hints, + **kw, + ) + if extra_from_text: + text += " " + extra_from_text + + if update_stmt._where_criteria: + t = self._generate_delimited_and_list( + update_stmt._where_criteria, **kw + ) + if t: + text += " WHERE " + t + + limit_clause = self.update_limit_clause(update_stmt) + if limit_clause: + text += " " + limit_clause + + if ( + self.implicit_returning or update_stmt._returning + ) and not self.returning_precedes_values: + text += " " + self.returning_clause( + update_stmt, + self.implicit_returning or update_stmt._returning, + populate_result_map=toplevel, + ) + + if self.ctes: + nesting_level = len(self.stack) if not toplevel else None + text = self._render_cte_clause(nesting_level=nesting_level) + text + + self.stack.pop(-1) + + return text + + +def _get_crud_params( + compiler: SQLCompiler, + stmt: ValuesBase, + compile_state: DMLState, + toplevel: bool, + **kw: Any, +) -> _CrudParams: + """create a set of tuples representing column/string pairs for use + in an INSERT or UPDATE statement. + + Also generates the Compiled object's postfetch, prefetch, and + returning column collections, used for default handling and ultimately + populating the CursorResult's prefetch_cols() and postfetch_cols() + collections. + + """ + + # note: the _get_crud_params() system was written with the notion in mind + # that INSERT, UPDATE, DELETE are always the top level statement and + # that there is only one of them. With the addition of CTEs that can + # make use of DML, this assumption is no longer accurate; the DML + # statement is not necessarily the top-level "row returning" thing + # and it is also theoretically possible (fortunately nobody has asked yet) + # to have a single statement with multiple DMLs inside of it via CTEs. + + # the current _get_crud_params() design doesn't accommodate these cases + # right now. It "just works" for a CTE that has a single DML inside of + # it, and for a CTE with multiple DML, it's not clear what would happen. + + # overall, the "compiler.XYZ" collections here would need to be in a + # per-DML structure of some kind, and DefaultDialect would need to + # navigate these collections on a per-statement basis, with additional + # emphasis on the "toplevel returning data" statement. However we + # still need to run through _get_crud_params() for all DML as we have + # Python / SQL generated column defaults that need to be rendered. + + # if there is user need for this kind of thing, it's likely a post 2.0 + # kind of change as it would require deep changes to DefaultDialect + # as well as here. + + compiler.postfetch = [] + compiler.insert_prefetch = [] + compiler.update_prefetch = [] + compiler.implicit_returning = [] + + # getters - these are normally just column.key, + # but in the case of mysql multi-table update, the rules for + # .key must conditionally take tablename into account + ( + _column_as_key, + _getattr_col_key, + _col_bind_name, + ) = _key_getters_for_crud_column(compiler, stmt, compile_state) + + compiler._get_bind_name_for_col = _col_bind_name + + if stmt._returning and stmt._return_defaults: + raise exc.CompileError( + "Can't compile statement that includes returning() and " + "return_defaults() simultaneously" + ) + + if compile_state.isdelete: + _setup_delete_return_defaults( + compiler, + stmt, + compile_state, + (), + _getattr_col_key, + _column_as_key, + _col_bind_name, + (), + (), + toplevel, + kw, + ) + return _CrudParams([], []) + + # no parameters in the statement, no parameters in the + # compiled params - return binds for all columns + if compiler.column_keys is None and compile_state._no_parameters: + return _CrudParams( + [ + ( + c, + compiler.preparer.format_column(c), + _create_bind_param(compiler, c, None, required=True), + (c.key,), + ) + for c in stmt.table.columns + ], + [], + ) + + stmt_parameter_tuples: Optional[ + List[Tuple[Union[str, ColumnClause[Any]], Any]] + ] + spd: Optional[MutableMapping[_DMLColumnElement, Any]] + + if ( + _compile_state_isinsert(compile_state) + and compile_state._has_multi_parameters + ): + mp = compile_state._multi_parameters + assert mp is not None + spd = mp[0] + stmt_parameter_tuples = list(spd.items()) + elif compile_state._ordered_values: + spd = compile_state._dict_parameters + stmt_parameter_tuples = compile_state._ordered_values + elif compile_state._dict_parameters: + spd = compile_state._dict_parameters + stmt_parameter_tuples = list(spd.items()) + else: + stmt_parameter_tuples = spd = None + + # if we have statement parameters - set defaults in the + # compiled params + if compiler.column_keys is None: + parameters = {} + elif stmt_parameter_tuples: + assert spd is not None + parameters = { + _column_as_key(key): REQUIRED + for key in compiler.column_keys + if key not in spd + } + else: + parameters = { + _column_as_key(key): REQUIRED for key in compiler.column_keys + } + + # create a list of column assignment clauses as tuples + values: List[_CrudParamElement] = [] + + if stmt_parameter_tuples is not None: + _get_stmt_parameter_tuples_params( + compiler, + compile_state, + parameters, + stmt_parameter_tuples, + _column_as_key, + values, + kw, + ) + + check_columns: Dict[str, ColumnClause[Any]] = {} + + # special logic that only occurs for multi-table UPDATE + # statements + if dml.isupdate(compile_state) and compile_state.is_multitable: + _get_update_multitable_params( + compiler, + stmt, + compile_state, + stmt_parameter_tuples, + check_columns, + _col_bind_name, + _getattr_col_key, + values, + kw, + ) + + if _compile_state_isinsert(compile_state) and stmt._select_names: + # is an insert from select, is not a multiparams + + assert not compile_state._has_multi_parameters + + _scan_insert_from_select_cols( + compiler, + stmt, + compile_state, + parameters, + _getattr_col_key, + _column_as_key, + _col_bind_name, + check_columns, + values, + toplevel, + kw, + ) + else: + _scan_cols( + compiler, + stmt, + compile_state, + parameters, + _getattr_col_key, + _column_as_key, + _col_bind_name, + check_columns, + values, + toplevel, + kw, + ) + + # [20] CrateDB patch. + # + # This sanity check performed by SQLAlchemy currently needs to be + # deactivated in order to satisfy the rewriting logic of the CrateDB + # dialect in `rewrite_update` and `visit_update`. + # + # It can be quickly reproduced by activating this section and running the + # test cases:: + # + # ./bin/test -vvvv -t dict_test + # + # That croaks like:: + # + # sqlalchemy.exc.CompileError: Unconsumed column names: characters_name + # + # TODO: Investigate why this is actually happening and eventually mitigate + # the root cause. + """ + if parameters and stmt_parameter_tuples: + check = ( + set(parameters) + .intersection(_column_as_key(k) for k, v in stmt_parameter_tuples) + .difference(check_columns) + ) + if check: + raise exc.CompileError( + "Unconsumed column names: %s" + % (", ".join("%s" % (c,) for c in check)) + ) + """ + + if ( + _compile_state_isinsert(compile_state) + and compile_state._has_multi_parameters + ): + # is a multiparams, is not an insert from a select + assert not stmt._select_names + multi_extended_values = _extend_values_for_multiparams( + compiler, + stmt, + compile_state, + cast( + "Sequence[_CrudParamElementStr]", + values, + ), + cast("Callable[..., str]", _column_as_key), + kw, + ) + return _CrudParams(values, multi_extended_values) + elif ( + not values + and compiler.for_executemany + and compiler.dialect.supports_default_metavalue + ): + # convert an "INSERT DEFAULT VALUES" + # into INSERT (firstcol) VALUES (DEFAULT) which can be turned + # into an in-place multi values. This supports + # insert_executemany_returning mode :) + values = [ + ( + _as_dml_column(stmt.table.columns[0]), + compiler.preparer.format_column(stmt.table.columns[0]), + compiler.dialect.default_metavalue_token, + (), + ) + ] + + return _CrudParams(values, []) diff --git a/src/crate/client/sqlalchemy/compiler.py b/src/crate/client/sqlalchemy/compiler.py index efa06a0c..7e6dad7d 100644 --- a/src/crate/client/sqlalchemy/compiler.py +++ b/src/crate/client/sqlalchemy/compiler.py @@ -221,12 +221,6 @@ def visit_any(self, element, **kw): self.process(element.right, **kw) ) - def returning_clause(self, stmt, returning_cols): - """ - Generate RETURNING clause, PostgreSQL-compatible. - """ - return PGCompiler.returning_clause(self, stmt, returning_cols) - def limit_clause(self, select, **kw): """ Generate OFFSET / LIMIT clause, PostgreSQL-compatible. diff --git a/src/crate/client/sqlalchemy/dialect.py b/src/crate/client/sqlalchemy/dialect.py index 80ab2c20..3f5f4c4f 100644 --- a/src/crate/client/sqlalchemy/dialect.py +++ b/src/crate/client/sqlalchemy/dialect.py @@ -32,7 +32,7 @@ CrateDDLCompiler ) from crate.client.exceptions import TimezoneUnawareException -from .sa_version import SA_VERSION, SA_1_4 +from .sa_version import SA_VERSION, SA_1_4, SA_2_0 from .types import Object, ObjectArray TYPES_MAP = { @@ -155,7 +155,10 @@ def process(value): } -if SA_VERSION >= SA_1_4: +if SA_VERSION >= SA_2_0: + from .compat.core20 import CrateCompilerSA20 + statement_compiler = CrateCompilerSA20 +elif SA_VERSION >= SA_1_4: from .compat.core14 import CrateCompilerSA14 statement_compiler = CrateCompilerSA14 else: diff --git a/src/crate/client/sqlalchemy/sa_version.py b/src/crate/client/sqlalchemy/sa_version.py index 502e5228..35517e27 100644 --- a/src/crate/client/sqlalchemy/sa_version.py +++ b/src/crate/client/sqlalchemy/sa_version.py @@ -25,3 +25,4 @@ SA_VERSION = V(sa.__version__) SA_1_4 = V('1.4.0b1') +SA_2_0 = V('2.0.0') diff --git a/src/crate/client/sqlalchemy/tests/bulk_test.py b/src/crate/client/sqlalchemy/tests/bulk_test.py index 95bc1ddd..ee4099cf 100644 --- a/src/crate/client/sqlalchemy/tests/bulk_test.py +++ b/src/crate/client/sqlalchemy/tests/bulk_test.py @@ -78,4 +78,4 @@ def test_bulk_save(self): ('Banshee', 26), ('Callisto', 37) ) - self.assertEqual(expected_bulk_args, bulk_args) + self.assertSequenceEqual(expected_bulk_args, bulk_args)