-
Notifications
You must be signed in to change notification settings - Fork 193
/
Copy pathmssqlcliclient.py
368 lines (304 loc) · 15.2 KB
/
mssqlcliclient.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
import sys
import copy
import logging
import time
from time import sleep
import uuid
import sqlparse
import click
from mssqlcli import mssqlqueries
from mssqlcli.jsonrpc.contracts import connectionservice, queryexecutestringservice as queryservice
from mssqlcli.packages import special
from mssqlcli.packages.parseutils.meta import ForeignKey
# Module-level logger for the MssqlCliClient class.
logger = logging.getLogger(u'mssqlcli.mssqlcliclient')
# Delay, in seconds, between polls of a pending tools-service request when no
# response has been received yet.
time_wait_if_no_response = 0.05
def generate_owner_uri():
    """Create a fresh owner URI used to identify one tools-service connection.

    Returns:
        str: the prefix ``'mssql-cli-'`` followed by the URN form of a new
        random UUID (``'urn:uuid:...'``).
    """
    prefix = u'mssql-cli-'
    unique_part = uuid.uuid4().urn
    return prefix + unique_part
class MssqlCliClient:
    """Client for one SQL Server connection made through the SQL Tools Service.

    Connection settings are copied from the mssql-cli options object when the
    client is built; ``connect_to_database`` opens the connection, and
    ``execute_query`` plus the ``get_*`` metadata helpers run statements on it.
    """
    # pylint: disable=too-many-instance-attributes

    def __init__(self, mssqlcli_options, sql_tools_client, owner_uri=None, **kwargs):
        """Store connection settings. No connection is opened here.

        Args:
            mssqlcli_options: options object; the attributes read are ``server``,
                ``username``, ``password``, ``integrated_auth``, ``database``,
                ``encrypt``, ``trust_server_certificate``, ``connection_timeout``,
                ``application_intent``, ``multi_subnet_failover`` and
                ``packet_size``.
            sql_tools_client: object used to create tools-service requests.
            owner_uri: URI identifying this connection; a new one is generated
                when it is not supplied.
            **kwargs: extra connection parameters, merged verbatim into the
                connection request by ``add_optional_connection_params``.
        """
        self.server_name = mssqlcli_options.server
        if ',' in self.server_name:
            # "host,port" form. maxsplit=1 avoids a ValueError at construction
            # time when the value contains more than one comma.
            self.prompt_host, self.prompt_port = self.server_name.split(',', 1)
        else:
            self.prompt_host = self.server_name
            self.prompt_port = 1433
        self.user_name = mssqlcli_options.username
        self.password = mssqlcli_options.password
        self.authentication_type = u'Integrated' if mssqlcli_options.integrated_auth \
            else u'SqlLogin'
        self.database = mssqlcli_options.database
        # Filled in from the service response after a successful connection.
        self.connected_database = None
        self.encrypt = mssqlcli_options.encrypt
        self.trust_server_certificate = mssqlcli_options.trust_server_certificate
        self.connection_timeout = mssqlcli_options.connection_timeout
        self.application_intent = mssqlcli_options.application_intent
        self.multi_subnet_failover = mssqlcli_options.multi_subnet_failover
        self.packet_size = mssqlcli_options.packet_size
        self.owner_uri = owner_uri if owner_uri else generate_owner_uri()
        self.sql_tools_client = sql_tools_client
        self.is_connected = False
        self.server_version = None
        self.server_edition = None
        self.is_cloud = False
        # Extra connection parameters forwarded verbatim to the service.
        self.extra_params = kwargs

        logger.info(u'Initialized MssqlCliClient with owner Uri %s', self.owner_uri)

    def get_base_connection_params(self):
        """Return the connection parameters that are always sent.

        ``DatabaseName`` prefers the database reported by the last successful
        connection, falling back to the database requested in the options.
        """
        return {
            u'ServerName': self.server_name,
            u'DatabaseName': self.connected_database if self.connected_database
                             else self.database,
            u'UserName': self.user_name,
            u'Password': self.password,
            u'AuthenticationType': self.authentication_type,
            u'OwnerUri': self.owner_uri
        }

    def add_optional_connection_params(self, base_connection_params):
        """Add the optional connection settings to *base_connection_params*.

        Only truthy settings are added. ``extra_params`` are applied last, so
        they override any key set earlier. The dict is modified in place and
        also returned.
        """
        if self.encrypt:
            base_connection_params[u'Encrypt'] = self.encrypt
        if self.trust_server_certificate:
            base_connection_params[u'TrustServerCertificate'] = self.trust_server_certificate
        if self.connection_timeout:
            base_connection_params[u'ConnectTimeout'] = self.connection_timeout
        if self.application_intent:
            base_connection_params[u'ApplicationIntent'] = self.application_intent
        if self.multi_subnet_failover:
            base_connection_params[u'MultiSubnetFailover'] = self.multi_subnet_failover
        if self.packet_size:
            base_connection_params[u'PacketSize'] = self.packet_size

        base_connection_params.update(self.extra_params)
        return base_connection_params

    def connect_to_database(self):
        """Open the connection if it is not already open.

        Returns:
            tuple: ``(owner_uri, error_messages)``. ``owner_uri`` is None when
            the connection failed.
        """
        connection_params = self.get_base_connection_params()
        connection_params = self.add_optional_connection_params(connection_params)
        owner_uri, error_messages = self._execute_connection_request_with(connection_params)
        return owner_uri, error_messages

    def execute_query(self, query):
        """Run *query* and yield one result tuple per result set.

        The text is first offered to the special-command handler. If it is not a
        special command, it is split into statements with ``sqlparse`` and each
        non-empty statement is executed in turn.

        Yields:
            tuple: ``(rows, columns, status, statement, is_error)``.
        """
        # Try to run first as special command
        try:
            for rows, columns, status, statement, is_error in special.execute(self, query):
                yield rows, columns, status, statement, is_error
        except special.CommandNotFound:
            # Not a special command: execute as normal SQL.
            # Remove surrounding whitespace (including EOLs).
            query = query.strip()
            if not query:
                yield None, None, None, query, False
                return

            for single_query in sqlparse.split(query):
                single_query = single_query.strip()
                if not single_query:
                    yield None, None, None, None, False
                    continue
                for rows, columns, status, statement, is_error \
                        in self._execute_query(single_query):
                    yield rows, columns, status, statement, is_error

    def _execute_query(self, query):
        """Execute one T-SQL statement and yield a result tuple per result set.

        Yields:
            tuple: ``(rows, columns, message, query, is_error)``, as built by
            ``_generate_query_results_to_tuples``.
        """
        query_response, query_messages, query_had_error = \
            self._execute_query_execute_request_for(query)

        if self._exception_found_in(query_response):
            yield self._generate_query_results_to_tuples(
                query=query,
                message=query_response.exception_message,
                is_error=query_had_error)
            return

        if self._no_results_found_in(query_response) or self._no_rows_found_in(query_response):
            # No rows to fetch: report only the first server message, if any.
            query_message = query_messages[0].message if query_messages else u''
            yield self._generate_query_results_to_tuples(query=query,
                                                         message=query_message,
                                                         is_error=query_had_error)
            return

        query_subset_responses_and_summaries = \
            self._execute_query_subset_request_for(query_response)

        for query_subset_response, result_set_summary, query_subset_error \
                in query_subset_responses_and_summaries:
            if self._error_message_found_in(query_subset_response):
                yield self._generate_query_results_to_tuples(
                    query=query,
                    message=query_subset_response.error_message,
                    is_error=query_subset_error)
                # Report the error by itself; don't also build a rows tuple
                # from the failed subset response.
                continue

            # Messages are matched to result sets by index. There can be fewer
            # messages than result sets, so guard the lookup.
            message_index = result_set_summary.request_id
            query_message_for_current_result_set = \
                query_messages[message_index].message \
                if message_index < len(query_messages) else u''

            yield self._generate_query_results_to_tuples(
                column_info=result_set_summary.column_info,
                result_rows=query_subset_response.rows,
                query=query,
                message=query_message_for_current_result_set
            )

    def clone(self, sqltoolsclient=None):
        """Return a shallow copy of this client with a new owner URI.

        The copy starts disconnected. If *sqltoolsclient* is given, it replaces
        the copied tools client.
        """
        cloned_mssqlcli_client = copy.copy(self)
        cloned_mssqlcli_client.owner_uri = generate_owner_uri()
        cloned_mssqlcli_client.is_connected = False

        if sqltoolsclient:
            cloned_mssqlcli_client.sql_tools_client = sqltoolsclient

        return cloned_mssqlcli_client

    def _execute_connection_request_with(self, connection_params):
        """Send a connection request and poll until it completes.

        Returns:
            tuple: ``(owner_uri, error_messages)`` on success, or
            ``(None, error_messages)`` on failure. If the client is already
            connected, no request is sent and ``(owner_uri, [])`` is returned.
        """
        if self.is_connected:
            return self.owner_uri, []

        connection_request = self.sql_tools_client.create_request(
            self.sql_tools_client.CONNECTION_REQUEST,
            connection_params,
            self.owner_uri
        )
        connection_request.execute()
        error_messages = []
        response = None

        while not connection_request.completed():
            response = connection_request.get_response()
            if isinstance(response, connectionservice.ConnectionCompleteEvent):
                if response.error_message:
                    error_messages.append(u'Error message: {}'.format(response.error_message))
                if response.messages:
                    logger.error(response.messages)
            elif not response:
                # Nothing has arrived yet; back off before polling again.
                time.sleep(time_wait_if_no_response)

        if response and response.connection_id:
            # Internal consistency check: the completion must belong to this client.
            assert response.owner_uri == self.owner_uri
            self.is_connected = True
            self.server_version = response.server_version
            self.server_edition = response.server_edition
            self.is_cloud = response.is_cloud
            self.connected_database = response.connected_database
            logger.info(u'Connection Successful. Connection Id: %s Connected database: %s',
                        response.connection_id, response.connected_database)
            return self.owner_uri, error_messages

        return None, error_messages

    def _execute_query_execute_request_for(self, query):
        """Execute *query* on the service and collect its messages.

        Exits the process with status 1 if the client is not connected.

        Returns:
            tuple: ``(query_response, query_messages, query_failed)``.
            ``query_failed`` is the first truthy value among the exception
            message, the first batch's error flag, and the first message's
            error flag.
        """
        if not self.is_connected:
            click.secho(u'No connection established with the server.',
                        err=True,
                        fg='yellow')
            sys.exit(1)

        query_request = self.sql_tools_client.create_request(
            self.sql_tools_client.QUERY_EXECUTE_STRING_REQUEST,
            {
                u'OwnerUri': self.owner_uri,
                u'Query': query
            },
            self.owner_uri
        )
        query_request.execute()
        query_response = None
        query_messages = []
        while not query_request.completed():
            query_response = query_request.get_response()
            if isinstance(query_response, queryservice.QueryMessageEvent):
                query_messages.append(query_response)
            elif not query_response:
                # Sleep only when nothing arrived. Sleeping after every other
                # event, including the final completion, added a fixed delay
                # to every statement.
                sleep(time_wait_if_no_response)

        query_has_exception = query_response.exception_message
        query_has_error_messages = query_messages[0].is_error if query_messages else False
        batch_summaries = getattr(query_response, 'batch_summaries', None)
        query_has_batch_error = batch_summaries[0].has_error if batch_summaries else False

        query_failed = query_has_exception or query_has_batch_error or query_has_error_messages

        return query_response, query_messages, query_failed

    def _execute_query_subset_request_for(self, query_response):
        """Fetch all rows of every result set in the first batch.

        Returns:
            list: ``(subset_response, result_set_summary, subset_error)``
            tuples, one per result set.
        """
        subset_responses_and_summaries = []
        for result_set_summary in query_response.batch_summaries[0].result_set_summaries:
            query_subset_request = self.sql_tools_client.create_request(
                self.sql_tools_client.QUERY_SUBSET_REQUEST,
                {
                    u'OwnerUri': query_response.owner_uri,
                    u'BatchIndex': result_set_summary.batch_id,
                    u'ResultSetIndex': result_set_summary.request_id,
                    u'RowsStartIndex': 0,
                    u'RowCount': result_set_summary.row_count
                },
                self.owner_uri)
            query_subset_request.execute()

            query_subset_response = None
            while not query_subset_request.completed():
                query_subset_response = query_subset_request.get_response()
                if not query_subset_response:
                    sleep(time_wait_if_no_response)

            query_subset_had_error = query_subset_request.error_message \
                if hasattr(query_subset_request, 'error_message') else False
            subset_responses_and_summaries.append(
                (query_subset_response, result_set_summary, query_subset_had_error)
            )

        return subset_responses_and_summaries

    @staticmethod
    def _error_message_found_in(query_subset_response):
        """Return the subset response's error message; truthy means an error."""
        return query_subset_response.error_message

    @staticmethod
    def _exception_found_in(query_response):
        """Return the query response's exception message; truthy means an exception."""
        return query_response.exception_message

    @staticmethod
    def _no_results_found_in(query_response):
        """True when there is no batch, or the first batch has no result sets."""
        return not query_response.batch_summaries \
            or not query_response.batch_summaries[0].result_set_summaries

    @staticmethod
    def _no_rows_found_in(query_response):
        """True when the first result set of the first batch has zero rows.

        Only the first result set is inspected. Callers must first make sure
        ``_no_results_found_in`` is False.
        """
        return query_response.batch_summaries[0].result_set_summaries[0].row_count == 0

    @staticmethod
    def _generate_query_results_to_tuples(query, message, column_info=None, result_rows=None,
                                          is_error=False):
        """Build a single result tuple.

        Returns:
            tuple: ``(rows, columns, message, query, is_error)``. ``message`` is
            the status text, such as rows affected or an error. On error,
            ``rows`` is ``()`` and ``columns`` is None.
        """
        if is_error:
            return (), None, message, query, is_error

        columns = [col.column_name for col in column_info] if column_info else None

        rows = ([[cell.display_value for cell in result_row.result_cells]
                 for result_row in result_rows]) if result_rows else ()

        return rows, columns, message, query, is_error

    def get_schemas(self):
        """Return a list of schema names (first column of the first result set)."""
        query = mssqlqueries.get_schemas()
        logger.info(u'Schemas query: %s', query)
        for tabular_result in self.execute_query(query):
            return [x[0] for x in tabular_result[0]]
        return None

    def get_databases(self):
        """Return a list of database names (first column of the first result set)."""
        query = mssqlqueries.get_databases()
        logger.info(u'Databases query: %s', query)
        for tabular_result in self.execute_query(query):
            return [x[0] for x in tabular_result[0]]
        return None

    def get_tables(self):
        """Yield (schema_name, table_name) tuples."""
        query = mssqlqueries.get_tables()
        logger.info(u'Tables query: %s', query)
        for tabular_result in self.execute_query(query):
            for row in tabular_result[0]:
                yield (row[0], row[1])

    def get_table_columns(self):
        """Yield (schema_name, table_name, column_name, data_type, column_default) tuples."""
        query = mssqlqueries.get_table_columns()
        logger.info(u'Table columns query: %s', query)
        for tabular_result in self.execute_query(query):
            for row in tabular_result[0]:
                yield (row[0], row[1], row[2], row[3], row[4])

    def get_views(self):
        """Yield (schema_name, view_name) tuples from the first two columns of the views query."""
        query = mssqlqueries.get_views()
        logger.info(u'Views query: %s', query)
        for tabular_result in self.execute_query(query):
            for row in tabular_result[0]:
                yield (row[0], row[1])

    def get_view_columns(self):
        """Yield (schema_name, view_name, column_name, data_type, column_default) tuples."""
        query = mssqlqueries.get_view_columns()
        logger.info(u'View columns query: %s', query)
        for tabular_result in self.execute_query(query):
            for row in tabular_result[0]:
                yield (row[0], row[1], row[2], row[3], row[4])

    def get_user_defined_types(self):
        """Yield (schema_name, type_name) tuples."""
        query = mssqlqueries.get_user_defined_types()
        logger.info(u'UDTs query: %s', query)
        for tabular_result in self.execute_query(query):
            for row in tabular_result[0]:
                yield (row[0], row[1])

    def get_foreign_keys(self):
        """Yield a ForeignKey built from each row of the foreign-keys query.

        Each row is (parent_schema, parent_table, parent_column, child_schema,
        child_table, child_column).
        """
        query = mssqlqueries.get_foreignkeys()
        logger.info(u'Foreign keys query: %s', query)
        for tabular_result in self.execute_query(query):
            for row in tabular_result[0]:
                yield ForeignKey(*row)

    def shutdown(self):
        """Shut down the underlying tools client."""
        self.sql_tools_client.shutdown()
        logger.info(u'Shutdown MssqlCliClient')