diff --git a/odbcli/app.py b/odbcli/app.py index ae761af..556f6e9 100644 --- a/odbcli/app.py +++ b/odbcli/app.py @@ -87,7 +87,8 @@ def __init__( # Loop over side-bar when moving past the element on the bottom self.obj_list[len(self.obj_list) - 1].next_object = self.obj_list[0] self._selected_object = self.obj_list[0] - self.completer = MssqlCompleter(smart_completion = True, get_conn = lambda: self.active_conn) + self.completer = MssqlCompleter(smart_completion = True, + get_conn = lambda: self.active_conn) self.application = self._create_application() diff --git a/odbcli/completion/mssqlcompleter.py b/odbcli/completion/mssqlcompleter.py index 673eeae..146b378 100644 --- a/odbcli/completion/mssqlcompleter.py +++ b/odbcli/completion/mssqlcompleter.py @@ -14,6 +14,7 @@ from prompt_toolkit.completion import Completer, Completion, PathCompleter from prompt_toolkit.document import Document from ..conn import sqlConnection +from ..dbmetadata import DbMetadata from .sqlcompletion import (Blank, FromClauseItem, suggest_type, Special, NamedQuery, Database, Schema, Table, Function, Column, View, Keyword, Datatype, Alias, Path, JoinCondition, Join) @@ -129,12 +130,12 @@ def __init__( self.logger.debug("Completer instantiated") @property - def active_conn(self) -> sqlConnection: - return self._get_conn() + def dbmetadata(self) -> DbMetadata: + return self._get_conn().dbmetadata def escape_name(self, name): - if self.active_conn is not None: - name = self.active_conn.escape_name(name) + if self.dbmetadata is not None: + name = self.dbmetadata.escape_name(name) return name @@ -143,21 +144,17 @@ def escape_schema(self, name): def unescape_name(self, name): """ Unquote a string.""" - if self.active_conn is not None: - name = self.active_conn.unescape_name(name) + if self.dbmetadata is not None: + name = self.dbmetadata.unescape_name(name) return name def escape_names(self, names): - if self.active_conn is not None: - names = self.active_conn.escape_names(names) + if self.dbmetadata is not None: + names = self.dbmetadata.escape_names(names) return names - def extend_database_names(self, databases): - databases = self.escape_names(databases) - self.databases.extend(databases) - def extend_keywords(self, additional_keywords): self.keywords = self.keywords + additional_keywords # OG: Unclear what the roll of all_completions is @@ -179,8 +176,7 @@ def extend_functions(self, func_data): # dbmetadata['schema_name']['functions']['function_name'] should return # the function metadata namedtuple for the corresponding function - conn = self.active_conn - metadata = conn.dbmetadata.data + metadata = self.dbmetadata.data submeta = metadata['function'] for f in func_data: @@ -203,8 +199,7 @@ def _refresh_arg_list_cache(self): # This is used when suggesting functions, to avoid the latency that would result # if we'd recalculate the arg lists each time we suggest functions (in # large DBs) - conn = self.active_conn - metadata = conn.dbmetadata.data + metadata = self.dbmetadata.data self._arg_list_cache = { usage: { meta: self._arg_list(meta, usage) @@ -226,8 +221,7 @@ def extend_foreignkeys(self, fk_data): # These are added as a list of ForeignKey namedtuples to the # ColumnMetadata namedtuple for both the child and parent # OG: This needs catalog facelift - conn = self.active_conn - metadata = conn.dbmetadata.data + metadata = self.dbmetadata.data submeta = metadata['table'] for fk in fk_data: @@ -249,8 +243,7 @@ def extend_datatypes(self, type_data): # dbmetadata['datatypes'][schema_name][type_name] should store type # metadata, such as composite type field names. Currently, we're not # storing any metadata beyond typename, so just store None - conn = self.active_conn - metadata = conn.dbmetadata.data + metadata = self.dbmetadata.data for t in type_data: schema, type_name = self.escape_names(t) @@ -276,8 +269,7 @@ def reset_completions(self): self.special_commands = [] # search_path at this point is not used self.search_path = [] - conn = self.active_conn - conn.dbmetadata.reset_metadata() + self.dbmetadata.reset_metadata() # OG: Unclear what the roll of all_completions is #self.all_completions = set(self.keywords + self.functions) @@ -676,47 +668,20 @@ def filt(_): return matches def get_schema_matches(self, suggestion, word_before_cursor): - conn = self.active_conn if suggestion.parent: catalog_u = self.unescape_name(suggestion.parent) else: - catalog_u = conn.current_catalog() + catalog_u = self.dbmetadata.current_catalog() catalog_e = self.escape_name(catalog_u) self.logger.debug("get_schema_matches: parent %s", suggestion.parent) # OG: Note here, if there is even a single schema in [catalog_e].keys() # we'll happily return a potentially incomplete result set. - schema_names_e = conn.dbmetadata.get_schemas(catalog = catalog_e) + schema_names_e = self.dbmetadata.get_schemas(catalog = catalog_e) if schema_names_e is None: # Asking for schema in a non-existant catalog return [] - - if len(schema_names_e) == 0: - # Catalog exists in dbmetadata but is empty - if suggestion.parent: - # Looking for schemas in a specified catalog - schema_names = [] - # Attempt list_schemas - schema_names = conn.list_schemas( - catalog = conn.sanitize_search_string(catalog_u)) - - if len(schema_names) < 1: - res = conn.find_tables( - catalog = conn.sanitize_search_string(catalog_u), - schema = "", - table = "", - type = "") - schema_names = [r.schema for r in res] - else: - # Looking for schemas in current catalog - schema_names = conn.list_schemas() - - schema_names = set(schema_names) - - schema_names_e = self.escape_names(schema_names) - conn.dbmetadata.extend_schemas(catalog = catalog_e, names = schema_names_e) - return self.find_matches( word_before_cursor, schema_names_e, meta='schema') @@ -836,11 +801,7 @@ def get_alias_matches(self, suggestion, word_before_cursor): meta='table alias') def get_database_matches(self, _, word_before_cursor): - conn = self.active_conn - catalogs_e = conn.dbmetadata.get_catalogs() - if catalogs_e is None and (conn.connected()): - catalogs_e = self.escape_names(conn.list_catalogs()) - conn.dbmetadata.extend_catalogs(catalogs_e) + catalogs_e = self.dbmetadata.get_catalogs() return self.find_matches(word_before_cursor, catalogs_e, meta='catalog') @@ -921,10 +882,9 @@ def populate_scoped_cols(self, scoped_tbls, local_tbls=()): :return: {TableReference:{colname:ColumnMetaData}} """ - conn = self.active_conn ctes = dict((normalize_ref(t.name), t.columns) for t in local_tbls) columns = OrderedDict() - metadata = conn.dbmetadata.data + metadata = self.dbmetadata.data def addcols(schema, rel, alias, reltype, cols): tbl = TableReference(schema, rel, alias, reltype == 'functions') @@ -978,7 +938,7 @@ def addcols(catalog, schema, rel, alias, reltype, cols): if tbl.catalog: catalog_u = self.unescape_name(tbl.catalog) else: - catalog_u = self.active_conn.current_catalog() + catalog_u = self.dbmetadata.current_catalog() # TODO: What if no schema? Possible in some DBMS if tbl.schema: @@ -1007,21 +967,21 @@ def addcols(catalog, schema, rel, alias, reltype, cols): # cols = func.fields() # addcols(schema, relname, tbl.alias, 'functions', cols) else: - conn = self.active_conn - # Per SQLColumns spec: CatalogName cannot contain a string search pattern - res = conn.find_columns( - catalog = catalog_u, - schema = conn.sanitize_search_string(schema_u), - table = conn.sanitize_search_string(relname_u), - column = "%") - if len(res): - cols = [ColumnMetadata( - name = col.column, - datatype = col.data_type, - has_default = col.default, - default = col.default - ) for col in res] - addcols(catalog, schema, relname, tbl.alias, "table", cols) + for reltype in ("table", "view"): + res = self.dbmetadata.get_columns( + catalog = catalog, + schema = schema, + name = relname, + obj_type = reltype) + + if res is not None and len(res): + cols = [ColumnMetadata( + name = col.column, + datatype = col.data_type, + has_default = col.default, + default = col.default + ) for col in res] + addcols(catalog, schema, relname, tbl.alias, reltype, cols) return columns @@ -1032,8 +992,7 @@ def _get_schemas(self, obj_typ, schema): :param schema is the schema qualification input by the user (if any) """ - conn = self.active_conn - metadata = conn.dbmetadata.data + metadata = self.dbmetadata.data submeta = metadata[obj_typ] if schema: schema = self.escape_name(schema) @@ -1050,8 +1009,7 @@ def populate_schema_objects(self, schema, obj_type): :param schema is the schema qualification input by the user (if any) """ - conn = self.active_conn - metadata = conn.dbmetadata.data + metadata = self.dbmetadata.data return [ SchemaObject( @@ -1070,7 +1028,6 @@ def populate_objects(self, catalog, schema, obj_type): """ ret = [] obj_names = [] - conn = self.active_conn self.logger.debug("populate_objects(%s): Called for %s.%s", obj_type, catalog, schema) if catalog is None and schema is None: @@ -1078,7 +1035,7 @@ def populate_objects(self, catalog, schema, obj_type): schema = "" if catalog is None: # Set to current catalog - catalog = conn.current_catalog() + catalog = self.dbmetadata.current_catalog() if catalog is None or catalog == "": # Don't allow "".[schema] # Interpret this to mean [schema]."" @@ -1091,55 +1048,20 @@ def populate_objects(self, catalog, schema, obj_type): # dbmetadata always escaped? catalog_e = self.escape_name(self.unescape_name(catalog)) schema_e = self.escape_name(self.unescape_name(schema)) - obj_names = conn.dbmetadata.get_objects(catalog = catalog_e, schema = schema_e, obj_type = obj_type) + obj_names = self.dbmetadata.get_objects(catalog = catalog_e, schema = schema_e, obj_type = obj_type) if obj_names is None: self.logger.debug("populate_objects(%s): Called for %s.%s, catalog/schema not found", obj_type, catalog, schema) return [] - if len(obj_names) == 0: - # catalog.schema were found but dbmetadata had no information as to - # content. So let's attempt to query - obj_names = [] - self.logger.debug("populate_objects(%s): Did not find %s.%s metadata. Will query.", obj_type, catalog_e, schema_e) - # Special case: Look for tables without catalog/schema - if catalog == "" and schema == "": - res = conn.find_tables( - catalog = "\x00", - schema = "\x00", - table = "", - type = obj_type) - else: - res = conn.find_tables( - catalog = conn.sanitize_search_string( - self.unescape_name(catalog)), - schema = conn.sanitize_search_string( - self.unescape_name(schema)), - table = "", - type = obj_type) - for r in res: - name_e = self.escape_name(r.name) - ret.append( - SchemaObject( - name=name_e, - schema=schema_e, - catalog=catalog_e - ) - ) - obj_names.append(name_e) - self.logger.debug("populate_objects(%s): Query complete %s.%s", obj_type, catalog_e, schema_e) - conn.dbmetadata.extend_objects( - catalog = catalog_e, schema = schema_e, - names = obj_names, obj_type = obj_type) - else: - for name_e in obj_names: - ret.append( - SchemaObject( - name=name_e, - schema=schema_e, #should this be r.schema - catalog=catalog_e #should this be r.catalog - ) + for name_e in obj_names: + ret.append( + SchemaObject( + name=name_e, + schema=schema_e, #should this be r.schema + catalog=catalog_e #should this be r.catalog ) + ) return ret def populate_functions(self, schema, filter_func): @@ -1152,8 +1074,7 @@ def populate_functions(self, schema, filter_func): """ - conn = self.active_conn - metadata = conn.dbmetadata.data + metadata = self.dbmetadata.data # Because of multiple dispatch, we can have multiple functions # with the same name, which is why `for meta in metas` is necessary # in the comprehensions below diff --git a/odbcli/conn.py b/odbcli/conn.py index 5106a3f..4369d9b 100644 --- a/odbcli/conn.py +++ b/odbcli/conn.py @@ -38,7 +38,7 @@ def __init__( self.username = username self.password = password self.logger = getLogger(__name__) - self.dbmetadata = DbMetadata() + self.dbmetadata = DbMetadata(self) self._quotechar = None self._search_escapechar = None self._search_escapepattern = None @@ -450,8 +450,12 @@ def find_procedure_columns( def current_catalog(self) -> str: if self.conn.connected(): - return self.conn.catalog_name - return None + with self._lock: + res = self.conn.catalog_name + else: + res = None + + return res def connected(self) -> bool: return self.conn.connected() diff --git a/odbcli/dbmetadata.py b/odbcli/dbmetadata.py index 75d415e..04d188c 100644 --- a/odbcli/dbmetadata.py +++ b/odbcli/dbmetadata.py @@ -1,16 +1,55 @@ from threading import Lock, Event, Thread class DbMetadata(): - def __init__(self) -> None: + """ + Internal representation of a database. The structure is that of a nested + dictionary, with the top nodes being "table", "view", "function", and + "datatype". From there the structure is: + "table" = {..., "catalog.lower" = (catalog, {...., "schema".lower = (...)}) + Some notes: + * All identifiers (catalog, schema, table, column) are quoted using + the quoting character for the connection. + * The name of a node is *.lower to allow for case insensitive search + through the connection metadata. + """ + def __init__(self, conn: "sqlConnection") -> None: + self._conn = conn self._lock = Lock() self._dbmetadata = {'table': {}, 'view': {}, 'function': {}, 'datatype': {}} + def escape_name(self, name): + if self._conn.connected(): + name = self._conn.escape_name(name) + + return name + + def escape_names(self, names): + if self._conn.connected(): + names = self._conn.escape_names(names) + + return names + + def unescape_name(self, name): + """ Unquote a string.""" + if self._conn.connected(): + name = self._conn.unescape_name(name) + return name + + def current_catalog(self): + if self._conn.connected(): + return self._conn.current_catalog() + + return None + def extend_catalogs(self, names: list) -> None: - with self._lock: - for metadata in self._dbmetadata.values(): - for catalog in names: - metadata[catalog.lower()] = (catalog, {}) + if len(names): + # Add "" catalog to house tables without catalog / schema + names.append("") + with self._lock: + for metadata in self._dbmetadata.values(): + for catalog in names: + metadata[catalog.lower()] = (catalog, {}) return def get_catalogs(self, obj_type: str = "table", cased: bool = True) -> list: @@ -23,9 +62,11 @@ def get_catalogs(self, obj_type: str = "table", cased: bool = True) -> list: else: res = list(self._dbmetadata[obj_type].keys()) - - if len(res) == 0: - return None + if len(res) == 0 and self._conn.connected(): + res = self.escape_names(self._conn.list_catalogs()) + self.extend_catalogs(res) + # TODO: Should we recursively call get_catalogs here so as to be + # able to respect the cased argument? return res @@ -37,6 +78,8 @@ def extend_schemas(self, catalog, names: list) -> None: catlower = catalog.lower() cat_cased = catalog if len(names): + # Add "" schema to house tables without schema + names.append("") with self._lock: for metadata in self._dbmetadata.values(): # Preserve casing if an entry already there @@ -45,6 +88,15 @@ def extend_schemas(self, catalog, names: list) -> None: metadata[catlower] = (cat_cased, {}) for schema in names: metadata[catlower][1][schema.lower()] = (schema, {}) + # If we passed nothing then take out that element entirely out + # of the dict + else: + with self._lock: + for otype in self._dbmetadata.keys(): + try: + del self._dbmetadata[otype][catlower] + except KeyError: + pass return def get_schemas(self, catalog: str, obj_type: str = "table", cased: bool = True) -> list: @@ -64,6 +116,26 @@ def get_schemas(self, catalog: str, obj_type: str = "table", cased: bool = True) else: res = list(self._dbmetadata[obj_type][catlower][1].keys()) + if len(res) == 0 and self._conn.connected(): + # Looking for schemas in a specified catalog + res_u = [] + catalog_u = self.unescape_name(catalog) + # Attempt list_schemas + res_u = self._conn.list_schemas( + catalog = self._conn.sanitize_search_string(catalog_u)) + + if len(res_u) < 1: + res_u = self._conn.find_tables( + catalog = self._conn.sanitize_search_string(catalog_u), + schema = "", + table = "", + type = "") + res_u = [r.schema for r in res_u] + + res = self.escape_names(res_u) + self.extend_schemas(catalog = catalog, names = res) + # TODO: Should we recursively call get_schemas here so as to be + # able to respect the cased argument? return res @@ -84,11 +156,14 @@ def extend_objects(self, catalog, schema, names: list, obj_type: str) -> None: # of the dict else: with self._lock: - del self._dbmetadata[obj_type][catlower][1][schlower] + try: + del self._dbmetadata[obj_type][catlower][1][schlower] + except KeyError: + pass return - def get_objects(self, catalog: str, schema: str, obj_type: str = "table") -> list: + def get_objects(self, catalog: str, schema: str, obj_type: str = "table", cased: bool = True) -> list: """ Retrieve objects as the keys for _dbmetadata[obj_type][catalog][schema] If catalog is not part of the _dbmetadata[obj_type] keys, or schema not one of the keys in _dbmetadata[obj_type][catalog] will return None @@ -97,13 +172,108 @@ def get_objects(self, catalog: str, schema: str, obj_type: str = "table") -> lis catlower = catalog.lower() schlower = schema.lower() schemas = self.get_schemas(catalog = catalog, obj_type = obj_type, cased = False) - if schemas is None or schlower not in schemas: + if (schemas is None or schlower not in schemas): + return None + + res = [] + with self._lock: + if cased: + res = [casedkey for casedkey, mappedvalue in self._dbmetadata[obj_type][catlower][1][schlower][1].values()] + else: + res = list(self._dbmetadata[obj_type][catlower][1][schlower][1].keys()) + + if len(res) == 0 and self._conn.connected(): + # Special case: Look for tables without catalog/schema + res = [] + if catalog == "" and schema == "": + res_u = self._conn.find_tables( + catalog = "\x00", + schema = "\x00", + table = "", + type = obj_type) + else: + res_u = self._conn.find_tables( + catalog = self._conn.sanitize_search_string( + self.unescape_name(catalog)), + schema = self._conn.sanitize_search_string( + self.unescape_name(schema)), + table = "", + type = obj_type) + res_u = [r.name for r in res_u] + res = self.escape_names(res_u) + + self.extend_objects( + catalog = catalog, schema = schema, + names = res, obj_type = obj_type) + # TODO: Should we recursively call get_objects here so as to make + # sure we are returning the correct spec + + return res + + def extend_columns(self, catalog, schema, name, cols: list, obj_type: str) -> None: + catlower = catalog.lower() + schlower = schema.lower() + nmlower = name.lower() + if len(cols): + with self._lock: + for otype in self._dbmetadata.keys(): + # Loop over tables, views, functions + if catlower not in self._dbmetadata[otype].keys(): + self._dbmetadata[otype][catlower] = (catalog, {}) + if schlower not in self._dbmetadata[otype][catlower][1].keys(): + self._dbmetadata[otype][catlower][1][schlower] = (schema, {}) + for col in cols: + try: + self._dbmetadata[obj_type][catlower][1][schlower][1][nmlower][1][col.column.lower()] = (col.column, col) + except KeyError: + pass + # If we passed nothing then take out that element entirely out + # of the dict + else: + with self._lock: + try: + del self._dbmetadata[obj_type][catlower][1][schlower] + except KeyError: + pass + + return + + def get_columns(self, catalog: str, schema: str, name: str, obj_type: str = "table") -> list: + """ + Returns a list of named tuples. See cyanodbc.connection.find_columns + """ + catlower = catalog.lower() + schlower = schema.lower() + nmlower = name.lower() + objs = self.get_objects(catalog = catalog, schema = schema, + obj_type = obj_type, cased = False) # TODO add cased argument to get_objects + if (objs is None or nmlower not in objs): return None + res = [] with self._lock: - res = [casedkey for casedkey, mappedvalue in self._dbmetadata[obj_type][catlower][1][schlower][1].values()] + res = [mappedvalue for casedkey, mappedvalue in self._dbmetadata[obj_type][catlower][1][schlower][1][nmlower][1].values()] - return list(res) + if len(res) == 0 and self._conn.connected(): + # res is a collections.namedtuple object + res = self._conn.find_columns( + # Per SQLColumns spec: CatalogName cannot contain a + # string search pattern. But should we sanitize + # regardless? + catalog = self.unescape_name(catalog), + schema = self._conn.sanitize_search_string( + self.unescape_name(schema)), + table = self._conn.sanitize_search_string( + self.unescape_name(name)), + column = "%") + + self.extend_columns( + catalog = catalog, schema = schema, name = name, + cols = res, obj_type = obj_type) + # TODO: Should we recursively call get_columns here so as to make + # sure we are returning the correct spec + + return res def reset_metadata(self) -> None: with self._lock: