diff --git a/edgedb/base_client.py b/edgedb/base_client.py index 218d1aaa..d8c33767 100644 --- a/edgedb/base_client.py +++ b/edgedb/base_client.py @@ -118,6 +118,10 @@ def remove_log_listener( def dbname(self) -> str: return self._params.database + @property + def branch(self) -> str: + return self._params.branch + @abc.abstractmethod def is_closed(self) -> bool: ... @@ -679,6 +683,7 @@ def __init__( password: str = None, secret_key: str = None, database: str = None, + branch: str = None, tls_ca: str = None, tls_ca_file: str = None, tls_security: str = None, @@ -697,6 +702,7 @@ def __init__( "password": password, "secret_key": secret_key, "database": database, + "branch": branch, "timeout": timeout, "tls_ca": tls_ca, "tls_ca_file": tls_ca_file, diff --git a/edgedb/con_utils.py b/edgedb/con_utils.py index 285790fc..52d4f0cc 100644 --- a/edgedb/con_utils.py +++ b/edgedb/con_utils.py @@ -181,9 +181,15 @@ class ResolvedConnectConfig: _port = None _port_source = None + # We keep track of database and branch separately, because we want to make + # sure that all the configuration is consistent and uses one or the other + # exclusively. _database = None _database_source = None + _branch = None + _branch_source = None + _user = None _user_source = None @@ -226,6 +232,9 @@ def set_port(self, port, source): def set_database(self, database, source): self._set_param('database', database, source, _validate_database) + def set_branch(self, branch, source): + self._set_param('branch', branch, source, _validate_branch) + def set_user(self, user, source): self._set_param('user', user, source, _validate_user) @@ -268,9 +277,24 @@ def address(self): self._port if self._port else 5656 ) + # The properties actually merge database and branch, but "default" is + # different. If you need to know the underlying config use the _database + # and _branch. @property def database(self): - return self._database if self._database else 'edgedb' + return ( + self._database if self._database else + self._branch if self._branch else + 'edgedb' + ) + + @property + def branch(self): + return ( + self._database if self._database else + self._branch if self._branch else + '__default__' + ) @property def user(self): @@ -391,6 +415,12 @@ def _validate_database(database): return database +def _validate_branch(branch): + if branch == '': + raise ValueError(f'invalid branch name: {branch}') + return branch + + def _validate_user(user): if user == '': raise ValueError(f'invalid user name: {user}') @@ -521,6 +551,7 @@ def _parse_connect_dsn_and_args( password, secret_key, database, + branch, tls_ca, tls_ca_file, tls_security, @@ -557,6 +588,10 @@ def _parse_connect_dsn_and_args( (database, '"database" option') if database is not None else None ), + branch=( + (branch, '"branch" option') + if branch is not None else None + ), user=(user, '"user" option') if user is not None else None, password=( (password, '"password" option') @@ -604,6 +639,7 @@ def _parse_connect_dsn_and_args( env_credentials_file = os.getenv('EDGEDB_CREDENTIALS_FILE') env_host = os.getenv('EDGEDB_HOST') env_database = os.getenv('EDGEDB_DATABASE') + env_branch = os.getenv('EDGEDB_BRANCH') env_user = os.getenv('EDGEDB_USER') env_password = os.getenv('EDGEDB_PASSWORD') env_secret_key = os.getenv('EDGEDB_SECRET_KEY') @@ -643,6 +679,10 @@ def _parse_connect_dsn_and_args( (env_database, '"EDGEDB_DATABASE" environment variable') if env_database is not None else None ), + branch=( + (env_branch, '"EDGEDB_BRANCH" environment variable') + if env_branch is not None else None + ), user=( (env_user, '"EDGEDB_USER" environment variable') if env_user is not None else None @@ -818,11 +858,52 @@ def handle_dsn_part( def strip_leading_slash(str): return str[1:] if str.startswith('/') else str - handle_dsn_part( - 'database', strip_leading_slash(database), - resolved_config._database, resolved_config.set_database, - strip_leading_slash - ) + if ( + 'branch' in query or + 'branch_env' in query or + 'branch_file' in query + ): + if ( + 'database' in query or + 'database_env' in query or + 'database_file' in query + ): + raise ValueError( + f"invalid DSN: `database` and `branch` cannot be present " + f"at the same time" + ) + if resolved_config._database is not None: + raise errors.ClientConnectionError( + f"`branch` in DSN and {resolved_config._database_source} " + f"are mutually exclusive" + ) + handle_dsn_part( + 'branch', strip_leading_slash(database), + resolved_config._branch, resolved_config.set_branch, + strip_leading_slash + ) + else: + if resolved_config._branch is not None: + if ( + 'database' in query or + 'database_env' in query or + 'database_file' in query + ): + raise errors.ClientConnectionError( + f"`database` in DSN and {resolved_config._branch_source} " + f"are mutually exclusive" + ) + handle_dsn_part( + 'branch', strip_leading_slash(database), + resolved_config._branch, resolved_config.set_branch, + strip_leading_slash + ) + else: + handle_dsn_part( + 'database', strip_leading_slash(database), + resolved_config._database, resolved_config.set_database, + strip_leading_slash + ) handle_dsn_part( 'user', user, resolved_config._user, resolved_config.set_user @@ -929,6 +1010,7 @@ def _resolve_config_options( host=None, port=None, database=None, + branch=None, user=None, password=None, secret_key=None, @@ -940,7 +1022,23 @@ def _resolve_config_options( cloud_profile=None, ): if database is not None: + if branch is not None: + raise errors.ClientConnectionError( + f"{database[1]} and {branch[1]} are mutually exclusive" + ) + if resolved_config._branch is not None: + raise errors.ClientConnectionError( + f"{database[1]} and {resolved_config._branch_source} are " + f"mutually exclusive" + ) resolved_config.set_database(*database) + if branch is not None: + if resolved_config._database is not None: + raise errors.ClientConnectionError( + f"{resolved_config._database_source} and {branch[1]} are " + f"mutually exclusive" + ) + resolved_config.set_branch(*branch) if user is not None: resolved_config.set_user(*user) if password is not None: @@ -950,7 +1048,8 @@ def _resolve_config_options( if tls_ca_file is not None: if tls_ca is not None: raise errors.ClientConnectionError( - f"{tls_ca[1]} and {tls_ca_file[1]} are mutually exclusive") + f"{tls_ca[1]} and {tls_ca_file[1]} are mutually exclusive" + ) resolved_config.set_tls_ca_file(*tls_ca_file) if tls_ca is not None: resolved_config.set_tls_ca_data(*tls_ca) @@ -1018,7 +1117,23 @@ def _resolve_config_options( resolved_config.set_host(creds.get('host'), source) resolved_config.set_port(creds.get('port'), source) - resolved_config.set_database(creds.get('database'), source) + # We know that credentials have been validated, but they might be + # inconsistent with other resolved config settings. + if 'database' in creds: + if resolved_config._branch is not None: + raise errors.ClientConnectionError( + f"`branch` in configuration and `database` " + f"in credentials are mutually exclusive" + ) + resolved_config.set_database(creds.get('database'), source) + + elif 'branch' in creds: + if resolved_config._database is not None: + raise errors.ClientConnectionError( + f"`database` in configuration and `branch` " + f"in credentials are mutually exclusive" + ) + resolved_config.set_branch(creds.get('branch'), source) resolved_config.set_user(creds.get('user'), source) resolved_config.set_password(creds.get('password'), source) resolved_config.set_tls_ca_data(creds.get('tls_ca'), source) @@ -1068,6 +1183,7 @@ def parse_connect_arguments( credentials, credentials_file, database, + branch, user, password, secret_key, @@ -1100,6 +1216,7 @@ def parse_connect_arguments( credentials=credentials, credentials_file=credentials_file, database=database, + branch=branch, user=user, password=password, secret_key=secret_key, diff --git a/edgedb/credentials.py b/edgedb/credentials.py index aa6266bc..6ed03dfc 100644 --- a/edgedb/credentials.py +++ b/edgedb/credentials.py @@ -14,7 +14,9 @@ class RequiredCredentials(typing.TypedDict, total=True): class Credentials(RequiredCredentials, total=False): host: typing.Optional[str] password: typing.Optional[str] + # Either database or branch may appear in credentials, but not both. database: typing.Optional[str] + branch: typing.Optional[str] tls_ca: typing.Optional[str] tls_security: typing.Optional[str] @@ -64,6 +66,15 @@ def validate_credentials(data: dict) -> Credentials: raise ValueError("`database` must be a string") result['database'] = database + branch = data.get('branch') + if branch is not None: + if not isinstance(branch, str): + raise ValueError("`branch` must be a string") + if database is not None: + raise ValueError( + f"`database` and `branch` cannot both be set") + result['branch'] = branch + password = data.get('password') if password is not None: if not isinstance(password, str): diff --git a/tests/shared-client-testcases b/tests/shared-client-testcases index b8959be8..0e0ae1c3 160000 --- a/tests/shared-client-testcases +++ b/tests/shared-client-testcases @@ -1 +1 @@ -Subproject commit b8959be8968aceeeac2af3da7639de02b19d7030 +Subproject commit 0e0ae1c31b3aa04104b344d200967b5ddad66605 diff --git a/tests/test_con_utils.py b/tests/test_con_utils.py index c2820a6e..120c1a19 100644 --- a/tests/test_con_utils.py +++ b/tests/test_con_utils.py @@ -120,6 +120,7 @@ def run_testcase(self, testcase): host = opts.get('host') port = opts.get('port') database = opts.get('database') + branch = opts.get('branch') user = opts.get('user') password = opts.get('password') secret_key = opts.get('secretKey') @@ -233,6 +234,7 @@ def mocked_open(filepath, *args, **kwargs): credentials=credentials, credentials_file=credentials_file, database=database, + branch=branch, user=user, password=password, secret_key=secret_key, @@ -250,6 +252,7 @@ def mocked_open(filepath, *args, **kwargs): connect_config.address[0], connect_config.address[1] ], 'database': connect_config.database, + 'branch': connect_config.branch, 'user': connect_config.user, 'password': connect_config.password, 'secretKey': connect_config.secret_key, @@ -289,7 +292,7 @@ def test_test_connect_params_environ(self): if key in os.environ: del os.environ[key] - def test_test_connect_params_run_testcase(self): + def test_test_connect_params_run_testcase_01(self): with self.environ(EDGEDB_PORT='777'): self.run_testcase({ 'env': { @@ -301,6 +304,31 @@ def test_test_connect_params_run_testcase(self): 'result': { 'address': ['abc', 5656], 'database': 'edgedb', + 'branch': '__default__', + 'user': '__test__', + 'password': None, + 'secretKey': None, + 'tlsCAData': None, + 'tlsSecurity': 'strict', + 'serverSettings': {}, + 'waitUntilAvailable': 30, + }, + }) + + def test_test_connect_params_run_testcase_02(self): + with self.environ(EDGEDB_PORT='777'): + self.run_testcase({ + 'env': { + 'EDGEDB_HOST': 'abc' + }, + 'opts': { + 'user': '__test__', + 'branch': 'new_branch', + }, + 'result': { + 'address': ['abc', 5656], + 'database': 'new_branch', + 'branch': 'new_branch', 'user': '__test__', 'password': None, 'secretKey': None, @@ -399,6 +427,7 @@ def test_project_config(self): password=None, secret_key=None, database=None, + branch=None, tls_ca=None, tls_ca_file=None, tls_security=None,