|
| 1 | +# -*- coding: utf-8; -*- |
| 2 | +# |
| 3 | +# Licensed to CRATE Technology GmbH ("Crate") under one or more contributor |
| 4 | +# license agreements. See the NOTICE file distributed with this work for |
| 5 | +# additional information regarding copyright ownership. Crate licenses |
| 6 | +# this file to you under the Apache License, Version 2.0 (the "License"); |
| 7 | +# you may not use this file except in compliance with the License. You may |
| 8 | +# obtain a copy of the License at |
| 9 | +# |
| 10 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 11 | +# |
| 12 | +# Unless required by applicable law or agreed to in writing, software |
| 13 | +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |
| 14 | +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |
| 15 | +# License for the specific language governing permissions and limitations |
| 16 | +# under the License. |
| 17 | +# |
| 18 | +# However, if you have executed another commercial license agreement |
| 19 | +# with Crate these terms will supersede the license and you may use the |
| 20 | +# software solely pursuant to the terms of the relevant commercial agreement. |
| 21 | + |
| 22 | +import sqlalchemy as sa |
| 23 | +from sqlalchemy.sql.crud import (REQUIRED, _create_bind_param, |
| 24 | + _extend_values_for_multiparams, |
| 25 | + _get_multitable_params, |
| 26 | + _get_stmt_parameters_params, |
| 27 | + _key_getters_for_crud_column, _scan_cols, |
| 28 | + _scan_insert_from_select_cols) |
| 29 | + |
| 30 | +from crate.client.sqlalchemy.compiler import CrateCompiler |
| 31 | + |
| 32 | + |
| 33 | +class CrateCompilerSA10(CrateCompiler): |
| 34 | + |
| 35 | + def visit_update(self, update_stmt, **kw): |
| 36 | + """ |
| 37 | + used to compile <sql.expression.Update> expressions |
| 38 | + Parts are taken from the SQLCompiler base class. |
| 39 | + """ |
| 40 | + |
| 41 | + # [10] CrateDB patch. |
| 42 | + if not update_stmt.parameters and \ |
| 43 | + not hasattr(update_stmt, '_crate_specific'): |
| 44 | + return super().visit_update(update_stmt, **kw) |
| 45 | + |
| 46 | + self.isupdate = True |
| 47 | + |
| 48 | + extra_froms = update_stmt._extra_froms |
| 49 | + |
| 50 | + text = 'UPDATE ' |
| 51 | + |
| 52 | + if update_stmt._prefixes: |
| 53 | + text += self._generate_prefixes(update_stmt, |
| 54 | + update_stmt._prefixes, **kw) |
| 55 | + |
| 56 | + table_text = self.update_tables_clause(update_stmt, update_stmt.table, |
| 57 | + extra_froms, **kw) |
| 58 | + |
| 59 | + dialect_hints = None |
| 60 | + if update_stmt._hints: |
| 61 | + dialect_hints, table_text = self._setup_crud_hints( |
| 62 | + update_stmt, table_text |
| 63 | + ) |
| 64 | + |
| 65 | + # [10] CrateDB patch. |
| 66 | + crud_params = _get_crud_params(self, update_stmt, **kw) |
| 67 | + |
| 68 | + text += table_text |
| 69 | + |
| 70 | + text += ' SET ' |
| 71 | + |
| 72 | + # [10] CrateDB patch begin. |
| 73 | + include_table = \ |
| 74 | + extra_froms and self.render_table_with_column_in_update_from |
| 75 | + |
| 76 | + set_clauses = [] |
| 77 | + |
| 78 | + for k, v in crud_params: |
| 79 | + clause = k._compiler_dispatch(self, |
| 80 | + include_table=include_table) + \ |
| 81 | + ' = ' + v |
| 82 | + set_clauses.append(clause) |
| 83 | + |
| 84 | + for k, v in update_stmt.parameters.items(): |
| 85 | + if isinstance(k, str) and '[' in k: |
| 86 | + bindparam = sa.sql.bindparam(k, v) |
| 87 | + set_clauses.append(k + ' = ' + self.process(bindparam)) |
| 88 | + |
| 89 | + text += ', '.join(set_clauses) |
| 90 | + # [10] CrateDB patch end. |
| 91 | + |
| 92 | + if self.returning or update_stmt._returning: |
| 93 | + if not self.returning: |
| 94 | + self.returning = update_stmt._returning |
| 95 | + if self.returning_precedes_values: |
| 96 | + text += " " + self.returning_clause( |
| 97 | + update_stmt, self.returning) |
| 98 | + |
| 99 | + if extra_froms: |
| 100 | + extra_from_text = self.update_from_clause( |
| 101 | + update_stmt, |
| 102 | + update_stmt.table, |
| 103 | + extra_froms, |
| 104 | + dialect_hints, |
| 105 | + **kw) |
| 106 | + if extra_from_text: |
| 107 | + text += " " + extra_from_text |
| 108 | + |
| 109 | + if update_stmt._whereclause is not None: |
| 110 | + t = self.process(update_stmt._whereclause) |
| 111 | + if t: |
| 112 | + text += " WHERE " + t |
| 113 | + |
| 114 | + limit_clause = self.update_limit_clause(update_stmt) |
| 115 | + if limit_clause: |
| 116 | + text += " " + limit_clause |
| 117 | + |
| 118 | + if self.returning and not self.returning_precedes_values: |
| 119 | + text += " " + self.returning_clause( |
| 120 | + update_stmt, self.returning) |
| 121 | + |
| 122 | + return text |
| 123 | + |
| 124 | + |
| 125 | +def _get_crud_params(compiler, stmt, **kw): |
| 126 | + """create a set of tuples representing column/string pairs for use |
| 127 | + in an INSERT or UPDATE statement. |
| 128 | +
|
| 129 | + Also generates the Compiled object's postfetch, prefetch, and |
| 130 | + returning column collections, used for default handling and ultimately |
| 131 | + populating the ResultProxy's prefetch_cols() and postfetch_cols() |
| 132 | + collections. |
| 133 | +
|
| 134 | + """ |
| 135 | + |
| 136 | + compiler.postfetch = [] |
| 137 | + compiler.insert_prefetch = [] |
| 138 | + compiler.update_prefetch = [] |
| 139 | + compiler.returning = [] |
| 140 | + |
| 141 | + # no parameters in the statement, no parameters in the |
| 142 | + # compiled params - return binds for all columns |
| 143 | + if compiler.column_keys is None and stmt.parameters is None: |
| 144 | + return [ |
| 145 | + (c, _create_bind_param(compiler, c, None, required=True)) |
| 146 | + for c in stmt.table.columns |
| 147 | + ] |
| 148 | + |
| 149 | + if stmt._has_multi_parameters: |
| 150 | + stmt_parameters = stmt.parameters[0] |
| 151 | + else: |
| 152 | + stmt_parameters = stmt.parameters |
| 153 | + |
| 154 | + # getters - these are normally just column.key, |
| 155 | + # but in the case of mysql multi-table update, the rules for |
| 156 | + # .key must conditionally take tablename into account |
| 157 | + ( |
| 158 | + _column_as_key, |
| 159 | + _getattr_col_key, |
| 160 | + _col_bind_name, |
| 161 | + ) = _key_getters_for_crud_column(compiler, stmt) |
| 162 | + |
| 163 | + # if we have statement parameters - set defaults in the |
| 164 | + # compiled params |
| 165 | + if compiler.column_keys is None: |
| 166 | + parameters = {} |
| 167 | + else: |
| 168 | + parameters = dict( |
| 169 | + (_column_as_key(key), REQUIRED) |
| 170 | + for key in compiler.column_keys |
| 171 | + if not stmt_parameters or key not in stmt_parameters |
| 172 | + ) |
| 173 | + |
| 174 | + # create a list of column assignment clauses as tuples |
| 175 | + values = [] |
| 176 | + |
| 177 | + if stmt_parameters is not None: |
| 178 | + _get_stmt_parameters_params( |
| 179 | + compiler, parameters, stmt_parameters, _column_as_key, values, kw |
| 180 | + ) |
| 181 | + |
| 182 | + check_columns = {} |
| 183 | + |
| 184 | + # special logic that only occurs for multi-table UPDATE |
| 185 | + # statements |
| 186 | + if compiler.isupdate and stmt._extra_froms and stmt_parameters: |
| 187 | + _get_multitable_params( |
| 188 | + compiler, |
| 189 | + stmt, |
| 190 | + stmt_parameters, |
| 191 | + check_columns, |
| 192 | + _col_bind_name, |
| 193 | + _getattr_col_key, |
| 194 | + values, |
| 195 | + kw, |
| 196 | + ) |
| 197 | + |
| 198 | + if compiler.isinsert and stmt.select_names: |
| 199 | + _scan_insert_from_select_cols( |
| 200 | + compiler, |
| 201 | + stmt, |
| 202 | + parameters, |
| 203 | + _getattr_col_key, |
| 204 | + _column_as_key, |
| 205 | + _col_bind_name, |
| 206 | + check_columns, |
| 207 | + values, |
| 208 | + kw, |
| 209 | + ) |
| 210 | + else: |
| 211 | + _scan_cols( |
| 212 | + compiler, |
| 213 | + stmt, |
| 214 | + parameters, |
| 215 | + _getattr_col_key, |
| 216 | + _column_as_key, |
| 217 | + _col_bind_name, |
| 218 | + check_columns, |
| 219 | + values, |
| 220 | + kw, |
| 221 | + ) |
| 222 | + |
| 223 | + # [10] CrateDB patch. |
| 224 | + # |
| 225 | + # This sanity check performed by SQLAlchemy currently needs to be |
| 226 | + # deactivated in order to satisfy the rewriting logic of the CrateDB |
| 227 | + # dialect in `rewrite_update` and `visit_update`. |
| 228 | + # |
| 229 | + # It can be quickly reproduced by activating this section and running the |
| 230 | + # test cases:: |
| 231 | + # |
| 232 | + # ./bin/test -vvvv -t dict_test |
| 233 | + # |
| 234 | + # That croaks like:: |
| 235 | + # |
| 236 | + # sqlalchemy.exc.CompileError: Unconsumed column names: characters_name, data['nested'] |
| 237 | + # |
| 238 | + # TODO: Investigate why this is actually happening and eventually mitigate |
| 239 | + # the root cause. |
| 240 | + """ |
| 241 | + if parameters and stmt_parameters: |
| 242 | + check = ( |
| 243 | + set(parameters) |
| 244 | + .intersection(_column_as_key(k) for k in stmt_parameters) |
| 245 | + .difference(check_columns) |
| 246 | + ) |
| 247 | + if check: |
| 248 | + raise exc.CompileError( |
| 249 | + "Unconsumed column names: %s" |
| 250 | + % (", ".join("%s" % c for c in check)) |
| 251 | + ) |
| 252 | + """ |
| 253 | + |
| 254 | + if stmt._has_multi_parameters: |
| 255 | + values = _extend_values_for_multiparams(compiler, stmt, values, kw) |
| 256 | + |
| 257 | + return values |
0 commit comments