diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index ca0eecc0..015cf07d 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -4,14 +4,15 @@ on: workflow_dispatch: inputs: {} - pull_request: - branches: - - "master" - - "ci" - - "release/[0-9]+.x" - - "release/[0-9]+.[0-9]+.x" - paths: - - "edgedb/_version.py" + # XXX: Commented out to prevent firing during this refactor. + # pull_request: + # branches: + # - "master" + # - "ci" + # - "release/[0-9]+.x" + # - "release/[0-9]+.[0-9]+.x" + # paths: + # - "gel/_version.py" jobs: validate-release-request: diff --git a/.gitignore b/.gitignore index 84d8e3dc..4f1a755b 100644 --- a/.gitignore +++ b/.gitignore @@ -29,3 +29,4 @@ docs/_build /edgedb/**/*.html /tmp /wheelhouse +/env* diff --git a/.gitmodules b/.gitmodules index 1830ced8..cdefde59 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,5 +1,5 @@ [submodule "edgedb/pgproto"] - path = edgedb/pgproto + path = gel/pgproto url = https://github.com/MagicStack/py-pgproto.git [submodule "tests/shared-client-testcases"] path = tests/shared-client-testcases diff --git a/Makefile b/Makefile index 4244be36..9f641039 100644 --- a/Makefile +++ b/Makefile @@ -11,25 +11,25 @@ all: compile clean: rm -fr $(ROOT)/dist/ rm -fr $(ROOT)/doc/_build/ - rm -fr $(ROOT)/edgedb/pgproto/*.c - rm -fr $(ROOT)/edgedb/pgproto/*.html - rm -fr $(ROOT)/edgedb/pgproto/codecs/*.html - rm -fr $(ROOT)/edgedb/protocol/*.c - rm -fr $(ROOT)/edgedb/protocol/*.html - rm -fr $(ROOT)/edgedb/protocol/*.so - rm -fr $(ROOT)/edgedb/datatypes/*.so - rm -fr $(ROOT)/edgedb/datatypes/datatypes.c + rm -fr $(ROOT)/gel/pgproto/*.c + rm -fr $(ROOT)/gel/pgproto/*.html + rm -fr $(ROOT)/gel/pgproto/codecs/*.html + rm -fr $(ROOT)/gel/protocol/*.c + rm -fr $(ROOT)/gel/protocol/*.html + rm -fr $(ROOT)/gel/protocol/*.so + rm -fr $(ROOT)/gel/datatypes/*.so + rm -fr $(ROOT)/gel/datatypes/datatypes.c rm -fr $(ROOT)/build - rm -fr $(ROOT)/edgedb/protocol/codecs/*.html + rm -fr $(ROOT)/gel/protocol/codecs/*.html find . -name '__pycache__' | xargs rm -rf _touch: - rm -fr $(ROOT)/edgedb/datatypes/datatypes.c - rm -fr $(ROOT)/edgedb/protocol/protocol.c - find $(ROOT)/edgedb/protocol -name '*.pyx' | xargs touch - find $(ROOT)/edgedb/datatypes -name '*.pyx' | xargs touch - find $(ROOT)/edgedb/datatypes -name '*.c' | xargs touch + rm -fr $(ROOT)/gel/datatypes/datatypes.c + rm -fr $(ROOT)/gel/protocol/protocol.c + find $(ROOT)/gel/protocol -name '*.pyx' | xargs touch + find $(ROOT)/gel/datatypes -name '*.pyx' | xargs touch + find $(ROOT)/gel/datatypes -name '*.c' | xargs touch compile: _touch @@ -39,12 +39,12 @@ compile: _touch gen-errors: edb gen-errors --import "$(echo "from edgedb.errors._base import *"; echo "from edgedb.errors.tags import *")" \ --extra-all "_base.__all__" --stdout --client > $(ROOT)/.errors - mv $(ROOT)/.errors $(ROOT)/edgedb/errors/__init__.py + mv $(ROOT)/.errors $(ROOT)/gel/errors/__init__.py $(PYTHON) tools/gen_init.py gen-types: - edb gen-types --stdout > $(ROOT)/edgedb/protocol/codecs/edb_types.pxi + edb gen-types --stdout > $(ROOT)/gel/protocol/codecs/edb_types.pxi debug: _touch diff --git a/edgedb/__init__.py b/edgedb/__init__.py index f714cc29..711d3ee0 100644 --- a/edgedb/__init__.py +++ b/edgedb/__init__.py @@ -1,284 +1,11 @@ -# -# This source file is part of the EdgeDB open source project. -# -# Copyright 2016-present MagicStack Inc. and the EdgeDB authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - - -# flake8: noqa - -from ._version import __version__ - -from edgedb.datatypes.datatypes import ( - Tuple, NamedTuple, EnumValue, RelativeDuration, DateDuration, ConfigMemory -) -from edgedb.datatypes.datatypes import Set, Object, Array -from edgedb.datatypes.range import Range, MultiRange - -from .abstract import ( - Executor, AsyncIOExecutor, ReadOnlyExecutor, AsyncIOReadOnlyExecutor, -) - -from .asyncio_client import ( - create_async_client, - AsyncIOClient -) - -from .blocking_client import create_client, Client -from .enums import Cardinality, ElementKind -from .options import RetryCondition, IsolationLevel, default_backoff -from .options import RetryOptions, TransactionOptions -from .options import State - -from .errors._base import EdgeDBError, EdgeDBMessage - -__all__ = [ - "Array", - "AsyncIOClient", - "AsyncIOExecutor", - "AsyncIOReadOnlyExecutor", - "Cardinality", - "Client", - "ConfigMemory", - "DateDuration", - "EdgeDBError", - "EdgeDBMessage", - "ElementKind", - "EnumValue", - "Executor", - "IsolationLevel", - "NamedTuple", - "Object", - "Range", - "ReadOnlyExecutor", - "RelativeDuration", - "RetryCondition", - "RetryOptions", - "Set", - "State", - "TransactionOptions", - "Tuple", - "create_async_client", - "create_client", - "default_backoff", -] - - -# The below is generated by `make gen-errors`. -# DO NOT MODIFY BY HAND. -# -# -from .errors import ( - InternalServerError, - UnsupportedFeatureError, - ProtocolError, - BinaryProtocolError, - UnsupportedProtocolVersionError, - TypeSpecNotFoundError, - UnexpectedMessageError, - InputDataError, - ParameterTypeMismatchError, - StateMismatchError, - ResultCardinalityMismatchError, - CapabilityError, - UnsupportedCapabilityError, - DisabledCapabilityError, - QueryError, - InvalidSyntaxError, - EdgeQLSyntaxError, - SchemaSyntaxError, - GraphQLSyntaxError, - InvalidTypeError, - InvalidTargetError, - InvalidLinkTargetError, - InvalidPropertyTargetError, - InvalidReferenceError, - UnknownModuleError, - UnknownLinkError, - UnknownPropertyError, - UnknownUserError, - UnknownDatabaseError, - UnknownParameterError, - SchemaError, - SchemaDefinitionError, - InvalidDefinitionError, - InvalidModuleDefinitionError, - InvalidLinkDefinitionError, - InvalidPropertyDefinitionError, - InvalidUserDefinitionError, - InvalidDatabaseDefinitionError, - InvalidOperatorDefinitionError, - InvalidAliasDefinitionError, - InvalidFunctionDefinitionError, - InvalidConstraintDefinitionError, - InvalidCastDefinitionError, - DuplicateDefinitionError, - DuplicateModuleDefinitionError, - DuplicateLinkDefinitionError, - DuplicatePropertyDefinitionError, - DuplicateUserDefinitionError, - DuplicateDatabaseDefinitionError, - DuplicateOperatorDefinitionError, - DuplicateViewDefinitionError, - DuplicateFunctionDefinitionError, - DuplicateConstraintDefinitionError, - DuplicateCastDefinitionError, - DuplicateMigrationError, - SessionTimeoutError, - IdleSessionTimeoutError, - QueryTimeoutError, - TransactionTimeoutError, - IdleTransactionTimeoutError, - ExecutionError, - InvalidValueError, - DivisionByZeroError, - NumericOutOfRangeError, - AccessPolicyError, - QueryAssertionError, - IntegrityError, - ConstraintViolationError, - CardinalityViolationError, - MissingRequiredError, - TransactionError, - TransactionConflictError, - TransactionSerializationError, - TransactionDeadlockError, - WatchError, - ConfigurationError, - AccessError, - AuthenticationError, - AvailabilityError, - BackendUnavailableError, - ServerOfflineError, - BackendError, - UnsupportedBackendFeatureError, - LogMessage, - WarningMessage, - ClientError, - ClientConnectionError, - ClientConnectionFailedError, - ClientConnectionFailedTemporarilyError, - ClientConnectionTimeoutError, - ClientConnectionClosedError, - InterfaceError, - QueryArgumentError, - MissingArgumentError, - UnknownArgumentError, - InvalidArgumentError, - NoDataError, - InternalClientError, -) - -__all__.extend([ - "InternalServerError", - "UnsupportedFeatureError", - "ProtocolError", - "BinaryProtocolError", - "UnsupportedProtocolVersionError", - "TypeSpecNotFoundError", - "UnexpectedMessageError", - "InputDataError", - "ParameterTypeMismatchError", - "StateMismatchError", - "ResultCardinalityMismatchError", - "CapabilityError", - "UnsupportedCapabilityError", - "DisabledCapabilityError", - "QueryError", - "InvalidSyntaxError", - "EdgeQLSyntaxError", - "SchemaSyntaxError", - "GraphQLSyntaxError", - "InvalidTypeError", - "InvalidTargetError", - "InvalidLinkTargetError", - "InvalidPropertyTargetError", - "InvalidReferenceError", - "UnknownModuleError", - "UnknownLinkError", - "UnknownPropertyError", - "UnknownUserError", - "UnknownDatabaseError", - "UnknownParameterError", - "SchemaError", - "SchemaDefinitionError", - "InvalidDefinitionError", - "InvalidModuleDefinitionError", - "InvalidLinkDefinitionError", - "InvalidPropertyDefinitionError", - "InvalidUserDefinitionError", - "InvalidDatabaseDefinitionError", - "InvalidOperatorDefinitionError", - "InvalidAliasDefinitionError", - "InvalidFunctionDefinitionError", - "InvalidConstraintDefinitionError", - "InvalidCastDefinitionError", - "DuplicateDefinitionError", - "DuplicateModuleDefinitionError", - "DuplicateLinkDefinitionError", - "DuplicatePropertyDefinitionError", - "DuplicateUserDefinitionError", - "DuplicateDatabaseDefinitionError", - "DuplicateOperatorDefinitionError", - "DuplicateViewDefinitionError", - "DuplicateFunctionDefinitionError", - "DuplicateConstraintDefinitionError", - "DuplicateCastDefinitionError", - "DuplicateMigrationError", - "SessionTimeoutError", - "IdleSessionTimeoutError", - "QueryTimeoutError", - "TransactionTimeoutError", - "IdleTransactionTimeoutError", - "ExecutionError", - "InvalidValueError", - "DivisionByZeroError", - "NumericOutOfRangeError", - "AccessPolicyError", - "QueryAssertionError", - "IntegrityError", - "ConstraintViolationError", - "CardinalityViolationError", - "MissingRequiredError", - "TransactionError", - "TransactionConflictError", - "TransactionSerializationError", - "TransactionDeadlockError", - "WatchError", - "ConfigurationError", - "AccessError", - "AuthenticationError", - "AvailabilityError", - "BackendUnavailableError", - "ServerOfflineError", - "BackendError", - "UnsupportedBackendFeatureError", - "LogMessage", - "WarningMessage", - "ClientError", - "ClientConnectionError", - "ClientConnectionFailedError", - "ClientConnectionFailedTemporarilyError", - "ClientConnectionTimeoutError", - "ClientConnectionClosedError", - "InterfaceError", - "QueryArgumentError", - "MissingArgumentError", - "UnknownArgumentError", - "InvalidArgumentError", - "NoDataError", - "InternalClientError", -]) -# +# Auto-generated shim +import gel as _mod +import sys as _sys +_cur = _sys.modules['edgedb'] +for _k in vars(_mod): + if not _k.startswith('__') or _k in ('__all__', '__doc__'): + setattr(_cur, _k, getattr(_mod, _k)) +del _cur +del _sys +del _mod +del _k diff --git a/edgedb/_taskgroup.py b/edgedb/_taskgroup.py index 0a1859d7..93ca15db 100644 --- a/edgedb/_taskgroup.py +++ b/edgedb/_taskgroup.py @@ -1,295 +1,11 @@ -# -# This source file is part of the EdgeDB open source project. -# -# Copyright 2016-present MagicStack Inc. and the EdgeDB authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - - -import asyncio -import functools -import itertools -import textwrap -import traceback - - -class TaskGroup: - - def __init__(self, *, name=None): - if name is None: - self._name = f'tg-{_name_counter()}' - else: - self._name = str(name) - - self._entered = False - self._exiting = False - self._aborting = False - self._loop = None - self._parent_task = None - self._parent_cancel_requested = False - self._tasks = set() - self._unfinished_tasks = 0 - self._errors = [] - self._base_error = None - self._on_completed_fut = None - - def get_name(self): - return self._name - - def __repr__(self): - msg = f'= 0 - - if self._exiting and not self._unfinished_tasks: - if not self._on_completed_fut.done(): - self._on_completed_fut.set_result(True) - - if task.cancelled(): - return - - exc = task.exception() - if exc is None: - return - - self._errors.append(exc) - if self._is_base_error(exc) and self._base_error is None: - self._base_error = exc - - if self._parent_task.done(): - # Not sure if this case is possible, but we want to handle - # it anyways. - self._loop.call_exception_handler({ - 'message': f'Task {task!r} has errored out but its parent ' - f'task {self._parent_task} is already completed', - 'exception': exc, - 'task': task, - }) - return - - self._abort() - if not self._parent_task.__cancel_requested__: - # If parent task *is not* being cancelled, it means that we want - # to manually cancel it to abort whatever is being run right now - # in the TaskGroup. But we want to mark parent task as - # "not cancelled" later in __aexit__. Example situation that - # we need to handle: - # - # async def foo(): - # try: - # async with TaskGroup() as g: - # g.create_task(crash_soon()) - # await something # <- this needs to be canceled - # # by the TaskGroup, e.g. - # # foo() needs to be cancelled - # except Exception: - # # Ignore any exceptions raised in the TaskGroup - # pass - # await something_else # this line has to be called - # # after TaskGroup is finished. - self._parent_cancel_requested = True - self._parent_task.cancel() - - -class MultiError(Exception): - - def __init__(self, msg, *args, errors=()): - if errors: - types = set(type(e).__name__ for e in errors) - msg = f'{msg}; {len(errors)} sub errors: ({", ".join(types)})' - for er in errors: - msg += f'\n + {type(er).__name__}: {er}' - if er.__traceback__: - er_tb = ''.join(traceback.format_tb(er.__traceback__)) - er_tb = textwrap.indent(er_tb, ' | ') - msg += f'\n{er_tb}\n' - super().__init__(msg, *args) - self.__errors__ = tuple(errors) - - def get_error_types(self): - return {type(e) for e in self.__errors__} - - def __reduce__(self): - return (type(self), (self.args,), {'__errors__': self.__errors__}) - - -class TaskGroupError(MultiError): - pass - - -_name_counter = itertools.count(1).__next__ +# Auto-generated shim +import gel._taskgroup as _mod +import sys as _sys +_cur = _sys.modules['edgedb._taskgroup'] +for _k in vars(_mod): + if not _k.startswith('__') or _k in ('__all__', '__doc__'): + setattr(_cur, _k, getattr(_mod, _k)) +del _cur +del _sys +del _mod +del _k diff --git a/edgedb/_version.py b/edgedb/_version.py index 41998d40..5f0756de 100644 --- a/edgedb/_version.py +++ b/edgedb/_version.py @@ -1,31 +1,11 @@ -# -# This source file is part of the EdgeDB open source project. -# -# Copyright 2016-present MagicStack Inc. and the EdgeDB authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -# This file MUST NOT contain anything but the __version__ assignment. -# -# When making a release, change the value of __version__ -# to an appropriate value, and open a pull request against -# the correct branch (master if making a new feature release). -# The commit message MUST contain a properly formatted release -# log, and the commit must be signed. -# -# The release automation will: build and test the packages for the -# supported platforms, publish the packages on PyPI, merge the PR -# to the target branch, create a Git tag pointing to the commit. - -__version__ = '3.0.0b2' +# Auto-generated shim +import gel._version as _mod +import sys as _sys +_cur = _sys.modules['edgedb._version'] +for _k in vars(_mod): + if not _k.startswith('__') or _k in ('__all__', '__doc__'): + setattr(_cur, _k, getattr(_mod, _k)) +del _cur +del _sys +del _mod +del _k diff --git a/edgedb/abstract.py b/edgedb/abstract.py index 0c2c06cc..8ba2b2a2 100644 --- a/edgedb/abstract.py +++ b/edgedb/abstract.py @@ -1,437 +1,11 @@ -# -# This source file is part of the EdgeDB open source project. -# -# Copyright 2020-present MagicStack Inc. and the EdgeDB authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - - -from __future__ import annotations -import abc -import dataclasses -import typing - -from . import describe -from . import enums -from . import options -from .protocol import protocol - -__all__ = ( - "QueryWithArgs", - "QueryCache", - "QueryOptions", - "QueryContext", - "Executor", - "AsyncIOExecutor", - "ReadOnlyExecutor", - "AsyncIOReadOnlyExecutor", - "DescribeContext", - "DescribeResult", -) - - -class QueryWithArgs(typing.NamedTuple): - query: str - args: typing.Tuple - kwargs: typing.Dict[str, typing.Any] - input_language: protocol.InputLanguage = protocol.InputLanguage.EDGEQL - - -class QueryCache(typing.NamedTuple): - codecs_registry: protocol.CodecsRegistry - query_cache: protocol.LRUMapping - - -class QueryOptions(typing.NamedTuple): - output_format: protocol.OutputFormat - expect_one: bool - required_one: bool - - -class QueryContext(typing.NamedTuple): - query: QueryWithArgs - cache: QueryCache - query_options: QueryOptions - retry_options: typing.Optional[options.RetryOptions] - state: typing.Optional[options.State] - warning_handler: options.WarningHandler - - def lower( - self, *, allow_capabilities: enums.Capability - ) -> protocol.ExecuteContext: - return protocol.ExecuteContext( - query=self.query.query, - args=self.query.args, - kwargs=self.query.kwargs, - reg=self.cache.codecs_registry, - qc=self.cache.query_cache, - input_language=self.query.input_language, - output_format=self.query_options.output_format, - expect_one=self.query_options.expect_one, - required_one=self.query_options.required_one, - allow_capabilities=allow_capabilities, - state=self.state.as_dict() if self.state else None, - ) - - -class ExecuteContext(typing.NamedTuple): - query: QueryWithArgs - cache: QueryCache - state: typing.Optional[options.State] - warning_handler: options.WarningHandler - - def lower( - self, *, allow_capabilities: enums.Capability - ) -> protocol.ExecuteContext: - return protocol.ExecuteContext( - query=self.query.query, - args=self.query.args, - kwargs=self.query.kwargs, - reg=self.cache.codecs_registry, - qc=self.cache.query_cache, - input_language=self.query.input_language, - output_format=protocol.OutputFormat.NONE, - allow_capabilities=allow_capabilities, - state=self.state.as_dict() if self.state else None, - ) - - -@dataclasses.dataclass -class DescribeContext: - query: str - state: typing.Optional[options.State] - inject_type_names: bool - input_language: protocol.InputLanguage - output_format: protocol.OutputFormat - expect_one: bool - - def lower( - self, *, allow_capabilities: enums.Capability - ) -> protocol.ExecuteContext: - return protocol.ExecuteContext( - query=self.query, - args=None, - kwargs=None, - reg=protocol.CodecsRegistry(), - qc=protocol.LRUMapping(maxsize=1), - input_language=self.input_language, - output_format=self.output_format, - expect_one=self.expect_one, - inline_typenames=self.inject_type_names, - allow_capabilities=allow_capabilities, - state=self.state.as_dict() if self.state else None, - ) - - -@dataclasses.dataclass -class DescribeResult: - input_type: typing.Optional[describe.AnyType] - output_type: typing.Optional[describe.AnyType] - output_cardinality: enums.Cardinality - capabilities: enums.Capability - - -_query_opts = QueryOptions( - output_format=protocol.OutputFormat.BINARY, - expect_one=False, - required_one=False, -) -_query_single_opts = QueryOptions( - output_format=protocol.OutputFormat.BINARY, - expect_one=True, - required_one=False, -) -_query_required_single_opts = QueryOptions( - output_format=protocol.OutputFormat.BINARY, - expect_one=True, - required_one=True, -) -_query_json_opts = QueryOptions( - output_format=protocol.OutputFormat.JSON, - expect_one=False, - required_one=False, -) -_query_single_json_opts = QueryOptions( - output_format=protocol.OutputFormat.JSON, - expect_one=True, - required_one=False, -) -_query_required_single_json_opts = QueryOptions( - output_format=protocol.OutputFormat.JSON, - expect_one=True, - required_one=True, -) - - -class BaseReadOnlyExecutor(abc.ABC): - __slots__ = () - - @abc.abstractmethod - def _get_query_cache(self) -> QueryCache: - ... - - def _get_retry_options(self) -> typing.Optional[options.RetryOptions]: - return None - - @abc.abstractmethod - def _get_state(self) -> options.State: - ... - - @abc.abstractmethod - def _get_warning_handler(self) -> options.WarningHandler: - ... - - -class ReadOnlyExecutor(BaseReadOnlyExecutor): - """Subclasses can execute *at least* read-only queries""" - - __slots__ = () - - @abc.abstractmethod - def _query(self, query_context: QueryContext): - ... - - def query(self, query: str, *args, **kwargs) -> list: - return self._query(QueryContext( - query=QueryWithArgs(query, args, kwargs), - cache=self._get_query_cache(), - query_options=_query_opts, - retry_options=self._get_retry_options(), - state=self._get_state(), - warning_handler=self._get_warning_handler(), - )) - - def query_single( - self, query: str, *args, **kwargs - ) -> typing.Union[typing.Any, None]: - return self._query(QueryContext( - query=QueryWithArgs(query, args, kwargs), - cache=self._get_query_cache(), - query_options=_query_single_opts, - retry_options=self._get_retry_options(), - state=self._get_state(), - warning_handler=self._get_warning_handler(), - )) - - def query_required_single(self, query: str, *args, **kwargs) -> typing.Any: - return self._query(QueryContext( - query=QueryWithArgs(query, args, kwargs), - cache=self._get_query_cache(), - query_options=_query_required_single_opts, - retry_options=self._get_retry_options(), - state=self._get_state(), - warning_handler=self._get_warning_handler(), - )) - - def query_json(self, query: str, *args, **kwargs) -> str: - return self._query(QueryContext( - query=QueryWithArgs(query, args, kwargs), - cache=self._get_query_cache(), - query_options=_query_json_opts, - retry_options=self._get_retry_options(), - state=self._get_state(), - warning_handler=self._get_warning_handler(), - )) - - def query_single_json(self, query: str, *args, **kwargs) -> str: - return self._query(QueryContext( - query=QueryWithArgs(query, args, kwargs), - cache=self._get_query_cache(), - query_options=_query_single_json_opts, - retry_options=self._get_retry_options(), - state=self._get_state(), - warning_handler=self._get_warning_handler(), - )) - - def query_required_single_json(self, query: str, *args, **kwargs) -> str: - return self._query(QueryContext( - query=QueryWithArgs(query, args, kwargs), - cache=self._get_query_cache(), - query_options=_query_required_single_json_opts, - retry_options=self._get_retry_options(), - state=self._get_state(), - warning_handler=self._get_warning_handler(), - )) - - def query_sql(self, query: str, *args, **kwargs) -> typing.Any: - return self._query(QueryContext( - query=QueryWithArgs( - query, - args, - kwargs, - input_language=protocol.InputLanguage.SQL, - ), - cache=self._get_query_cache(), - query_options=_query_opts, - retry_options=self._get_retry_options(), - state=self._get_state(), - warning_handler=self._get_warning_handler(), - )) - - @abc.abstractmethod - def _execute(self, execute_context: ExecuteContext): - ... - - def execute(self, commands: str, *args, **kwargs) -> None: - self._execute(ExecuteContext( - query=QueryWithArgs(commands, args, kwargs), - cache=self._get_query_cache(), - state=self._get_state(), - warning_handler=self._get_warning_handler(), - )) - - def execute_sql(self, commands: str, *args, **kwargs) -> None: - self._execute(ExecuteContext( - query=QueryWithArgs( - commands, - args, - kwargs, - input_language=protocol.InputLanguage.SQL, - ), - cache=self._get_query_cache(), - state=self._get_state(), - warning_handler=self._get_warning_handler(), - )) - - -class Executor(ReadOnlyExecutor): - """Subclasses can execute both read-only and modification queries""" - - __slots__ = () - - -class AsyncIOReadOnlyExecutor(BaseReadOnlyExecutor): - """Subclasses can execute *at least* read-only queries""" - - __slots__ = () - - @abc.abstractmethod - async def _query(self, query_context: QueryContext): - ... - - async def query(self, query: str, *args, **kwargs) -> list: - return await self._query(QueryContext( - query=QueryWithArgs(query, args, kwargs), - cache=self._get_query_cache(), - query_options=_query_opts, - retry_options=self._get_retry_options(), - state=self._get_state(), - warning_handler=self._get_warning_handler(), - )) - - async def query_single(self, query: str, *args, **kwargs) -> typing.Any: - return await self._query(QueryContext( - query=QueryWithArgs(query, args, kwargs), - cache=self._get_query_cache(), - query_options=_query_single_opts, - retry_options=self._get_retry_options(), - state=self._get_state(), - warning_handler=self._get_warning_handler(), - )) - - async def query_required_single( - self, - query: str, - *args, - **kwargs - ) -> typing.Any: - return await self._query(QueryContext( - query=QueryWithArgs(query, args, kwargs), - cache=self._get_query_cache(), - query_options=_query_required_single_opts, - retry_options=self._get_retry_options(), - state=self._get_state(), - warning_handler=self._get_warning_handler(), - )) - - async def query_json(self, query: str, *args, **kwargs) -> str: - return await self._query(QueryContext( - query=QueryWithArgs(query, args, kwargs), - cache=self._get_query_cache(), - query_options=_query_json_opts, - retry_options=self._get_retry_options(), - state=self._get_state(), - warning_handler=self._get_warning_handler(), - )) - - async def query_single_json(self, query: str, *args, **kwargs) -> str: - return await self._query(QueryContext( - query=QueryWithArgs(query, args, kwargs), - cache=self._get_query_cache(), - query_options=_query_single_json_opts, - retry_options=self._get_retry_options(), - state=self._get_state(), - warning_handler=self._get_warning_handler(), - )) - - async def query_required_single_json( - self, - query: str, - *args, - **kwargs - ) -> str: - return await self._query(QueryContext( - query=QueryWithArgs(query, args, kwargs), - cache=self._get_query_cache(), - query_options=_query_required_single_json_opts, - retry_options=self._get_retry_options(), - state=self._get_state(), - warning_handler=self._get_warning_handler(), - )) - - async def query_sql(self, query: str, *args, **kwargs) -> typing.Any: - return await self._query(QueryContext( - query=QueryWithArgs( - query, - args, - kwargs, - input_language=protocol.InputLanguage.SQL, - ), - cache=self._get_query_cache(), - query_options=_query_opts, - retry_options=self._get_retry_options(), - state=self._get_state(), - warning_handler=self._get_warning_handler(), - )) - - @abc.abstractmethod - async def _execute(self, execute_context: ExecuteContext) -> None: - ... - - async def execute(self, commands: str, *args, **kwargs) -> None: - await self._execute(ExecuteContext( - query=QueryWithArgs(commands, args, kwargs), - cache=self._get_query_cache(), - state=self._get_state(), - warning_handler=self._get_warning_handler(), - )) - - async def execute_sql(self, commands: str, *args, **kwargs) -> None: - await self._execute(ExecuteContext( - query=QueryWithArgs( - commands, - args, - kwargs, - input_language=protocol.InputLanguage.SQL, - ), - cache=self._get_query_cache(), - state=self._get_state(), - warning_handler=self._get_warning_handler(), - )) - - -class AsyncIOExecutor(AsyncIOReadOnlyExecutor): - """Subclasses can execute both read-only and modification queries""" - - __slots__ = () +# Auto-generated shim +import gel.abstract as _mod +import sys as _sys +_cur = _sys.modules['edgedb.abstract'] +for _k in vars(_mod): + if not _k.startswith('__') or _k in ('__all__', '__doc__'): + setattr(_cur, _k, getattr(_mod, _k)) +del _cur +del _sys +del _mod +del _k diff --git a/edgedb/ai/__init__.py b/edgedb/ai/__init__.py index 96111c2b..71c802a7 100644 --- a/edgedb/ai/__init__.py +++ b/edgedb/ai/__init__.py @@ -1,32 +1,11 @@ -# -# This source file is part of the EdgeDB open source project. -# -# Copyright 2024-present MagicStack Inc. and the EdgeDB authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from .types import AIOptions, ChatParticipantRole, Prompt, QueryContext -from .core import create_ai, EdgeDBAI -from .core import create_async_ai, AsyncEdgeDBAI - -__all__ = [ - "AIOptions", - "ChatParticipantRole", - "Prompt", - "QueryContext", - "create_ai", - "EdgeDBAI", - "create_async_ai", - "AsyncEdgeDBAI", -] +# Auto-generated shim +import gel.ai as _mod +import sys as _sys +_cur = _sys.modules['edgedb.ai'] +for _k in vars(_mod): + if not _k.startswith('__') or _k in ('__all__', '__doc__'): + setattr(_cur, _k, getattr(_mod, _k)) +del _cur +del _sys +del _mod +del _k diff --git a/edgedb/ai/core.py b/edgedb/ai/core.py index 69fe235d..ec1645ca 100644 --- a/edgedb/ai/core.py +++ b/edgedb/ai/core.py @@ -1,191 +1,11 @@ -# -# This source file is part of the EdgeDB open source project. -# -# Copyright 2024-present MagicStack Inc. and the EdgeDB authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import annotations -import typing - -import edgedb -import httpx -import httpx_sse - -from . import types - - -def create_ai(client: edgedb.Client, **kwargs) -> EdgeDBAI: - client.ensure_connected() - return EdgeDBAI(client, types.AIOptions(**kwargs)) - - -async def create_async_ai( - client: edgedb.AsyncIOClient, **kwargs -) -> AsyncEdgeDBAI: - await client.ensure_connected() - return AsyncEdgeDBAI(client, types.AIOptions(**kwargs)) - - -class BaseEdgeDBAI: - options: types.AIOptions - context: types.QueryContext - client_cls = NotImplemented - - def __init__( - self, - client: typing.Union[edgedb.Client, edgedb.AsyncIOClient], - options: types.AIOptions, - **kwargs, - ): - pool = client._impl - host, port = pool._working_addr - params = pool._working_params - proto = "http" if params.tls_security == "insecure" else "https" - branch = params.branch - self.options = options - self.context = types.QueryContext(**kwargs) - args = dict( - base_url=f"{proto}://{host}:{port}/branch/{branch}/ext/ai", - verify=params.ssl_ctx, - ) - if params.password is not None: - args["auth"] = (params.user, params.password) - elif params.secret_key is not None: - args["headers"] = {"Authorization": f"Bearer {params.secret_key}"} - self._init_client(**args) - - def _init_client(self, **kwargs): - raise NotImplementedError - - def with_config(self, **kwargs) -> typing.Self: - cls = type(self) - rv = cls.__new__(cls) - rv.options = self.options.derive(kwargs) - rv.context = self.context - rv.client = self.client - return rv - - def with_context(self, **kwargs) -> typing.Self: - cls = type(self) - rv = cls.__new__(cls) - rv.options = self.options - rv.context = self.context.derive(kwargs) - rv.client = self.client - return rv - - def _make_rag_request( - self, - *, - message: str, - context: typing.Optional[types.QueryContext] = None, - stream: bool, - ) -> types.RAGRequest: - if context is None: - context = self.context - return types.RAGRequest( - model=self.options.model, - prompt=self.options.prompt, - context=context, - query=message, - stream=stream, - ) - - -class EdgeDBAI(BaseEdgeDBAI): - client: httpx.Client - - def _init_client(self, **kwargs): - self.client = httpx.Client(**kwargs) - - def query_rag( - self, message: str, context: typing.Optional[types.QueryContext] = None - ) -> str: - resp = self.client.post( - **self._make_rag_request( - context=context, - message=message, - stream=False, - ).to_httpx_request() - ) - resp.raise_for_status() - return resp.json()["response"] - - def stream_rag( - self, message: str, context: typing.Optional[types.QueryContext] = None - ) -> typing.Iterator[str]: - with httpx_sse.connect_sse( - self.client, - "post", - **self._make_rag_request( - context=context, - message=message, - stream=True, - ).to_httpx_request(), - ) as event_source: - event_source.response.raise_for_status() - for sse in event_source.iter_sse(): - yield sse.data - - def generate_embeddings(self, *inputs: str, model: str) -> list[float]: - resp = self.client.post( - "/embeddings", json={"input": inputs, "model": model} - ) - resp.raise_for_status() - return resp.json()["data"][0]["embedding"] - - -class AsyncEdgeDBAI(BaseEdgeDBAI): - client: httpx.AsyncClient - - def _init_client(self, **kwargs): - self.client = httpx.AsyncClient(**kwargs) - - async def query_rag( - self, message: str, context: typing.Optional[types.QueryContext] = None - ) -> str: - resp = await self.client.post( - **self._make_rag_request( - context=context, - message=message, - stream=False, - ).to_httpx_request() - ) - resp.raise_for_status() - return resp.json()["response"] - - async def stream_rag( - self, message: str, context: typing.Optional[types.QueryContext] = None - ) -> typing.Iterator[str]: - async with httpx_sse.aconnect_sse( - self.client, - "post", - **self._make_rag_request( - context=context, - message=message, - stream=True, - ).to_httpx_request(), - ) as event_source: - event_source.response.raise_for_status() - async for sse in event_source.aiter_sse(): - yield sse.data - - async def generate_embeddings( - self, *inputs: str, model: str - ) -> list[float]: - resp = await self.client.post( - "/embeddings", json={"input": inputs, "model": model} - ) - resp.raise_for_status() - return resp.json()["data"][0]["embedding"] +# Auto-generated shim +import gel.ai.core as _mod +import sys as _sys +_cur = _sys.modules['edgedb.ai.core'] +for _k in vars(_mod): + if not _k.startswith('__') or _k in ('__all__', '__doc__'): + setattr(_cur, _k, getattr(_mod, _k)) +del _cur +del _sys +del _mod +del _k diff --git a/edgedb/ai/types.py b/edgedb/ai/types.py index 41bf24c0..e4d1f77d 100644 --- a/edgedb/ai/types.py +++ b/edgedb/ai/types.py @@ -1,81 +1,11 @@ -# -# This source file is part of the EdgeDB open source project. -# -# Copyright 2024-present MagicStack Inc. and the EdgeDB authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -import typing - -import dataclasses as dc -import enum - - -class ChatParticipantRole(enum.Enum): - SYSTEM = "system" - USER = "user" - ASSISTANT = "assistant" - TOOL = "tool" - - -class Custom(typing.TypedDict): - role: ChatParticipantRole - content: str - - -class Prompt: - name: typing.Optional[str] - id: typing.Optional[str] - custom: typing.Optional[typing.List[Custom]] - - -@dc.dataclass -class AIOptions: - model: str - prompt: typing.Optional[Prompt] = None - - def derive(self, kwargs): - return AIOptions(**{**dc.asdict(self), **kwargs}) - - -@dc.dataclass -class QueryContext: - query: str = "" - variables: typing.Optional[typing.Dict[str, typing.Any]] = None - globals: typing.Optional[typing.Dict[str, typing.Any]] = None - max_object_count: typing.Optional[int] = None - - def derive(self, kwargs): - return QueryContext(**{**dc.asdict(self), **kwargs}) - - -@dc.dataclass -class RAGRequest: - model: str - prompt: typing.Optional[Prompt] - context: QueryContext - query: str - stream: typing.Optional[bool] - - def to_httpx_request(self) -> typing.Dict[str, typing.Any]: - return dict( - url="/rag", - headers={ - "Content-Type": "application/json", - "Accept": ( - "text/event-stream" if self.stream else "application/json" - ), - }, - json=dc.asdict(self), - ) +# Auto-generated shim +import gel.ai.types as _mod +import sys as _sys +_cur = _sys.modules['edgedb.ai.types'] +for _k in vars(_mod): + if not _k.startswith('__') or _k in ('__all__', '__doc__'): + setattr(_cur, _k, getattr(_mod, _k)) +del _cur +del _sys +del _mod +del _k diff --git a/edgedb/asyncio_client.py b/edgedb/asyncio_client.py index 78625040..bc80c00f 100644 --- a/edgedb/asyncio_client.py +++ b/edgedb/asyncio_client.py @@ -1,448 +1,11 @@ -# -# This source file is part of the EdgeDB open source project. -# -# Copyright 2019-present MagicStack Inc. and the EdgeDB authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - - -import asyncio -import contextlib -import logging -import socket -import ssl -import typing - -from . import abstract -from . import base_client -from . import con_utils -from . import errors -from . import transaction -from .protocol import asyncio_proto -from .protocol.protocol import InputLanguage, OutputFormat - - -__all__ = ( - 'create_async_client', 'AsyncIOClient' -) - - -logger = logging.getLogger(__name__) - - -class AsyncIOConnection(base_client.BaseConnection): - __slots__ = ("_loop",) - - def __init__(self, loop, *args, **kwargs): - super().__init__(*args, **kwargs) - self._loop = loop - - def is_closed(self): - protocol = self._protocol - return protocol is None or not protocol.connected - - async def connect_addr(self, addr, timeout): - try: - await asyncio.wait_for(self._connect_addr(addr), timeout) - except asyncio.TimeoutError as e: - raise TimeoutError from e - - async def sleep(self, seconds): - await asyncio.sleep(seconds) - - async def aclose(self): - """Send graceful termination message wait for connection to drop.""" - if not self.is_closed(): - try: - self._protocol.terminate() - await self._protocol.wait_for_disconnect() - except (Exception, asyncio.CancelledError): - self.terminate() - raise - finally: - self._cleanup() - - def _protocol_factory(self): - return asyncio_proto.AsyncIOProtocol(self._params, self._loop) - - async def _connect_addr(self, addr): - tr = None - - try: - if isinstance(addr, str): - # UNIX socket - tr, pr = await self._loop.create_unix_connection( - self._protocol_factory, addr - ) - else: - try: - tr, pr = await self._loop.create_connection( - self._protocol_factory, - *addr, - ssl=self._params.ssl_ctx, - server_hostname=( - self._params.tls_server_name or addr[0] - ), - ) - except ssl.CertificateError as e: - raise con_utils.wrap_error(e) from e - except ssl.SSLError as e: - raise con_utils.wrap_error(e) from e - else: - con_utils.check_alpn_protocol( - tr.get_extra_info('ssl_object') - ) - except socket.gaierror as e: - # All name resolution errors are considered temporary - raise errors.ClientConnectionFailedTemporarilyError(str(e)) from e - except OSError as e: - raise con_utils.wrap_error(e) from e - except Exception: - if tr is not None: - tr.close() - raise - - pr.set_connection(self) - - try: - await pr.connect() - except OSError as e: - if tr is not None: - tr.close() - raise con_utils.wrap_error(e) from e - except BaseException: - if tr is not None: - tr.close() - raise - - self._protocol = pr - self._addr = addr - - def _dispatch_log_message(self, msg): - for cb in self._log_listeners: - self._loop.call_soon(cb, self, msg) - - -class _PoolConnectionHolder(base_client.PoolConnectionHolder): - __slots__ = () - _event_class = asyncio.Event - - async def close(self, *, wait=True): - if self._con is None: - return - if wait: - await self._con.aclose() - else: - self._pool._loop.create_task(self._con.aclose()) - - async def wait_until_released(self, timeout=None): - await self._release_event.wait() - - -class _AsyncIOPoolImpl(base_client.BasePoolImpl): - __slots__ = ('_loop',) - _holder_class = _PoolConnectionHolder - - def __init__( - self, - connect_args, - *, - max_concurrency: typing.Optional[int], - connection_class, - ): - if not issubclass(connection_class, AsyncIOConnection): - raise TypeError( - f'connection_class is expected to be a subclass of ' - f'edgedb.asyncio_client.AsyncIOConnection, ' - f'got {connection_class}') - self._loop = None - super().__init__( - connect_args, - lambda *args: connection_class(self._loop, *args), - max_concurrency=max_concurrency, - ) - - def _ensure_initialized(self): - if self._loop is None: - self._loop = asyncio.get_event_loop() - self._queue = asyncio.LifoQueue(maxsize=self._max_concurrency) - self._first_connect_lock = asyncio.Lock() - self._resize_holder_pool() - - def _set_queue_maxsize(self, maxsize): - self._queue._maxsize = maxsize - - async def _maybe_get_first_connection(self): - async with self._first_connect_lock: - if self._working_addr is None: - return await self._get_first_connection() - - async def acquire(self, timeout=None): - self._ensure_initialized() - - async def _acquire_impl(): - ch = await self._queue.get() # type: _PoolConnectionHolder - try: - proxy = await ch.acquire() # type: AsyncIOConnection - except (Exception, asyncio.CancelledError): - self._queue.put_nowait(ch) - raise - else: - # Record the timeout, as we will apply it by default - # in release(). - ch._timeout = timeout - return proxy - - if self._closing: - raise errors.InterfaceError('pool is closing') - - if timeout is None: - return await _acquire_impl() - else: - return await asyncio.wait_for( - _acquire_impl(), timeout=timeout) - - async def _release(self, holder): - - if not isinstance(holder._con, AsyncIOConnection): - raise errors.InterfaceError( - f'release() received invalid connection: ' - f'{holder._con!r} does not belong to any connection pool' - ) - - timeout = None - - # Use asyncio.shield() to guarantee that task cancellation - # does not prevent the connection from being returned to the - # pool properly. - return await asyncio.shield(holder.release(timeout)) - - async def aclose(self): - """Attempt to gracefully close all connections in the pool. - - Wait until all pool connections are released, close them and - shut down the pool. If any error (including cancellation) occurs - in ``close()`` the pool will terminate by calling - _AsyncIOPoolImpl.terminate() . - - It is advisable to use :func:`python:asyncio.wait_for` to set - a timeout. - """ - if self._closed: - return - - if not self._loop: - self._closed = True - return - - self._closing = True - - try: - warning_callback = self._loop.call_later( - 60, self._warn_on_long_close) - - release_coros = [ - ch.wait_until_released() for ch in self._holders] - await asyncio.gather(*release_coros) - - close_coros = [ - ch.close() for ch in self._holders] - await asyncio.gather(*close_coros) - - except (Exception, asyncio.CancelledError): - self.terminate() - raise - - finally: - warning_callback.cancel() - self._closed = True - self._closing = False - - def _warn_on_long_close(self): - logger.warning( - 'AsyncIOClient.aclose() is taking over 60 seconds to complete. ' - 'Check if you have any unreleased connections left. ' - 'Use asyncio.wait_for() to set a timeout for ' - 'AsyncIOClient.aclose().') - - -class AsyncIOIteration(transaction.BaseTransaction, abstract.AsyncIOExecutor): - - __slots__ = ("_managed", "_locked") - - def __init__(self, retry, client, iteration): - super().__init__(retry, client, iteration) - self._managed = False - self._locked = False - - async def __aenter__(self): - if self._managed: - raise errors.InterfaceError( - 'cannot enter context: already in an `async with` block') - self._managed = True - return self - - async def __aexit__(self, extype, ex, tb): - with self._exclusive(): - self._managed = False - return await self._exit(extype, ex) - - async def _ensure_transaction(self): - if not self._managed: - raise errors.InterfaceError( - "Only managed retriable transactions are supported. " - "Use `async with transaction:`" - ) - await super()._ensure_transaction() - - async def _query(self, query_context: abstract.QueryContext): - with self._exclusive(): - return await super()._query(query_context) - - async def _execute(self, execute_context: abstract.ExecuteContext) -> None: - with self._exclusive(): - await super()._execute(execute_context) - - @contextlib.contextmanager - def _exclusive(self): - if self._locked: - raise errors.InterfaceError( - "concurrent queries within the same transaction " - "are not allowed" - ) - self._locked = True - try: - yield - finally: - self._locked = False - - -class AsyncIORetry(transaction.BaseRetry): - - def __aiter__(self): - return self - - async def __anext__(self): - # Note: when changing this code consider also - # updating Retry.__next__. - if self._done: - raise StopAsyncIteration - if self._next_backoff: - await asyncio.sleep(self._next_backoff) - self._done = True - iteration = AsyncIOIteration(self, self._owner, self._iteration) - self._iteration += 1 - return iteration - - -class AsyncIOClient(base_client.BaseClient, abstract.AsyncIOExecutor): - """A lazy connection pool. - - A Client can be used to manage a set of connections to the database. - Connections are first acquired from the pool, then used, and then released - back to the pool. Once a connection is released, it's reset to close all - open cursors and other resources *except* prepared statements. - - Clients are created by calling - :func:`~edgedb.asyncio_client.create_async_client`. - """ - - __slots__ = () - _impl_class = _AsyncIOPoolImpl - - async def ensure_connected(self): - await self._impl.ensure_connected() - return self - - async def aclose(self): - """Attempt to gracefully close all connections in the pool. - - Wait until all pool connections are released, close them and - shut down the pool. If any error (including cancellation) occurs - in ``aclose()`` the pool will terminate by calling - AsyncIOClient.terminate() . - - It is advisable to use :func:`python:asyncio.wait_for` to set - a timeout. - """ - await self._impl.aclose() - - def transaction(self) -> AsyncIORetry: - return AsyncIORetry(self) - - async def __aenter__(self): - return await self.ensure_connected() - - async def __aexit__(self, exc_type, exc_val, exc_tb): - await self.aclose() - - async def _describe_query( - self, - query: str, - *, - inject_type_names: bool = False, - input_language: InputLanguage = InputLanguage.EDGEQL, - output_format: OutputFormat = OutputFormat.BINARY, - expect_one: bool = False, - ) -> abstract.DescribeResult: - return await self._describe(abstract.DescribeContext( - query=query, - state=self._get_state(), - inject_type_names=inject_type_names, - input_language=input_language, - output_format=output_format, - expect_one=expect_one, - )) - - -def create_async_client( - dsn=None, - *, - max_concurrency=None, - host: str = None, - port: int = None, - credentials: str = None, - credentials_file: str = None, - user: str = None, - password: str = None, - secret_key: str = None, - database: str = None, - branch: str = None, - tls_ca: str = None, - tls_ca_file: str = None, - tls_security: str = None, - wait_until_available: int = 30, - timeout: int = 10, -): - return AsyncIOClient( - connection_class=AsyncIOConnection, - max_concurrency=max_concurrency, - - # connect arguments - dsn=dsn, - host=host, - port=port, - credentials=credentials, - credentials_file=credentials_file, - user=user, - password=password, - secret_key=secret_key, - database=database, - branch=branch, - tls_ca=tls_ca, - tls_ca_file=tls_ca_file, - tls_security=tls_security, - wait_until_available=wait_until_available, - timeout=timeout, - ) +# Auto-generated shim +import gel.asyncio_client as _mod +import sys as _sys +_cur = _sys.modules['edgedb.asyncio_client'] +for _k in vars(_mod): + if not _k.startswith('__') or _k in ('__all__', '__doc__'): + setattr(_cur, _k, getattr(_mod, _k)) +del _cur +del _sys +del _mod +del _k diff --git a/edgedb/base_client.py b/edgedb/base_client.py index e1610219..0a491133 100644 --- a/edgedb/base_client.py +++ b/edgedb/base_client.py @@ -1,734 +1,11 @@ -# -# This source file is part of the EdgeDB open source project. -# -# Copyright 2016-present MagicStack Inc. and the EdgeDB authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - - -import abc -import random -import time -import typing - -from . import abstract -from . import con_utils -from . import enums -from . import errors -from . import options as _options -from .protocol import protocol - - -BaseConnection_T = typing.TypeVar('BaseConnection_T', bound='BaseConnection') -QUERY_CACHE_SIZE = 1000 - - -class BaseConnection(metaclass=abc.ABCMeta): - _protocol: typing.Any - _addr: typing.Optional[typing.Union[str, typing.Tuple[str, int]]] - _addrs: typing.Iterable[typing.Union[str, typing.Tuple[str, int]]] - _config: con_utils.ClientConfiguration - _params: con_utils.ResolvedConnectConfig - _log_listeners: typing.Set[ - typing.Callable[[BaseConnection_T, errors.EdgeDBMessage], None] - ] - __slots__ = ( - "__weakref__", - "_protocol", - "_addr", - "_addrs", - "_config", - "_params", - "_log_listeners", - "_holder", - ) - - def __init__( - self, - addrs: typing.Iterable[typing.Union[str, typing.Tuple[str, int]]], - config: con_utils.ClientConfiguration, - params: con_utils.ResolvedConnectConfig, - ): - self._addr = None - self._protocol = None - self._addrs = addrs - self._config = config - self._params = params - self._log_listeners = set() - self._holder = None - - @abc.abstractmethod - def _dispatch_log_message(self, msg): - ... - - def _on_log_message(self, msg): - if self._log_listeners: - self._dispatch_log_message(msg) - - def connected_addr(self): - return self._addr - - def _get_last_status(self) -> typing.Optional[str]: - if self._protocol is None: - return None - status = self._protocol.last_status - if status is not None: - status = status.decode() - return status - - def _cleanup(self): - self._log_listeners.clear() - if self._holder: - self._holder._release_on_close() - self._holder = None - - def add_log_listener( - self: BaseConnection_T, - callback: typing.Callable[[BaseConnection_T, errors.EdgeDBMessage], - None] - ) -> None: - """Add a listener for EdgeDB log messages. - - :param callable callback: - A callable receiving the following arguments: - **connection**: a Connection the callback is registered with; - **message**: the `edgedb.EdgeDBMessage` message. - """ - self._log_listeners.add(callback) - - def remove_log_listener( - self: BaseConnection_T, - callback: typing.Callable[[BaseConnection_T, errors.EdgeDBMessage], - None] - ) -> None: - """Remove a listening callback for log messages.""" - self._log_listeners.discard(callback) - - @property - def dbname(self) -> str: - return self._params.database - - @property - def branch(self) -> str: - return self._params.branch - - @abc.abstractmethod - def is_closed(self) -> bool: - ... - - @abc.abstractmethod - async def connect_addr(self, addr, timeout): - ... - - @abc.abstractmethod - async def sleep(self, seconds): - ... - - async def connect(self, *, single_attempt=False): - start = time.monotonic() - if single_attempt: - max_time = 0 - else: - max_time = start + self._config.wait_until_available - iteration = 1 - - while True: - for addr in self._addrs: - try: - await self.connect_addr(addr, self._config.connect_timeout) - except TimeoutError as e: - if iteration == 1 or time.monotonic() < max_time: - continue - else: - raise errors.ClientConnectionTimeoutError( - f"connecting to {addr} failed in" - f" {self._config.connect_timeout} sec" - ) from e - except errors.ClientConnectionError as e: - if ( - e.has_tag(errors.SHOULD_RECONNECT) and - (iteration == 1 or time.monotonic() < max_time) - ): - continue - nice_err = e.__class__( - con_utils.render_client_no_connection_error( - e, - addr, - attempts=iteration, - duration=time.monotonic() - start, - )) - raise nice_err from e.__cause__ - else: - return - - iteration += 1 - await self.sleep(0.01 + random.random() * 0.2) - - async def privileged_execute( - self, execute_context: abstract.ExecuteContext - ): - if self._protocol.is_legacy: - await self._protocol.legacy_simple_query( - execute_context.query.query, enums.Capability.ALL - ) - else: - await self._protocol.execute( - execute_context.lower(allow_capabilities=enums.Capability.ALL) - ) - - def is_in_transaction(self) -> bool: - """Return True if Connection is currently inside a transaction. - - :return bool: True if inside transaction, False otherwise. - """ - return self._protocol.is_in_transaction() - - def get_settings(self) -> typing.Dict[str, typing.Any]: - return self._protocol.get_settings() - - async def raw_query(self, query_context: abstract.QueryContext): - if self.is_closed(): - await self.connect() - - reconnect = False - i = 0 - if self._protocol.is_legacy: - allow_capabilities = enums.Capability.LEGACY_EXECUTE - else: - allow_capabilities = enums.Capability.EXECUTE - ctx = query_context.lower(allow_capabilities=allow_capabilities) - while True: - i += 1 - try: - if reconnect: - await self.connect(single_attempt=True) - if self._protocol.is_legacy: - return await self._protocol.legacy_execute_anonymous(ctx) - else: - res = await self._protocol.query(ctx) - if ctx.warnings: - res = query_context.warning_handler(ctx.warnings, res) - return res - - except errors.EdgeDBError as e: - if query_context.retry_options is None: - raise - if not e.has_tag(errors.SHOULD_RETRY): - raise e - # A query is read-only if it has no capabilities i.e. - # capabilities == 0. Read-only queries are safe to retry. - # Explicit transaction conflicts as well. - if ( - ctx.capabilities != 0 - and not isinstance(e, errors.TransactionConflictError) - ): - raise e - rule = query_context.retry_options.get_rule_for_exception(e) - if i >= rule.attempts: - raise e - await self.sleep(rule.backoff(i)) - reconnect = self.is_closed() - - async def _execute(self, execute_context: abstract.ExecuteContext) -> None: - if self._protocol.is_legacy: - if execute_context.query.args or execute_context.query.kwargs: - raise errors.InterfaceError( - "Legacy protocol doesn't support arguments in execute()" - ) - await self._protocol.legacy_simple_query( - execute_context.query.query, enums.Capability.LEGACY_EXECUTE - ) - else: - ctx = execute_context.lower( - allow_capabilities=enums.Capability.EXECUTE - ) - res = await self._protocol.execute(ctx) - if ctx.warnings: - res = execute_context.warning_handler(ctx.warnings, res) - - async def describe( - self, describe_context: abstract.DescribeContext - ) -> abstract.DescribeResult: - ctx = describe_context.lower( - allow_capabilities=enums.Capability.EXECUTE - ) - await self._protocol._parse(ctx) - return abstract.DescribeResult( - input_type=ctx.in_dc.make_type(describe_context), - output_type=ctx.out_dc.make_type(describe_context), - output_cardinality=enums.Cardinality(ctx.cardinality[0]), - capabilities=ctx.capabilities, - ) - - def terminate(self): - if not self.is_closed(): - try: - self._protocol.abort() - finally: - self._cleanup() - - def __repr__(self): - if self.is_closed(): - return '<{classname} [closed] {id:#x}>'.format( - classname=self.__class__.__name__, id=id(self)) - else: - return '<{classname} [connected to {addr}] {id:#x}>'.format( - classname=self.__class__.__name__, - addr=self.connected_addr(), - id=id(self)) - - -class PoolConnectionHolder(abc.ABC): - __slots__ = ( - "_con", - "_pool", - "_release_event", - "_timeout", - "_generation", - ) - _event_class = NotImplemented - - def __init__(self, pool): - - self._pool = pool - self._con = None - - self._timeout = None - self._generation = None - - self._release_event = self._event_class() - self._release_event.set() - - @abc.abstractmethod - async def close(self, *, wait=True): - ... - - @abc.abstractmethod - async def wait_until_released(self, timeout=None): - ... - - async def connect(self): - if self._con is not None: - raise errors.InternalClientError( - 'PoolConnectionHolder.connect() called while another ' - 'connection already exists') - - self._con = await self._pool._get_new_connection() - assert self._con._holder is None - self._con._holder = self - self._generation = self._pool._generation - - async def acquire(self) -> BaseConnection: - if self._con is None or self._con.is_closed(): - self._con = None - await self.connect() - - elif self._generation != self._pool._generation: - # Connections have been expired, re-connect the holder. - self._con._holder = None # don't release the connection - await self.close(wait=False) - self._con = None - await self.connect() - - self._release_event.clear() - - return self._con - - async def release(self, timeout): - if self._release_event.is_set(): - raise errors.InternalClientError( - 'PoolConnectionHolder.release() called on ' - 'a free connection holder') - - if self._con.is_closed(): - # This is usually the case when the connection is broken rather - # than closed by the user, so we need to call _release_on_close() - # here to release the holder back to the queue, because - # self._con._cleanup() was never called. On the other hand, it is - # safe to call self._release() twice - the second call is no-op. - self._release_on_close() - return - - self._timeout = None - - if self._generation != self._pool._generation: - # The connection has expired because it belongs to - # an older generation (BasePoolImpl.expire_connections() has - # been called.) - await self.close() - return - - # Free this connection holder and invalidate the - # connection proxy. - self._release() - - def terminate(self): - if self._con is not None: - # AsyncIOConnection.terminate() will call _release_on_close() to - # finish holder cleanup. - self._con.terminate() - - def _release_on_close(self): - self._release() - self._con = None - - def _release(self): - """Release this connection holder.""" - if self._release_event.is_set(): - # The holder is not checked out. - return - - self._release_event.set() - - # Put ourselves back to the pool queue. - self._pool._queue.put_nowait(self) - - -class BasePoolImpl(abc.ABC): - __slots__ = ( - "_connect_args", - "_codecs_registry", - "_query_cache", - "_connection_factory", - "_queue", - "_user_max_concurrency", - "_max_concurrency", - "_first_connect_lock", - "_working_addr", - "_working_config", - "_working_params", - "_holders", - "_initialized", - "_initializing", - "_closing", - "_closed", - "_generation", - ) - - _holder_class = NotImplemented - - def __init__( - self, - connect_args, - connection_factory, - *, - max_concurrency: typing.Optional[int], - ): - self._connection_factory = connection_factory - self._connect_args = connect_args - self._codecs_registry = protocol.CodecsRegistry() - self._query_cache = protocol.LRUMapping(maxsize=QUERY_CACHE_SIZE) - - if max_concurrency is not None and max_concurrency <= 0: - raise ValueError( - 'max_concurrency is expected to be greater than zero' - ) - - self._user_max_concurrency = max_concurrency - self._max_concurrency = max_concurrency if max_concurrency else 1 - - self._holders = [] - self._queue = None - - self._first_connect_lock = None - self._working_addr = None - self._working_config = None - self._working_params = None - - self._closing = False - self._closed = False - self._generation = 0 - - @abc.abstractmethod - def _ensure_initialized(self): - ... - - @abc.abstractmethod - def _set_queue_maxsize(self, maxsize): - ... - - @abc.abstractmethod - async def _maybe_get_first_connection(self): - ... - - @abc.abstractmethod - async def acquire(self, timeout=None): - ... - - @abc.abstractmethod - async def _release(self, connection): - ... - - @property - def codecs_registry(self): - return self._codecs_registry - - @property - def query_cache(self): - return self._query_cache - - def _resize_holder_pool(self): - resize_diff = self._max_concurrency - len(self._holders) - - if (resize_diff > 0): - if self._queue.maxsize != self._max_concurrency: - self._set_queue_maxsize(self._max_concurrency) - - for _ in range(resize_diff): - ch = self._holder_class(self) - - self._holders.append(ch) - self._queue.put_nowait(ch) - elif resize_diff < 0: - # TODO: shrink the pool - pass - - def get_max_concurrency(self): - return self._max_concurrency - - def get_free_size(self): - if self._queue is None: - # Queue has not been initialized yet - return self._max_concurrency - - return self._queue.qsize() - - def set_connect_args(self, dsn=None, **connect_kwargs): - r"""Set the new connection arguments for this pool. - - The new connection arguments will be used for all subsequent - new connection attempts. Existing connections will remain until - they expire. Use BasePoolImpl.expire_connections() to expedite - the connection expiry. - - :param str dsn: - Connection arguments specified using as a single string in - the following format: - ``edgedb://user:pass@host:port/database?option=value``. - - :param \*\*connect_kwargs: - Keyword arguments for the - :func:`~edgedb.asyncio_client.create_async_client` function. - """ - - connect_kwargs["dsn"] = dsn - self._connect_args = connect_kwargs - self._codecs_registry = protocol.CodecsRegistry() - self._query_cache = protocol.LRUMapping(maxsize=QUERY_CACHE_SIZE) - self._working_addr = None - self._working_config = None - self._working_params = None - - async def _get_first_connection(self): - # First connection attempt on this pool. - connect_config, client_config = con_utils.parse_connect_arguments( - **self._connect_args, - # ToDos - command_timeout=None, - server_settings=None, - ) - con = self._connection_factory( - [connect_config.address], client_config, connect_config - ) - await con.connect() - self._working_addr = con.connected_addr() - self._working_config = client_config - self._working_params = connect_config - - if self._user_max_concurrency is None: - suggested_concurrency = con.get_settings().get( - 'suggested_pool_concurrency') - if suggested_concurrency: - self._max_concurrency = suggested_concurrency - self._resize_holder_pool() - return con - - async def _get_new_connection(self): - con = None - if self._working_addr is None: - con = await self._maybe_get_first_connection() - if con is None: - assert self._working_addr is not None - # We've connected before and have a resolved address, - # and parsed options and config. - con = self._connection_factory( - [self._working_addr], - self._working_config, - self._working_params, - ) - await con.connect() - - return con - - async def release(self, connection): - - if not isinstance(connection, BaseConnection): - raise errors.InterfaceError( - f'BasePoolImpl.release() received invalid connection: ' - f'{connection!r} does not belong to any connection pool' - ) - - ch = connection._holder - if ch is None: - # Already released, do nothing. - return - - if ch._pool is not self: - raise errors.InterfaceError( - f'BasePoolImpl.release() received invalid connection: ' - f'{connection!r} is not a member of this pool' - ) - - return await self._release(ch) - - def terminate(self): - """Terminate all connections in the pool.""" - if self._closed: - return - for ch in self._holders: - ch.terminate() - self._closed = True - - def expire_connections(self): - """Expire all currently open connections. - - Cause all currently open connections to get replaced on the - next query. - """ - self._generation += 1 - - async def ensure_connected(self): - self._ensure_initialized() - - for ch in self._holders: - if ch._con is not None and not ch._con.is_closed(): - return - - ch = self._holders[0] - ch._con = None - await ch.connect() - - -class BaseClient(abstract.BaseReadOnlyExecutor, _options._OptionsMixin): - __slots__ = ("_impl", "_options") - _impl_class = NotImplemented - - def __init__( - self, - *, - connection_class, - max_concurrency: typing.Optional[int], - dsn=None, - host: str = None, - port: int = None, - credentials: str = None, - credentials_file: str = None, - user: str = None, - password: str = None, - secret_key: str = None, - database: str = None, - branch: str = None, - tls_ca: str = None, - tls_ca_file: str = None, - tls_security: str = None, - tls_server_name: str = None, - wait_until_available: int = 30, - timeout: int = 10, - **kwargs, - ): - super().__init__() - connect_args = { - "dsn": dsn, - "host": host, - "port": port, - "credentials": credentials, - "credentials_file": credentials_file, - "user": user, - "password": password, - "secret_key": secret_key, - "database": database, - "branch": branch, - "timeout": timeout, - "tls_ca": tls_ca, - "tls_ca_file": tls_ca_file, - "tls_security": tls_security, - "tls_server_name": tls_server_name, - "wait_until_available": wait_until_available, - } - - self._impl = self._impl_class( - connect_args, - connection_class=connection_class, - max_concurrency=max_concurrency, - **kwargs, - ) - - def _shallow_clone(self): - new_client = self.__class__.__new__(self.__class__) - new_client._impl = self._impl - return new_client - - def _get_query_cache(self) -> abstract.QueryCache: - return abstract.QueryCache( - codecs_registry=self._impl.codecs_registry, - query_cache=self._impl.query_cache, - ) - - def _get_retry_options(self) -> typing.Optional[_options.RetryOptions]: - return self._options.retry_options - - def _get_state(self) -> _options.State: - return self._options.state - - def _get_warning_handler(self) -> _options.WarningHandler: - return self._options.warning_handler - - @property - def max_concurrency(self) -> int: - """Max number of connections in the pool.""" - - return self._impl.get_max_concurrency() - - @property - def free_size(self) -> int: - """Number of available connections in the pool.""" - - return self._impl.get_free_size() - - async def _query(self, query_context: abstract.QueryContext): - con = await self._impl.acquire() - try: - return await con.raw_query(query_context) - finally: - await self._impl.release(con) - - async def _execute(self, execute_context: abstract.ExecuteContext) -> None: - con = await self._impl.acquire() - try: - await con._execute(execute_context) - finally: - await self._impl.release(con) - - async def _describe( - self, describe_context: abstract.DescribeContext - ) -> abstract.DescribeResult: - con = await self._impl.acquire() - try: - return await con.describe(describe_context) - finally: - await self._impl.release(con) - - def terminate(self): - """Terminate all connections in the pool.""" - self._impl.terminate() +# Auto-generated shim +import gel.base_client as _mod +import sys as _sys +_cur = _sys.modules['edgedb.base_client'] +for _k in vars(_mod): + if not _k.startswith('__') or _k in ('__all__', '__doc__'): + setattr(_cur, _k, getattr(_mod, _k)) +del _cur +del _sys +del _mod +del _k diff --git a/edgedb/blocking_client.py b/edgedb/blocking_client.py index f1706e59..12788a48 100644 --- a/edgedb/blocking_client.py +++ b/edgedb/blocking_client.py @@ -1,490 +1,11 @@ -# -# This source file is part of the EdgeDB open source project. -# -# Copyright 2022-present MagicStack Inc. and the EdgeDB authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - - -import contextlib -import datetime -import queue -import socket -import ssl -import threading -import time -import typing - -from . import abstract -from . import base_client -from . import con_utils -from . import errors -from . import transaction -from .protocol import blocking_proto -from .protocol.protocol import InputLanguage, OutputFormat - - -DEFAULT_PING_BEFORE_IDLE_TIMEOUT = datetime.timedelta(seconds=5) -MINIMUM_PING_WAIT_TIME = datetime.timedelta(seconds=1) - - -class BlockingIOConnection(base_client.BaseConnection): - __slots__ = ("_ping_wait_time",) - - async def connect_addr(self, addr, timeout): - deadline = time.monotonic() + timeout - - if isinstance(addr, str): - # UNIX socket - res_list = [(socket.AF_UNIX, socket.SOCK_STREAM, -1, None, addr)] - else: - host, port = addr - try: - # getaddrinfo() doesn't take timeout!! - res_list = socket.getaddrinfo( - host, port, socket.AF_UNSPEC, socket.SOCK_STREAM - ) - except socket.gaierror as e: - # All name resolution errors are considered temporary - err = errors.ClientConnectionFailedTemporarilyError(str(e)) - raise err from e - - for i, res in enumerate(res_list): - af, socktype, proto, _, sa = res - try: - sock = socket.socket(af, socktype, proto) - except OSError as e: - sock.close() - if i < len(res_list) - 1: - continue - else: - raise con_utils.wrap_error(e) from e - try: - await self._connect_addr(sock, addr, sa, deadline) - except TimeoutError: - raise - except Exception: - if i < len(res_list) - 1: - continue - else: - raise - else: - break - - async def _connect_addr(self, sock, addr, sa, deadline): - try: - time_left = deadline - time.monotonic() - if time_left <= 0: - raise TimeoutError - try: - sock.settimeout(time_left) - sock.connect(sa) - except OSError as e: - raise con_utils.wrap_error(e) from e - - if not isinstance(addr, str): - time_left = deadline - time.monotonic() - if time_left <= 0: - raise TimeoutError - try: - # Upgrade to TLS - sock.settimeout(time_left) - try: - sock = self._params.ssl_ctx.wrap_socket( - sock, - server_hostname=( - self._params.tls_server_name or addr[0] - ), - ) - except ssl.CertificateError as e: - raise con_utils.wrap_error(e) from e - except ssl.SSLError as e: - raise con_utils.wrap_error(e) from e - else: - con_utils.check_alpn_protocol(sock) - except OSError as e: - raise con_utils.wrap_error(e) from e - - time_left = deadline - time.monotonic() - if time_left <= 0: - raise TimeoutError - - if not isinstance(addr, str): - sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) - - proto = blocking_proto.BlockingIOProtocol(self._params, sock) - proto.set_connection(self) - - try: - await proto.wait_for(proto.connect(), time_left) - except TimeoutError: - raise - except OSError as e: - raise con_utils.wrap_error(e) from e - - self._protocol = proto - self._addr = addr - self._ping_wait_time = max( - ( - self.get_settings() - .get("system_config") - .session_idle_timeout - - DEFAULT_PING_BEFORE_IDLE_TIMEOUT - ), - MINIMUM_PING_WAIT_TIME, - ).total_seconds() - - except Exception: - sock.close() - raise - - async def sleep(self, seconds): - time.sleep(seconds) - - def is_closed(self): - proto = self._protocol - return not (proto and proto.sock is not None and - proto.sock.fileno() >= 0 and proto.connected) - - async def close(self, timeout=None): - """Send graceful termination message wait for connection to drop.""" - if not self.is_closed(): - try: - self._protocol.terminate() - if timeout is None: - await self._protocol.wait_for_disconnect() - else: - await self._protocol.wait_for( - self._protocol.wait_for_disconnect(), timeout - ) - except TimeoutError: - self.terminate() - raise errors.QueryTimeoutError() - except Exception: - self.terminate() - raise - finally: - self._cleanup() - - def _dispatch_log_message(self, msg): - for cb in self._log_listeners: - cb(self, msg) - - async def raw_query(self, query_context: abstract.QueryContext): - try: - if ( - time.monotonic() - self._protocol.last_active_timestamp - > self._ping_wait_time - ): - await self._protocol.ping() - except (errors.IdleSessionTimeoutError, errors.ClientConnectionError): - await self.connect() - - return await super().raw_query(query_context) - - -class _PoolConnectionHolder(base_client.PoolConnectionHolder): - __slots__ = () - _event_class = threading.Event - - async def close(self, *, wait=True, timeout=None): - if self._con is None: - return - await self._con.close(timeout=timeout) - - async def wait_until_released(self, timeout=None): - return self._release_event.wait(timeout) - - -class _PoolImpl(base_client.BasePoolImpl): - _holder_class = _PoolConnectionHolder - - def __init__( - self, - connect_args, - *, - max_concurrency: typing.Optional[int], - connection_class, - ): - if not issubclass(connection_class, BlockingIOConnection): - raise TypeError( - f'connection_class is expected to be a subclass of ' - f'edgedb.blocking_client.BlockingIOConnection, ' - f'got {connection_class}') - super().__init__( - connect_args, - connection_class, - max_concurrency=max_concurrency, - ) - - def _ensure_initialized(self): - if self._queue is None: - self._queue = queue.LifoQueue(maxsize=self._max_concurrency) - self._first_connect_lock = threading.Lock() - self._resize_holder_pool() - - def _set_queue_maxsize(self, maxsize): - with self._queue.mutex: - self._queue.maxsize = maxsize - - async def _maybe_get_first_connection(self): - with self._first_connect_lock: - if self._working_addr is None: - return await self._get_first_connection() - - async def acquire(self, timeout=None): - self._ensure_initialized() - - if self._closing: - raise errors.InterfaceError('pool is closing') - - ch = self._queue.get(timeout=timeout) - try: - con = await ch.acquire() - except Exception: - self._queue.put_nowait(ch) - raise - else: - # Record the timeout, as we will apply it by default - # in release(). - ch._timeout = timeout - return con - - async def _release(self, holder): - if not isinstance(holder._con, BlockingIOConnection): - raise errors.InterfaceError( - f'release() received invalid connection: ' - f'{holder._con!r} does not belong to any connection pool' - ) - - timeout = None - return await holder.release(timeout) - - async def close(self, timeout=None): - if self._closed: - return - self._closing = True - try: - if timeout is None: - for ch in self._holders: - await ch.wait_until_released() - for ch in self._holders: - await ch.close() - else: - deadline = time.monotonic() + timeout - for ch in self._holders: - secs = deadline - time.monotonic() - if secs <= 0: - raise TimeoutError - if not await ch.wait_until_released(secs): - raise TimeoutError - for ch in self._holders: - secs = deadline - time.monotonic() - if secs <= 0: - raise TimeoutError - await ch.close(timeout=secs) - except TimeoutError as e: - self.terminate() - raise errors.InterfaceError( - "client is not fully closed in {} seconds; " - "terminating now.".format(timeout) - ) from e - except Exception: - self.terminate() - raise - finally: - self._closed = True - self._closing = False - - -class Iteration(transaction.BaseTransaction, abstract.Executor): - - __slots__ = ("_managed", "_lock") - - def __init__(self, retry, client, iteration): - super().__init__(retry, client, iteration) - self._managed = False - self._lock = threading.Lock() - - def __enter__(self): - with self._exclusive(): - if self._managed: - raise errors.InterfaceError( - 'cannot enter context: already in a `with` block') - self._managed = True - return self - - def __exit__(self, extype, ex, tb): - with self._exclusive(): - self._managed = False - return self._client._iter_coroutine(self._exit(extype, ex)) - - async def _ensure_transaction(self): - if not self._managed: - raise errors.InterfaceError( - "Only managed retriable transactions are supported. " - "Use `with transaction:`" - ) - await super()._ensure_transaction() - - def _query(self, query_context: abstract.QueryContext): - with self._exclusive(): - return self._client._iter_coroutine(super()._query(query_context)) - - def _execute(self, execute_context: abstract.ExecuteContext) -> None: - with self._exclusive(): - self._client._iter_coroutine(super()._execute(execute_context)) - - @contextlib.contextmanager - def _exclusive(self): - if not self._lock.acquire(blocking=False): - raise errors.InterfaceError( - "concurrent queries within the same transaction " - "are not allowed" - ) - try: - yield - finally: - self._lock.release() - - -class Retry(transaction.BaseRetry): - - def __iter__(self): - return self - - def __next__(self): - # Note: when changing this code consider also - # updating AsyncIORetry.__anext__. - if self._done: - raise StopIteration - if self._next_backoff: - time.sleep(self._next_backoff) - self._done = True - iteration = Iteration(self, self._owner, self._iteration) - self._iteration += 1 - return iteration - - -class Client(base_client.BaseClient, abstract.Executor): - """A lazy connection pool. - - A Client can be used to manage a set of connections to the database. - Connections are first acquired from the pool, then used, and then released - back to the pool. Once a connection is released, it's reset to close all - open cursors and other resources *except* prepared statements. - - Clients are created by calling - :func:`~edgedb.blocking_client.create_client`. - """ - - __slots__ = () - _impl_class = _PoolImpl - - def _iter_coroutine(self, coro): - try: - coro.send(None) - except StopIteration as ex: - return ex.value - finally: - coro.close() - - def _query(self, query_context: abstract.QueryContext): - return self._iter_coroutine(super()._query(query_context)) - - def _execute(self, execute_context: abstract.ExecuteContext) -> None: - self._iter_coroutine(super()._execute(execute_context)) - - def ensure_connected(self): - self._iter_coroutine(self._impl.ensure_connected()) - return self - - def transaction(self) -> Retry: - return Retry(self) - - def close(self, timeout=None): - """Attempt to gracefully close all connections in the client. - - Wait until all pool connections are released, close them and - shut down the pool. If any error (including cancellation) occurs - in ``close()`` the pool will terminate by calling - Client.terminate() . - """ - self._iter_coroutine(self._impl.close(timeout)) - - def __enter__(self): - return self.ensure_connected() - - def __exit__(self, exc_type, exc_val, exc_tb): - self.close() - - def _describe_query( - self, - query: str, - *, - inject_type_names: bool = False, - input_language: InputLanguage = InputLanguage.EDGEQL, - output_format: OutputFormat = OutputFormat.BINARY, - expect_one: bool = False, - ) -> abstract.DescribeResult: - return self._iter_coroutine(self._describe(abstract.DescribeContext( - query=query, - state=self._get_state(), - inject_type_names=inject_type_names, - input_language=input_language, - output_format=output_format, - expect_one=expect_one, - ))) - - -def create_client( - dsn=None, - *, - max_concurrency=None, - host: str = None, - port: int = None, - credentials: str = None, - credentials_file: str = None, - user: str = None, - password: str = None, - secret_key: str = None, - database: str = None, - branch: str = None, - tls_ca: str = None, - tls_ca_file: str = None, - tls_security: str = None, - wait_until_available: int = 30, - timeout: int = 10, -): - return Client( - connection_class=BlockingIOConnection, - max_concurrency=max_concurrency, - - # connect arguments - dsn=dsn, - host=host, - port=port, - credentials=credentials, - credentials_file=credentials_file, - user=user, - password=password, - secret_key=secret_key, - database=database, - branch=branch, - tls_ca=tls_ca, - tls_ca_file=tls_ca_file, - tls_security=tls_security, - wait_until_available=wait_until_available, - timeout=timeout, - ) +# Auto-generated shim +import gel.blocking_client as _mod +import sys as _sys +_cur = _sys.modules['edgedb.blocking_client'] +for _k in vars(_mod): + if not _k.startswith('__') or _k in ('__all__', '__doc__'): + setattr(_cur, _k, getattr(_mod, _k)) +del _cur +del _sys +del _mod +del _k diff --git a/edgedb/codegen.py b/edgedb/codegen.py new file mode 100644 index 00000000..7bae2079 --- /dev/null +++ b/edgedb/codegen.py @@ -0,0 +1,11 @@ +# Auto-generated shim +import gel.codegen as _mod +import sys as _sys +_cur = _sys.modules['edgedb.codegen'] +for _k in vars(_mod): + if not _k.startswith('__') or _k in ('__all__', '__doc__'): + setattr(_cur, _k, getattr(_mod, _k)) +del _cur +del _sys +del _mod +del _k diff --git a/edgedb/color.py b/edgedb/color.py index 2b972aaa..c761ee80 100644 --- a/edgedb/color.py +++ b/edgedb/color.py @@ -1,61 +1,11 @@ -import os -import sys -import warnings - -COLOR = None - - -class Color: - HEADER = "" - BLUE = "" - CYAN = "" - GREEN = "" - WARNING = "" - FAIL = "" - ENDC = "" - BOLD = "" - UNDERLINE = "" - - -def get_color() -> Color: - global COLOR - - if COLOR is None: - COLOR = Color() - if type(USE_COLOR) is bool: - use_color = USE_COLOR - else: - try: - use_color = USE_COLOR() - except Exception: - use_color = False - if use_color: - COLOR.HEADER = '\033[95m' - COLOR.BLUE = '\033[94m' - COLOR.CYAN = '\033[96m' - COLOR.GREEN = '\033[92m' - COLOR.WARNING = '\033[93m' - COLOR.FAIL = '\033[91m' - COLOR.ENDC = '\033[0m' - COLOR.BOLD = '\033[1m' - COLOR.UNDERLINE = '\033[4m' - - return COLOR - - -try: - USE_COLOR = { - "default": lambda: sys.stderr.isatty(), - "auto": lambda: sys.stderr.isatty(), - "enabled": True, - "disabled": False, - }[ - os.getenv("EDGEDB_COLOR_OUTPUT", "default") - ] -except KeyError: - warnings.warn( - "EDGEDB_COLOR_OUTPUT can only be one of: " - "default, auto, enabled or disabled", - stacklevel=1, - ) - USE_COLOR = False +# Auto-generated shim +import gel.color as _mod +import sys as _sys +_cur = _sys.modules['edgedb.color'] +for _k in vars(_mod): + if not _k.startswith('__') or _k in ('__all__', '__doc__'): + setattr(_cur, _k, getattr(_mod, _k)) +del _cur +del _sys +del _mod +del _k diff --git a/edgedb/con_utils.py b/edgedb/con_utils.py index 0863ead9..cad210b4 100644 --- a/edgedb/con_utils.py +++ b/edgedb/con_utils.py @@ -1,1278 +1,11 @@ -# -# This source file is part of the EdgeDB open source project. -# -# Copyright 2016-present MagicStack Inc. and the EdgeDB authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - - -import base64 -import binascii -import errno -import json -import os -import re -import ssl -import typing -import urllib.parse -import warnings -import hashlib - -from . import errors -from . import credentials as cred_utils -from . import platform - - -EDGEDB_PORT = 5656 -ERRNO_RE = re.compile(r"\[Errno (\d+)\]") -TEMPORARY_ERRORS = ( - ConnectionAbortedError, - ConnectionRefusedError, - ConnectionResetError, - FileNotFoundError, -) -TEMPORARY_ERROR_CODES = frozenset({ - errno.ECONNREFUSED, - errno.ECONNABORTED, - errno.ECONNRESET, - errno.ENOENT, -}) - -ISO_SECONDS_RE = re.compile(r'(-?\d+|-?\d+\.\d*|-?\d*\.\d+)S') -ISO_MINUTES_RE = re.compile(r'(-?\d+|-?\d+\.\d*|-?\d*\.\d+)M') -ISO_HOURS_RE = re.compile(r'(-?\d+|-?\d+\.\d*|-?\d*\.\d+)H') -ISO_UNITLESS_HOURS_RE = re.compile(r'^(-?\d+|-?\d+\.\d*|-?\d*\.\d+)$') -ISO_DAYS_RE = re.compile(r'(-?\d+|-?\d+\.\d*|-?\d*\.\d+)D') -ISO_WEEKS_RE = re.compile(r'(-?\d+|-?\d+\.\d*|-?\d*\.\d+)W') -ISO_MONTHS_RE = re.compile(r'(-?\d+|-?\d+\.\d*|-?\d*\.\d+)M') -ISO_YEARS_RE = re.compile(r'(-?\d+|-?\d+\.\d*|-?\d*\.\d+)Y') - -HUMAN_HOURS_RE = re.compile( - r'((?:(?:\s|^)-\s*)?\d*\.?\d*)\s*(?i:h(\s|\d|\.|$)|hours?(\s|$))', -) -HUMAN_MINUTES_RE = re.compile( - r'((?:(?:\s|^)-\s*)?\d*\.?\d*)\s*(?i:m(\s|\d|\.|$)|minutes?(\s|$))', -) -HUMAN_SECONDS_RE = re.compile( - r'((?:(?:\s|^)-\s*)?\d*\.?\d*)\s*(?i:s(\s|\d|\.|$)|seconds?(\s|$))', -) -HUMAN_MS_RE = re.compile( - r'((?:(?:\s|^)-\s*)?\d*\.?\d*)\s*(?i:ms(\s|\d|\.|$)|milliseconds?(\s|$))', -) -HUMAN_US_RE = re.compile( - r'((?:(?:\s|^)-\s*)?\d*\.?\d*)\s*(?i:us(\s|\d|\.|$)|microseconds?(\s|$))', -) -INSTANCE_NAME_RE = re.compile( - r'^(\w(?:-?\w)*)$', - re.ASCII, -) -CLOUD_INSTANCE_NAME_RE = re.compile( - r'^([A-Za-z0-9_-](?:-?[A-Za-z0-9_])*)/([A-Za-z0-9](?:-?[A-Za-z0-9])*)$', - re.ASCII, -) -DSN_RE = re.compile( - r'^[a-z]+://', - re.IGNORECASE, -) -DOMAIN_LABEL_MAX_LENGTH = 63 - - -class ClientConfiguration(typing.NamedTuple): - - connect_timeout: float - command_timeout: float - wait_until_available: float - - -def _validate_port_spec(hosts, port): - if isinstance(port, list): - # If there is a list of ports, its length must - # match that of the host list. - if len(port) != len(hosts): - raise errors.InterfaceError( - 'could not match {} port numbers to {} hosts'.format( - len(port), len(hosts))) - else: - port = [port for _ in range(len(hosts))] - - return port - - -def _parse_hostlist(hostlist, port): - if ',' in hostlist: - # A comma-separated list of host addresses. - hostspecs = hostlist.split(',') - else: - hostspecs = [hostlist] - - hosts = [] - hostlist_ports = [] - - if not port: - portspec = _getenv('PORT') - if portspec: - if ',' in portspec: - default_port = [int(p) for p in portspec.split(',')] - else: - default_port = int(portspec) - else: - default_port = EDGEDB_PORT - - default_port = _validate_port_spec(hostspecs, default_port) - - else: - port = _validate_port_spec(hostspecs, port) - - for i, hostspec in enumerate(hostspecs): - addr, _, hostspec_port = hostspec.partition(':') - hosts.append(addr) - - if not port: - if hostspec_port: - hostlist_ports.append(int(hostspec_port)) - else: - hostlist_ports.append(default_port[i]) - - if not port: - port = hostlist_ports - - return hosts, port - - -def _hash_path(path): - path = os.path.realpath(path) - if platform.IS_WINDOWS and not path.startswith('\\\\'): - path = '\\\\?\\' + path - return hashlib.sha1(str(path).encode('utf-8')).hexdigest() - - -def _stash_path(path): - base_name = os.path.basename(path) - dir_name = base_name + '-' + _hash_path(path) - return platform.search_config_dir('projects', dir_name) - - -def _validate_tls_security(val: str) -> str: - val = val.lower() - if val not in {"insecure", "no_host_verification", "strict", "default"}: - raise ValueError( - "tls_security can only be one of " - "`insecure`, `no_host_verification`, `strict` or `default`" - ) - - return val - - -def _getenv_and_key(key: str) -> typing.Tuple[typing.Optional[str], str]: - edgedb_key = f'EDGEDB_{key}' - edgedb_val = os.getenv(edgedb_key) - gel_key = f'GEL_{key}' - gel_val = os.getenv(gel_key) - if edgedb_val is not None and gel_val is not None: - warnings.warn( - f'Both {gel_key} and {edgedb_key} are set; ' - f'{edgedb_key} will be ignored', - stacklevel=1, - ) - - if gel_val is None and edgedb_val is not None: - return edgedb_val, edgedb_key - else: - return gel_val, gel_key - - -def _getenv(key: str) -> typing.Optional[str]: - return _getenv_and_key(key)[0] - - -class ResolvedConnectConfig: - _host = None - _host_source = None - - _port = None - _port_source = None - - # We keep track of database and branch separately, because we want to make - # sure that we don't use both at the same time on the same configuration - # level. - _database = None - _database_source = None - - _branch = None - _branch_source = None - - _user = None - _user_source = None - - _password = None - _password_source = None - - _secret_key = None - _secret_key_source = None - - _tls_ca_data = None - _tls_ca_data_source = None - - _tls_server_name = None - _tls_security = None - _tls_security_source = None - - _wait_until_available = None - - _cloud_profile = None - _cloud_profile_source = None - - server_settings = {} - - def _set_param(self, param, value, source, validator=None): - param_name = '_' + param - if getattr(self, param_name) is None: - setattr(self, param_name + '_source', source) - if value is not None: - setattr( - self, - param_name, - validator(value) if validator else value - ) - - def set_host(self, host, source): - self._set_param('host', host, source, _validate_host) - - def set_port(self, port, source): - self._set_param('port', port, source, _validate_port) - - def set_database(self, database, source): - self._set_param('database', database, source, _validate_database) - - def set_branch(self, branch, source): - self._set_param('branch', branch, source, _validate_branch) - - def set_user(self, user, source): - self._set_param('user', user, source, _validate_user) - - def set_password(self, password, source): - self._set_param('password', password, source) - - def set_secret_key(self, secret_key, source): - self._set_param('secret_key', secret_key, source) - - def set_tls_ca_data(self, ca_data, source): - self._set_param('tls_ca_data', ca_data, source) - - def set_tls_ca_file(self, ca_file, source): - def read_ca_file(file_path): - with open(file_path) as f: - return f.read() - - self._set_param('tls_ca_data', ca_file, source, read_ca_file) - - def set_tls_server_name(self, ca_data, source): - self._set_param('tls_server_name', ca_data, source) - - def set_tls_security(self, security, source): - self._set_param('tls_security', security, source, - _validate_tls_security) - - def set_wait_until_available(self, wait_until_available, source): - self._set_param( - 'wait_until_available', - wait_until_available, - source, - _validate_wait_until_available, - ) - - def add_server_settings(self, server_settings): - _validate_server_settings(server_settings) - self.server_settings = {**server_settings, **self.server_settings} - - @property - def address(self): - return ( - self._host if self._host else 'localhost', - self._port if self._port else 5656 - ) - - # The properties actually merge database and branch, but "default" is - # different. If you need to know the underlying config use the _database - # and _branch. - @property - def database(self): - return ( - self._database if self._database else - self._branch if self._branch else - 'edgedb' - ) - - @property - def branch(self): - return ( - self._database if self._database else - self._branch if self._branch else - '__default__' - ) - - @property - def user(self): - return self._user if self._user else 'edgedb' - - @property - def password(self): - return self._password - - @property - def secret_key(self): - return self._secret_key - - @property - def tls_server_name(self): - return self._tls_server_name - - @property - def tls_security(self): - tls_security = self._tls_security or 'default' - security, security_key = _getenv_and_key('CLIENT_SECURITY') - security = security or 'default' - if security not in {'default', 'insecure_dev_mode', 'strict'}: - raise ValueError( - f'environment variable {security_key} should be ' - f'one of strict, insecure_dev_mode or default, ' - f'got: {security!r}') - - if security == 'default': - pass - elif security == 'insecure_dev_mode': - if tls_security == 'default': - tls_security = 'insecure' - elif security == 'strict': - if tls_security == 'default': - tls_security = 'strict' - elif tls_security in {'no_host_verification', 'insecure'}: - raise ValueError( - f'{security_key}=strict but ' - f'tls_security={tls_security}, tls_security must be ' - f'set to strict when {security_key} is strict') - - if tls_security != 'default': - return tls_security - - if self._tls_ca_data is not None: - return "no_host_verification" - - return "strict" - - _ssl_ctx = None - - @property - def ssl_ctx(self): - if (self._ssl_ctx): - return self._ssl_ctx - - self._ssl_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) - - if self._tls_ca_data: - self._ssl_ctx.load_verify_locations( - cadata=self._tls_ca_data - ) - else: - self._ssl_ctx.load_default_certs(ssl.Purpose.SERVER_AUTH) - if platform.IS_WINDOWS: - import certifi - self._ssl_ctx.load_verify_locations(cafile=certifi.where()) - - tls_security = self.tls_security - self._ssl_ctx.check_hostname = tls_security == "strict" - - if tls_security in {"strict", "no_host_verification"}: - self._ssl_ctx.verify_mode = ssl.CERT_REQUIRED - else: - self._ssl_ctx.verify_mode = ssl.CERT_NONE - - self._ssl_ctx.set_alpn_protocols(['edgedb-binary']) - - return self._ssl_ctx - - @property - def wait_until_available(self): - return ( - self._wait_until_available - if self._wait_until_available is not None - else 30 - ) - - -def _validate_host(host): - if '/' in host: - raise ValueError('unix socket paths not supported') - if host == '' or ',' in host: - raise ValueError(f'invalid host: "{host}"') - return host - - -def _prepare_host_for_dsn(host): - host = _validate_host(host) - if ':' in host: - # IPv6 - host = f'[{host}]' - return host - - -def _validate_port(port): - try: - if isinstance(port, str): - port = int(port) - if not isinstance(port, int): - raise ValueError() - except Exception: - raise ValueError(f'invalid port: {port}, not an integer') - if port < 1 or port > 65535: - raise ValueError(f'invalid port: {port}, must be between 1 and 65535') - return port - - -def _validate_database(database): - if database == '': - raise ValueError(f'invalid database name: {database}') - return database - - -def _validate_branch(branch): - if branch == '': - raise ValueError(f'invalid branch name: {branch}') - return branch - - -def _validate_user(user): - if user == '': - raise ValueError(f'invalid user name: {user}') - return user - - -def _pop_iso_unit(rgex: re.Pattern, string: str) -> typing.Tuple[float, str]: - s = string - total = 0 - match = rgex.search(string) - if match: - total += float(match.group(1)) - s = s.replace(match.group(0), "", 1) - - return (total, s) - - -def _parse_iso_duration(string: str) -> typing.Union[float, int]: - if not string.startswith("PT"): - raise ValueError(f"invalid duration {string!r}") - - time = string[2:] - match = ISO_UNITLESS_HOURS_RE.search(time) - if match: - hours = float(match.group(0)) - return 3600 * hours - - hours, time = _pop_iso_unit(ISO_HOURS_RE, time) - minutes, time = _pop_iso_unit(ISO_MINUTES_RE, time) - seconds, time = _pop_iso_unit(ISO_SECONDS_RE, time) - - if time: - raise ValueError(f'invalid duration {string!r}') - - return 3600 * hours + 60 * minutes + seconds - - -def _remove_white_space(s: str) -> str: - return ''.join(c for c in s if not c.isspace()) - - -def _pop_human_duration_unit( - rgex: re.Pattern, - string: str, -) -> typing.Tuple[float, bool, str]: - match = rgex.search(string) - if not match: - return 0, False, string - - number = 0 - if match.group(1): - literal = _remove_white_space(match.group(1)) - if literal.endswith('.'): - return 0, False, string - - if literal.startswith('-.'): - return 0, False, string - - number = float(literal) - string = string.replace( - match.group(0), - match.group(2) or match.group(3) or "", - 1, - ) - - return number, True, string - - -def _parse_human_duration(string: str) -> float: - found = False - - hour, f, s = _pop_human_duration_unit(HUMAN_HOURS_RE, string) - found |= f - - minute, f, s = _pop_human_duration_unit(HUMAN_MINUTES_RE, s) - found |= f - - second, f, s = _pop_human_duration_unit(HUMAN_SECONDS_RE, s) - found |= f - - ms, f, s = _pop_human_duration_unit(HUMAN_MS_RE, s) - found |= f - - us, f, s = _pop_human_duration_unit(HUMAN_US_RE, s) - found |= f - - if s.strip() or not found: - raise ValueError(f'invalid duration {string!r}') - - return 3600 * hour + 60 * minute + second + 0.001 * ms + 0.000001 * us - - -def _parse_duration_str(string: str) -> float: - if string.startswith('PT'): - return _parse_iso_duration(string) - return _parse_human_duration(string) - - -def _validate_wait_until_available(wait_until_available): - if isinstance(wait_until_available, str): - return _parse_duration_str(wait_until_available) - - if isinstance(wait_until_available, (int, float)): - return wait_until_available - - raise ValueError(f"invalid duration {wait_until_available!r}") - - -def _validate_server_settings(server_settings): - if ( - not isinstance(server_settings, dict) or - not all(isinstance(k, str) for k in server_settings) or - not all(isinstance(v, str) for v in server_settings.values()) - ): - raise ValueError( - 'server_settings is expected to be None or ' - 'a Dict[str, str]') - - -def _parse_connect_dsn_and_args( - *, - dsn, - host, - port, - credentials, - credentials_file, - user, - password, - secret_key, - database, - branch, - tls_ca, - tls_ca_file, - tls_security, - tls_server_name, - server_settings, - wait_until_available, -): - resolved_config = ResolvedConnectConfig() - - if dsn and DSN_RE.match(dsn): - instance_name = None - else: - instance_name, dsn = dsn, None - - def _get(key: str) -> typing.Optional[typing.Tuple[str, str]]: - val, env = _getenv_and_key(key) - return ( - (val, f'"{env}" environment variable') - if val is not None else None - ) - - # The cloud profile is potentially relevant to resolving credentials at - # any stage, including the config stage when other environment variables - # are not yet read. - cloud_profile_tuple = _get('CLOUD_PROFILE') - cloud_profile = cloud_profile_tuple[0] if cloud_profile_tuple else None - - has_compound_options = _resolve_config_options( - resolved_config, - 'Cannot have more than one of the following connection options: ' - + '"dsn", "credentials", "credentials_file" or "host"/"port"', - dsn=(dsn, '"dsn" option') if dsn is not None else None, - instance_name=( - (instance_name, '"dsn" option (parsed as instance name)') - if instance_name is not None else None - ), - credentials=( - (credentials, '"credentials" option') - if credentials is not None else None - ), - credentials_file=( - (credentials_file, '"credentials_file" option') - if credentials_file is not None else None - ), - host=(host, '"host" option') if host is not None else None, - port=(port, '"port" option') if port is not None else None, - database=( - (database, '"database" option') - if database is not None else None - ), - branch=( - (branch, '"branch" option') - if branch is not None else None - ), - user=(user, '"user" option') if user is not None else None, - password=( - (password, '"password" option') - if password is not None else None - ), - secret_key=( - (secret_key, '"secret_key" option') - if secret_key is not None else None - ), - tls_ca=( - (tls_ca, '"tls_ca" option') - if tls_ca is not None else None - ), - tls_ca_file=( - (tls_ca_file, '"tls_ca_file" option') - if tls_ca_file is not None else None - ), - tls_security=( - (tls_security, '"tls_security" option') - if tls_security is not None else None - ), - tls_server_name=( - (tls_server_name, '"tls_server_name" option') - if tls_server_name is not None else None - ), - server_settings=( - (server_settings, '"server_settings" option') - if server_settings is not None else None - ), - wait_until_available=( - (wait_until_available, '"wait_until_available" option') - if wait_until_available is not None else None - ), - cloud_profile=cloud_profile_tuple, - ) - - if has_compound_options is False: - env_port_tuple = _get("PORT") - if ( - resolved_config._port is None - and env_port_tuple - and env_port_tuple[0].startswith('tcp://') - ): - # EDGEDB_PORT is set by 'docker --link' so ignore and warn - warnings.warn('EDGEDB_PORT in "tcp://host:port" format, ' + - 'so will be ignored', stacklevel=1) - env_port_tuple = None - - has_compound_options = _resolve_config_options( - resolved_config, - # XXX - 'Cannot have more than one of the following connection ' - + 'environment variables: "EDGEDB_DSN", "EDGEDB_INSTANCE", ' - + '"EDGEDB_CREDENTIALS_FILE" or "EDGEDB_HOST"/"EDGEDB_PORT"', - dsn=_get('DSN'), - instance_name=_get('INSTANCE'), - credentials_file=_get('CREDENTIALS_FILE'), - host=_get('HOST'), - port=env_port_tuple, - database=_get('DATABASE'), - branch=_get('BRANCH'), - user=_get('USER'), - password=_get('PASSWORD'), - secret_key=_get('SECRET_KEY'), - tls_ca=_get('TLS_CA'), - tls_ca_file=_get('TLS_CA_FILE'), - tls_security=_get('CLIENT_TLS_SECURITY'), - tls_server_name=_get('TLS_SERVER_NAME'), - wait_until_available=_get('WAIT_UNTIL_AVAILABLE'), - ) - - if not has_compound_options: - dir = find_edgedb_project_dir() - stash_dir = _stash_path(dir) - if os.path.exists(stash_dir): - with open(os.path.join(stash_dir, 'instance-name'), 'rt') as f: - instance_name = f.read().strip() - cloud_profile_file = os.path.join(stash_dir, 'cloud-profile') - if os.path.exists(cloud_profile_file): - with open(cloud_profile_file, 'rt') as f: - cloud_profile = f.read().strip() - else: - cloud_profile = None - - _resolve_config_options( - resolved_config, - '', - instance_name=( - instance_name, - f'project linked instance ("{instance_name}")' - ), - cloud_profile=( - cloud_profile, - f'project defined cloud profile ("{cloud_profile}")' - ), - ) - - opt_database_file = os.path.join(stash_dir, 'database') - if os.path.exists(opt_database_file): - with open(opt_database_file, 'rt') as f: - database = f.read().strip() - resolved_config.set_database(database, "project") - else: - raise errors.ClientConnectionError( - f'Found `edgedb.toml` but the project is not initialized. ' - f'Run `edgedb project init`.' - ) - - return resolved_config - - -def _parse_dsn_into_config( - resolved_config: ResolvedConnectConfig, - dsn: typing.Tuple[str, str] -): - dsn_str, source = dsn - - try: - parsed = urllib.parse.urlparse(dsn_str) - host = ( - urllib.parse.unquote(parsed.hostname) if parsed.hostname else None - ) - port = parsed.port - database = parsed.path - user = parsed.username - password = parsed.password - except Exception as e: - raise ValueError(f'invalid DSN or instance name: {str(e)}') - - if parsed.scheme != 'edgedb': - raise ValueError( - f'invalid DSN: scheme is expected to be ' - f'"edgedb", got {parsed.scheme!r}') - - query = ( - urllib.parse.parse_qs(parsed.query, keep_blank_values=True) - if parsed.query != '' - else {} - ) - for key, val in query.items(): - if isinstance(val, list): - if len(val) > 1: - raise ValueError( - f'invalid DSN: duplicate query parameter {key}') - query[key] = val[-1] - - def handle_dsn_part( - paramName, value, currentValue, setter, - formatter=lambda val: val - ): - param_values = [ - (value if value != '' else None), - query.get(paramName), - query.get(paramName + '_env'), - query.get(paramName + '_file') - ] - if len([p for p in param_values if p is not None]) > 1: - raise ValueError( - f'invalid DSN: more than one of ' + - f'{(paramName + ", ") if value else ""}' + - f'?{paramName}=, ?{paramName}_env=, ?{paramName}_file= ' + - f'was specified' - ) - - if currentValue is None: - param = ( - value if (value is not None and value != '') - else query.get(paramName) - ) - paramSource = source - - if param is None: - env = query.get(paramName + '_env') - if env is not None: - param = os.getenv(env) - if param is None: - raise ValueError( - f'{paramName}_env environment variable "{env}" ' + - f'doesn\'t exist') - paramSource = paramSource + f' ({paramName}_env: {env})' - if param is None: - filename = query.get(paramName + '_file') - if filename is not None: - with open(filename) as f: - param = f.read() - paramSource = ( - paramSource + f' ({paramName}_file: {filename})' - ) - - param = formatter(param) if param is not None else None - - setter(param, paramSource) - - query.pop(paramName, None) - query.pop(paramName + '_env', None) - query.pop(paramName + '_file', None) - - handle_dsn_part( - 'host', host, resolved_config._host, resolved_config.set_host - ) - - handle_dsn_part( - 'port', port, resolved_config._port, resolved_config.set_port - ) - - def strip_leading_slash(str): - return str[1:] if str.startswith('/') else str - - if ( - 'branch' in query or - 'branch_env' in query or - 'branch_file' in query - ): - if ( - 'database' in query or - 'database_env' in query or - 'database_file' in query - ): - raise ValueError( - f"invalid DSN: `database` and `branch` cannot be present " - f"at the same time" - ) - if resolved_config._database is None: - # Only update the config if 'database' has not been already - # resolved. - handle_dsn_part( - 'branch', strip_leading_slash(database), - resolved_config._branch, resolved_config.set_branch, - strip_leading_slash - ) - else: - # Clean up the query, if config already has 'database' - query.pop('branch', None) - query.pop('branch_env', None) - query.pop('branch_file', None) - - else: - if resolved_config._branch is None: - # Only update the config if 'branch' has not been already - # resolved. - handle_dsn_part( - 'database', strip_leading_slash(database), - resolved_config._database, resolved_config.set_database, - strip_leading_slash - ) - else: - # Clean up the query, if config already has 'branch' - query.pop('database', None) - query.pop('database_env', None) - query.pop('database_file', None) - - handle_dsn_part( - 'user', user, resolved_config._user, resolved_config.set_user - ) - - handle_dsn_part( - 'password', password, - resolved_config._password, resolved_config.set_password - ) - - handle_dsn_part( - 'secret_key', None, - resolved_config._secret_key, resolved_config.set_secret_key - ) - - handle_dsn_part( - 'tls_ca_file', None, - resolved_config._tls_ca_data, resolved_config.set_tls_ca_file - ) - - handle_dsn_part( - 'tls_server_name', None, - resolved_config._tls_server_name, - resolved_config.set_tls_server_name - ) - - handle_dsn_part( - 'tls_security', None, - resolved_config._tls_security, - resolved_config.set_tls_security - ) - - handle_dsn_part( - 'wait_until_available', None, - resolved_config._wait_until_available, - resolved_config.set_wait_until_available - ) - - resolved_config.add_server_settings(query) - - -def _jwt_base64_decode(payload): - remainder = len(payload) % 4 - if remainder == 2: - payload += '==' - elif remainder == 3: - payload += '=' - elif remainder != 0: - raise errors.ClientConnectionError("Invalid secret key") - payload = base64.urlsafe_b64decode(payload.encode("utf-8")) - return json.loads(payload.decode("utf-8")) - - -def _parse_cloud_instance_name_into_config( - resolved_config: ResolvedConnectConfig, - source: str, - org_slug: str, - instance_name: str, -): - org_slug = org_slug.lower() - instance_name = instance_name.lower() - - label = f"{instance_name}--{org_slug}" - if len(label) > DOMAIN_LABEL_MAX_LENGTH: - raise ValueError( - f"invalid instance name: cloud instance name length cannot exceed " - f"{DOMAIN_LABEL_MAX_LENGTH - 1} characters: " - f"{org_slug}/{instance_name}" - ) - secret_key = resolved_config.secret_key - if secret_key is None: - try: - config_dir = platform.config_dir() - if resolved_config._cloud_profile is None: - profile = profile_src = "default" - else: - profile = resolved_config._cloud_profile - profile_src = resolved_config._cloud_profile_source - path = config_dir / "cloud-credentials" / f"{profile}.json" - with open(path, "rt") as f: - secret_key = json.load(f)["secret_key"] - except Exception: - raise errors.ClientConnectionError( - "Cannot connect to cloud instances without secret key." - ) - resolved_config.set_secret_key( - secret_key, - f"cloud-credentials/{profile}.json specified by {profile_src}", - ) - try: - dns_zone = _jwt_base64_decode(secret_key.split(".", 2)[1])["iss"] - except errors.EdgeDBError: - raise - except Exception: - raise errors.ClientConnectionError("Invalid secret key") - payload = f"{org_slug}/{instance_name}".encode("utf-8") - dns_bucket = binascii.crc_hqx(payload, 0) % 100 - host = f"{label}.c-{dns_bucket:02d}.i.{dns_zone}" - resolved_config.set_host(host, source) - - -def _resolve_config_options( - resolved_config: ResolvedConnectConfig, - compound_error: str, - *, - dsn=None, - instance_name=None, - credentials=None, - credentials_file=None, - host=None, - port=None, - database=None, - branch=None, - user=None, - password=None, - secret_key=None, - tls_ca=None, - tls_ca_file=None, - tls_security=None, - tls_server_name=None, - server_settings=None, - wait_until_available=None, - cloud_profile=None, -): - if database is not None: - if branch is not None: - raise errors.ClientConnectionError( - f"{database[1]} and {branch[1]} are mutually exclusive" - ) - if resolved_config._branch is None: - # Only update the config if 'branch' has not been already - # resolved. - resolved_config.set_database(*database) - if branch is not None: - if resolved_config._database is None: - # Only update the config if 'database' has not been already - # resolved. - resolved_config.set_branch(*branch) - if user is not None: - resolved_config.set_user(*user) - if password is not None: - resolved_config.set_password(*password) - if secret_key is not None: - resolved_config.set_secret_key(*secret_key) - if tls_ca_file is not None: - if tls_ca is not None: - raise errors.ClientConnectionError( - f"{tls_ca[1]} and {tls_ca_file[1]} are mutually exclusive" - ) - resolved_config.set_tls_ca_file(*tls_ca_file) - if tls_ca is not None: - resolved_config.set_tls_ca_data(*tls_ca) - if tls_security is not None: - resolved_config.set_tls_security(*tls_security) - if tls_server_name is not None: - resolved_config.set_tls_server_name(*tls_server_name) - if server_settings is not None: - resolved_config.add_server_settings(server_settings[0]) - if wait_until_available is not None: - resolved_config.set_wait_until_available(*wait_until_available) - if cloud_profile is not None: - resolved_config._set_param('cloud_profile', *cloud_profile) - - compound_params = [ - dsn, - instance_name, - credentials, - credentials_file, - host or port, - ] - compound_params_count = len([p for p in compound_params if p is not None]) - - if compound_params_count > 1: - raise errors.ClientConnectionError(compound_error) - - elif compound_params_count == 1: - if dsn is not None or host is not None or port is not None: - if port is not None: - resolved_config.set_port(*port) - if dsn is None: - dsn = ( - 'edgedb://' + - (_prepare_host_for_dsn(host[0]) if host else ''), - host[1] if host is not None else port[1] - ) - _parse_dsn_into_config(resolved_config, dsn) - else: - if credentials_file is not None: - creds = cred_utils.read_credentials(credentials_file[0]) - source = "credentials" - elif credentials is not None: - try: - cred_data = json.loads(credentials[0]) - except ValueError as e: - raise RuntimeError(f"cannot read credentials") from e - else: - creds = cred_utils.validate_credentials(cred_data) - source = "credentials" - elif INSTANCE_NAME_RE.match(instance_name[0]): - source = instance_name[1] - creds = cred_utils.read_credentials( - cred_utils.get_credentials_path(instance_name[0]), - ) - else: - name_match = CLOUD_INSTANCE_NAME_RE.match(instance_name[0]) - if name_match is None: - raise ValueError( - f'invalid DSN or instance name: "{instance_name[0]}"' - ) - source = instance_name[1] - org, inst = name_match.groups() - _parse_cloud_instance_name_into_config( - resolved_config, source, org, inst - ) - return True - - resolved_config.set_host(creds.get('host'), source) - resolved_config.set_port(creds.get('port'), source) - if 'database' in creds and resolved_config._branch is None: - # Only update the config if 'branch' has not been already - # resolved. - resolved_config.set_database(creds.get('database'), source) - - elif 'branch' in creds and resolved_config._database is None: - # Only update the config if 'database' has not been already - # resolved. - resolved_config.set_branch(creds.get('branch'), source) - resolved_config.set_user(creds.get('user'), source) - resolved_config.set_password(creds.get('password'), source) - resolved_config.set_tls_ca_data(creds.get('tls_ca'), source) - resolved_config.set_tls_security( - creds.get('tls_security'), - source - ) - - return True - - else: - return False - - -def find_edgedb_project_dir(): - dir = os.getcwd() - dev = os.stat(dir).st_dev - - while True: - gel_toml = os.path.join(dir, 'gel.toml') - edgedb_toml = os.path.join(dir, 'edgedb.toml') - if not os.path.isfile(gel_toml) and not os.path.isfile(edgedb_toml): - parent = os.path.dirname(dir) - if parent == dir: - raise errors.ClientConnectionError( - f'no `gel.toml` found and ' - f'no connection options specified' - ) - parent_dev = os.stat(parent).st_dev - if parent_dev != dev: - raise errors.ClientConnectionError( - f'no `gel.toml` found and ' - f'no connection options specified' - f'(stopped searching for `edgedb.toml` at file system' - f'boundary {dir!r})' - ) - dir = parent - dev = parent_dev - continue - return dir - - -def parse_connect_arguments( - *, - dsn, - host, - port, - credentials, - credentials_file, - database, - branch, - user, - password, - secret_key, - tls_ca, - tls_ca_file, - tls_security, - tls_server_name, - timeout, - command_timeout, - wait_until_available, - server_settings, -) -> typing.Tuple[ResolvedConnectConfig, ClientConfiguration]: - - if command_timeout is not None: - try: - if isinstance(command_timeout, bool): - raise ValueError - command_timeout = float(command_timeout) - if command_timeout <= 0: - raise ValueError - except ValueError: - raise ValueError( - 'invalid command_timeout value: ' - 'expected greater than 0 float (got {!r})'.format( - command_timeout)) from None - - connect_config = _parse_connect_dsn_and_args( - dsn=dsn, - host=host, - port=port, - credentials=credentials, - credentials_file=credentials_file, - database=database, - branch=branch, - user=user, - password=password, - secret_key=secret_key, - tls_ca=tls_ca, - tls_ca_file=tls_ca_file, - tls_security=tls_security, - tls_server_name=tls_server_name, - server_settings=server_settings, - wait_until_available=wait_until_available, - ) - - client_config = ClientConfiguration( - connect_timeout=timeout, - command_timeout=command_timeout, - wait_until_available=connect_config.wait_until_available, - ) - - return connect_config, client_config - - -def check_alpn_protocol(ssl_obj): - if ssl_obj.selected_alpn_protocol() != 'edgedb-binary': - raise errors.ClientConnectionFailedError( - "The server doesn't support the edgedb-binary protocol." - ) - - -def render_client_no_connection_error(prefix, addr, attempts, duration): - if isinstance(addr, str): - msg = ( - f'{prefix}' - f'\n\tAfter {attempts} attempts in {duration:.1f} sec' - f'\n\tIs the server running locally and accepting ' - f'\n\tconnections on Unix domain socket {addr!r}?' - ) - else: - msg = ( - f'{prefix}' - f'\n\tAfter {attempts} attempts in {duration:.1f} sec' - f'\n\tIs the server running on host {addr[0]!r} ' - f'and accepting ' - f'\n\tTCP/IP connections on port {addr[1]}?' - ) - return msg - - -def _extract_errno(s): - """Extract multiple errnos from error string - - When we connect to a host that has multiple underlying IP addresses, say - ``localhost`` having ``::1`` and ``127.0.0.1``, we get - ``OSError("Multiple exceptions:...")`` error without ``.errno`` attribute - set. There are multiple ones in the text, so we extract all of them. - """ - result = [] - for match in ERRNO_RE.finditer(s): - result.append(int(match.group(1))) - if result: - return result - - -def wrap_error(e): - message = str(e) - if e.errno is None: - errnos = _extract_errno(message) - else: - errnos = [e.errno] - - if errnos: - is_temp = any((code in TEMPORARY_ERROR_CODES for code in errnos)) - else: - is_temp = isinstance(e, TEMPORARY_ERRORS) - - if is_temp: - return errors.ClientConnectionFailedTemporarilyError(message) - else: - return errors.ClientConnectionFailedError(message) +# Auto-generated shim +import gel.con_utils as _mod +import sys as _sys +_cur = _sys.modules['edgedb.con_utils'] +for _k in vars(_mod): + if not _k.startswith('__') or _k in ('__all__', '__doc__'): + setattr(_cur, _k, getattr(_mod, _k)) +del _cur +del _sys +del _mod +del _k diff --git a/edgedb/credentials.py b/edgedb/credentials.py index 71146fd9..ea5f0390 100644 --- a/edgedb/credentials.py +++ b/edgedb/credentials.py @@ -1,119 +1,11 @@ -import os -import pathlib -import typing -import json - -from . import platform - - -class RequiredCredentials(typing.TypedDict, total=True): - port: int - user: str - - -class Credentials(RequiredCredentials, total=False): - host: typing.Optional[str] - password: typing.Optional[str] - # It's OK for database and branch to appear in credentials, as long as - # they match. - database: typing.Optional[str] - branch: typing.Optional[str] - tls_ca: typing.Optional[str] - tls_security: typing.Optional[str] - - -def get_credentials_path(instance_name: str) -> pathlib.Path: - return platform.search_config_dir("credentials", instance_name + ".json") - - -def read_credentials(path: os.PathLike) -> Credentials: - try: - with open(path, encoding='utf-8') as f: - credentials = json.load(f) - return validate_credentials(credentials) - except Exception as e: - raise RuntimeError( - f"cannot read credentials at {path}" - ) from e - - -def validate_credentials(data: dict) -> Credentials: - port = data.get('port') - if port is None: - port = 5656 - if not isinstance(port, int) or port < 1 or port > 65535: - raise ValueError("invalid `port` value") - - user = data.get('user') - if user is None: - raise ValueError("`user` key is required") - if not isinstance(user, str): - raise ValueError("`user` must be a string") - - result = { # required keys - "user": user, - "port": port, - } - - host = data.get('host') - if host is not None: - if not isinstance(host, str): - raise ValueError("`host` must be a string") - result['host'] = host - - database = data.get('database') - if database is not None: - if not isinstance(database, str): - raise ValueError("`database` must be a string") - result['database'] = database - - branch = data.get('branch') - if branch is not None: - if not isinstance(branch, str): - raise ValueError("`branch` must be a string") - if database is not None and branch != database: - raise ValueError( - f"`database` and `branch` cannot be different") - result['branch'] = branch - - password = data.get('password') - if password is not None: - if not isinstance(password, str): - raise ValueError("`password` must be a string") - result['password'] = password - - ca = data.get('tls_ca') - if ca is not None: - if not isinstance(ca, str): - raise ValueError("`tls_ca` must be a string") - result['tls_ca'] = ca - - cert_data = data.get('tls_cert_data') - if cert_data is not None: - if not isinstance(cert_data, str): - raise ValueError("`tls_cert_data` must be a string") - if ca is not None and ca != cert_data: - raise ValueError( - f"tls_ca and tls_cert_data are both set and disagree") - result['tls_ca'] = cert_data - - verify = data.get('tls_verify_hostname') - if verify is not None: - if not isinstance(verify, bool): - raise ValueError("`tls_verify_hostname` must be a bool") - result['tls_security'] = 'strict' if verify else 'no_host_verification' - - tls_security = data.get('tls_security') - if tls_security is not None: - if not isinstance(tls_security, str): - raise ValueError("`tls_security` must be a string") - result['tls_security'] = tls_security - - missmatch = ValueError(f"tls_verify_hostname={verify} and " - f"tls_security={tls_security} are incompatible") - if tls_security == "strict" and verify is False: - raise missmatch - if tls_security in {"no_host_verification", "insecure"} and verify is True: - raise missmatch - - return result +# Auto-generated shim +import gel.credentials as _mod +import sys as _sys +_cur = _sys.modules['edgedb.credentials'] +for _k in vars(_mod): + if not _k.startswith('__') or _k in ('__all__', '__doc__'): + setattr(_cur, _k, getattr(_mod, _k)) +del _cur +del _sys +del _mod +del _k diff --git a/edgedb/datatypes/__init__.py b/edgedb/datatypes/__init__.py index 73609285..6da01f78 100644 --- a/edgedb/datatypes/__init__.py +++ b/edgedb/datatypes/__init__.py @@ -1,17 +1,11 @@ -# -# This source file is part of the EdgeDB open source project. -# -# Copyright 2019-present MagicStack Inc. and the EdgeDB authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# +# Auto-generated shim +import gel.datatypes as _mod +import sys as _sys +_cur = _sys.modules['edgedb.datatypes'] +for _k in vars(_mod): + if not _k.startswith('__') or _k in ('__all__', '__doc__'): + setattr(_cur, _k, getattr(_mod, _k)) +del _cur +del _sys +del _mod +del _k diff --git a/edgedb/datatypes/datatypes.py b/edgedb/datatypes/datatypes.py new file mode 100644 index 00000000..05dc5979 --- /dev/null +++ b/edgedb/datatypes/datatypes.py @@ -0,0 +1,11 @@ +# Auto-generated shim +import gel.datatypes.datatypes as _mod +import sys as _sys +_cur = _sys.modules['edgedb.datatypes.datatypes'] +for _k in vars(_mod): + if not _k.startswith('__') or _k in ('__all__', '__doc__'): + setattr(_cur, _k, getattr(_mod, _k)) +del _cur +del _sys +del _mod +del _k diff --git a/edgedb/datatypes/range.py b/edgedb/datatypes/range.py index e3fd3d1e..29e3c8cf 100644 --- a/edgedb/datatypes/range.py +++ b/edgedb/datatypes/range.py @@ -1,165 +1,11 @@ -# -# This source file is part of the EdgeDB open source project. -# -# Copyright 2022-present MagicStack Inc. and the EdgeDB authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from typing import (TypeVar, Any, Generic, Optional, Iterable, Iterator, - Sequence) - -T = TypeVar("T") - - -class Range(Generic[T]): - - __slots__ = ("_lower", "_upper", "_inc_lower", "_inc_upper", "_empty") - - def __init__( - self, - lower: Optional[T] = None, - upper: Optional[T] = None, - *, - inc_lower: bool = True, - inc_upper: bool = False, - empty: bool = False, - ) -> None: - self._empty = empty - - if empty: - if ( - lower != upper - or lower is not None and inc_upper and inc_lower - ): - raise ValueError( - "conflicting arguments in range constructor: " - "\"empty\" is `true` while the specified bounds " - "suggest otherwise" - ) - - self._lower = self._upper = None - self._inc_lower = self._inc_upper = False - else: - self._lower = lower - self._upper = upper - self._inc_lower = lower is not None and inc_lower - self._inc_upper = upper is not None and inc_upper - - @property - def lower(self) -> Optional[T]: - return self._lower - - @property - def inc_lower(self) -> bool: - return self._inc_lower - - @property - def upper(self) -> Optional[T]: - return self._upper - - @property - def inc_upper(self) -> bool: - return self._inc_upper - - def is_empty(self) -> bool: - return self._empty - - def __bool__(self): - return not self.is_empty() - - def __eq__(self, other) -> bool: - if isinstance(other, Range): - o = other - else: - return NotImplemented - - return ( - self._lower, - self._upper, - self._inc_lower, - self._inc_upper, - self._empty, - ) == ( - o._lower, - o._upper, - o._inc_lower, - o._inc_upper, - o._empty, - ) - - def __hash__(self) -> int: - return hash(( - self._lower, - self._upper, - self._inc_lower, - self._inc_upper, - self._empty, - )) - - def __str__(self) -> str: - if self._empty: - desc = "empty" - else: - lb = "(" if not self._inc_lower else "[" - if self._lower is not None: - lb += repr(self._lower) - - if self._upper is not None: - ub = repr(self._upper) - else: - ub = "" - - ub += ")" if self._inc_upper else "]" - - desc = f"{lb}, {ub}" - - return f"" - - __repr__ = __str__ - - -# TODO: maybe we should implement range and multirange operations as well as -# normalization of the sub-ranges? -class MultiRange(Iterable[T]): - - _ranges: Sequence[T] - - def __init__(self, iterable: Optional[Iterable[T]] = None) -> None: - if iterable is not None: - self._ranges = tuple(iterable) - else: - self._ranges = tuple() - - def __len__(self) -> int: - return len(self._ranges) - - def __iter__(self) -> Iterator[T]: - return iter(self._ranges) - - def __reversed__(self) -> Iterator[T]: - return reversed(self._ranges) - - def __str__(self) -> str: - return f'' - - __repr__ = __str__ - - def __eq__(self, other: Any) -> bool: - if isinstance(other, MultiRange): - return set(self._ranges) == set(other._ranges) - else: - return NotImplemented - - def __hash__(self) -> int: - return hash(self._ranges) +# Auto-generated shim +import gel.datatypes.range as _mod +import sys as _sys +_cur = _sys.modules['edgedb.datatypes.range'] +for _k in vars(_mod): + if not _k.startswith('__') or _k in ('__all__', '__doc__'): + setattr(_cur, _k, getattr(_mod, _k)) +del _cur +del _sys +del _mod +del _k diff --git a/edgedb/describe.py b/edgedb/describe.py index 92e38854..4b55e179 100644 --- a/edgedb/describe.py +++ b/edgedb/describe.py @@ -1,98 +1,11 @@ -# -# This source file is part of the EdgeDB open source project. -# -# Copyright 2022-present MagicStack Inc. and the EdgeDB authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - - -import dataclasses -import typing -import uuid - -from . import enums - - -@dataclasses.dataclass(frozen=True) -class AnyType: - desc_id: uuid.UUID - name: typing.Optional[str] - - -@dataclasses.dataclass(frozen=True) -class Element: - type: AnyType - cardinality: enums.Cardinality - is_implicit: bool - kind: enums.ElementKind - - -@dataclasses.dataclass(frozen=True) -class SequenceType(AnyType): - element_type: AnyType - - -@dataclasses.dataclass(frozen=True) -class SetType(SequenceType): - pass - - -@dataclasses.dataclass(frozen=True) -class ObjectType(AnyType): - elements: typing.Dict[str, Element] - - -@dataclasses.dataclass(frozen=True) -class BaseScalarType(AnyType): - pass - - -@dataclasses.dataclass(frozen=True) -class ScalarType(AnyType): - base_type: BaseScalarType - - -@dataclasses.dataclass(frozen=True) -class TupleType(AnyType): - element_types: typing.Tuple[AnyType] - - -@dataclasses.dataclass(frozen=True) -class NamedTupleType(AnyType): - element_types: typing.Dict[str, AnyType] - - -@dataclasses.dataclass(frozen=True) -class ArrayType(SequenceType): - pass - - -@dataclasses.dataclass(frozen=True) -class EnumType(AnyType): - members: typing.Tuple[str] - - -@dataclasses.dataclass(frozen=True) -class SparseObjectType(ObjectType): - pass - - -@dataclasses.dataclass(frozen=True) -class RangeType(AnyType): - value_type: AnyType - - -@dataclasses.dataclass(frozen=True) -class MultiRangeType(AnyType): - value_type: AnyType +# Auto-generated shim +import gel.describe as _mod +import sys as _sys +_cur = _sys.modules['edgedb.describe'] +for _k in vars(_mod): + if not _k.startswith('__') or _k in ('__all__', '__doc__'): + setattr(_cur, _k, getattr(_mod, _k)) +del _cur +del _sys +del _mod +del _k diff --git a/edgedb/enums.py b/edgedb/enums.py index 6312f596..feeb69ea 100644 --- a/edgedb/enums.py +++ b/edgedb/enums.py @@ -1,74 +1,11 @@ -# -# This source file is part of the EdgeDB open source project. -# -# Copyright 2021-present MagicStack Inc. and the EdgeDB authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - - -import enum - - -class Capability(enum.IntFlag): - - NONE = 0 # noqa - MODIFICATIONS = 1 << 0 # noqa - SESSION_CONFIG = 1 << 1 # noqa - TRANSACTION = 1 << 2 # noqa - DDL = 1 << 3 # noqa - PERSISTENT_CONFIG = 1 << 4 # noqa - - ALL = 0xFFFF_FFFF_FFFF_FFFF # noqa - EXECUTE = ALL & ~TRANSACTION & ~SESSION_CONFIG # noqa - LEGACY_EXECUTE = ALL & ~TRANSACTION # noqa - - -class CompilationFlag(enum.IntFlag): - - INJECT_OUTPUT_TYPE_IDS = 1 << 0 # noqa - INJECT_OUTPUT_TYPE_NAMES = 1 << 1 # noqa - INJECT_OUTPUT_OBJECT_IDS = 1 << 2 # noqa - - -class Cardinality(enum.Enum): - # Cardinality isn't applicable for the query: - # * the query is a command like CONFIGURE that - # does not return any data; - # * the query is composed of multiple queries. - NO_RESULT = 0x6e - - # Cardinality is 1 or 0 - AT_MOST_ONE = 0x6f - - # Cardinality is 1 - ONE = 0x41 - - # Cardinality is >= 0 - MANY = 0x6d - - # Cardinality is >= 1 - AT_LEAST_ONE = 0x4d - - def is_single(self) -> bool: - return self in {Cardinality.AT_MOST_ONE, Cardinality.ONE} - - def is_multi(self) -> bool: - return self in {Cardinality.AT_LEAST_ONE, Cardinality.MANY} - - -class ElementKind(enum.Enum): - - LINK = 1 # noqa - PROPERTY = 2 # noqa - LINK_PROPERTY = 3 # noqa +# Auto-generated shim +import gel.enums as _mod +import sys as _sys +_cur = _sys.modules['edgedb.enums'] +for _k in vars(_mod): + if not _k.startswith('__') or _k in ('__all__', '__doc__'): + setattr(_cur, _k, getattr(_mod, _k)) +del _cur +del _sys +del _mod +del _k diff --git a/edgedb/errors/__init__.py b/edgedb/errors/__init__.py index 424edc91..f510b333 100644 --- a/edgedb/errors/__init__.py +++ b/edgedb/errors/__init__.py @@ -1,518 +1,11 @@ -# AUTOGENERATED FROM "edb/api/errors.txt" WITH -# $ edb gen-errors \ -# --import 'from edgedb.errors._base import *\nfrom edgedb.errors.tags import *' \ -# --extra-all "_base.__all__" \ -# --stdout \ -# --client - - -# flake8: noqa - - -from edgedb.errors._base import * -from edgedb.errors.tags import * - - -__all__ = _base.__all__ + ( # type: ignore - 'InternalServerError', - 'UnsupportedFeatureError', - 'ProtocolError', - 'BinaryProtocolError', - 'UnsupportedProtocolVersionError', - 'TypeSpecNotFoundError', - 'UnexpectedMessageError', - 'InputDataError', - 'ParameterTypeMismatchError', - 'StateMismatchError', - 'ResultCardinalityMismatchError', - 'CapabilityError', - 'UnsupportedCapabilityError', - 'DisabledCapabilityError', - 'QueryError', - 'InvalidSyntaxError', - 'EdgeQLSyntaxError', - 'SchemaSyntaxError', - 'GraphQLSyntaxError', - 'InvalidTypeError', - 'InvalidTargetError', - 'InvalidLinkTargetError', - 'InvalidPropertyTargetError', - 'InvalidReferenceError', - 'UnknownModuleError', - 'UnknownLinkError', - 'UnknownPropertyError', - 'UnknownUserError', - 'UnknownDatabaseError', - 'UnknownParameterError', - 'SchemaError', - 'SchemaDefinitionError', - 'InvalidDefinitionError', - 'InvalidModuleDefinitionError', - 'InvalidLinkDefinitionError', - 'InvalidPropertyDefinitionError', - 'InvalidUserDefinitionError', - 'InvalidDatabaseDefinitionError', - 'InvalidOperatorDefinitionError', - 'InvalidAliasDefinitionError', - 'InvalidFunctionDefinitionError', - 'InvalidConstraintDefinitionError', - 'InvalidCastDefinitionError', - 'DuplicateDefinitionError', - 'DuplicateModuleDefinitionError', - 'DuplicateLinkDefinitionError', - 'DuplicatePropertyDefinitionError', - 'DuplicateUserDefinitionError', - 'DuplicateDatabaseDefinitionError', - 'DuplicateOperatorDefinitionError', - 'DuplicateViewDefinitionError', - 'DuplicateFunctionDefinitionError', - 'DuplicateConstraintDefinitionError', - 'DuplicateCastDefinitionError', - 'DuplicateMigrationError', - 'SessionTimeoutError', - 'IdleSessionTimeoutError', - 'QueryTimeoutError', - 'TransactionTimeoutError', - 'IdleTransactionTimeoutError', - 'ExecutionError', - 'InvalidValueError', - 'DivisionByZeroError', - 'NumericOutOfRangeError', - 'AccessPolicyError', - 'QueryAssertionError', - 'IntegrityError', - 'ConstraintViolationError', - 'CardinalityViolationError', - 'MissingRequiredError', - 'TransactionError', - 'TransactionConflictError', - 'TransactionSerializationError', - 'TransactionDeadlockError', - 'WatchError', - 'ConfigurationError', - 'AccessError', - 'AuthenticationError', - 'AvailabilityError', - 'BackendUnavailableError', - 'ServerOfflineError', - 'BackendError', - 'UnsupportedBackendFeatureError', - 'LogMessage', - 'WarningMessage', - 'ClientError', - 'ClientConnectionError', - 'ClientConnectionFailedError', - 'ClientConnectionFailedTemporarilyError', - 'ClientConnectionTimeoutError', - 'ClientConnectionClosedError', - 'InterfaceError', - 'QueryArgumentError', - 'MissingArgumentError', - 'UnknownArgumentError', - 'InvalidArgumentError', - 'NoDataError', - 'InternalClientError', -) - - -class InternalServerError(EdgeDBError): - _code = 0x_01_00_00_00 - - -class UnsupportedFeatureError(EdgeDBError): - _code = 0x_02_00_00_00 - - -class ProtocolError(EdgeDBError): - _code = 0x_03_00_00_00 - - -class BinaryProtocolError(ProtocolError): - _code = 0x_03_01_00_00 - - -class UnsupportedProtocolVersionError(BinaryProtocolError): - _code = 0x_03_01_00_01 - - -class TypeSpecNotFoundError(BinaryProtocolError): - _code = 0x_03_01_00_02 - - -class UnexpectedMessageError(BinaryProtocolError): - _code = 0x_03_01_00_03 - - -class InputDataError(ProtocolError): - _code = 0x_03_02_00_00 - - -class ParameterTypeMismatchError(InputDataError): - _code = 0x_03_02_01_00 - - -class StateMismatchError(InputDataError): - _code = 0x_03_02_02_00 - tags = frozenset({SHOULD_RETRY}) - - -class ResultCardinalityMismatchError(ProtocolError): - _code = 0x_03_03_00_00 - - -class CapabilityError(ProtocolError): - _code = 0x_03_04_00_00 - - -class UnsupportedCapabilityError(CapabilityError): - _code = 0x_03_04_01_00 - - -class DisabledCapabilityError(CapabilityError): - _code = 0x_03_04_02_00 - - -class QueryError(EdgeDBError): - _code = 0x_04_00_00_00 - - -class InvalidSyntaxError(QueryError): - _code = 0x_04_01_00_00 - - -class EdgeQLSyntaxError(InvalidSyntaxError): - _code = 0x_04_01_01_00 - - -class SchemaSyntaxError(InvalidSyntaxError): - _code = 0x_04_01_02_00 - - -class GraphQLSyntaxError(InvalidSyntaxError): - _code = 0x_04_01_03_00 - - -class InvalidTypeError(QueryError): - _code = 0x_04_02_00_00 - - -class InvalidTargetError(InvalidTypeError): - _code = 0x_04_02_01_00 - - -class InvalidLinkTargetError(InvalidTargetError): - _code = 0x_04_02_01_01 - - -class InvalidPropertyTargetError(InvalidTargetError): - _code = 0x_04_02_01_02 - - -class InvalidReferenceError(QueryError): - _code = 0x_04_03_00_00 - - -class UnknownModuleError(InvalidReferenceError): - _code = 0x_04_03_00_01 - - -class UnknownLinkError(InvalidReferenceError): - _code = 0x_04_03_00_02 - - -class UnknownPropertyError(InvalidReferenceError): - _code = 0x_04_03_00_03 - - -class UnknownUserError(InvalidReferenceError): - _code = 0x_04_03_00_04 - - -class UnknownDatabaseError(InvalidReferenceError): - _code = 0x_04_03_00_05 - - -class UnknownParameterError(InvalidReferenceError): - _code = 0x_04_03_00_06 - - -class SchemaError(QueryError): - _code = 0x_04_04_00_00 - - -class SchemaDefinitionError(QueryError): - _code = 0x_04_05_00_00 - - -class InvalidDefinitionError(SchemaDefinitionError): - _code = 0x_04_05_01_00 - - -class InvalidModuleDefinitionError(InvalidDefinitionError): - _code = 0x_04_05_01_01 - - -class InvalidLinkDefinitionError(InvalidDefinitionError): - _code = 0x_04_05_01_02 - - -class InvalidPropertyDefinitionError(InvalidDefinitionError): - _code = 0x_04_05_01_03 - - -class InvalidUserDefinitionError(InvalidDefinitionError): - _code = 0x_04_05_01_04 - - -class InvalidDatabaseDefinitionError(InvalidDefinitionError): - _code = 0x_04_05_01_05 - - -class InvalidOperatorDefinitionError(InvalidDefinitionError): - _code = 0x_04_05_01_06 - - -class InvalidAliasDefinitionError(InvalidDefinitionError): - _code = 0x_04_05_01_07 - - -class InvalidFunctionDefinitionError(InvalidDefinitionError): - _code = 0x_04_05_01_08 - - -class InvalidConstraintDefinitionError(InvalidDefinitionError): - _code = 0x_04_05_01_09 - - -class InvalidCastDefinitionError(InvalidDefinitionError): - _code = 0x_04_05_01_0A - - -class DuplicateDefinitionError(SchemaDefinitionError): - _code = 0x_04_05_02_00 - - -class DuplicateModuleDefinitionError(DuplicateDefinitionError): - _code = 0x_04_05_02_01 - - -class DuplicateLinkDefinitionError(DuplicateDefinitionError): - _code = 0x_04_05_02_02 - - -class DuplicatePropertyDefinitionError(DuplicateDefinitionError): - _code = 0x_04_05_02_03 - - -class DuplicateUserDefinitionError(DuplicateDefinitionError): - _code = 0x_04_05_02_04 - - -class DuplicateDatabaseDefinitionError(DuplicateDefinitionError): - _code = 0x_04_05_02_05 - - -class DuplicateOperatorDefinitionError(DuplicateDefinitionError): - _code = 0x_04_05_02_06 - - -class DuplicateViewDefinitionError(DuplicateDefinitionError): - _code = 0x_04_05_02_07 - - -class DuplicateFunctionDefinitionError(DuplicateDefinitionError): - _code = 0x_04_05_02_08 - - -class DuplicateConstraintDefinitionError(DuplicateDefinitionError): - _code = 0x_04_05_02_09 - - -class DuplicateCastDefinitionError(DuplicateDefinitionError): - _code = 0x_04_05_02_0A - - -class DuplicateMigrationError(DuplicateDefinitionError): - _code = 0x_04_05_02_0B - - -class SessionTimeoutError(QueryError): - _code = 0x_04_06_00_00 - - -class IdleSessionTimeoutError(SessionTimeoutError): - _code = 0x_04_06_01_00 - tags = frozenset({SHOULD_RETRY}) - - -class QueryTimeoutError(SessionTimeoutError): - _code = 0x_04_06_02_00 - - -class TransactionTimeoutError(SessionTimeoutError): - _code = 0x_04_06_0A_00 - - -class IdleTransactionTimeoutError(TransactionTimeoutError): - _code = 0x_04_06_0A_01 - - -class ExecutionError(EdgeDBError): - _code = 0x_05_00_00_00 - - -class InvalidValueError(ExecutionError): - _code = 0x_05_01_00_00 - - -class DivisionByZeroError(InvalidValueError): - _code = 0x_05_01_00_01 - - -class NumericOutOfRangeError(InvalidValueError): - _code = 0x_05_01_00_02 - - -class AccessPolicyError(InvalidValueError): - _code = 0x_05_01_00_03 - - -class QueryAssertionError(InvalidValueError): - _code = 0x_05_01_00_04 - - -class IntegrityError(ExecutionError): - _code = 0x_05_02_00_00 - - -class ConstraintViolationError(IntegrityError): - _code = 0x_05_02_00_01 - - -class CardinalityViolationError(IntegrityError): - _code = 0x_05_02_00_02 - - -class MissingRequiredError(IntegrityError): - _code = 0x_05_02_00_03 - - -class TransactionError(ExecutionError): - _code = 0x_05_03_00_00 - - -class TransactionConflictError(TransactionError): - _code = 0x_05_03_01_00 - tags = frozenset({SHOULD_RETRY}) - - -class TransactionSerializationError(TransactionConflictError): - _code = 0x_05_03_01_01 - tags = frozenset({SHOULD_RETRY}) - - -class TransactionDeadlockError(TransactionConflictError): - _code = 0x_05_03_01_02 - tags = frozenset({SHOULD_RETRY}) - - -class WatchError(ExecutionError): - _code = 0x_05_04_00_00 - - -class ConfigurationError(EdgeDBError): - _code = 0x_06_00_00_00 - - -class AccessError(EdgeDBError): - _code = 0x_07_00_00_00 - - -class AuthenticationError(AccessError): - _code = 0x_07_01_00_00 - - -class AvailabilityError(EdgeDBError): - _code = 0x_08_00_00_00 - - -class BackendUnavailableError(AvailabilityError): - _code = 0x_08_00_00_01 - tags = frozenset({SHOULD_RETRY}) - - -class ServerOfflineError(AvailabilityError): - _code = 0x_08_00_00_02 - tags = frozenset({SHOULD_RECONNECT, SHOULD_RETRY}) - - -class BackendError(EdgeDBError): - _code = 0x_09_00_00_00 - - -class UnsupportedBackendFeatureError(BackendError): - _code = 0x_09_00_01_00 - - -class LogMessage(EdgeDBMessage): - _code = 0x_F0_00_00_00 - - -class WarningMessage(LogMessage): - _code = 0x_F0_01_00_00 - - -class ClientError(EdgeDBError): - _code = 0x_FF_00_00_00 - - -class ClientConnectionError(ClientError): - _code = 0x_FF_01_00_00 - - -class ClientConnectionFailedError(ClientConnectionError): - _code = 0x_FF_01_01_00 - - -class ClientConnectionFailedTemporarilyError(ClientConnectionFailedError): - _code = 0x_FF_01_01_01 - tags = frozenset({SHOULD_RECONNECT, SHOULD_RETRY}) - - -class ClientConnectionTimeoutError(ClientConnectionError): - _code = 0x_FF_01_02_00 - tags = frozenset({SHOULD_RECONNECT, SHOULD_RETRY}) - - -class ClientConnectionClosedError(ClientConnectionError): - _code = 0x_FF_01_03_00 - tags = frozenset({SHOULD_RECONNECT, SHOULD_RETRY}) - - -class InterfaceError(ClientError): - _code = 0x_FF_02_00_00 - - -class QueryArgumentError(InterfaceError): - _code = 0x_FF_02_01_00 - - -class MissingArgumentError(QueryArgumentError): - _code = 0x_FF_02_01_01 - - -class UnknownArgumentError(QueryArgumentError): - _code = 0x_FF_02_01_02 - - -class InvalidArgumentError(QueryArgumentError): - _code = 0x_FF_02_01_03 - - -class NoDataError(ClientError): - _code = 0x_FF_03_00_00 - - -class InternalClientError(ClientError): - _code = 0x_FF_04_00_00 - +# Auto-generated shim +import gel.errors as _mod +import sys as _sys +_cur = _sys.modules['edgedb.errors'] +for _k in vars(_mod): + if not _k.startswith('__') or _k in ('__all__', '__doc__'): + setattr(_cur, _k, getattr(_mod, _k)) +del _cur +del _sys +del _mod +del _k diff --git a/edgedb/errors/_base.py b/edgedb/errors/_base.py index 03ce547c..d567937f 100644 --- a/edgedb/errors/_base.py +++ b/edgedb/errors/_base.py @@ -1,363 +1,11 @@ -# -# This source file is part of the EdgeDB open source project. -# -# Copyright 2016-present MagicStack Inc. and the EdgeDB authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - - -import io -import os -import traceback -import unicodedata -import warnings - -__all__ = ( - 'EdgeDBError', 'EdgeDBMessage', -) - - -class Meta(type): - - def __new__(mcls, name, bases, dct): - cls = super().__new__(mcls, name, bases, dct) - - code = dct.get('_code') - if code is not None: - mcls._index[code] = cls - - # If it's a base class add it to the base class index - b1, b2, b3, b4 = _decode(code) - if b1 == 0 or b2 == 0 or b3 == 0 or b4 == 0: - mcls._base_class_index[(b1, b2, b3, b4)] = cls - - return cls - - -class EdgeDBMessageMeta(Meta): - - _base_class_index = {} - _index = {} - - -class EdgeDBMessage(Warning, metaclass=EdgeDBMessageMeta): - - _code = None - - def __init__(self, severity, message): - super().__init__(message) - self._severity = severity - - def get_severity(self): - return self._severity - - def get_severity_name(self): - return _severity_name(self._severity) - - def get_code(self): - return self._code - - @staticmethod - def _from_code(code, severity, message, *args, **kwargs): - cls = _lookup_message_cls(code) - exc = cls(severity, message, *args, **kwargs) - exc._code = code - return exc - - -class EdgeDBErrorMeta(Meta): - - _base_class_index = {} - _index = {} - - -class EdgeDBError(Exception, metaclass=EdgeDBErrorMeta): - - _code = None - _query = None - tags = frozenset() - - def __init__(self, *args, **kwargs): - self._attrs = {} - super().__init__(*args, **kwargs) - - def has_tag(self, tag): - return tag in self.tags - - @property - def _position(self): - # not a stable API method - return int(self._read_str_field(FIELD_POSITION_START, -1)) - - @property - def _position_start(self): - # not a stable API method - return int(self._read_str_field(FIELD_CHARACTER_START, -1)) - - @property - def _position_end(self): - # not a stable API method - return int(self._read_str_field(FIELD_CHARACTER_END, -1)) - - @property - def _line(self): - # not a stable API method - return int(self._read_str_field(FIELD_LINE_START, -1)) - - @property - def _col(self): - # not a stable API method - return int(self._read_str_field(FIELD_COLUMN_START, -1)) - - @property - def _hint(self): - # not a stable API method - return self._read_str_field(FIELD_HINT) - - @property - def _details(self): - # not a stable API method - return self._read_str_field(FIELD_DETAILS) - - def _read_str_field(self, key, default=None): - val = self._attrs.get(key) - if isinstance(val, bytes): - return val.decode('utf-8') - elif val is not None: - return val - return default - - def get_code(self): - return self._code - - def get_server_context(self): - return self._read_str_field(FIELD_SERVER_TRACEBACK) - - @staticmethod - def _from_code(code, *args, **kwargs): - cls = _lookup_error_cls(code) - exc = cls(*args, **kwargs) - exc._code = code - return exc - - @staticmethod - def _from_json(data): - exc = EdgeDBError._from_code(data['code'], data['message']) - exc._attrs = { - field: data[name] - for name, field in _JSON_FIELDS.items() - if name in data - } - return exc - - def __str__(self): - msg = super().__str__() - if SHOW_HINT and self._query and self._position_start >= 0: - try: - return _format_error( - msg, - self._query, - self._position_start, - max(1, self._position_end - self._position_start), - self._line if self._line > 0 else "?", - self._col if self._col > 0 else "?", - self._hint or "error", - self._details, - ) - except Exception: - return "".join( - ( - msg, - LINESEP, - LINESEP, - "During formatting of the above exception, " - "another exception occurred:", - LINESEP, - LINESEP, - traceback.format_exc(), - ) - ) - else: - return msg - - -def _lookup_cls(code: int, *, meta: type, default: type): - try: - return meta._index[code] - except KeyError: - pass - - b1, b2, b3, _ = _decode(code) - - try: - return meta._base_class_index[(b1, b2, b3, 0)] - except KeyError: - pass - try: - return meta._base_class_index[(b1, b2, 0, 0)] - except KeyError: - pass - try: - return meta._base_class_index[(b1, 0, 0, 0)] - except KeyError: - pass - - return default - - -def _lookup_error_cls(code: int): - return _lookup_cls(code, meta=EdgeDBErrorMeta, default=EdgeDBError) - - -def _lookup_message_cls(code: int): - return _lookup_cls(code, meta=EdgeDBMessageMeta, default=EdgeDBMessage) - - -def _decode(code: int): - return tuple(code.to_bytes(4, 'big')) - - -def _severity_name(severity): - if severity <= EDGE_SEVERITY_DEBUG: - return 'DEBUG' - if severity <= EDGE_SEVERITY_INFO: - return 'INFO' - if severity <= EDGE_SEVERITY_NOTICE: - return 'NOTICE' - if severity <= EDGE_SEVERITY_WARNING: - return 'WARNING' - if severity <= EDGE_SEVERITY_ERROR: - return 'ERROR' - if severity <= EDGE_SEVERITY_FATAL: - return 'FATAL' - return 'PANIC' - - -def _format_error(msg, query, start, offset, line, col, hint, details): - c = get_color() - rv = io.StringIO() - rv.write(f"{c.BOLD}{msg}{c.ENDC}{LINESEP}") - lines = query.splitlines(keepends=True) - num_len = len(str(len(lines))) - rv.write(f"{c.BLUE}{'':>{num_len}} ┌─{c.ENDC} query:{line}:{col}{LINESEP}") - rv.write(f"{c.BLUE}{'':>{num_len}} │ {c.ENDC}{LINESEP}") - for num, line in enumerate(lines): - length = len(line) - line = line.rstrip() # we'll use our own line separator - if start >= length: - # skip lines before the error - start -= length - continue - - if start >= 0: - # Error starts in current line, write the line before the error - first_half = repr(line[:start])[1:-1] - line = line[start:] - length -= start - rv.write(f"{c.BLUE}{num + 1:>{num_len}} │ {c.ENDC}{first_half}") - start = _unicode_width(first_half) - else: - # Multi-line error continues - rv.write(f"{c.BLUE}{num + 1:>{num_len}} │ {c.FAIL}│ {c.ENDC}") - - if offset > length: - # Error is ending beyond current line - line = repr(line)[1:-1] - rv.write(f"{c.FAIL}{line}{c.ENDC}{LINESEP}") - if start >= 0: - # Multi-line error starts - rv.write(f"{c.BLUE}{'':>{num_len}} │ " - f"{c.FAIL}╭─{'─' * start}^{c.ENDC}{LINESEP}") - offset -= length - start = -1 # mark multi-line - else: - # Error is ending within current line - first_half = repr(line[:offset])[1:-1] - line = repr(line[offset:])[1:-1] - rv.write(f"{c.FAIL}{first_half}{c.ENDC}{line}{LINESEP}") - size = _unicode_width(first_half) - if start >= 0: - # Mark single-line error - rv.write(f"{c.BLUE}{'':>{num_len}} │ {' ' * start}" - f"{c.FAIL}{'^' * size} {hint}{c.ENDC}") - else: - # End of multi-line error - rv.write(f"{c.BLUE}{'':>{num_len}} │ " - f"{c.FAIL}╰─{'─' * (size - 1)}^ {hint}{c.ENDC}") - break - - if details: - rv.write(f"{LINESEP}Details: {details}") - - return rv.getvalue() - - -def _unicode_width(text): - return sum(0 if unicodedata.category(c) in ('Mn', 'Cf') else - 2 if unicodedata.east_asian_width(c) == "W" else 1 - for c in text) - - -FIELD_HINT = 0x_00_01 -FIELD_DETAILS = 0x_00_02 -FIELD_SERVER_TRACEBACK = 0x_01_01 - -# XXX: Subject to be changed/deprecated. -FIELD_POSITION_START = 0x_FF_F1 -FIELD_POSITION_END = 0x_FF_F2 -FIELD_LINE_START = 0x_FF_F3 -FIELD_COLUMN_START = 0x_FF_F4 -FIELD_UTF16_COLUMN_START = 0x_FF_F5 -FIELD_LINE_END = 0x_FF_F6 -FIELD_COLUMN_END = 0x_FF_F7 -FIELD_UTF16_COLUMN_END = 0x_FF_F8 -FIELD_CHARACTER_START = 0x_FF_F9 -FIELD_CHARACTER_END = 0x_FF_FA - - -EDGE_SEVERITY_DEBUG = 20 -EDGE_SEVERITY_INFO = 40 -EDGE_SEVERITY_NOTICE = 60 -EDGE_SEVERITY_WARNING = 80 -EDGE_SEVERITY_ERROR = 120 -EDGE_SEVERITY_FATAL = 200 -EDGE_SEVERITY_PANIC = 255 - - -# Fields to include in the json dump of the type -_JSON_FIELDS = { - 'hint': FIELD_HINT, - 'details': FIELD_DETAILS, - 'start': FIELD_CHARACTER_START, - 'end': FIELD_CHARACTER_END, - 'line': FIELD_LINE_START, - 'col': FIELD_COLUMN_START, -} - - -LINESEP = os.linesep - -try: - SHOW_HINT = {"default": True, "enabled": True, "disabled": False}[ - os.getenv("EDGEDB_ERROR_HINT", "default") - ] -except KeyError: - warnings.warn( - "EDGEDB_ERROR_HINT can only be one of: default, enabled or disabled", - stacklevel=1, - ) - SHOW_HINT = False - - -from edgedb.color import get_color +# Auto-generated shim +import gel.errors._base as _mod +import sys as _sys +_cur = _sys.modules['edgedb.errors._base'] +for _k in vars(_mod): + if not _k.startswith('__') or _k in ('__all__', '__doc__'): + setattr(_cur, _k, getattr(_mod, _k)) +del _cur +del _sys +del _mod +del _k diff --git a/edgedb/errors/tags.py b/edgedb/errors/tags.py index 275b31ac..ff307bb9 100644 --- a/edgedb/errors/tags.py +++ b/edgedb/errors/tags.py @@ -1,25 +1,11 @@ -__all__ = [ - 'Tag', - 'SHOULD_RECONNECT', - 'SHOULD_RETRY', -] - - -class Tag(object): - """Error tag - - Tags are used to differentiate certain properties of errors that apply to - error classes across hierarchy. - - Use ``error.has_tag(tag_name)`` to check for a tag. - """ - - def __init__(self, name): - self.name = name - - def __repr__(self): - return f'' - - -SHOULD_RECONNECT = Tag('SHOULD_RECONNECT') -SHOULD_RETRY = Tag('SHOULD_RETRY') +# Auto-generated shim +import gel.errors.tags as _mod +import sys as _sys +_cur = _sys.modules['edgedb.errors.tags'] +for _k in vars(_mod): + if not _k.startswith('__') or _k in ('__all__', '__doc__'): + setattr(_cur, _k, getattr(_mod, _k)) +del _cur +del _sys +del _mod +del _k diff --git a/edgedb/introspect.py b/edgedb/introspect.py index 0db045e6..7fa06bfc 100644 --- a/edgedb/introspect.py +++ b/edgedb/introspect.py @@ -1,66 +1,11 @@ -# -# This source file is part of the EdgeDB open source project. -# -# Copyright 2019-present MagicStack Inc. and the EdgeDB authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - - -# IMPORTANT: this private API is subject to change. - - -import functools -import typing - -from edgedb.datatypes import datatypes as dt -from edgedb.enums import ElementKind - - -class PointerDescription(typing.NamedTuple): - - name: str - kind: ElementKind - implicit: bool - - -class ObjectDescription(typing.NamedTuple): - - pointers: typing.Tuple[PointerDescription, ...] - - -@functools.lru_cache() -def _introspect_object_desc(desc) -> ObjectDescription: - pointers = [] - # Call __dir__ directly as dir() scrambles the order. - for name in desc.__dir__(): - if desc.is_link(name): - kind = ElementKind.LINK - elif desc.is_linkprop(name): - continue - else: - kind = ElementKind.PROPERTY - - pointers.append( - PointerDescription( - name=name, - kind=kind, - implicit=desc.is_implicit(name))) - - return ObjectDescription( - pointers=tuple(pointers)) - - -def introspect_object(obj) -> ObjectDescription: - return _introspect_object_desc( - dt.get_object_descriptor(obj)) +# Auto-generated shim +import gel.introspect as _mod +import sys as _sys +_cur = _sys.modules['edgedb.introspect'] +for _k in vars(_mod): + if not _k.startswith('__') or _k in ('__all__', '__doc__'): + setattr(_cur, _k, getattr(_mod, _k)) +del _cur +del _sys +del _mod +del _k diff --git a/edgedb/options.py b/edgedb/options.py index cec9541a..ba17dbc2 100644 --- a/edgedb/options.py +++ b/edgedb/options.py @@ -1,492 +1,11 @@ -import abc -import enum -import logging -import random -import typing -import sys -from collections import namedtuple - -from . import errors - - -logger = logging.getLogger('edgedb') - - -_RetryRule = namedtuple("_RetryRule", ["attempts", "backoff"]) - - -def default_backoff(attempt): - return (2 ** attempt) * 0.1 + random.randrange(100) * 0.001 - - -WarningHandler = typing.Callable[ - [typing.Tuple[errors.EdgeDBError, ...], typing.Any], - typing.Any, -] - - -def raise_warnings(warnings, res): - if ( - len(warnings) > 1 - and sys.version_info >= (3, 11) - ): - raise ExceptionGroup( # noqa - "Query produced warnings", warnings - ) - else: - raise warnings[0] - - -def log_warnings(warnings, res): - for w in warnings: - logger.warning("EdgeDB warning: %s", str(w)) - return res - - -class RetryCondition: - """Specific condition to retry on for fine-grained control""" - TransactionConflict = enum.auto() - NetworkError = enum.auto() - - -class IsolationLevel: - """Isolation level for transaction""" - Serializable = "SERIALIZABLE" - - -class RetryOptions: - """An immutable class that contains rules for `transaction()`""" - __slots__ = ['_default', '_overrides'] - - def __init__(self, attempts: int, backoff=default_backoff): - self._default = _RetryRule(attempts, backoff) - self._overrides = None - - def with_rule(self, condition, attempts=None, backoff=None): - default = self._default - overrides = self._overrides - if overrides is None: - overrides = {} - else: - overrides = overrides.copy() - overrides[condition] = _RetryRule( - default.attempts if attempts is None else attempts, - default.backoff if backoff is None else backoff, - ) - result = RetryOptions.__new__(RetryOptions) - result._default = default - result._overrides = overrides - return result - - @classmethod - def defaults(cls): - return cls( - attempts=3, - backoff=default_backoff, - ) - - def get_rule_for_exception(self, exception): - default = self._default - overrides = self._overrides - res = default - if overrides: - if isinstance(exception, errors.TransactionConflictError): - res = overrides.get(RetryCondition.TransactionConflict, res) - elif isinstance(exception, errors.ClientError): - res = overrides.get(RetryCondition.NetworkError, res) - return res - - -class TransactionOptions: - """Options for `transaction()`""" - __slots__ = ['_isolation', '_readonly', '_deferrable'] - - def __init__( - self, - isolation: IsolationLevel=IsolationLevel.Serializable, - readonly: bool = False, - deferrable: bool = False, - ): - self._isolation = isolation - self._readonly = readonly - self._deferrable = deferrable - - @classmethod - def defaults(cls): - return cls() - - def start_transaction_query(self): - isolation = str(self._isolation) - if self._readonly: - mode = 'READ ONLY' - else: - mode = 'READ WRITE' - - if self._deferrable: - defer = 'DEFERRABLE' - else: - defer = 'NOT DEFERRABLE' - - return f'START TRANSACTION ISOLATION {isolation}, {mode}, {defer};' - - def __repr__(self): - return ( - f'<{self.__class__.__name__} ' - f'isolation:{self._isolation}, ' - f'readonly:{self._readonly}, ' - f'deferrable:{self._deferrable}>' - ) - - -class State: - __slots__ = ['_module', '_aliases', '_config', '_globals'] - - def __init__( - self, - default_module: typing.Optional[str] = None, - module_aliases: typing.Mapping[str, str] = None, - config: typing.Mapping[str, typing.Any] = None, - globals_: typing.Mapping[str, typing.Any] = None, - ): - self._module = default_module - self._aliases = {} if module_aliases is None else dict(module_aliases) - self._config = {} if config is None else dict(config) - self._globals = ( - {} if globals_ is None else self.with_globals(globals_)._globals - ) - - @classmethod - def _new(cls, default_module, module_aliases, config, globals_): - rv = cls.__new__(cls) - rv._module = default_module - rv._aliases = module_aliases - rv._config = config - rv._globals = globals_ - return rv - - @classmethod - def defaults(cls): - return cls() - - def with_default_module(self, module: typing.Optional[str] = None): - return self._new( - default_module=module, - module_aliases=self._aliases, - config=self._config, - globals_=self._globals, - ) - - def with_module_aliases(self, *args, **aliases): - if len(args) > 1: - raise errors.InvalidArgumentError( - "with_module_aliases() takes from 0 to 1 positional arguments " - "but {} were given".format(len(args)) - ) - aliases_dict = args[0] if args else {} - aliases_dict.update(aliases) - new_aliases = self._aliases.copy() - new_aliases.update(aliases_dict) - return self._new( - default_module=self._module, - module_aliases=new_aliases, - config=self._config, - globals_=self._globals, - ) - - def with_config(self, *args, **config): - if len(args) > 1: - raise errors.InvalidArgumentError( - "with_config() takes from 0 to 1 positional arguments " - "but {} were given".format(len(args)) - ) - config_dict = args[0] if args else {} - config_dict.update(config) - new_config = self._config.copy() - new_config.update(config_dict) - return self._new( - default_module=self._module, - module_aliases=self._aliases, - config=new_config, - globals_=self._globals, - ) - - def resolve(self, name: str) -> str: - parts = name.split("::", 1) - if len(parts) == 1: - return f"{self._module or 'default'}::{name}" - elif len(parts) == 2: - mod, name = parts - mod = self._aliases.get(mod, mod) - return f"{mod}::{name}" - else: - raise AssertionError('broken split') - - def with_globals(self, *args, **globals_): - if len(args) > 1: - raise errors.InvalidArgumentError( - "with_globals() takes from 0 to 1 positional arguments " - "but {} were given".format(len(args)) - ) - new_globals = self._globals.copy() - if args: - for k, v in args[0].items(): - new_globals[self.resolve(k)] = v - for k, v in globals_.items(): - new_globals[self.resolve(k)] = v - return self._new( - default_module=self._module, - module_aliases=self._aliases, - config=self._config, - globals_=new_globals, - ) - - def without_module_aliases(self, *aliases): - if not aliases: - new_aliases = {} - else: - new_aliases = self._aliases.copy() - for alias in aliases: - new_aliases.pop(alias, None) - return self._new( - default_module=self._module, - module_aliases=new_aliases, - config=self._config, - globals_=self._globals, - ) - - def without_config(self, *config_names): - if not config_names: - new_config = {} - else: - new_config = self._config.copy() - for name in config_names: - new_config.pop(name, None) - return self._new( - default_module=self._module, - module_aliases=self._aliases, - config=new_config, - globals_=self._globals, - ) - - def without_globals(self, *global_names): - if not global_names: - new_globals = {} - else: - new_globals = self._globals.copy() - for name in global_names: - new_globals.pop(self.resolve(name), None) - return self._new( - default_module=self._module, - module_aliases=self._aliases, - config=self._config, - globals_=new_globals, - ) - - def as_dict(self): - rv = {} - if self._module is not None: - rv["module"] = self._module - if self._aliases: - rv["aliases"] = list(self._aliases.items()) - if self._config: - rv["config"] = self._config - if self._globals: - rv["globals"] = self._globals - return rv - - -class _OptionsMixin: - def __init__(self, *args, **kwargs): - self._options = _Options.defaults() - super().__init__(*args, **kwargs) - - @abc.abstractmethod - def _shallow_clone(self): - pass - - def with_transaction_options(self, options: TransactionOptions = None): - """Returns object with adjusted options for future transactions. - - :param options TransactionOptions: - Object that encapsulates transaction options. - - This method returns a "shallow copy" of the current object - with modified transaction options. - - Both ``self`` and returned object can be used after, but when using - them transaction options applied will be different. - - Transaction options are used by the ``transaction`` method. - """ - result = self._shallow_clone() - result._options = self._options.with_transaction_options(options) - return result - - def with_retry_options(self, options: RetryOptions=None): - """Returns object with adjusted options for future retrying - transactions. - - :param options RetryOptions: - Object that encapsulates retry options. - - This method returns a "shallow copy" of the current object - with modified retry options. - - Both ``self`` and returned object can be used after, but when using - them retry options applied will be different. - """ - - result = self._shallow_clone() - result._options = self._options.with_retry_options(options) - return result - - def with_warning_handler(self, warning_handler: WarningHandler=None): - """Returns object with adjusted options for handling warnings. - - :param warning_handler WarningHandler: - Function for handling warnings. It is passed a tuple of warnings - and the query result and returns a potentially updated query - result. - - This method returns a "shallow copy" of the current object - with modified retry options. - - Both ``self`` and returned object can be used after, but when using - them retry options applied will be different. - """ - - result = self._shallow_clone() - result._options = self._options.with_warning_handler(warning_handler) - return result - - def with_state(self, state: State): - result = self._shallow_clone() - result._options = self._options.with_state(state) - return result - - def with_default_module(self, module: typing.Optional[str] = None): - result = self._shallow_clone() - result._options = self._options.with_state( - self._options.state.with_default_module(module) - ) - return result - - def with_module_aliases(self, *args, **aliases): - result = self._shallow_clone() - result._options = self._options.with_state( - self._options.state.with_module_aliases(*args, **aliases) - ) - return result - - def with_config(self, *args, **config): - result = self._shallow_clone() - result._options = self._options.with_state( - self._options.state.with_config(*args, **config) - ) - return result - - def with_globals(self, *args, **globals_): - result = self._shallow_clone() - result._options = self._options.with_state( - self._options.state.with_globals(*args, **globals_) - ) - return result - - def without_module_aliases(self, *aliases): - result = self._shallow_clone() - result._options = self._options.with_state( - self._options.state.without_module_aliases(*aliases) - ) - return result - - def without_config(self, *config_names): - result = self._shallow_clone() - result._options = self._options.with_state( - self._options.state.without_config(*config_names) - ) - return result - - def without_globals(self, *global_names): - result = self._shallow_clone() - result._options = self._options.with_state( - self._options.state.without_globals(*global_names) - ) - return result - - -class _Options: - """Internal class for storing connection options""" - - __slots__ = [ - '_retry_options', '_transaction_options', '_state', - '_warning_handler' - ] - - def __init__( - self, - retry_options: RetryOptions, - transaction_options: TransactionOptions, - state: State, - warning_handler: WarningHandler, - ): - self._retry_options = retry_options - self._transaction_options = transaction_options - self._state = state - self._warning_handler = warning_handler - - @property - def retry_options(self): - return self._retry_options - - @property - def transaction_options(self): - return self._transaction_options - - @property - def state(self): - return self._state - - @property - def warning_handler(self): - return self._warning_handler - - def with_retry_options(self, options: RetryOptions): - return _Options( - options, - self._transaction_options, - self._state, - self._warning_handler, - ) - - def with_transaction_options(self, options: TransactionOptions): - return _Options( - self._retry_options, - options, - self._state, - self._warning_handler, - ) - - def with_state(self, state: State): - return _Options( - self._retry_options, - self._transaction_options, - state, - self._warning_handler, - ) - - def with_warning_handler(self, warning_handler: WarningHandler): - return _Options( - self._retry_options, - self._transaction_options, - self._state, - warning_handler, - ) - - @classmethod - def defaults(cls): - return cls( - RetryOptions.defaults(), - TransactionOptions.defaults(), - State.defaults(), - log_warnings, - ) +# Auto-generated shim +import gel.options as _mod +import sys as _sys +_cur = _sys.modules['edgedb.options'] +for _k in vars(_mod): + if not _k.startswith('__') or _k in ('__all__', '__doc__'): + setattr(_cur, _k, getattr(_mod, _k)) +del _cur +del _sys +del _mod +del _k diff --git a/edgedb/pgproto/__init__.py b/edgedb/pgproto/__init__.py new file mode 100644 index 00000000..14e16623 --- /dev/null +++ b/edgedb/pgproto/__init__.py @@ -0,0 +1,11 @@ +# Auto-generated shim +import gel.pgproto as _mod +import sys as _sys +_cur = _sys.modules['edgedb.pgproto'] +for _k in vars(_mod): + if not _k.startswith('__') or _k in ('__all__', '__doc__'): + setattr(_cur, _k, getattr(_mod, _k)) +del _cur +del _sys +del _mod +del _k diff --git a/edgedb/pgproto/pgproto.py b/edgedb/pgproto/pgproto.py new file mode 100644 index 00000000..879e7585 --- /dev/null +++ b/edgedb/pgproto/pgproto.py @@ -0,0 +1,11 @@ +# Auto-generated shim +import gel.pgproto.pgproto as _mod +import sys as _sys +_cur = _sys.modules['edgedb.pgproto.pgproto'] +for _k in vars(_mod): + if not _k.startswith('__') or _k in ('__all__', '__doc__'): + setattr(_cur, _k, getattr(_mod, _k)) +del _cur +del _sys +del _mod +del _k diff --git a/edgedb/pgproto/types.py b/edgedb/pgproto/types.py new file mode 100644 index 00000000..68484dd3 --- /dev/null +++ b/edgedb/pgproto/types.py @@ -0,0 +1,11 @@ +# Auto-generated shim +import gel.pgproto.types as _mod +import sys as _sys +_cur = _sys.modules['edgedb.pgproto.types'] +for _k in vars(_mod): + if not _k.startswith('__') or _k in ('__all__', '__doc__'): + setattr(_cur, _k, getattr(_mod, _k)) +del _cur +del _sys +del _mod +del _k diff --git a/edgedb/platform.py b/edgedb/platform.py index 55410532..18eec595 100644 --- a/edgedb/platform.py +++ b/edgedb/platform.py @@ -1,52 +1,11 @@ -import functools -import os -import pathlib -import sys - -if sys.platform == "darwin": - def config_dir() -> pathlib.Path: - return ( - pathlib.Path.home() / "Library" / "Application Support" / "edgedb" - ) - - IS_WINDOWS = False - -elif sys.platform == "win32": - import ctypes - from ctypes import windll - - def config_dir() -> pathlib.Path: - path_buf = ctypes.create_unicode_buffer(255) - csidl = 28 # CSIDL_LOCAL_APPDATA - windll.shell32.SHGetFolderPathW(0, csidl, 0, 0, path_buf) - return pathlib.Path(path_buf.value) / "EdgeDB" / "config" - - IS_WINDOWS = True - -else: - def config_dir() -> pathlib.Path: - xdg_conf_dir = pathlib.Path(os.environ.get("XDG_CONFIG_HOME", ".")) - if not xdg_conf_dir.is_absolute(): - xdg_conf_dir = pathlib.Path.home() / ".config" - return xdg_conf_dir / "edgedb" - - IS_WINDOWS = False - - -def old_config_dir() -> pathlib.Path: - return pathlib.Path.home() / ".edgedb" - - -def search_config_dir(*suffix): - rv = functools.reduce(lambda p1, p2: p1 / p2, [config_dir(), *suffix]) - if rv.exists(): - return rv - - fallback = functools.reduce( - lambda p1, p2: p1 / p2, [old_config_dir(), *suffix] - ) - if fallback.exists(): - return fallback - - # None of the searched files exists, return the new path. - return rv +# Auto-generated shim +import gel.platform as _mod +import sys as _sys +_cur = _sys.modules['edgedb.platform'] +for _k in vars(_mod): + if not _k.startswith('__') or _k in ('__all__', '__doc__'): + setattr(_cur, _k, getattr(_mod, _k)) +del _cur +del _sys +del _mod +del _k diff --git a/edgedb/protocol/__init__.py b/edgedb/protocol/__init__.py index 46ceb625..f53f0a30 100644 --- a/edgedb/protocol/__init__.py +++ b/edgedb/protocol/__init__.py @@ -1,17 +1,11 @@ -# -# This source file is part of the EdgeDB open source project. -# -# Copyright 2016-present MagicStack Inc. and the EdgeDB authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# +# Auto-generated shim +import gel.protocol as _mod +import sys as _sys +_cur = _sys.modules['edgedb.protocol'] +for _k in vars(_mod): + if not _k.startswith('__') or _k in ('__all__', '__doc__'): + setattr(_cur, _k, getattr(_mod, _k)) +del _cur +del _sys +del _mod +del _k diff --git a/edgedb/protocol/asyncio_proto.py b/edgedb/protocol/asyncio_proto.py new file mode 100644 index 00000000..189b76d5 --- /dev/null +++ b/edgedb/protocol/asyncio_proto.py @@ -0,0 +1,11 @@ +# Auto-generated shim +import gel.protocol.asyncio_proto as _mod +import sys as _sys +_cur = _sys.modules['edgedb.protocol.asyncio_proto'] +for _k in vars(_mod): + if not _k.startswith('__') or _k in ('__all__', '__doc__'): + setattr(_cur, _k, getattr(_mod, _k)) +del _cur +del _sys +del _mod +del _k diff --git a/edgedb/protocol/blocking_proto.py b/edgedb/protocol/blocking_proto.py new file mode 100644 index 00000000..5974af66 --- /dev/null +++ b/edgedb/protocol/blocking_proto.py @@ -0,0 +1,11 @@ +# Auto-generated shim +import gel.protocol.blocking_proto as _mod +import sys as _sys +_cur = _sys.modules['edgedb.protocol.blocking_proto'] +for _k in vars(_mod): + if not _k.startswith('__') or _k in ('__all__', '__doc__'): + setattr(_cur, _k, getattr(_mod, _k)) +del _cur +del _sys +del _mod +del _k diff --git a/edgedb/protocol/protocol.py b/edgedb/protocol/protocol.py new file mode 100644 index 00000000..811284d5 --- /dev/null +++ b/edgedb/protocol/protocol.py @@ -0,0 +1,11 @@ +# Auto-generated shim +import gel.protocol.protocol as _mod +import sys as _sys +_cur = _sys.modules['edgedb.protocol.protocol'] +for _k in vars(_mod): + if not _k.startswith('__') or _k in ('__all__', '__doc__'): + setattr(_cur, _k, getattr(_mod, _k)) +del _cur +del _sys +del _mod +del _k diff --git a/edgedb/scram/__init__.py b/edgedb/scram/__init__.py index 62f3e712..e0ca2611 100644 --- a/edgedb/scram/__init__.py +++ b/edgedb/scram/__init__.py @@ -1,434 +1,11 @@ -# -# This source file is part of the EdgeDB open source project. -# -# Copyright 2019-present MagicStack Inc. and the EdgeDB authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -"""Helpers for SCRAM authentication.""" - -import base64 -import hashlib -import hmac -import os -import typing - -from .saslprep import saslprep - - -RAW_NONCE_LENGTH = 18 - -# Per recommendations in RFC 7677. -DEFAULT_SALT_LENGTH = 16 -DEFAULT_ITERATIONS = 4096 - - -def generate_salt(length: int = DEFAULT_SALT_LENGTH) -> bytes: - return os.urandom(length) - - -def generate_nonce(length: int = RAW_NONCE_LENGTH) -> str: - return B64(os.urandom(length)) - - -def build_verifier(password: str, *, salt: typing.Optional[bytes] = None, - iterations: int = DEFAULT_ITERATIONS) -> str: - """Build the SCRAM verifier for the given password. - - Returns a string in the following format: - - "$:$:" - - The salt and keys are base64-encoded values. - """ - password = saslprep(password).encode('utf-8') - - if salt is None: - salt = generate_salt() - - salted_password = get_salted_password(password, salt, iterations) - client_key = get_client_key(salted_password) - stored_key = H(client_key) - server_key = get_server_key(salted_password) - - return (f'SCRAM-SHA-256${iterations}:{B64(salt)}$' - f'{B64(stored_key)}:{B64(server_key)}') - - -class SCRAMVerifier(typing.NamedTuple): - - mechanism: str - iterations: int - salt: bytes - stored_key: bytes - server_key: bytes - - -def parse_verifier(verifier: str) -> SCRAMVerifier: - - parts = verifier.split('$') - if len(parts) != 3: - raise ValueError('invalid SCRAM verifier') - - mechanism = parts[0] - if mechanism != 'SCRAM-SHA-256': - raise ValueError('invalid SCRAM verifier') - - iterations, _, salt = parts[1].partition(':') - stored_key, _, server_key = parts[2].partition(':') - if not salt or not server_key: - raise ValueError('invalid SCRAM verifier') - - try: - iterations = int(iterations) - except ValueError: - raise ValueError('invalid SCRAM verifier') from None - - return SCRAMVerifier( - mechanism=mechanism, - iterations=iterations, - salt=base64.b64decode(salt), - stored_key=base64.b64decode(stored_key), - server_key=base64.b64decode(server_key), - ) - - -def parse_client_first_message(resp: bytes): - - # Relevant bits of RFC 5802: - # - # saslname = 1*(value-safe-char / "=2C" / "=3D") - # ;; Conforms to . - # - # authzid = "a=" saslname - # ;; Protocol specific. - # - # cb-name = 1*(ALPHA / DIGIT / "." / "-") - # ;; See RFC 5056, Section 7. - # ;; E.g., "tls-server-end-point" or - # ;; "tls-unique". - # - # gs2-cbind-flag = ("p=" cb-name) / "n" / "y" - # ;; "n" -> client doesn't support channel binding. - # ;; "y" -> client does support channel binding - # ;; but thinks the server does not. - # ;; "p" -> client requires channel binding. - # ;; The selected channel binding follows "p=". - # - # gs2-header = gs2-cbind-flag "," [ authzid ] "," - # ;; GS2 header for SCRAM - # ;; (the actual GS2 header includes an optional - # ;; flag to indicate that the GSS mechanism is not - # ;; "standard", but since SCRAM is "standard", we - # ;; don't include that flag). - # - # username = "n=" saslname - # ;; Usernames are prepared using SASLprep. - # - # reserved-mext = "m=" 1*(value-char) - # ;; Reserved for signaling mandatory extensions. - # ;; The exact syntax will be defined in - # ;; the future. - # - # nonce = "r=" c-nonce [s-nonce] - # ;; Second part provided by server. - # - # c-nonce = printable - # - # client-first-message-bare = - # [reserved-mext ","] - # username "," nonce ["," extensions] - # - # client-first-message = - # gs2-header client-first-message-bare - - attrs = resp.split(b',') - - cb_attr = attrs[0] - if cb_attr == b'y': - cb = True - elif cb_attr == b'n': - cb = False - elif cb_attr[0:1] == b'p': - _, _, cb = cb_attr.partition(b'=') - if not cb: - raise ValueError('malformed SCRAM message') - else: - raise ValueError('malformed SCRAM message') - - authzid_attr = attrs[1] - if authzid_attr: - if authzid_attr[0:1] != b'a': - raise ValueError('malformed SCRAM message') - _, _, authzid = authzid_attr.partition(b'=') - else: - authzid = None - - user_attr = attrs[2] - if user_attr[0:1] == b'm': - raise ValueError('unsupported SCRAM extensions in message') - elif user_attr[0:1] != b'n': - raise ValueError('malformed SCRAM message') - - _, _, user = user_attr.partition(b'=') - - nonce_attr = attrs[3] - if nonce_attr[0:1] != b'r': - raise ValueError('malformed SCRAM message') - - _, _, nonce_bin = nonce_attr.partition(b'=') - nonce = nonce_bin.decode('ascii') - if not nonce.isprintable(): - raise ValueError('invalid characters in client nonce') - - # ["," extensions] are ignored - - return len(cb_attr) + 2, cb, authzid, user, nonce - - -def parse_client_final_message( - msg: bytes, client_nonce: str, server_nonce: str): - - # Relevant bits of RFC 5802: - # - # gs2-header = gs2-cbind-flag "," [ authzid ] "," - # ;; GS2 header for SCRAM - # ;; (the actual GS2 header includes an optional - # ;; flag to indicate that the GSS mechanism is not - # ;; "standard", but since SCRAM is "standard", we - # ;; don't include that flag). - # - # cbind-input = gs2-header [ cbind-data ] - # ;; cbind-data MUST be present for - # ;; gs2-cbind-flag of "p" and MUST be absent - # ;; for "y" or "n". - # - # channel-binding = "c=" base64 - # ;; base64 encoding of cbind-input. - # - # proof = "p=" base64 - # - # client-final-message-without-proof = - # channel-binding "," nonce ["," - # extensions] - # - # client-final-message = - # client-final-message-without-proof "," proof - - attrs = msg.split(b',') - - cb_attr = attrs[0] - if cb_attr[0:1] != b'c': - raise ValueError('malformed SCRAM message') - - _, _, cb_data = cb_attr.partition(b'=') - - nonce_attr = attrs[1] - if nonce_attr[0:1] != b'r': - raise ValueError('malformed SCRAM message') - - _, _, nonce_bin = nonce_attr.partition(b'=') - nonce = nonce_bin.decode('ascii') - - expected_nonce = f'{client_nonce}{server_nonce}' - - if nonce != expected_nonce: - raise ValueError( - 'invalid SCRAM client-final message: nonce does not match') - - proof = None - - for attr in attrs[2:]: - if attr[0:1] == b'p': - _, _, proof = attr.partition(b'=') - proof_attr_len = len(attr) - proof = base64.b64decode(proof) - elif proof is not None: - raise ValueError('malformed SCRAM message') - - if proof is None: - raise ValueError('malformed SCRAM message') - - return cb_data, proof, proof_attr_len + 1 - - -def build_client_first_message(client_nonce: str, username: str) -> str: - - bare = f'n={saslprep(username)},r={client_nonce}' - return f'n,,{bare}', bare - - -def build_server_first_message(server_nonce: str, client_nonce: str, - salt: bytes, iterations: int) -> str: - - return ( - f'r={client_nonce}{server_nonce},' - f's={B64(salt)},i={iterations}' - ) - - -def build_auth_message( - client_first_bare: bytes, - server_first: bytes, client_final: bytes) -> bytes: - - return b'%b,%b,%b' % (client_first_bare, server_first, client_final) - - -def build_client_final_message( - password: str, - salt: bytes, - iterations: int, - client_first_bare: bytes, - server_first: bytes, - server_nonce: str) -> str: - - client_final = f'c=biws,r={server_nonce}' - - AuthMessage = build_auth_message( - client_first_bare, server_first, client_final.encode('utf-8')) - - SaltedPassword = get_salted_password( - saslprep(password).encode('utf-8'), - salt, - iterations) - - ClientKey = get_client_key(SaltedPassword) - StoredKey = H(ClientKey) - ClientSignature = HMAC(StoredKey, AuthMessage) - ClientProof = XOR(ClientKey, ClientSignature) - - ServerKey = get_server_key(SaltedPassword) - ServerProof = HMAC(ServerKey, AuthMessage) - - return f'{client_final},p={B64(ClientProof)}', ServerProof - - -def build_server_final_message( - client_first_bare: bytes, server_first: bytes, - client_final: bytes, server_key: bytes) -> str: - - AuthMessage = build_auth_message( - client_first_bare, server_first, client_final) - ServerSignature = HMAC(server_key, AuthMessage) - return f'v={B64(ServerSignature)}' - - -def parse_server_first_message(msg: bytes): - - attrs = msg.split(b',') - - nonce_attr = attrs[0] - if nonce_attr[0:1] != b'r': - raise ValueError('malformed SCRAM message') - - _, _, nonce_bin = nonce_attr.partition(b'=') - nonce = nonce_bin.decode('ascii') - if not nonce.isprintable(): - raise ValueError('malformed SCRAM message') - - salt_attr = attrs[1] - if salt_attr[0:1] != b's': - raise ValueError('malformed SCRAM message') - - _, _, salt_b64 = salt_attr.partition(b'=') - salt = base64.b64decode(salt_b64) - - iter_attr = attrs[2] - if iter_attr[0:1] != b'i': - raise ValueError('malformed SCRAM message') - - _, _, iterations = iter_attr.partition(b'=') - - try: - itercount = int(iterations) - except ValueError: - raise ValueError('malformed SCRAM message') from None - - return nonce, salt, itercount - - -def parse_server_final_message(msg: bytes): - - attrs = msg.split(b',') - - nonce_attr = attrs[0] - if nonce_attr[0:1] != b'v': - raise ValueError('malformed SCRAM message') - - _, _, signature_b64 = nonce_attr.partition(b'=') - signature = base64.b64decode(signature_b64) - - return signature - - -def verify_password(password: bytes, verifier: str) -> bool: - """Check the given password against a verifier. - - Returns True if the password is OK, False otherwise. - """ - - password = saslprep(password).encode('utf-8') - v = parse_verifier(verifier) - salted_password = get_salted_password(password, v.salt, v.iterations) - computed_key = get_server_key(salted_password) - return v.server_key == computed_key - - -def verify_client_proof(client_first: bytes, server_first: bytes, - client_final: bytes, StoredKey: bytes, - ClientProof: bytes) -> bool: - AuthMessage = build_auth_message(client_first, server_first, client_final) - ClientSignature = HMAC(StoredKey, AuthMessage) - ClientKey = XOR(ClientProof, ClientSignature) - return H(ClientKey) == StoredKey - - -def B64(val: bytes) -> str: - """Return base64-encoded string representation of input binary data.""" - return base64.b64encode(val).decode() - - -def HMAC(key: bytes, msg: bytes) -> bytes: - return hmac.new(key, msg, digestmod=hashlib.sha256).digest() - - -def XOR(a: bytes, b: bytes) -> bytes: - if len(a) != len(b): - raise ValueError('scram.XOR received operands of unequal length') - xint = int.from_bytes(a, 'big') ^ int.from_bytes(b, 'big') - return xint.to_bytes(len(a), 'big') - - -def H(s: bytes) -> bytes: - return hashlib.sha256(s).digest() - - -def get_salted_password(password: bytes, salt: bytes, - iterations: int) -> bytes: - # U1 := HMAC(str, salt + INT(1)) - H_i = U_i = HMAC(password, salt + b'\x00\x00\x00\x01') - - for _ in range(iterations - 1): - U_i = HMAC(password, U_i) - H_i = XOR(H_i, U_i) - - return H_i - - -def get_client_key(salted_password: bytes) -> bytes: - return HMAC(salted_password, b'Client Key') - - -def get_server_key(salted_password: bytes) -> bytes: - return HMAC(salted_password, b'Server Key') +# Auto-generated shim +import gel.scram as _mod +import sys as _sys +_cur = _sys.modules['edgedb.scram'] +for _k in vars(_mod): + if not _k.startswith('__') or _k in ('__all__', '__doc__'): + setattr(_cur, _k, getattr(_mod, _k)) +del _cur +del _sys +del _mod +del _k diff --git a/edgedb/scram/saslprep.py b/edgedb/scram/saslprep.py index 79eb84d8..18128fe9 100644 --- a/edgedb/scram/saslprep.py +++ b/edgedb/scram/saslprep.py @@ -1,82 +1,11 @@ -# Copyright 2016-present MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import stringprep -import unicodedata - - -# RFC4013 section 2.3 prohibited output. -_PROHIBITED = ( - # A strict reading of RFC 4013 requires table c12 here, but - # characters from it are mapped to SPACE in the Map step. Can - # normalization reintroduce them somehow? - stringprep.in_table_c12, - stringprep.in_table_c21_c22, - stringprep.in_table_c3, - stringprep.in_table_c4, - stringprep.in_table_c5, - stringprep.in_table_c6, - stringprep.in_table_c7, - stringprep.in_table_c8, - stringprep.in_table_c9) - - -def saslprep(data: str, prohibit_unassigned_code_points=True): - """An implementation of RFC4013 SASLprep.""" - - if data == '': - return data - - if prohibit_unassigned_code_points: - prohibited = _PROHIBITED + (stringprep.in_table_a1,) - else: - prohibited = _PROHIBITED - - # RFC3454 section 2, step 1 - Map - # RFC4013 section 2.1 mappings - # Map Non-ASCII space characters to SPACE (U+0020). Map - # commonly mapped to nothing characters to, well, nothing. - in_table_c12 = stringprep.in_table_c12 - in_table_b1 = stringprep.in_table_b1 - data = u"".join( - [u"\u0020" if in_table_c12(elt) else elt - for elt in data if not in_table_b1(elt)]) - - # RFC3454 section 2, step 2 - Normalize - # RFC4013 section 2.2 normalization - data = unicodedata.ucd_3_2_0.normalize('NFKC', data) - - in_table_d1 = stringprep.in_table_d1 - if in_table_d1(data[0]): - if not in_table_d1(data[-1]): - # RFC3454, Section 6, #3. If a string contains any - # RandALCat character, the first and last characters - # MUST be RandALCat characters. - raise ValueError("SASLprep: failed bidirectional check") - # RFC3454, Section 6, #2. If a string contains any RandALCat - # character, it MUST NOT contain any LCat character. - prohibited = prohibited + (stringprep.in_table_d2,) - else: - # RFC3454, Section 6, #3. Following the logic of #3, if - # the first character is not a RandALCat, no other character - # can be either. - prohibited = prohibited + (in_table_d1,) - - # RFC3454 section 2, step 3 and 4 - Prohibit and check bidi - for char in data: - if any(in_table(char) for in_table in prohibited): - raise ValueError( - "SASLprep: failed prohibited character check") - - return data +# Auto-generated shim +import gel.scram.saslprep as _mod +import sys as _sys +_cur = _sys.modules['edgedb.scram.saslprep'] +for _k in vars(_mod): + if not _k.startswith('__') or _k in ('__all__', '__doc__'): + setattr(_cur, _k, getattr(_mod, _k)) +del _cur +del _sys +del _mod +del _k diff --git a/edgedb/transaction.py b/edgedb/transaction.py index 2e2d871a..91828a2b 100644 --- a/edgedb/transaction.py +++ b/edgedb/transaction.py @@ -1,224 +1,11 @@ -# -# This source file is part of the EdgeDB open source project. -# -# Copyright 2016-present MagicStack Inc. and the EdgeDB authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - - -import enum - -from . import abstract -from . import errors -from . import options - - -class TransactionState(enum.Enum): - NEW = 0 - STARTED = 1 - COMMITTED = 2 - ROLLEDBACK = 3 - FAILED = 4 - - -class BaseTransaction: - - __slots__ = ( - '_client', - '_connection', - '_options', - '_state', - '__retry', - '__iteration', - '__started', - ) - - def __init__(self, retry, client, iteration): - self._client = client - self._connection = None - self._options = retry._options.transaction_options - self._state = TransactionState.NEW - self.__retry = retry - self.__iteration = iteration - self.__started = False - - def is_active(self) -> bool: - return self._state is TransactionState.STARTED - - def __check_state_base(self, opname): - if self._state is TransactionState.COMMITTED: - raise errors.InterfaceError( - 'cannot {}; the transaction is already committed'.format( - opname)) - if self._state is TransactionState.ROLLEDBACK: - raise errors.InterfaceError( - 'cannot {}; the transaction is already rolled back'.format( - opname)) - if self._state is TransactionState.FAILED: - raise errors.InterfaceError( - 'cannot {}; the transaction is in error state'.format( - opname)) - - def __check_state(self, opname): - if self._state is not TransactionState.STARTED: - if self._state is TransactionState.NEW: - raise errors.InterfaceError( - 'cannot {}; the transaction is not yet started'.format( - opname)) - self.__check_state_base(opname) - - def _make_start_query(self): - self.__check_state_base('start') - if self._state is TransactionState.STARTED: - raise errors.InterfaceError( - 'cannot start; the transaction is already started') - - return self._options.start_transaction_query() - - def _make_commit_query(self): - self.__check_state('commit') - return 'COMMIT;' - - def _make_rollback_query(self): - self.__check_state('rollback') - return 'ROLLBACK;' - - def __repr__(self): - attrs = [] - attrs.append('state:{}'.format(self._state.name.lower())) - attrs.append(repr(self._options)) - - if self.__class__.__module__.startswith('edgedb.'): - mod = 'edgedb' - else: - mod = self.__class__.__module__ - - return '<{}.{} {} {:#x}>'.format( - mod, self.__class__.__name__, ' '.join(attrs), id(self)) - - async def _ensure_transaction(self): - if not self.__started: - self.__started = True - query = self._make_start_query() - self._connection = await self._client._impl.acquire() - if self._connection.is_closed(): - await self._connection.connect( - single_attempt=self.__iteration != 0 - ) - try: - await self._privileged_execute(query) - except BaseException: - self._state = TransactionState.FAILED - raise - else: - self._state = TransactionState.STARTED - - async def _exit(self, extype, ex): - if not self.__started: - return False - - try: - if extype is None: - query = self._make_commit_query() - state = TransactionState.COMMITTED - else: - query = self._make_rollback_query() - state = TransactionState.ROLLEDBACK - try: - await self._privileged_execute(query) - except BaseException: - self._state = TransactionState.FAILED - if extype is None: - # COMMIT itself may fail; recover in connection - await self._privileged_execute("ROLLBACK;") - raise - else: - self._state = state - except errors.EdgeDBError as err: - if ex is None: - # On commit we don't know if commit is succeeded before the - # database have received it or after it have been done but - # network is dropped before we were able to receive a response. - # On a TransactionError, though, we know the we need - # to retry. - # TODO(tailhook) should other errors have retries? - if ( - isinstance(err, errors.TransactionError) - and err.has_tag(errors.SHOULD_RETRY) - and self.__retry._retry(err) - ): - pass - else: - raise err - # If we were going to rollback, look at original error - # to find out whether we want to retry, regardless of - # the rollback error. - # In this case we ignore rollback issue as original error is more - # important, e.g. in case `CancelledError` it's important - # to propagate it to cancel the whole task. - # NOTE: rollback error is always swallowed, should we use - # on_log_message for it? - finally: - await self._client._impl.release(self._connection) - - if ( - extype is not None and - issubclass(extype, errors.EdgeDBError) and - ex.has_tag(errors.SHOULD_RETRY) - ): - return self.__retry._retry(ex) - - def _get_query_cache(self) -> abstract.QueryCache: - return self._client._get_query_cache() - - def _get_state(self) -> options.State: - return self._client._get_state() - - def _get_warning_handler(self) -> options.WarningHandler: - return self._client._get_warning_handler() - - async def _query(self, query_context: abstract.QueryContext): - await self._ensure_transaction() - return await self._connection.raw_query(query_context) - - async def _execute(self, execute_context: abstract.ExecuteContext) -> None: - await self._ensure_transaction() - await self._connection._execute(execute_context) - - async def _privileged_execute(self, query: str) -> None: - await self._connection.privileged_execute(abstract.ExecuteContext( - query=abstract.QueryWithArgs(query, (), {}), - cache=self._get_query_cache(), - state=self._get_state(), - warning_handler=self._get_warning_handler(), - )) - - -class BaseRetry: - - def __init__(self, owner): - self._owner = owner - self._iteration = 0 - self._done = False - self._next_backoff = 0 - self._options = owner._options - - def _retry(self, exc): - self._last_exception = exc - rule = self._options.retry_options.get_rule_for_exception(exc) - if self._iteration >= rule.attempts: - return False - self._done = False - self._next_backoff = rule.backoff(self._iteration) - return True +# Auto-generated shim +import gel.transaction as _mod +import sys as _sys +_cur = _sys.modules['edgedb.transaction'] +for _k in vars(_mod): + if not _k.startswith('__') or _k in ('__all__', '__doc__'): + setattr(_cur, _k, getattr(_mod, _k)) +del _cur +del _sys +del _mod +del _k diff --git a/gel/__init__.py b/gel/__init__.py new file mode 100644 index 00000000..c19f7778 --- /dev/null +++ b/gel/__init__.py @@ -0,0 +1,284 @@ +# +# This source file is part of the EdgeDB open source project. +# +# Copyright 2016-present MagicStack Inc. and the EdgeDB authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +# flake8: noqa + +from ._version import __version__ + +from gel.datatypes.datatypes import ( + Tuple, NamedTuple, EnumValue, RelativeDuration, DateDuration, ConfigMemory +) +from gel.datatypes.datatypes import Set, Object, Array +from gel.datatypes.range import Range, MultiRange + +from .abstract import ( + Executor, AsyncIOExecutor, ReadOnlyExecutor, AsyncIOReadOnlyExecutor, +) + +from .asyncio_client import ( + create_async_client, + AsyncIOClient +) + +from .blocking_client import create_client, Client +from .enums import Cardinality, ElementKind +from .options import RetryCondition, IsolationLevel, default_backoff +from .options import RetryOptions, TransactionOptions +from .options import State + +from .errors._base import EdgeDBError, EdgeDBMessage + +__all__ = [ + "Array", + "AsyncIOClient", + "AsyncIOExecutor", + "AsyncIOReadOnlyExecutor", + "Cardinality", + "Client", + "ConfigMemory", + "DateDuration", + "EdgeDBError", + "EdgeDBMessage", + "ElementKind", + "EnumValue", + "Executor", + "IsolationLevel", + "NamedTuple", + "Object", + "Range", + "ReadOnlyExecutor", + "RelativeDuration", + "RetryCondition", + "RetryOptions", + "Set", + "State", + "TransactionOptions", + "Tuple", + "create_async_client", + "create_client", + "default_backoff", +] + + +# The below is generated by `make gen-errors`. +# DO NOT MODIFY BY HAND. +# +# +from .errors import ( + InternalServerError, + UnsupportedFeatureError, + ProtocolError, + BinaryProtocolError, + UnsupportedProtocolVersionError, + TypeSpecNotFoundError, + UnexpectedMessageError, + InputDataError, + ParameterTypeMismatchError, + StateMismatchError, + ResultCardinalityMismatchError, + CapabilityError, + UnsupportedCapabilityError, + DisabledCapabilityError, + QueryError, + InvalidSyntaxError, + EdgeQLSyntaxError, + SchemaSyntaxError, + GraphQLSyntaxError, + InvalidTypeError, + InvalidTargetError, + InvalidLinkTargetError, + InvalidPropertyTargetError, + InvalidReferenceError, + UnknownModuleError, + UnknownLinkError, + UnknownPropertyError, + UnknownUserError, + UnknownDatabaseError, + UnknownParameterError, + SchemaError, + SchemaDefinitionError, + InvalidDefinitionError, + InvalidModuleDefinitionError, + InvalidLinkDefinitionError, + InvalidPropertyDefinitionError, + InvalidUserDefinitionError, + InvalidDatabaseDefinitionError, + InvalidOperatorDefinitionError, + InvalidAliasDefinitionError, + InvalidFunctionDefinitionError, + InvalidConstraintDefinitionError, + InvalidCastDefinitionError, + DuplicateDefinitionError, + DuplicateModuleDefinitionError, + DuplicateLinkDefinitionError, + DuplicatePropertyDefinitionError, + DuplicateUserDefinitionError, + DuplicateDatabaseDefinitionError, + DuplicateOperatorDefinitionError, + DuplicateViewDefinitionError, + DuplicateFunctionDefinitionError, + DuplicateConstraintDefinitionError, + DuplicateCastDefinitionError, + DuplicateMigrationError, + SessionTimeoutError, + IdleSessionTimeoutError, + QueryTimeoutError, + TransactionTimeoutError, + IdleTransactionTimeoutError, + ExecutionError, + InvalidValueError, + DivisionByZeroError, + NumericOutOfRangeError, + AccessPolicyError, + QueryAssertionError, + IntegrityError, + ConstraintViolationError, + CardinalityViolationError, + MissingRequiredError, + TransactionError, + TransactionConflictError, + TransactionSerializationError, + TransactionDeadlockError, + WatchError, + ConfigurationError, + AccessError, + AuthenticationError, + AvailabilityError, + BackendUnavailableError, + ServerOfflineError, + BackendError, + UnsupportedBackendFeatureError, + LogMessage, + WarningMessage, + ClientError, + ClientConnectionError, + ClientConnectionFailedError, + ClientConnectionFailedTemporarilyError, + ClientConnectionTimeoutError, + ClientConnectionClosedError, + InterfaceError, + QueryArgumentError, + MissingArgumentError, + UnknownArgumentError, + InvalidArgumentError, + NoDataError, + InternalClientError, +) + +__all__.extend([ + "InternalServerError", + "UnsupportedFeatureError", + "ProtocolError", + "BinaryProtocolError", + "UnsupportedProtocolVersionError", + "TypeSpecNotFoundError", + "UnexpectedMessageError", + "InputDataError", + "ParameterTypeMismatchError", + "StateMismatchError", + "ResultCardinalityMismatchError", + "CapabilityError", + "UnsupportedCapabilityError", + "DisabledCapabilityError", + "QueryError", + "InvalidSyntaxError", + "EdgeQLSyntaxError", + "SchemaSyntaxError", + "GraphQLSyntaxError", + "InvalidTypeError", + "InvalidTargetError", + "InvalidLinkTargetError", + "InvalidPropertyTargetError", + "InvalidReferenceError", + "UnknownModuleError", + "UnknownLinkError", + "UnknownPropertyError", + "UnknownUserError", + "UnknownDatabaseError", + "UnknownParameterError", + "SchemaError", + "SchemaDefinitionError", + "InvalidDefinitionError", + "InvalidModuleDefinitionError", + "InvalidLinkDefinitionError", + "InvalidPropertyDefinitionError", + "InvalidUserDefinitionError", + "InvalidDatabaseDefinitionError", + "InvalidOperatorDefinitionError", + "InvalidAliasDefinitionError", + "InvalidFunctionDefinitionError", + "InvalidConstraintDefinitionError", + "InvalidCastDefinitionError", + "DuplicateDefinitionError", + "DuplicateModuleDefinitionError", + "DuplicateLinkDefinitionError", + "DuplicatePropertyDefinitionError", + "DuplicateUserDefinitionError", + "DuplicateDatabaseDefinitionError", + "DuplicateOperatorDefinitionError", + "DuplicateViewDefinitionError", + "DuplicateFunctionDefinitionError", + "DuplicateConstraintDefinitionError", + "DuplicateCastDefinitionError", + "DuplicateMigrationError", + "SessionTimeoutError", + "IdleSessionTimeoutError", + "QueryTimeoutError", + "TransactionTimeoutError", + "IdleTransactionTimeoutError", + "ExecutionError", + "InvalidValueError", + "DivisionByZeroError", + "NumericOutOfRangeError", + "AccessPolicyError", + "QueryAssertionError", + "IntegrityError", + "ConstraintViolationError", + "CardinalityViolationError", + "MissingRequiredError", + "TransactionError", + "TransactionConflictError", + "TransactionSerializationError", + "TransactionDeadlockError", + "WatchError", + "ConfigurationError", + "AccessError", + "AuthenticationError", + "AvailabilityError", + "BackendUnavailableError", + "ServerOfflineError", + "BackendError", + "UnsupportedBackendFeatureError", + "LogMessage", + "WarningMessage", + "ClientError", + "ClientConnectionError", + "ClientConnectionFailedError", + "ClientConnectionFailedTemporarilyError", + "ClientConnectionTimeoutError", + "ClientConnectionClosedError", + "InterfaceError", + "QueryArgumentError", + "MissingArgumentError", + "UnknownArgumentError", + "InvalidArgumentError", + "NoDataError", + "InternalClientError", +]) +# diff --git a/gel/_taskgroup.py b/gel/_taskgroup.py new file mode 100644 index 00000000..0a1859d7 --- /dev/null +++ b/gel/_taskgroup.py @@ -0,0 +1,295 @@ +# +# This source file is part of the EdgeDB open source project. +# +# Copyright 2016-present MagicStack Inc. and the EdgeDB authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +import asyncio +import functools +import itertools +import textwrap +import traceback + + +class TaskGroup: + + def __init__(self, *, name=None): + if name is None: + self._name = f'tg-{_name_counter()}' + else: + self._name = str(name) + + self._entered = False + self._exiting = False + self._aborting = False + self._loop = None + self._parent_task = None + self._parent_cancel_requested = False + self._tasks = set() + self._unfinished_tasks = 0 + self._errors = [] + self._base_error = None + self._on_completed_fut = None + + def get_name(self): + return self._name + + def __repr__(self): + msg = f'= 0 + + if self._exiting and not self._unfinished_tasks: + if not self._on_completed_fut.done(): + self._on_completed_fut.set_result(True) + + if task.cancelled(): + return + + exc = task.exception() + if exc is None: + return + + self._errors.append(exc) + if self._is_base_error(exc) and self._base_error is None: + self._base_error = exc + + if self._parent_task.done(): + # Not sure if this case is possible, but we want to handle + # it anyways. + self._loop.call_exception_handler({ + 'message': f'Task {task!r} has errored out but its parent ' + f'task {self._parent_task} is already completed', + 'exception': exc, + 'task': task, + }) + return + + self._abort() + if not self._parent_task.__cancel_requested__: + # If parent task *is not* being cancelled, it means that we want + # to manually cancel it to abort whatever is being run right now + # in the TaskGroup. But we want to mark parent task as + # "not cancelled" later in __aexit__. Example situation that + # we need to handle: + # + # async def foo(): + # try: + # async with TaskGroup() as g: + # g.create_task(crash_soon()) + # await something # <- this needs to be canceled + # # by the TaskGroup, e.g. + # # foo() needs to be cancelled + # except Exception: + # # Ignore any exceptions raised in the TaskGroup + # pass + # await something_else # this line has to be called + # # after TaskGroup is finished. + self._parent_cancel_requested = True + self._parent_task.cancel() + + +class MultiError(Exception): + + def __init__(self, msg, *args, errors=()): + if errors: + types = set(type(e).__name__ for e in errors) + msg = f'{msg}; {len(errors)} sub errors: ({", ".join(types)})' + for er in errors: + msg += f'\n + {type(er).__name__}: {er}' + if er.__traceback__: + er_tb = ''.join(traceback.format_tb(er.__traceback__)) + er_tb = textwrap.indent(er_tb, ' | ') + msg += f'\n{er_tb}\n' + super().__init__(msg, *args) + self.__errors__ = tuple(errors) + + def get_error_types(self): + return {type(e) for e in self.__errors__} + + def __reduce__(self): + return (type(self), (self.args,), {'__errors__': self.__errors__}) + + +class TaskGroupError(MultiError): + pass + + +_name_counter = itertools.count(1).__next__ diff --git a/edgedb/_testbase.py b/gel/_testbase.py similarity index 96% rename from edgedb/_testbase.py rename to gel/_testbase.py index 167c9fb5..ef97d81b 100644 --- a/edgedb/_testbase.py +++ b/gel/_testbase.py @@ -32,9 +32,9 @@ import time import unittest -import edgedb -from edgedb import asyncio_client -from edgedb import blocking_client +import gel +from gel import asyncio_client +from gel import blocking_client log = logging.getLogger(__name__) @@ -89,9 +89,9 @@ def _start_cluster(*, cleanup_atexit=True): # not interfere with the server's. env.pop('PYTHONPATH', None) - edgedb_server = env.get('EDGEDB_SERVER_BINARY', 'edgedb-server') + gel_server = env.get('EDGEDB_SERVER_BINARY', 'edgedb-server') args = [ - edgedb_server, + gel_server, "--temp-dir", "--testmode", f"--emit-server-status={status_file_unix}", @@ -100,7 +100,7 @@ def _start_cluster(*, cleanup_atexit=True): "--bootstrap-command=ALTER ROLE edgedb { SET password := 'test' }", ] - help_args = [edgedb_server, "--help"] + help_args = [gel_server, "--help"] if sys.platform == 'win32': help_args = ['wsl', '-u', 'edgedb'] + help_args @@ -173,7 +173,7 @@ def _start_cluster(*, cleanup_atexit=True): else: con_args['tls_ca_file'] = data['tls_cert_file'] - client = edgedb.create_client(password='test', **con_args) + client = gel.create_client(password='test', **con_args) client.ensure_connected() client.execute(""" # Set session_idle_transaction_timeout to 5 minutes. @@ -237,7 +237,7 @@ def wrapper(self, *args, __meth__=meth, **kwargs): # retry the test. self.loop.run_until_complete( __meth__(self, *args, **kwargs)) - except edgedb.TransactionSerializationError: + except gel.TransactionSerializationError: if try_no == 3: raise else: @@ -335,7 +335,7 @@ def setUpClass(cls): cls.cluster = _start_cluster(cleanup_atexit=True) -class TestAsyncIOClient(edgedb.AsyncIOClient): +class TestAsyncIOClient(gel.AsyncIOClient): def _clear_codecs_cache(self): self._impl.codecs_registry.clear_cache() @@ -352,7 +352,7 @@ def is_proto_lt_1_0(self): return self.connection._protocol.is_legacy -class TestClient(edgedb.Client): +class TestClient(gel.Client): @property def connection(self): return self._impl._holders[0]._con @@ -560,12 +560,12 @@ def tearDownClass(cls): try: cls.adapt_call( cls.admin_client.execute(script)) - except edgedb.errors.ExecutionError: + except gel.errors.ExecutionError: if i < retry - 1: time.sleep(0.1) else: raise - except edgedb.errors.UnknownDatabaseError: + except gel.errors.UnknownDatabaseError: break except Exception: diff --git a/gel/_version.py b/gel/_version.py new file mode 100644 index 00000000..41998d40 --- /dev/null +++ b/gel/_version.py @@ -0,0 +1,31 @@ +# +# This source file is part of the EdgeDB open source project. +# +# Copyright 2016-present MagicStack Inc. and the EdgeDB authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# This file MUST NOT contain anything but the __version__ assignment. +# +# When making a release, change the value of __version__ +# to an appropriate value, and open a pull request against +# the correct branch (master if making a new feature release). +# The commit message MUST contain a properly formatted release +# log, and the commit must be signed. +# +# The release automation will: build and test the packages for the +# supported platforms, publish the packages on PyPI, merge the PR +# to the target branch, create a Git tag pointing to the commit. + +__version__ = '3.0.0b2' diff --git a/gel/abstract.py b/gel/abstract.py new file mode 100644 index 00000000..0c2c06cc --- /dev/null +++ b/gel/abstract.py @@ -0,0 +1,437 @@ +# +# This source file is part of the EdgeDB open source project. +# +# Copyright 2020-present MagicStack Inc. and the EdgeDB authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +from __future__ import annotations +import abc +import dataclasses +import typing + +from . import describe +from . import enums +from . import options +from .protocol import protocol + +__all__ = ( + "QueryWithArgs", + "QueryCache", + "QueryOptions", + "QueryContext", + "Executor", + "AsyncIOExecutor", + "ReadOnlyExecutor", + "AsyncIOReadOnlyExecutor", + "DescribeContext", + "DescribeResult", +) + + +class QueryWithArgs(typing.NamedTuple): + query: str + args: typing.Tuple + kwargs: typing.Dict[str, typing.Any] + input_language: protocol.InputLanguage = protocol.InputLanguage.EDGEQL + + +class QueryCache(typing.NamedTuple): + codecs_registry: protocol.CodecsRegistry + query_cache: protocol.LRUMapping + + +class QueryOptions(typing.NamedTuple): + output_format: protocol.OutputFormat + expect_one: bool + required_one: bool + + +class QueryContext(typing.NamedTuple): + query: QueryWithArgs + cache: QueryCache + query_options: QueryOptions + retry_options: typing.Optional[options.RetryOptions] + state: typing.Optional[options.State] + warning_handler: options.WarningHandler + + def lower( + self, *, allow_capabilities: enums.Capability + ) -> protocol.ExecuteContext: + return protocol.ExecuteContext( + query=self.query.query, + args=self.query.args, + kwargs=self.query.kwargs, + reg=self.cache.codecs_registry, + qc=self.cache.query_cache, + input_language=self.query.input_language, + output_format=self.query_options.output_format, + expect_one=self.query_options.expect_one, + required_one=self.query_options.required_one, + allow_capabilities=allow_capabilities, + state=self.state.as_dict() if self.state else None, + ) + + +class ExecuteContext(typing.NamedTuple): + query: QueryWithArgs + cache: QueryCache + state: typing.Optional[options.State] + warning_handler: options.WarningHandler + + def lower( + self, *, allow_capabilities: enums.Capability + ) -> protocol.ExecuteContext: + return protocol.ExecuteContext( + query=self.query.query, + args=self.query.args, + kwargs=self.query.kwargs, + reg=self.cache.codecs_registry, + qc=self.cache.query_cache, + input_language=self.query.input_language, + output_format=protocol.OutputFormat.NONE, + allow_capabilities=allow_capabilities, + state=self.state.as_dict() if self.state else None, + ) + + +@dataclasses.dataclass +class DescribeContext: + query: str + state: typing.Optional[options.State] + inject_type_names: bool + input_language: protocol.InputLanguage + output_format: protocol.OutputFormat + expect_one: bool + + def lower( + self, *, allow_capabilities: enums.Capability + ) -> protocol.ExecuteContext: + return protocol.ExecuteContext( + query=self.query, + args=None, + kwargs=None, + reg=protocol.CodecsRegistry(), + qc=protocol.LRUMapping(maxsize=1), + input_language=self.input_language, + output_format=self.output_format, + expect_one=self.expect_one, + inline_typenames=self.inject_type_names, + allow_capabilities=allow_capabilities, + state=self.state.as_dict() if self.state else None, + ) + + +@dataclasses.dataclass +class DescribeResult: + input_type: typing.Optional[describe.AnyType] + output_type: typing.Optional[describe.AnyType] + output_cardinality: enums.Cardinality + capabilities: enums.Capability + + +_query_opts = QueryOptions( + output_format=protocol.OutputFormat.BINARY, + expect_one=False, + required_one=False, +) +_query_single_opts = QueryOptions( + output_format=protocol.OutputFormat.BINARY, + expect_one=True, + required_one=False, +) +_query_required_single_opts = QueryOptions( + output_format=protocol.OutputFormat.BINARY, + expect_one=True, + required_one=True, +) +_query_json_opts = QueryOptions( + output_format=protocol.OutputFormat.JSON, + expect_one=False, + required_one=False, +) +_query_single_json_opts = QueryOptions( + output_format=protocol.OutputFormat.JSON, + expect_one=True, + required_one=False, +) +_query_required_single_json_opts = QueryOptions( + output_format=protocol.OutputFormat.JSON, + expect_one=True, + required_one=True, +) + + +class BaseReadOnlyExecutor(abc.ABC): + __slots__ = () + + @abc.abstractmethod + def _get_query_cache(self) -> QueryCache: + ... + + def _get_retry_options(self) -> typing.Optional[options.RetryOptions]: + return None + + @abc.abstractmethod + def _get_state(self) -> options.State: + ... + + @abc.abstractmethod + def _get_warning_handler(self) -> options.WarningHandler: + ... + + +class ReadOnlyExecutor(BaseReadOnlyExecutor): + """Subclasses can execute *at least* read-only queries""" + + __slots__ = () + + @abc.abstractmethod + def _query(self, query_context: QueryContext): + ... + + def query(self, query: str, *args, **kwargs) -> list: + return self._query(QueryContext( + query=QueryWithArgs(query, args, kwargs), + cache=self._get_query_cache(), + query_options=_query_opts, + retry_options=self._get_retry_options(), + state=self._get_state(), + warning_handler=self._get_warning_handler(), + )) + + def query_single( + self, query: str, *args, **kwargs + ) -> typing.Union[typing.Any, None]: + return self._query(QueryContext( + query=QueryWithArgs(query, args, kwargs), + cache=self._get_query_cache(), + query_options=_query_single_opts, + retry_options=self._get_retry_options(), + state=self._get_state(), + warning_handler=self._get_warning_handler(), + )) + + def query_required_single(self, query: str, *args, **kwargs) -> typing.Any: + return self._query(QueryContext( + query=QueryWithArgs(query, args, kwargs), + cache=self._get_query_cache(), + query_options=_query_required_single_opts, + retry_options=self._get_retry_options(), + state=self._get_state(), + warning_handler=self._get_warning_handler(), + )) + + def query_json(self, query: str, *args, **kwargs) -> str: + return self._query(QueryContext( + query=QueryWithArgs(query, args, kwargs), + cache=self._get_query_cache(), + query_options=_query_json_opts, + retry_options=self._get_retry_options(), + state=self._get_state(), + warning_handler=self._get_warning_handler(), + )) + + def query_single_json(self, query: str, *args, **kwargs) -> str: + return self._query(QueryContext( + query=QueryWithArgs(query, args, kwargs), + cache=self._get_query_cache(), + query_options=_query_single_json_opts, + retry_options=self._get_retry_options(), + state=self._get_state(), + warning_handler=self._get_warning_handler(), + )) + + def query_required_single_json(self, query: str, *args, **kwargs) -> str: + return self._query(QueryContext( + query=QueryWithArgs(query, args, kwargs), + cache=self._get_query_cache(), + query_options=_query_required_single_json_opts, + retry_options=self._get_retry_options(), + state=self._get_state(), + warning_handler=self._get_warning_handler(), + )) + + def query_sql(self, query: str, *args, **kwargs) -> typing.Any: + return self._query(QueryContext( + query=QueryWithArgs( + query, + args, + kwargs, + input_language=protocol.InputLanguage.SQL, + ), + cache=self._get_query_cache(), + query_options=_query_opts, + retry_options=self._get_retry_options(), + state=self._get_state(), + warning_handler=self._get_warning_handler(), + )) + + @abc.abstractmethod + def _execute(self, execute_context: ExecuteContext): + ... + + def execute(self, commands: str, *args, **kwargs) -> None: + self._execute(ExecuteContext( + query=QueryWithArgs(commands, args, kwargs), + cache=self._get_query_cache(), + state=self._get_state(), + warning_handler=self._get_warning_handler(), + )) + + def execute_sql(self, commands: str, *args, **kwargs) -> None: + self._execute(ExecuteContext( + query=QueryWithArgs( + commands, + args, + kwargs, + input_language=protocol.InputLanguage.SQL, + ), + cache=self._get_query_cache(), + state=self._get_state(), + warning_handler=self._get_warning_handler(), + )) + + +class Executor(ReadOnlyExecutor): + """Subclasses can execute both read-only and modification queries""" + + __slots__ = () + + +class AsyncIOReadOnlyExecutor(BaseReadOnlyExecutor): + """Subclasses can execute *at least* read-only queries""" + + __slots__ = () + + @abc.abstractmethod + async def _query(self, query_context: QueryContext): + ... + + async def query(self, query: str, *args, **kwargs) -> list: + return await self._query(QueryContext( + query=QueryWithArgs(query, args, kwargs), + cache=self._get_query_cache(), + query_options=_query_opts, + retry_options=self._get_retry_options(), + state=self._get_state(), + warning_handler=self._get_warning_handler(), + )) + + async def query_single(self, query: str, *args, **kwargs) -> typing.Any: + return await self._query(QueryContext( + query=QueryWithArgs(query, args, kwargs), + cache=self._get_query_cache(), + query_options=_query_single_opts, + retry_options=self._get_retry_options(), + state=self._get_state(), + warning_handler=self._get_warning_handler(), + )) + + async def query_required_single( + self, + query: str, + *args, + **kwargs + ) -> typing.Any: + return await self._query(QueryContext( + query=QueryWithArgs(query, args, kwargs), + cache=self._get_query_cache(), + query_options=_query_required_single_opts, + retry_options=self._get_retry_options(), + state=self._get_state(), + warning_handler=self._get_warning_handler(), + )) + + async def query_json(self, query: str, *args, **kwargs) -> str: + return await self._query(QueryContext( + query=QueryWithArgs(query, args, kwargs), + cache=self._get_query_cache(), + query_options=_query_json_opts, + retry_options=self._get_retry_options(), + state=self._get_state(), + warning_handler=self._get_warning_handler(), + )) + + async def query_single_json(self, query: str, *args, **kwargs) -> str: + return await self._query(QueryContext( + query=QueryWithArgs(query, args, kwargs), + cache=self._get_query_cache(), + query_options=_query_single_json_opts, + retry_options=self._get_retry_options(), + state=self._get_state(), + warning_handler=self._get_warning_handler(), + )) + + async def query_required_single_json( + self, + query: str, + *args, + **kwargs + ) -> str: + return await self._query(QueryContext( + query=QueryWithArgs(query, args, kwargs), + cache=self._get_query_cache(), + query_options=_query_required_single_json_opts, + retry_options=self._get_retry_options(), + state=self._get_state(), + warning_handler=self._get_warning_handler(), + )) + + async def query_sql(self, query: str, *args, **kwargs) -> typing.Any: + return await self._query(QueryContext( + query=QueryWithArgs( + query, + args, + kwargs, + input_language=protocol.InputLanguage.SQL, + ), + cache=self._get_query_cache(), + query_options=_query_opts, + retry_options=self._get_retry_options(), + state=self._get_state(), + warning_handler=self._get_warning_handler(), + )) + + @abc.abstractmethod + async def _execute(self, execute_context: ExecuteContext) -> None: + ... + + async def execute(self, commands: str, *args, **kwargs) -> None: + await self._execute(ExecuteContext( + query=QueryWithArgs(commands, args, kwargs), + cache=self._get_query_cache(), + state=self._get_state(), + warning_handler=self._get_warning_handler(), + )) + + async def execute_sql(self, commands: str, *args, **kwargs) -> None: + await self._execute(ExecuteContext( + query=QueryWithArgs( + commands, + args, + kwargs, + input_language=protocol.InputLanguage.SQL, + ), + cache=self._get_query_cache(), + state=self._get_state(), + warning_handler=self._get_warning_handler(), + )) + + +class AsyncIOExecutor(AsyncIOReadOnlyExecutor): + """Subclasses can execute both read-only and modification queries""" + + __slots__ = () diff --git a/gel/ai/__init__.py b/gel/ai/__init__.py new file mode 100644 index 00000000..96111c2b --- /dev/null +++ b/gel/ai/__init__.py @@ -0,0 +1,32 @@ +# +# This source file is part of the EdgeDB open source project. +# +# Copyright 2024-present MagicStack Inc. and the EdgeDB authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from .types import AIOptions, ChatParticipantRole, Prompt, QueryContext +from .core import create_ai, EdgeDBAI +from .core import create_async_ai, AsyncEdgeDBAI + +__all__ = [ + "AIOptions", + "ChatParticipantRole", + "Prompt", + "QueryContext", + "create_ai", + "EdgeDBAI", + "create_async_ai", + "AsyncEdgeDBAI", +] diff --git a/gel/ai/core.py b/gel/ai/core.py new file mode 100644 index 00000000..2822bae4 --- /dev/null +++ b/gel/ai/core.py @@ -0,0 +1,191 @@ +# +# This source file is part of the EdgeDB open source project. +# +# Copyright 2024-present MagicStack Inc. and the EdgeDB authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import annotations +import typing + +import gel +import httpx +import httpx_sse + +from . import types + + +def create_ai(client: gel.Client, **kwargs) -> EdgeDBAI: + client.ensure_connected() + return EdgeDBAI(client, types.AIOptions(**kwargs)) + + +async def create_async_ai( + client: gel.AsyncIOClient, **kwargs +) -> AsyncEdgeDBAI: + await client.ensure_connected() + return AsyncEdgeDBAI(client, types.AIOptions(**kwargs)) + + +class BaseEdgeDBAI: + options: types.AIOptions + context: types.QueryContext + client_cls = NotImplemented + + def __init__( + self, + client: typing.Union[gel.Client, gel.AsyncIOClient], + options: types.AIOptions, + **kwargs, + ): + pool = client._impl + host, port = pool._working_addr + params = pool._working_params + proto = "http" if params.tls_security == "insecure" else "https" + branch = params.branch + self.options = options + self.context = types.QueryContext(**kwargs) + args = dict( + base_url=f"{proto}://{host}:{port}/branch/{branch}/ext/ai", + verify=params.ssl_ctx, + ) + if params.password is not None: + args["auth"] = (params.user, params.password) + elif params.secret_key is not None: + args["headers"] = {"Authorization": f"Bearer {params.secret_key}"} + self._init_client(**args) + + def _init_client(self, **kwargs): + raise NotImplementedError + + def with_config(self, **kwargs) -> typing.Self: + cls = type(self) + rv = cls.__new__(cls) + rv.options = self.options.derive(kwargs) + rv.context = self.context + rv.client = self.client + return rv + + def with_context(self, **kwargs) -> typing.Self: + cls = type(self) + rv = cls.__new__(cls) + rv.options = self.options + rv.context = self.context.derive(kwargs) + rv.client = self.client + return rv + + def _make_rag_request( + self, + *, + message: str, + context: typing.Optional[types.QueryContext] = None, + stream: bool, + ) -> types.RAGRequest: + if context is None: + context = self.context + return types.RAGRequest( + model=self.options.model, + prompt=self.options.prompt, + context=context, + query=message, + stream=stream, + ) + + +class EdgeDBAI(BaseEdgeDBAI): + client: httpx.Client + + def _init_client(self, **kwargs): + self.client = httpx.Client(**kwargs) + + def query_rag( + self, message: str, context: typing.Optional[types.QueryContext] = None + ) -> str: + resp = self.client.post( + **self._make_rag_request( + context=context, + message=message, + stream=False, + ).to_httpx_request() + ) + resp.raise_for_status() + return resp.json()["response"] + + def stream_rag( + self, message: str, context: typing.Optional[types.QueryContext] = None + ) -> typing.Iterator[str]: + with httpx_sse.connect_sse( + self.client, + "post", + **self._make_rag_request( + context=context, + message=message, + stream=True, + ).to_httpx_request(), + ) as event_source: + event_source.response.raise_for_status() + for sse in event_source.iter_sse(): + yield sse.data + + def generate_embeddings(self, *inputs: str, model: str) -> list[float]: + resp = self.client.post( + "/embeddings", json={"input": inputs, "model": model} + ) + resp.raise_for_status() + return resp.json()["data"][0]["embedding"] + + +class AsyncEdgeDBAI(BaseEdgeDBAI): + client: httpx.AsyncClient + + def _init_client(self, **kwargs): + self.client = httpx.AsyncClient(**kwargs) + + async def query_rag( + self, message: str, context: typing.Optional[types.QueryContext] = None + ) -> str: + resp = await self.client.post( + **self._make_rag_request( + context=context, + message=message, + stream=False, + ).to_httpx_request() + ) + resp.raise_for_status() + return resp.json()["response"] + + async def stream_rag( + self, message: str, context: typing.Optional[types.QueryContext] = None + ) -> typing.Iterator[str]: + async with httpx_sse.aconnect_sse( + self.client, + "post", + **self._make_rag_request( + context=context, + message=message, + stream=True, + ).to_httpx_request(), + ) as event_source: + event_source.response.raise_for_status() + async for sse in event_source.aiter_sse(): + yield sse.data + + async def generate_embeddings( + self, *inputs: str, model: str + ) -> list[float]: + resp = await self.client.post( + "/embeddings", json={"input": inputs, "model": model} + ) + resp.raise_for_status() + return resp.json()["data"][0]["embedding"] diff --git a/gel/ai/types.py b/gel/ai/types.py new file mode 100644 index 00000000..41bf24c0 --- /dev/null +++ b/gel/ai/types.py @@ -0,0 +1,81 @@ +# +# This source file is part of the EdgeDB open source project. +# +# Copyright 2024-present MagicStack Inc. and the EdgeDB authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import typing + +import dataclasses as dc +import enum + + +class ChatParticipantRole(enum.Enum): + SYSTEM = "system" + USER = "user" + ASSISTANT = "assistant" + TOOL = "tool" + + +class Custom(typing.TypedDict): + role: ChatParticipantRole + content: str + + +class Prompt: + name: typing.Optional[str] + id: typing.Optional[str] + custom: typing.Optional[typing.List[Custom]] + + +@dc.dataclass +class AIOptions: + model: str + prompt: typing.Optional[Prompt] = None + + def derive(self, kwargs): + return AIOptions(**{**dc.asdict(self), **kwargs}) + + +@dc.dataclass +class QueryContext: + query: str = "" + variables: typing.Optional[typing.Dict[str, typing.Any]] = None + globals: typing.Optional[typing.Dict[str, typing.Any]] = None + max_object_count: typing.Optional[int] = None + + def derive(self, kwargs): + return QueryContext(**{**dc.asdict(self), **kwargs}) + + +@dc.dataclass +class RAGRequest: + model: str + prompt: typing.Optional[Prompt] + context: QueryContext + query: str + stream: typing.Optional[bool] + + def to_httpx_request(self) -> typing.Dict[str, typing.Any]: + return dict( + url="/rag", + headers={ + "Content-Type": "application/json", + "Accept": ( + "text/event-stream" if self.stream else "application/json" + ), + }, + json=dc.asdict(self), + ) diff --git a/gel/asyncio_client.py b/gel/asyncio_client.py new file mode 100644 index 00000000..81cbc808 --- /dev/null +++ b/gel/asyncio_client.py @@ -0,0 +1,448 @@ +# +# This source file is part of the EdgeDB open source project. +# +# Copyright 2019-present MagicStack Inc. and the EdgeDB authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +import asyncio +import contextlib +import logging +import socket +import ssl +import typing + +from . import abstract +from . import base_client +from . import con_utils +from . import errors +from . import transaction +from .protocol import asyncio_proto +from .protocol.protocol import InputLanguage, OutputFormat + + +__all__ = ( + 'create_async_client', 'AsyncIOClient' +) + + +logger = logging.getLogger(__name__) + + +class AsyncIOConnection(base_client.BaseConnection): + __slots__ = ("_loop",) + + def __init__(self, loop, *args, **kwargs): + super().__init__(*args, **kwargs) + self._loop = loop + + def is_closed(self): + protocol = self._protocol + return protocol is None or not protocol.connected + + async def connect_addr(self, addr, timeout): + try: + await asyncio.wait_for(self._connect_addr(addr), timeout) + except asyncio.TimeoutError as e: + raise TimeoutError from e + + async def sleep(self, seconds): + await asyncio.sleep(seconds) + + async def aclose(self): + """Send graceful termination message wait for connection to drop.""" + if not self.is_closed(): + try: + self._protocol.terminate() + await self._protocol.wait_for_disconnect() + except (Exception, asyncio.CancelledError): + self.terminate() + raise + finally: + self._cleanup() + + def _protocol_factory(self): + return asyncio_proto.AsyncIOProtocol(self._params, self._loop) + + async def _connect_addr(self, addr): + tr = None + + try: + if isinstance(addr, str): + # UNIX socket + tr, pr = await self._loop.create_unix_connection( + self._protocol_factory, addr + ) + else: + try: + tr, pr = await self._loop.create_connection( + self._protocol_factory, + *addr, + ssl=self._params.ssl_ctx, + server_hostname=( + self._params.tls_server_name or addr[0] + ), + ) + except ssl.CertificateError as e: + raise con_utils.wrap_error(e) from e + except ssl.SSLError as e: + raise con_utils.wrap_error(e) from e + else: + con_utils.check_alpn_protocol( + tr.get_extra_info('ssl_object') + ) + except socket.gaierror as e: + # All name resolution errors are considered temporary + raise errors.ClientConnectionFailedTemporarilyError(str(e)) from e + except OSError as e: + raise con_utils.wrap_error(e) from e + except Exception: + if tr is not None: + tr.close() + raise + + pr.set_connection(self) + + try: + await pr.connect() + except OSError as e: + if tr is not None: + tr.close() + raise con_utils.wrap_error(e) from e + except BaseException: + if tr is not None: + tr.close() + raise + + self._protocol = pr + self._addr = addr + + def _dispatch_log_message(self, msg): + for cb in self._log_listeners: + self._loop.call_soon(cb, self, msg) + + +class _PoolConnectionHolder(base_client.PoolConnectionHolder): + __slots__ = () + _event_class = asyncio.Event + + async def close(self, *, wait=True): + if self._con is None: + return + if wait: + await self._con.aclose() + else: + self._pool._loop.create_task(self._con.aclose()) + + async def wait_until_released(self, timeout=None): + await self._release_event.wait() + + +class _AsyncIOPoolImpl(base_client.BasePoolImpl): + __slots__ = ('_loop',) + _holder_class = _PoolConnectionHolder + + def __init__( + self, + connect_args, + *, + max_concurrency: typing.Optional[int], + connection_class, + ): + if not issubclass(connection_class, AsyncIOConnection): + raise TypeError( + f'connection_class is expected to be a subclass of ' + f'gel.asyncio_client.AsyncIOConnection, ' + f'got {connection_class}') + self._loop = None + super().__init__( + connect_args, + lambda *args: connection_class(self._loop, *args), + max_concurrency=max_concurrency, + ) + + def _ensure_initialized(self): + if self._loop is None: + self._loop = asyncio.get_event_loop() + self._queue = asyncio.LifoQueue(maxsize=self._max_concurrency) + self._first_connect_lock = asyncio.Lock() + self._resize_holder_pool() + + def _set_queue_maxsize(self, maxsize): + self._queue._maxsize = maxsize + + async def _maybe_get_first_connection(self): + async with self._first_connect_lock: + if self._working_addr is None: + return await self._get_first_connection() + + async def acquire(self, timeout=None): + self._ensure_initialized() + + async def _acquire_impl(): + ch = await self._queue.get() # type: _PoolConnectionHolder + try: + proxy = await ch.acquire() # type: AsyncIOConnection + except (Exception, asyncio.CancelledError): + self._queue.put_nowait(ch) + raise + else: + # Record the timeout, as we will apply it by default + # in release(). + ch._timeout = timeout + return proxy + + if self._closing: + raise errors.InterfaceError('pool is closing') + + if timeout is None: + return await _acquire_impl() + else: + return await asyncio.wait_for( + _acquire_impl(), timeout=timeout) + + async def _release(self, holder): + + if not isinstance(holder._con, AsyncIOConnection): + raise errors.InterfaceError( + f'release() received invalid connection: ' + f'{holder._con!r} does not belong to any connection pool' + ) + + timeout = None + + # Use asyncio.shield() to guarantee that task cancellation + # does not prevent the connection from being returned to the + # pool properly. + return await asyncio.shield(holder.release(timeout)) + + async def aclose(self): + """Attempt to gracefully close all connections in the pool. + + Wait until all pool connections are released, close them and + shut down the pool. If any error (including cancellation) occurs + in ``close()`` the pool will terminate by calling + _AsyncIOPoolImpl.terminate() . + + It is advisable to use :func:`python:asyncio.wait_for` to set + a timeout. + """ + if self._closed: + return + + if not self._loop: + self._closed = True + return + + self._closing = True + + try: + warning_callback = self._loop.call_later( + 60, self._warn_on_long_close) + + release_coros = [ + ch.wait_until_released() for ch in self._holders] + await asyncio.gather(*release_coros) + + close_coros = [ + ch.close() for ch in self._holders] + await asyncio.gather(*close_coros) + + except (Exception, asyncio.CancelledError): + self.terminate() + raise + + finally: + warning_callback.cancel() + self._closed = True + self._closing = False + + def _warn_on_long_close(self): + logger.warning( + 'AsyncIOClient.aclose() is taking over 60 seconds to complete. ' + 'Check if you have any unreleased connections left. ' + 'Use asyncio.wait_for() to set a timeout for ' + 'AsyncIOClient.aclose().') + + +class AsyncIOIteration(transaction.BaseTransaction, abstract.AsyncIOExecutor): + + __slots__ = ("_managed", "_locked") + + def __init__(self, retry, client, iteration): + super().__init__(retry, client, iteration) + self._managed = False + self._locked = False + + async def __aenter__(self): + if self._managed: + raise errors.InterfaceError( + 'cannot enter context: already in an `async with` block') + self._managed = True + return self + + async def __aexit__(self, extype, ex, tb): + with self._exclusive(): + self._managed = False + return await self._exit(extype, ex) + + async def _ensure_transaction(self): + if not self._managed: + raise errors.InterfaceError( + "Only managed retriable transactions are supported. " + "Use `async with transaction:`" + ) + await super()._ensure_transaction() + + async def _query(self, query_context: abstract.QueryContext): + with self._exclusive(): + return await super()._query(query_context) + + async def _execute(self, execute_context: abstract.ExecuteContext) -> None: + with self._exclusive(): + await super()._execute(execute_context) + + @contextlib.contextmanager + def _exclusive(self): + if self._locked: + raise errors.InterfaceError( + "concurrent queries within the same transaction " + "are not allowed" + ) + self._locked = True + try: + yield + finally: + self._locked = False + + +class AsyncIORetry(transaction.BaseRetry): + + def __aiter__(self): + return self + + async def __anext__(self): + # Note: when changing this code consider also + # updating Retry.__next__. + if self._done: + raise StopAsyncIteration + if self._next_backoff: + await asyncio.sleep(self._next_backoff) + self._done = True + iteration = AsyncIOIteration(self, self._owner, self._iteration) + self._iteration += 1 + return iteration + + +class AsyncIOClient(base_client.BaseClient, abstract.AsyncIOExecutor): + """A lazy connection pool. + + A Client can be used to manage a set of connections to the database. + Connections are first acquired from the pool, then used, and then released + back to the pool. Once a connection is released, it's reset to close all + open cursors and other resources *except* prepared statements. + + Clients are created by calling + :func:`~gel.asyncio_client.create_async_client`. + """ + + __slots__ = () + _impl_class = _AsyncIOPoolImpl + + async def ensure_connected(self): + await self._impl.ensure_connected() + return self + + async def aclose(self): + """Attempt to gracefully close all connections in the pool. + + Wait until all pool connections are released, close them and + shut down the pool. If any error (including cancellation) occurs + in ``aclose()`` the pool will terminate by calling + AsyncIOClient.terminate() . + + It is advisable to use :func:`python:asyncio.wait_for` to set + a timeout. + """ + await self._impl.aclose() + + def transaction(self) -> AsyncIORetry: + return AsyncIORetry(self) + + async def __aenter__(self): + return await self.ensure_connected() + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.aclose() + + async def _describe_query( + self, + query: str, + *, + inject_type_names: bool = False, + input_language: InputLanguage = InputLanguage.EDGEQL, + output_format: OutputFormat = OutputFormat.BINARY, + expect_one: bool = False, + ) -> abstract.DescribeResult: + return await self._describe(abstract.DescribeContext( + query=query, + state=self._get_state(), + inject_type_names=inject_type_names, + input_language=input_language, + output_format=output_format, + expect_one=expect_one, + )) + + +def create_async_client( + dsn=None, + *, + max_concurrency=None, + host: str = None, + port: int = None, + credentials: str = None, + credentials_file: str = None, + user: str = None, + password: str = None, + secret_key: str = None, + database: str = None, + branch: str = None, + tls_ca: str = None, + tls_ca_file: str = None, + tls_security: str = None, + wait_until_available: int = 30, + timeout: int = 10, +): + return AsyncIOClient( + connection_class=AsyncIOConnection, + max_concurrency=max_concurrency, + + # connect arguments + dsn=dsn, + host=host, + port=port, + credentials=credentials, + credentials_file=credentials_file, + user=user, + password=password, + secret_key=secret_key, + database=database, + branch=branch, + tls_ca=tls_ca, + tls_ca_file=tls_ca_file, + tls_security=tls_security, + wait_until_available=wait_until_available, + timeout=timeout, + ) diff --git a/gel/base_client.py b/gel/base_client.py new file mode 100644 index 00000000..5bc1a86a --- /dev/null +++ b/gel/base_client.py @@ -0,0 +1,734 @@ +# +# This source file is part of the EdgeDB open source project. +# +# Copyright 2016-present MagicStack Inc. and the EdgeDB authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +import abc +import random +import time +import typing + +from . import abstract +from . import con_utils +from . import enums +from . import errors +from . import options as _options +from .protocol import protocol + + +BaseConnection_T = typing.TypeVar('BaseConnection_T', bound='BaseConnection') +QUERY_CACHE_SIZE = 1000 + + +class BaseConnection(metaclass=abc.ABCMeta): + _protocol: typing.Any + _addr: typing.Optional[typing.Union[str, typing.Tuple[str, int]]] + _addrs: typing.Iterable[typing.Union[str, typing.Tuple[str, int]]] + _config: con_utils.ClientConfiguration + _params: con_utils.ResolvedConnectConfig + _log_listeners: typing.Set[ + typing.Callable[[BaseConnection_T, errors.EdgeDBMessage], None] + ] + __slots__ = ( + "__weakref__", + "_protocol", + "_addr", + "_addrs", + "_config", + "_params", + "_log_listeners", + "_holder", + ) + + def __init__( + self, + addrs: typing.Iterable[typing.Union[str, typing.Tuple[str, int]]], + config: con_utils.ClientConfiguration, + params: con_utils.ResolvedConnectConfig, + ): + self._addr = None + self._protocol = None + self._addrs = addrs + self._config = config + self._params = params + self._log_listeners = set() + self._holder = None + + @abc.abstractmethod + def _dispatch_log_message(self, msg): + ... + + def _on_log_message(self, msg): + if self._log_listeners: + self._dispatch_log_message(msg) + + def connected_addr(self): + return self._addr + + def _get_last_status(self) -> typing.Optional[str]: + if self._protocol is None: + return None + status = self._protocol.last_status + if status is not None: + status = status.decode() + return status + + def _cleanup(self): + self._log_listeners.clear() + if self._holder: + self._holder._release_on_close() + self._holder = None + + def add_log_listener( + self: BaseConnection_T, + callback: typing.Callable[[BaseConnection_T, errors.EdgeDBMessage], + None] + ) -> None: + """Add a listener for EdgeDB log messages. + + :param callable callback: + A callable receiving the following arguments: + **connection**: a Connection the callback is registered with; + **message**: the `gel.EdgeDBMessage` message. + """ + self._log_listeners.add(callback) + + def remove_log_listener( + self: BaseConnection_T, + callback: typing.Callable[[BaseConnection_T, errors.EdgeDBMessage], + None] + ) -> None: + """Remove a listening callback for log messages.""" + self._log_listeners.discard(callback) + + @property + def dbname(self) -> str: + return self._params.database + + @property + def branch(self) -> str: + return self._params.branch + + @abc.abstractmethod + def is_closed(self) -> bool: + ... + + @abc.abstractmethod + async def connect_addr(self, addr, timeout): + ... + + @abc.abstractmethod + async def sleep(self, seconds): + ... + + async def connect(self, *, single_attempt=False): + start = time.monotonic() + if single_attempt: + max_time = 0 + else: + max_time = start + self._config.wait_until_available + iteration = 1 + + while True: + for addr in self._addrs: + try: + await self.connect_addr(addr, self._config.connect_timeout) + except TimeoutError as e: + if iteration == 1 or time.monotonic() < max_time: + continue + else: + raise errors.ClientConnectionTimeoutError( + f"connecting to {addr} failed in" + f" {self._config.connect_timeout} sec" + ) from e + except errors.ClientConnectionError as e: + if ( + e.has_tag(errors.SHOULD_RECONNECT) and + (iteration == 1 or time.monotonic() < max_time) + ): + continue + nice_err = e.__class__( + con_utils.render_client_no_connection_error( + e, + addr, + attempts=iteration, + duration=time.monotonic() - start, + )) + raise nice_err from e.__cause__ + else: + return + + iteration += 1 + await self.sleep(0.01 + random.random() * 0.2) + + async def privileged_execute( + self, execute_context: abstract.ExecuteContext + ): + if self._protocol.is_legacy: + await self._protocol.legacy_simple_query( + execute_context.query.query, enums.Capability.ALL + ) + else: + await self._protocol.execute( + execute_context.lower(allow_capabilities=enums.Capability.ALL) + ) + + def is_in_transaction(self) -> bool: + """Return True if Connection is currently inside a transaction. + + :return bool: True if inside transaction, False otherwise. + """ + return self._protocol.is_in_transaction() + + def get_settings(self) -> typing.Dict[str, typing.Any]: + return self._protocol.get_settings() + + async def raw_query(self, query_context: abstract.QueryContext): + if self.is_closed(): + await self.connect() + + reconnect = False + i = 0 + if self._protocol.is_legacy: + allow_capabilities = enums.Capability.LEGACY_EXECUTE + else: + allow_capabilities = enums.Capability.EXECUTE + ctx = query_context.lower(allow_capabilities=allow_capabilities) + while True: + i += 1 + try: + if reconnect: + await self.connect(single_attempt=True) + if self._protocol.is_legacy: + return await self._protocol.legacy_execute_anonymous(ctx) + else: + res = await self._protocol.query(ctx) + if ctx.warnings: + res = query_context.warning_handler(ctx.warnings, res) + return res + + except errors.EdgeDBError as e: + if query_context.retry_options is None: + raise + if not e.has_tag(errors.SHOULD_RETRY): + raise e + # A query is read-only if it has no capabilities i.e. + # capabilities == 0. Read-only queries are safe to retry. + # Explicit transaction conflicts as well. + if ( + ctx.capabilities != 0 + and not isinstance(e, errors.TransactionConflictError) + ): + raise e + rule = query_context.retry_options.get_rule_for_exception(e) + if i >= rule.attempts: + raise e + await self.sleep(rule.backoff(i)) + reconnect = self.is_closed() + + async def _execute(self, execute_context: abstract.ExecuteContext) -> None: + if self._protocol.is_legacy: + if execute_context.query.args or execute_context.query.kwargs: + raise errors.InterfaceError( + "Legacy protocol doesn't support arguments in execute()" + ) + await self._protocol.legacy_simple_query( + execute_context.query.query, enums.Capability.LEGACY_EXECUTE + ) + else: + ctx = execute_context.lower( + allow_capabilities=enums.Capability.EXECUTE + ) + res = await self._protocol.execute(ctx) + if ctx.warnings: + res = execute_context.warning_handler(ctx.warnings, res) + + async def describe( + self, describe_context: abstract.DescribeContext + ) -> abstract.DescribeResult: + ctx = describe_context.lower( + allow_capabilities=enums.Capability.EXECUTE + ) + await self._protocol._parse(ctx) + return abstract.DescribeResult( + input_type=ctx.in_dc.make_type(describe_context), + output_type=ctx.out_dc.make_type(describe_context), + output_cardinality=enums.Cardinality(ctx.cardinality[0]), + capabilities=ctx.capabilities, + ) + + def terminate(self): + if not self.is_closed(): + try: + self._protocol.abort() + finally: + self._cleanup() + + def __repr__(self): + if self.is_closed(): + return '<{classname} [closed] {id:#x}>'.format( + classname=self.__class__.__name__, id=id(self)) + else: + return '<{classname} [connected to {addr}] {id:#x}>'.format( + classname=self.__class__.__name__, + addr=self.connected_addr(), + id=id(self)) + + +class PoolConnectionHolder(abc.ABC): + __slots__ = ( + "_con", + "_pool", + "_release_event", + "_timeout", + "_generation", + ) + _event_class = NotImplemented + + def __init__(self, pool): + + self._pool = pool + self._con = None + + self._timeout = None + self._generation = None + + self._release_event = self._event_class() + self._release_event.set() + + @abc.abstractmethod + async def close(self, *, wait=True): + ... + + @abc.abstractmethod + async def wait_until_released(self, timeout=None): + ... + + async def connect(self): + if self._con is not None: + raise errors.InternalClientError( + 'PoolConnectionHolder.connect() called while another ' + 'connection already exists') + + self._con = await self._pool._get_new_connection() + assert self._con._holder is None + self._con._holder = self + self._generation = self._pool._generation + + async def acquire(self) -> BaseConnection: + if self._con is None or self._con.is_closed(): + self._con = None + await self.connect() + + elif self._generation != self._pool._generation: + # Connections have been expired, re-connect the holder. + self._con._holder = None # don't release the connection + await self.close(wait=False) + self._con = None + await self.connect() + + self._release_event.clear() + + return self._con + + async def release(self, timeout): + if self._release_event.is_set(): + raise errors.InternalClientError( + 'PoolConnectionHolder.release() called on ' + 'a free connection holder') + + if self._con.is_closed(): + # This is usually the case when the connection is broken rather + # than closed by the user, so we need to call _release_on_close() + # here to release the holder back to the queue, because + # self._con._cleanup() was never called. On the other hand, it is + # safe to call self._release() twice - the second call is no-op. + self._release_on_close() + return + + self._timeout = None + + if self._generation != self._pool._generation: + # The connection has expired because it belongs to + # an older generation (BasePoolImpl.expire_connections() has + # been called.) + await self.close() + return + + # Free this connection holder and invalidate the + # connection proxy. + self._release() + + def terminate(self): + if self._con is not None: + # AsyncIOConnection.terminate() will call _release_on_close() to + # finish holder cleanup. + self._con.terminate() + + def _release_on_close(self): + self._release() + self._con = None + + def _release(self): + """Release this connection holder.""" + if self._release_event.is_set(): + # The holder is not checked out. + return + + self._release_event.set() + + # Put ourselves back to the pool queue. + self._pool._queue.put_nowait(self) + + +class BasePoolImpl(abc.ABC): + __slots__ = ( + "_connect_args", + "_codecs_registry", + "_query_cache", + "_connection_factory", + "_queue", + "_user_max_concurrency", + "_max_concurrency", + "_first_connect_lock", + "_working_addr", + "_working_config", + "_working_params", + "_holders", + "_initialized", + "_initializing", + "_closing", + "_closed", + "_generation", + ) + + _holder_class = NotImplemented + + def __init__( + self, + connect_args, + connection_factory, + *, + max_concurrency: typing.Optional[int], + ): + self._connection_factory = connection_factory + self._connect_args = connect_args + self._codecs_registry = protocol.CodecsRegistry() + self._query_cache = protocol.LRUMapping(maxsize=QUERY_CACHE_SIZE) + + if max_concurrency is not None and max_concurrency <= 0: + raise ValueError( + 'max_concurrency is expected to be greater than zero' + ) + + self._user_max_concurrency = max_concurrency + self._max_concurrency = max_concurrency if max_concurrency else 1 + + self._holders = [] + self._queue = None + + self._first_connect_lock = None + self._working_addr = None + self._working_config = None + self._working_params = None + + self._closing = False + self._closed = False + self._generation = 0 + + @abc.abstractmethod + def _ensure_initialized(self): + ... + + @abc.abstractmethod + def _set_queue_maxsize(self, maxsize): + ... + + @abc.abstractmethod + async def _maybe_get_first_connection(self): + ... + + @abc.abstractmethod + async def acquire(self, timeout=None): + ... + + @abc.abstractmethod + async def _release(self, connection): + ... + + @property + def codecs_registry(self): + return self._codecs_registry + + @property + def query_cache(self): + return self._query_cache + + def _resize_holder_pool(self): + resize_diff = self._max_concurrency - len(self._holders) + + if (resize_diff > 0): + if self._queue.maxsize != self._max_concurrency: + self._set_queue_maxsize(self._max_concurrency) + + for _ in range(resize_diff): + ch = self._holder_class(self) + + self._holders.append(ch) + self._queue.put_nowait(ch) + elif resize_diff < 0: + # TODO: shrink the pool + pass + + def get_max_concurrency(self): + return self._max_concurrency + + def get_free_size(self): + if self._queue is None: + # Queue has not been initialized yet + return self._max_concurrency + + return self._queue.qsize() + + def set_connect_args(self, dsn=None, **connect_kwargs): + r"""Set the new connection arguments for this pool. + + The new connection arguments will be used for all subsequent + new connection attempts. Existing connections will remain until + they expire. Use BasePoolImpl.expire_connections() to expedite + the connection expiry. + + :param str dsn: + Connection arguments specified using as a single string in + the following format: + ``gel://user:pass@host:port/database?option=value``. + + :param \*\*connect_kwargs: + Keyword arguments for the + :func:`~gel.asyncio_client.create_async_client` function. + """ + + connect_kwargs["dsn"] = dsn + self._connect_args = connect_kwargs + self._codecs_registry = protocol.CodecsRegistry() + self._query_cache = protocol.LRUMapping(maxsize=QUERY_CACHE_SIZE) + self._working_addr = None + self._working_config = None + self._working_params = None + + async def _get_first_connection(self): + # First connection attempt on this pool. + connect_config, client_config = con_utils.parse_connect_arguments( + **self._connect_args, + # ToDos + command_timeout=None, + server_settings=None, + ) + con = self._connection_factory( + [connect_config.address], client_config, connect_config + ) + await con.connect() + self._working_addr = con.connected_addr() + self._working_config = client_config + self._working_params = connect_config + + if self._user_max_concurrency is None: + suggested_concurrency = con.get_settings().get( + 'suggested_pool_concurrency') + if suggested_concurrency: + self._max_concurrency = suggested_concurrency + self._resize_holder_pool() + return con + + async def _get_new_connection(self): + con = None + if self._working_addr is None: + con = await self._maybe_get_first_connection() + if con is None: + assert self._working_addr is not None + # We've connected before and have a resolved address, + # and parsed options and config. + con = self._connection_factory( + [self._working_addr], + self._working_config, + self._working_params, + ) + await con.connect() + + return con + + async def release(self, connection): + + if not isinstance(connection, BaseConnection): + raise errors.InterfaceError( + f'BasePoolImpl.release() received invalid connection: ' + f'{connection!r} does not belong to any connection pool' + ) + + ch = connection._holder + if ch is None: + # Already released, do nothing. + return + + if ch._pool is not self: + raise errors.InterfaceError( + f'BasePoolImpl.release() received invalid connection: ' + f'{connection!r} is not a member of this pool' + ) + + return await self._release(ch) + + def terminate(self): + """Terminate all connections in the pool.""" + if self._closed: + return + for ch in self._holders: + ch.terminate() + self._closed = True + + def expire_connections(self): + """Expire all currently open connections. + + Cause all currently open connections to get replaced on the + next query. + """ + self._generation += 1 + + async def ensure_connected(self): + self._ensure_initialized() + + for ch in self._holders: + if ch._con is not None and not ch._con.is_closed(): + return + + ch = self._holders[0] + ch._con = None + await ch.connect() + + +class BaseClient(abstract.BaseReadOnlyExecutor, _options._OptionsMixin): + __slots__ = ("_impl", "_options") + _impl_class = NotImplemented + + def __init__( + self, + *, + connection_class, + max_concurrency: typing.Optional[int], + dsn=None, + host: str = None, + port: int = None, + credentials: str = None, + credentials_file: str = None, + user: str = None, + password: str = None, + secret_key: str = None, + database: str = None, + branch: str = None, + tls_ca: str = None, + tls_ca_file: str = None, + tls_security: str = None, + tls_server_name: str = None, + wait_until_available: int = 30, + timeout: int = 10, + **kwargs, + ): + super().__init__() + connect_args = { + "dsn": dsn, + "host": host, + "port": port, + "credentials": credentials, + "credentials_file": credentials_file, + "user": user, + "password": password, + "secret_key": secret_key, + "database": database, + "branch": branch, + "timeout": timeout, + "tls_ca": tls_ca, + "tls_ca_file": tls_ca_file, + "tls_security": tls_security, + "tls_server_name": tls_server_name, + "wait_until_available": wait_until_available, + } + + self._impl = self._impl_class( + connect_args, + connection_class=connection_class, + max_concurrency=max_concurrency, + **kwargs, + ) + + def _shallow_clone(self): + new_client = self.__class__.__new__(self.__class__) + new_client._impl = self._impl + return new_client + + def _get_query_cache(self) -> abstract.QueryCache: + return abstract.QueryCache( + codecs_registry=self._impl.codecs_registry, + query_cache=self._impl.query_cache, + ) + + def _get_retry_options(self) -> typing.Optional[_options.RetryOptions]: + return self._options.retry_options + + def _get_state(self) -> _options.State: + return self._options.state + + def _get_warning_handler(self) -> _options.WarningHandler: + return self._options.warning_handler + + @property + def max_concurrency(self) -> int: + """Max number of connections in the pool.""" + + return self._impl.get_max_concurrency() + + @property + def free_size(self) -> int: + """Number of available connections in the pool.""" + + return self._impl.get_free_size() + + async def _query(self, query_context: abstract.QueryContext): + con = await self._impl.acquire() + try: + return await con.raw_query(query_context) + finally: + await self._impl.release(con) + + async def _execute(self, execute_context: abstract.ExecuteContext) -> None: + con = await self._impl.acquire() + try: + await con._execute(execute_context) + finally: + await self._impl.release(con) + + async def _describe( + self, describe_context: abstract.DescribeContext + ) -> abstract.DescribeResult: + con = await self._impl.acquire() + try: + return await con.describe(describe_context) + finally: + await self._impl.release(con) + + def terminate(self): + """Terminate all connections in the pool.""" + self._impl.terminate() diff --git a/gel/blocking_client.py b/gel/blocking_client.py new file mode 100644 index 00000000..797cdcb0 --- /dev/null +++ b/gel/blocking_client.py @@ -0,0 +1,490 @@ +# +# This source file is part of the EdgeDB open source project. +# +# Copyright 2022-present MagicStack Inc. and the EdgeDB authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +import contextlib +import datetime +import queue +import socket +import ssl +import threading +import time +import typing + +from . import abstract +from . import base_client +from . import con_utils +from . import errors +from . import transaction +from .protocol import blocking_proto +from .protocol.protocol import InputLanguage, OutputFormat + + +DEFAULT_PING_BEFORE_IDLE_TIMEOUT = datetime.timedelta(seconds=5) +MINIMUM_PING_WAIT_TIME = datetime.timedelta(seconds=1) + + +class BlockingIOConnection(base_client.BaseConnection): + __slots__ = ("_ping_wait_time",) + + async def connect_addr(self, addr, timeout): + deadline = time.monotonic() + timeout + + if isinstance(addr, str): + # UNIX socket + res_list = [(socket.AF_UNIX, socket.SOCK_STREAM, -1, None, addr)] + else: + host, port = addr + try: + # getaddrinfo() doesn't take timeout!! + res_list = socket.getaddrinfo( + host, port, socket.AF_UNSPEC, socket.SOCK_STREAM + ) + except socket.gaierror as e: + # All name resolution errors are considered temporary + err = errors.ClientConnectionFailedTemporarilyError(str(e)) + raise err from e + + for i, res in enumerate(res_list): + af, socktype, proto, _, sa = res + try: + sock = socket.socket(af, socktype, proto) + except OSError as e: + sock.close() + if i < len(res_list) - 1: + continue + else: + raise con_utils.wrap_error(e) from e + try: + await self._connect_addr(sock, addr, sa, deadline) + except TimeoutError: + raise + except Exception: + if i < len(res_list) - 1: + continue + else: + raise + else: + break + + async def _connect_addr(self, sock, addr, sa, deadline): + try: + time_left = deadline - time.monotonic() + if time_left <= 0: + raise TimeoutError + try: + sock.settimeout(time_left) + sock.connect(sa) + except OSError as e: + raise con_utils.wrap_error(e) from e + + if not isinstance(addr, str): + time_left = deadline - time.monotonic() + if time_left <= 0: + raise TimeoutError + try: + # Upgrade to TLS + sock.settimeout(time_left) + try: + sock = self._params.ssl_ctx.wrap_socket( + sock, + server_hostname=( + self._params.tls_server_name or addr[0] + ), + ) + except ssl.CertificateError as e: + raise con_utils.wrap_error(e) from e + except ssl.SSLError as e: + raise con_utils.wrap_error(e) from e + else: + con_utils.check_alpn_protocol(sock) + except OSError as e: + raise con_utils.wrap_error(e) from e + + time_left = deadline - time.monotonic() + if time_left <= 0: + raise TimeoutError + + if not isinstance(addr, str): + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + + proto = blocking_proto.BlockingIOProtocol(self._params, sock) + proto.set_connection(self) + + try: + await proto.wait_for(proto.connect(), time_left) + except TimeoutError: + raise + except OSError as e: + raise con_utils.wrap_error(e) from e + + self._protocol = proto + self._addr = addr + self._ping_wait_time = max( + ( + self.get_settings() + .get("system_config") + .session_idle_timeout + - DEFAULT_PING_BEFORE_IDLE_TIMEOUT + ), + MINIMUM_PING_WAIT_TIME, + ).total_seconds() + + except Exception: + sock.close() + raise + + async def sleep(self, seconds): + time.sleep(seconds) + + def is_closed(self): + proto = self._protocol + return not (proto and proto.sock is not None and + proto.sock.fileno() >= 0 and proto.connected) + + async def close(self, timeout=None): + """Send graceful termination message wait for connection to drop.""" + if not self.is_closed(): + try: + self._protocol.terminate() + if timeout is None: + await self._protocol.wait_for_disconnect() + else: + await self._protocol.wait_for( + self._protocol.wait_for_disconnect(), timeout + ) + except TimeoutError: + self.terminate() + raise errors.QueryTimeoutError() + except Exception: + self.terminate() + raise + finally: + self._cleanup() + + def _dispatch_log_message(self, msg): + for cb in self._log_listeners: + cb(self, msg) + + async def raw_query(self, query_context: abstract.QueryContext): + try: + if ( + time.monotonic() - self._protocol.last_active_timestamp + > self._ping_wait_time + ): + await self._protocol.ping() + except (errors.IdleSessionTimeoutError, errors.ClientConnectionError): + await self.connect() + + return await super().raw_query(query_context) + + +class _PoolConnectionHolder(base_client.PoolConnectionHolder): + __slots__ = () + _event_class = threading.Event + + async def close(self, *, wait=True, timeout=None): + if self._con is None: + return + await self._con.close(timeout=timeout) + + async def wait_until_released(self, timeout=None): + return self._release_event.wait(timeout) + + +class _PoolImpl(base_client.BasePoolImpl): + _holder_class = _PoolConnectionHolder + + def __init__( + self, + connect_args, + *, + max_concurrency: typing.Optional[int], + connection_class, + ): + if not issubclass(connection_class, BlockingIOConnection): + raise TypeError( + f'connection_class is expected to be a subclass of ' + f'gel.blocking_client.BlockingIOConnection, ' + f'got {connection_class}') + super().__init__( + connect_args, + connection_class, + max_concurrency=max_concurrency, + ) + + def _ensure_initialized(self): + if self._queue is None: + self._queue = queue.LifoQueue(maxsize=self._max_concurrency) + self._first_connect_lock = threading.Lock() + self._resize_holder_pool() + + def _set_queue_maxsize(self, maxsize): + with self._queue.mutex: + self._queue.maxsize = maxsize + + async def _maybe_get_first_connection(self): + with self._first_connect_lock: + if self._working_addr is None: + return await self._get_first_connection() + + async def acquire(self, timeout=None): + self._ensure_initialized() + + if self._closing: + raise errors.InterfaceError('pool is closing') + + ch = self._queue.get(timeout=timeout) + try: + con = await ch.acquire() + except Exception: + self._queue.put_nowait(ch) + raise + else: + # Record the timeout, as we will apply it by default + # in release(). + ch._timeout = timeout + return con + + async def _release(self, holder): + if not isinstance(holder._con, BlockingIOConnection): + raise errors.InterfaceError( + f'release() received invalid connection: ' + f'{holder._con!r} does not belong to any connection pool' + ) + + timeout = None + return await holder.release(timeout) + + async def close(self, timeout=None): + if self._closed: + return + self._closing = True + try: + if timeout is None: + for ch in self._holders: + await ch.wait_until_released() + for ch in self._holders: + await ch.close() + else: + deadline = time.monotonic() + timeout + for ch in self._holders: + secs = deadline - time.monotonic() + if secs <= 0: + raise TimeoutError + if not await ch.wait_until_released(secs): + raise TimeoutError + for ch in self._holders: + secs = deadline - time.monotonic() + if secs <= 0: + raise TimeoutError + await ch.close(timeout=secs) + except TimeoutError as e: + self.terminate() + raise errors.InterfaceError( + "client is not fully closed in {} seconds; " + "terminating now.".format(timeout) + ) from e + except Exception: + self.terminate() + raise + finally: + self._closed = True + self._closing = False + + +class Iteration(transaction.BaseTransaction, abstract.Executor): + + __slots__ = ("_managed", "_lock") + + def __init__(self, retry, client, iteration): + super().__init__(retry, client, iteration) + self._managed = False + self._lock = threading.Lock() + + def __enter__(self): + with self._exclusive(): + if self._managed: + raise errors.InterfaceError( + 'cannot enter context: already in a `with` block') + self._managed = True + return self + + def __exit__(self, extype, ex, tb): + with self._exclusive(): + self._managed = False + return self._client._iter_coroutine(self._exit(extype, ex)) + + async def _ensure_transaction(self): + if not self._managed: + raise errors.InterfaceError( + "Only managed retriable transactions are supported. " + "Use `with transaction:`" + ) + await super()._ensure_transaction() + + def _query(self, query_context: abstract.QueryContext): + with self._exclusive(): + return self._client._iter_coroutine(super()._query(query_context)) + + def _execute(self, execute_context: abstract.ExecuteContext) -> None: + with self._exclusive(): + self._client._iter_coroutine(super()._execute(execute_context)) + + @contextlib.contextmanager + def _exclusive(self): + if not self._lock.acquire(blocking=False): + raise errors.InterfaceError( + "concurrent queries within the same transaction " + "are not allowed" + ) + try: + yield + finally: + self._lock.release() + + +class Retry(transaction.BaseRetry): + + def __iter__(self): + return self + + def __next__(self): + # Note: when changing this code consider also + # updating AsyncIORetry.__anext__. + if self._done: + raise StopIteration + if self._next_backoff: + time.sleep(self._next_backoff) + self._done = True + iteration = Iteration(self, self._owner, self._iteration) + self._iteration += 1 + return iteration + + +class Client(base_client.BaseClient, abstract.Executor): + """A lazy connection pool. + + A Client can be used to manage a set of connections to the database. + Connections are first acquired from the pool, then used, and then released + back to the pool. Once a connection is released, it's reset to close all + open cursors and other resources *except* prepared statements. + + Clients are created by calling + :func:`~gel.blocking_client.create_client`. + """ + + __slots__ = () + _impl_class = _PoolImpl + + def _iter_coroutine(self, coro): + try: + coro.send(None) + except StopIteration as ex: + return ex.value + finally: + coro.close() + + def _query(self, query_context: abstract.QueryContext): + return self._iter_coroutine(super()._query(query_context)) + + def _execute(self, execute_context: abstract.ExecuteContext) -> None: + self._iter_coroutine(super()._execute(execute_context)) + + def ensure_connected(self): + self._iter_coroutine(self._impl.ensure_connected()) + return self + + def transaction(self) -> Retry: + return Retry(self) + + def close(self, timeout=None): + """Attempt to gracefully close all connections in the client. + + Wait until all pool connections are released, close them and + shut down the pool. If any error (including cancellation) occurs + in ``close()`` the pool will terminate by calling + Client.terminate() . + """ + self._iter_coroutine(self._impl.close(timeout)) + + def __enter__(self): + return self.ensure_connected() + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + + def _describe_query( + self, + query: str, + *, + inject_type_names: bool = False, + input_language: InputLanguage = InputLanguage.EDGEQL, + output_format: OutputFormat = OutputFormat.BINARY, + expect_one: bool = False, + ) -> abstract.DescribeResult: + return self._iter_coroutine(self._describe(abstract.DescribeContext( + query=query, + state=self._get_state(), + inject_type_names=inject_type_names, + input_language=input_language, + output_format=output_format, + expect_one=expect_one, + ))) + + +def create_client( + dsn=None, + *, + max_concurrency=None, + host: str = None, + port: int = None, + credentials: str = None, + credentials_file: str = None, + user: str = None, + password: str = None, + secret_key: str = None, + database: str = None, + branch: str = None, + tls_ca: str = None, + tls_ca_file: str = None, + tls_security: str = None, + wait_until_available: int = 30, + timeout: int = 10, +): + return Client( + connection_class=BlockingIOConnection, + max_concurrency=max_concurrency, + + # connect arguments + dsn=dsn, + host=host, + port=port, + credentials=credentials, + credentials_file=credentials_file, + user=user, + password=password, + secret_key=secret_key, + database=database, + branch=branch, + tls_ca=tls_ca, + tls_ca_file=tls_ca_file, + tls_security=tls_security, + wait_until_available=wait_until_available, + timeout=timeout, + ) diff --git a/edgedb/codegen/__init__.py b/gel/codegen/__init__.py similarity index 100% rename from edgedb/codegen/__init__.py rename to gel/codegen/__init__.py diff --git a/edgedb/codegen/__main__.py b/gel/codegen/__main__.py similarity index 100% rename from edgedb/codegen/__main__.py rename to gel/codegen/__main__.py diff --git a/edgedb/codegen/cli.py b/gel/codegen/cli.py similarity index 100% rename from edgedb/codegen/cli.py rename to gel/codegen/cli.py diff --git a/edgedb/codegen/generator.py b/gel/codegen/generator.py similarity index 94% rename from edgedb/codegen/generator.py rename to gel/codegen/generator.py index 08fbdc6e..b5c70ddb 100644 --- a/edgedb/codegen/generator.py +++ b/gel/codegen/generator.py @@ -25,11 +25,11 @@ import textwrap import typing -import edgedb -from edgedb import abstract -from edgedb import describe -from edgedb.con_utils import find_edgedb_project_dir -from edgedb.color import get_color +import gel +from gel import abstract +from gel import describe +from gel.con_utils import find_gel_project_dir +from gel.color import get_color C = get_color() @@ -64,9 +64,9 @@ "cal::local_date": "datetime.date", "cal::local_time": "datetime.time", "cal::local_datetime": "datetime.datetime", - "cal::relative_duration": "edgedb.RelativeDuration", - "cal::date_duration": "edgedb.DateDuration", - "cfg::memory": "edgedb.ConfigMemory", + "cal::relative_duration": "gel.RelativeDuration", + "cal::date_duration": "gel.DateDuration", + "cfg::memory": "gel.ConfigMemory", "ext::pgvector::vector": "array.array", } @@ -158,15 +158,15 @@ def __init__(self, args: argparse.Namespace): self._skip_pydantic_validation = args.skip_pydantic_validation self._async = False try: - self._project_dir = pathlib.Path(find_edgedb_project_dir()) - except edgedb.ClientConnectionError: + self._project_dir = pathlib.Path(find_gel_project_dir()) + except gel.ClientConnectionError: print( - "Cannot find edgedb.toml: " + "Cannot find gel.toml: " "codegen must be run under an EdgeDB project dir" ) sys.exit(2) print_msg(f"Found EdgeDB project: {C.BOLD}{self._project_dir}{C.ENDC}") - self._client = edgedb.create_client(**_get_conn_args(args)) + self._client = gel.create_client(**_get_conn_args(args)) self._single_mode_files = args.file self._search_dirs = [] for search_dir in args.dir or []: @@ -203,7 +203,7 @@ def _new_file(self): def run(self): try: self._client.ensure_connected() - except edgedb.EdgeDBError as e: + except gel.EdgeDBError as e: print(f"Failed to connect to EdgeDB instance: {e}") sys.exit(61) with self._client: @@ -301,7 +301,7 @@ def _write_comments( cmd = [] if sys.argv[0].endswith("__main__.py"): cmd.append(pathlib.Path(sys.executable).name) - cmd.extend(["-m", "edgedb.codegen"]) + cmd.extend(["-m", "gel.codegen"]) else: cmd.append(pathlib.Path(sys.argv[0]).name) cmd.extend(sys.argv[1:]) @@ -346,7 +346,7 @@ def _generate( else: self._imports.add("typing") out_type = f"typing.List[{out_type}]" - elif dr.output_cardinality == edgedb.Cardinality.AT_MOST_ONE: + elif dr.output_cardinality == gel.Cardinality.AT_MOST_ONE: if SYS_VERSION_INFO >= (3, 10): out_type = f"{out_type} | None" else: @@ -377,11 +377,11 @@ def _generate( print(f"async def {name}(", file=buf) else: print(f"def {name}(", file=buf) - self._imports.add("edgedb") + self._imports.add("gel") if self._async: - print(f"{INDENT}executor: edgedb.AsyncIOExecutor,", file=buf) + print(f"{INDENT}executor: gel.AsyncIOExecutor,", file=buf) else: - print(f"{INDENT}executor: edgedb.Executor,", file=buf) + print(f"{INDENT}executor: gel.Executor,", file=buf) if kw_only: print(f"{INDENT}*,", file=buf) for name, arg in args.items(): @@ -390,7 +390,7 @@ def _generate( if dr.output_cardinality.is_multi(): method = "query" rt = "return " - elif dr.output_cardinality == edgedb.Cardinality.NO_RESULT: + elif dr.output_cardinality == gel.Cardinality.NO_RESULT: method = "execute" rt = "" else: @@ -485,7 +485,7 @@ def _generate_code( el_code = self._generate_code_with_cardinality( element.type, name_hint, element.cardinality ) - if element.kind == edgedb.ElementKind.LINK_PROPERTY: + if element.kind == gel.ElementKind.LINK_PROPERTY: link_props.append((el_name, el_code)) else: print(f"{INDENT}{el_name}: {el_code}", file=buf) @@ -536,7 +536,7 @@ def _generate_code( elif isinstance(type_, describe.RangeType): value = self._generate_code(type_.value_type, name_hint, is_input) - rv = f"edgedb.Range[{value}]" + rv = f"gel.Range[{value}]" else: rv = "??" @@ -548,12 +548,12 @@ def _generate_code_with_cardinality( self, type_: typing.Optional[describe.AnyType], name_hint: str, - cardinality: edgedb.Cardinality, + cardinality: gel.Cardinality, keyword_argument: bool = False, is_input: bool = False, ): rv = self._generate_code(type_, name_hint, is_input) - if cardinality == edgedb.Cardinality.AT_MOST_ONE: + if cardinality == gel.Cardinality.AT_MOST_ONE: if SYS_VERSION_INFO >= (3, 10): rv = f"{rv} | None" else: diff --git a/gel/color.py b/gel/color.py new file mode 100644 index 00000000..2b972aaa --- /dev/null +++ b/gel/color.py @@ -0,0 +1,61 @@ +import os +import sys +import warnings + +COLOR = None + + +class Color: + HEADER = "" + BLUE = "" + CYAN = "" + GREEN = "" + WARNING = "" + FAIL = "" + ENDC = "" + BOLD = "" + UNDERLINE = "" + + +def get_color() -> Color: + global COLOR + + if COLOR is None: + COLOR = Color() + if type(USE_COLOR) is bool: + use_color = USE_COLOR + else: + try: + use_color = USE_COLOR() + except Exception: + use_color = False + if use_color: + COLOR.HEADER = '\033[95m' + COLOR.BLUE = '\033[94m' + COLOR.CYAN = '\033[96m' + COLOR.GREEN = '\033[92m' + COLOR.WARNING = '\033[93m' + COLOR.FAIL = '\033[91m' + COLOR.ENDC = '\033[0m' + COLOR.BOLD = '\033[1m' + COLOR.UNDERLINE = '\033[4m' + + return COLOR + + +try: + USE_COLOR = { + "default": lambda: sys.stderr.isatty(), + "auto": lambda: sys.stderr.isatty(), + "enabled": True, + "disabled": False, + }[ + os.getenv("EDGEDB_COLOR_OUTPUT", "default") + ] +except KeyError: + warnings.warn( + "EDGEDB_COLOR_OUTPUT can only be one of: " + "default, auto, enabled or disabled", + stacklevel=1, + ) + USE_COLOR = False diff --git a/gel/con_utils.py b/gel/con_utils.py new file mode 100644 index 00000000..9a593d9f --- /dev/null +++ b/gel/con_utils.py @@ -0,0 +1,1278 @@ +# +# This source file is part of the EdgeDB open source project. +# +# Copyright 2016-present MagicStack Inc. and the EdgeDB authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +import base64 +import binascii +import errno +import json +import os +import re +import ssl +import typing +import urllib.parse +import warnings +import hashlib + +from . import errors +from . import credentials as cred_utils +from . import platform + + +EDGEDB_PORT = 5656 +ERRNO_RE = re.compile(r"\[Errno (\d+)\]") +TEMPORARY_ERRORS = ( + ConnectionAbortedError, + ConnectionRefusedError, + ConnectionResetError, + FileNotFoundError, +) +TEMPORARY_ERROR_CODES = frozenset({ + errno.ECONNREFUSED, + errno.ECONNABORTED, + errno.ECONNRESET, + errno.ENOENT, +}) + +ISO_SECONDS_RE = re.compile(r'(-?\d+|-?\d+\.\d*|-?\d*\.\d+)S') +ISO_MINUTES_RE = re.compile(r'(-?\d+|-?\d+\.\d*|-?\d*\.\d+)M') +ISO_HOURS_RE = re.compile(r'(-?\d+|-?\d+\.\d*|-?\d*\.\d+)H') +ISO_UNITLESS_HOURS_RE = re.compile(r'^(-?\d+|-?\d+\.\d*|-?\d*\.\d+)$') +ISO_DAYS_RE = re.compile(r'(-?\d+|-?\d+\.\d*|-?\d*\.\d+)D') +ISO_WEEKS_RE = re.compile(r'(-?\d+|-?\d+\.\d*|-?\d*\.\d+)W') +ISO_MONTHS_RE = re.compile(r'(-?\d+|-?\d+\.\d*|-?\d*\.\d+)M') +ISO_YEARS_RE = re.compile(r'(-?\d+|-?\d+\.\d*|-?\d*\.\d+)Y') + +HUMAN_HOURS_RE = re.compile( + r'((?:(?:\s|^)-\s*)?\d*\.?\d*)\s*(?i:h(\s|\d|\.|$)|hours?(\s|$))', +) +HUMAN_MINUTES_RE = re.compile( + r'((?:(?:\s|^)-\s*)?\d*\.?\d*)\s*(?i:m(\s|\d|\.|$)|minutes?(\s|$))', +) +HUMAN_SECONDS_RE = re.compile( + r'((?:(?:\s|^)-\s*)?\d*\.?\d*)\s*(?i:s(\s|\d|\.|$)|seconds?(\s|$))', +) +HUMAN_MS_RE = re.compile( + r'((?:(?:\s|^)-\s*)?\d*\.?\d*)\s*(?i:ms(\s|\d|\.|$)|milliseconds?(\s|$))', +) +HUMAN_US_RE = re.compile( + r'((?:(?:\s|^)-\s*)?\d*\.?\d*)\s*(?i:us(\s|\d|\.|$)|microseconds?(\s|$))', +) +INSTANCE_NAME_RE = re.compile( + r'^(\w(?:-?\w)*)$', + re.ASCII, +) +CLOUD_INSTANCE_NAME_RE = re.compile( + r'^([A-Za-z0-9_-](?:-?[A-Za-z0-9_])*)/([A-Za-z0-9](?:-?[A-Za-z0-9])*)$', + re.ASCII, +) +DSN_RE = re.compile( + r'^[a-z]+://', + re.IGNORECASE, +) +DOMAIN_LABEL_MAX_LENGTH = 63 + + +class ClientConfiguration(typing.NamedTuple): + + connect_timeout: float + command_timeout: float + wait_until_available: float + + +def _validate_port_spec(hosts, port): + if isinstance(port, list): + # If there is a list of ports, its length must + # match that of the host list. + if len(port) != len(hosts): + raise errors.InterfaceError( + 'could not match {} port numbers to {} hosts'.format( + len(port), len(hosts))) + else: + port = [port for _ in range(len(hosts))] + + return port + + +def _parse_hostlist(hostlist, port): + if ',' in hostlist: + # A comma-separated list of host addresses. + hostspecs = hostlist.split(',') + else: + hostspecs = [hostlist] + + hosts = [] + hostlist_ports = [] + + if not port: + portspec = _getenv('PORT') + if portspec: + if ',' in portspec: + default_port = [int(p) for p in portspec.split(',')] + else: + default_port = int(portspec) + else: + default_port = EDGEDB_PORT + + default_port = _validate_port_spec(hostspecs, default_port) + + else: + port = _validate_port_spec(hostspecs, port) + + for i, hostspec in enumerate(hostspecs): + addr, _, hostspec_port = hostspec.partition(':') + hosts.append(addr) + + if not port: + if hostspec_port: + hostlist_ports.append(int(hostspec_port)) + else: + hostlist_ports.append(default_port[i]) + + if not port: + port = hostlist_ports + + return hosts, port + + +def _hash_path(path): + path = os.path.realpath(path) + if platform.IS_WINDOWS and not path.startswith('\\\\'): + path = '\\\\?\\' + path + return hashlib.sha1(str(path).encode('utf-8')).hexdigest() + + +def _stash_path(path): + base_name = os.path.basename(path) + dir_name = base_name + '-' + _hash_path(path) + return platform.search_config_dir('projects', dir_name) + + +def _validate_tls_security(val: str) -> str: + val = val.lower() + if val not in {"insecure", "no_host_verification", "strict", "default"}: + raise ValueError( + "tls_security can only be one of " + "`insecure`, `no_host_verification`, `strict` or `default`" + ) + + return val + + +def _getenv_and_key(key: str) -> typing.Tuple[typing.Optional[str], str]: + edgedb_key = f'EDGEDB_{key}' + edgedb_val = os.getenv(edgedb_key) + gel_key = f'GEL_{key}' + gel_val = os.getenv(gel_key) + if edgedb_val is not None and gel_val is not None: + warnings.warn( + f'Both {gel_key} and {edgedb_key} are set; ' + f'{edgedb_key} will be ignored', + stacklevel=1, + ) + + if gel_val is None and edgedb_val is not None: + return edgedb_val, edgedb_key + else: + return gel_val, gel_key + + +def _getenv(key: str) -> typing.Optional[str]: + return _getenv_and_key(key)[0] + + +class ResolvedConnectConfig: + _host = None + _host_source = None + + _port = None + _port_source = None + + # We keep track of database and branch separately, because we want to make + # sure that we don't use both at the same time on the same configuration + # level. + _database = None + _database_source = None + + _branch = None + _branch_source = None + + _user = None + _user_source = None + + _password = None + _password_source = None + + _secret_key = None + _secret_key_source = None + + _tls_ca_data = None + _tls_ca_data_source = None + + _tls_server_name = None + _tls_security = None + _tls_security_source = None + + _wait_until_available = None + + _cloud_profile = None + _cloud_profile_source = None + + server_settings = {} + + def _set_param(self, param, value, source, validator=None): + param_name = '_' + param + if getattr(self, param_name) is None: + setattr(self, param_name + '_source', source) + if value is not None: + setattr( + self, + param_name, + validator(value) if validator else value + ) + + def set_host(self, host, source): + self._set_param('host', host, source, _validate_host) + + def set_port(self, port, source): + self._set_param('port', port, source, _validate_port) + + def set_database(self, database, source): + self._set_param('database', database, source, _validate_database) + + def set_branch(self, branch, source): + self._set_param('branch', branch, source, _validate_branch) + + def set_user(self, user, source): + self._set_param('user', user, source, _validate_user) + + def set_password(self, password, source): + self._set_param('password', password, source) + + def set_secret_key(self, secret_key, source): + self._set_param('secret_key', secret_key, source) + + def set_tls_ca_data(self, ca_data, source): + self._set_param('tls_ca_data', ca_data, source) + + def set_tls_ca_file(self, ca_file, source): + def read_ca_file(file_path): + with open(file_path) as f: + return f.read() + + self._set_param('tls_ca_data', ca_file, source, read_ca_file) + + def set_tls_server_name(self, ca_data, source): + self._set_param('tls_server_name', ca_data, source) + + def set_tls_security(self, security, source): + self._set_param('tls_security', security, source, + _validate_tls_security) + + def set_wait_until_available(self, wait_until_available, source): + self._set_param( + 'wait_until_available', + wait_until_available, + source, + _validate_wait_until_available, + ) + + def add_server_settings(self, server_settings): + _validate_server_settings(server_settings) + self.server_settings = {**server_settings, **self.server_settings} + + @property + def address(self): + return ( + self._host if self._host else 'localhost', + self._port if self._port else 5656 + ) + + # The properties actually merge database and branch, but "default" is + # different. If you need to know the underlying config use the _database + # and _branch. + @property + def database(self): + return ( + self._database if self._database else + self._branch if self._branch else + 'edgedb' + ) + + @property + def branch(self): + return ( + self._database if self._database else + self._branch if self._branch else + '__default__' + ) + + @property + def user(self): + return self._user if self._user else 'edgedb' + + @property + def password(self): + return self._password + + @property + def secret_key(self): + return self._secret_key + + @property + def tls_server_name(self): + return self._tls_server_name + + @property + def tls_security(self): + tls_security = self._tls_security or 'default' + security, security_key = _getenv_and_key('CLIENT_SECURITY') + security = security or 'default' + if security not in {'default', 'insecure_dev_mode', 'strict'}: + raise ValueError( + f'environment variable {security_key} should be ' + f'one of strict, insecure_dev_mode or default, ' + f'got: {security!r}') + + if security == 'default': + pass + elif security == 'insecure_dev_mode': + if tls_security == 'default': + tls_security = 'insecure' + elif security == 'strict': + if tls_security == 'default': + tls_security = 'strict' + elif tls_security in {'no_host_verification', 'insecure'}: + raise ValueError( + f'{security_key}=strict but ' + f'tls_security={tls_security}, tls_security must be ' + f'set to strict when {security_key} is strict') + + if tls_security != 'default': + return tls_security + + if self._tls_ca_data is not None: + return "no_host_verification" + + return "strict" + + _ssl_ctx = None + + @property + def ssl_ctx(self): + if (self._ssl_ctx): + return self._ssl_ctx + + self._ssl_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + + if self._tls_ca_data: + self._ssl_ctx.load_verify_locations( + cadata=self._tls_ca_data + ) + else: + self._ssl_ctx.load_default_certs(ssl.Purpose.SERVER_AUTH) + if platform.IS_WINDOWS: + import certifi + self._ssl_ctx.load_verify_locations(cafile=certifi.where()) + + tls_security = self.tls_security + self._ssl_ctx.check_hostname = tls_security == "strict" + + if tls_security in {"strict", "no_host_verification"}: + self._ssl_ctx.verify_mode = ssl.CERT_REQUIRED + else: + self._ssl_ctx.verify_mode = ssl.CERT_NONE + + self._ssl_ctx.set_alpn_protocols(['edgedb-binary']) + + return self._ssl_ctx + + @property + def wait_until_available(self): + return ( + self._wait_until_available + if self._wait_until_available is not None + else 30 + ) + + +def _validate_host(host): + if '/' in host: + raise ValueError('unix socket paths not supported') + if host == '' or ',' in host: + raise ValueError(f'invalid host: "{host}"') + return host + + +def _prepare_host_for_dsn(host): + host = _validate_host(host) + if ':' in host: + # IPv6 + host = f'[{host}]' + return host + + +def _validate_port(port): + try: + if isinstance(port, str): + port = int(port) + if not isinstance(port, int): + raise ValueError() + except Exception: + raise ValueError(f'invalid port: {port}, not an integer') + if port < 1 or port > 65535: + raise ValueError(f'invalid port: {port}, must be between 1 and 65535') + return port + + +def _validate_database(database): + if database == '': + raise ValueError(f'invalid database name: {database}') + return database + + +def _validate_branch(branch): + if branch == '': + raise ValueError(f'invalid branch name: {branch}') + return branch + + +def _validate_user(user): + if user == '': + raise ValueError(f'invalid user name: {user}') + return user + + +def _pop_iso_unit(rgex: re.Pattern, string: str) -> typing.Tuple[float, str]: + s = string + total = 0 + match = rgex.search(string) + if match: + total += float(match.group(1)) + s = s.replace(match.group(0), "", 1) + + return (total, s) + + +def _parse_iso_duration(string: str) -> typing.Union[float, int]: + if not string.startswith("PT"): + raise ValueError(f"invalid duration {string!r}") + + time = string[2:] + match = ISO_UNITLESS_HOURS_RE.search(time) + if match: + hours = float(match.group(0)) + return 3600 * hours + + hours, time = _pop_iso_unit(ISO_HOURS_RE, time) + minutes, time = _pop_iso_unit(ISO_MINUTES_RE, time) + seconds, time = _pop_iso_unit(ISO_SECONDS_RE, time) + + if time: + raise ValueError(f'invalid duration {string!r}') + + return 3600 * hours + 60 * minutes + seconds + + +def _remove_white_space(s: str) -> str: + return ''.join(c for c in s if not c.isspace()) + + +def _pop_human_duration_unit( + rgex: re.Pattern, + string: str, +) -> typing.Tuple[float, bool, str]: + match = rgex.search(string) + if not match: + return 0, False, string + + number = 0 + if match.group(1): + literal = _remove_white_space(match.group(1)) + if literal.endswith('.'): + return 0, False, string + + if literal.startswith('-.'): + return 0, False, string + + number = float(literal) + string = string.replace( + match.group(0), + match.group(2) or match.group(3) or "", + 1, + ) + + return number, True, string + + +def _parse_human_duration(string: str) -> float: + found = False + + hour, f, s = _pop_human_duration_unit(HUMAN_HOURS_RE, string) + found |= f + + minute, f, s = _pop_human_duration_unit(HUMAN_MINUTES_RE, s) + found |= f + + second, f, s = _pop_human_duration_unit(HUMAN_SECONDS_RE, s) + found |= f + + ms, f, s = _pop_human_duration_unit(HUMAN_MS_RE, s) + found |= f + + us, f, s = _pop_human_duration_unit(HUMAN_US_RE, s) + found |= f + + if s.strip() or not found: + raise ValueError(f'invalid duration {string!r}') + + return 3600 * hour + 60 * minute + second + 0.001 * ms + 0.000001 * us + + +def _parse_duration_str(string: str) -> float: + if string.startswith('PT'): + return _parse_iso_duration(string) + return _parse_human_duration(string) + + +def _validate_wait_until_available(wait_until_available): + if isinstance(wait_until_available, str): + return _parse_duration_str(wait_until_available) + + if isinstance(wait_until_available, (int, float)): + return wait_until_available + + raise ValueError(f"invalid duration {wait_until_available!r}") + + +def _validate_server_settings(server_settings): + if ( + not isinstance(server_settings, dict) or + not all(isinstance(k, str) for k in server_settings) or + not all(isinstance(v, str) for v in server_settings.values()) + ): + raise ValueError( + 'server_settings is expected to be None or ' + 'a Dict[str, str]') + + +def _parse_connect_dsn_and_args( + *, + dsn, + host, + port, + credentials, + credentials_file, + user, + password, + secret_key, + database, + branch, + tls_ca, + tls_ca_file, + tls_security, + tls_server_name, + server_settings, + wait_until_available, +): + resolved_config = ResolvedConnectConfig() + + if dsn and DSN_RE.match(dsn): + instance_name = None + else: + instance_name, dsn = dsn, None + + def _get(key: str) -> typing.Optional[typing.Tuple[str, str]]: + val, env = _getenv_and_key(key) + return ( + (val, f'"{env}" environment variable') + if val is not None else None + ) + + # The cloud profile is potentially relevant to resolving credentials at + # any stage, including the config stage when other environment variables + # are not yet read. + cloud_profile_tuple = _get('CLOUD_PROFILE') + cloud_profile = cloud_profile_tuple[0] if cloud_profile_tuple else None + + has_compound_options = _resolve_config_options( + resolved_config, + 'Cannot have more than one of the following connection options: ' + + '"dsn", "credentials", "credentials_file" or "host"/"port"', + dsn=(dsn, '"dsn" option') if dsn is not None else None, + instance_name=( + (instance_name, '"dsn" option (parsed as instance name)') + if instance_name is not None else None + ), + credentials=( + (credentials, '"credentials" option') + if credentials is not None else None + ), + credentials_file=( + (credentials_file, '"credentials_file" option') + if credentials_file is not None else None + ), + host=(host, '"host" option') if host is not None else None, + port=(port, '"port" option') if port is not None else None, + database=( + (database, '"database" option') + if database is not None else None + ), + branch=( + (branch, '"branch" option') + if branch is not None else None + ), + user=(user, '"user" option') if user is not None else None, + password=( + (password, '"password" option') + if password is not None else None + ), + secret_key=( + (secret_key, '"secret_key" option') + if secret_key is not None else None + ), + tls_ca=( + (tls_ca, '"tls_ca" option') + if tls_ca is not None else None + ), + tls_ca_file=( + (tls_ca_file, '"tls_ca_file" option') + if tls_ca_file is not None else None + ), + tls_security=( + (tls_security, '"tls_security" option') + if tls_security is not None else None + ), + tls_server_name=( + (tls_server_name, '"tls_server_name" option') + if tls_server_name is not None else None + ), + server_settings=( + (server_settings, '"server_settings" option') + if server_settings is not None else None + ), + wait_until_available=( + (wait_until_available, '"wait_until_available" option') + if wait_until_available is not None else None + ), + cloud_profile=cloud_profile_tuple, + ) + + if has_compound_options is False: + env_port_tuple = _get("PORT") + if ( + resolved_config._port is None + and env_port_tuple + and env_port_tuple[0].startswith('tcp://') + ): + # EDGEDB_PORT is set by 'docker --link' so ignore and warn + warnings.warn('EDGEDB_PORT in "tcp://host:port" format, ' + + 'so will be ignored', stacklevel=1) + env_port_tuple = None + + has_compound_options = _resolve_config_options( + resolved_config, + # XXX + 'Cannot have more than one of the following connection ' + + 'environment variables: "EDGEDB_DSN", "EDGEDB_INSTANCE", ' + + '"EDGEDB_CREDENTIALS_FILE" or "EDGEDB_HOST"/"EDGEDB_PORT"', + dsn=_get('DSN'), + instance_name=_get('INSTANCE'), + credentials_file=_get('CREDENTIALS_FILE'), + host=_get('HOST'), + port=env_port_tuple, + database=_get('DATABASE'), + branch=_get('BRANCH'), + user=_get('USER'), + password=_get('PASSWORD'), + secret_key=_get('SECRET_KEY'), + tls_ca=_get('TLS_CA'), + tls_ca_file=_get('TLS_CA_FILE'), + tls_security=_get('CLIENT_TLS_SECURITY'), + tls_server_name=_get('TLS_SERVER_NAME'), + wait_until_available=_get('WAIT_UNTIL_AVAILABLE'), + ) + + if not has_compound_options: + dir = find_gel_project_dir() + stash_dir = _stash_path(dir) + if os.path.exists(stash_dir): + with open(os.path.join(stash_dir, 'instance-name'), 'rt') as f: + instance_name = f.read().strip() + cloud_profile_file = os.path.join(stash_dir, 'cloud-profile') + if os.path.exists(cloud_profile_file): + with open(cloud_profile_file, 'rt') as f: + cloud_profile = f.read().strip() + else: + cloud_profile = None + + _resolve_config_options( + resolved_config, + '', + instance_name=( + instance_name, + f'project linked instance ("{instance_name}")' + ), + cloud_profile=( + cloud_profile, + f'project defined cloud profile ("{cloud_profile}")' + ), + ) + + opt_database_file = os.path.join(stash_dir, 'database') + if os.path.exists(opt_database_file): + with open(opt_database_file, 'rt') as f: + database = f.read().strip() + resolved_config.set_database(database, "project") + else: + raise errors.ClientConnectionError( + f'Found `gel.toml` but the project is not initialized. ' + f'Run `gel project init`.' + ) + + return resolved_config + + +def _parse_dsn_into_config( + resolved_config: ResolvedConnectConfig, + dsn: typing.Tuple[str, str] +): + dsn_str, source = dsn + + try: + parsed = urllib.parse.urlparse(dsn_str) + host = ( + urllib.parse.unquote(parsed.hostname) if parsed.hostname else None + ) + port = parsed.port + database = parsed.path + user = parsed.username + password = parsed.password + except Exception as e: + raise ValueError(f'invalid DSN or instance name: {str(e)}') + + if parsed.scheme != 'edgedb': + raise ValueError( + f'invalid DSN: scheme is expected to be ' + f'"edgedb", got {parsed.scheme!r}') + + query = ( + urllib.parse.parse_qs(parsed.query, keep_blank_values=True) + if parsed.query != '' + else {} + ) + for key, val in query.items(): + if isinstance(val, list): + if len(val) > 1: + raise ValueError( + f'invalid DSN: duplicate query parameter {key}') + query[key] = val[-1] + + def handle_dsn_part( + paramName, value, currentValue, setter, + formatter=lambda val: val + ): + param_values = [ + (value if value != '' else None), + query.get(paramName), + query.get(paramName + '_env'), + query.get(paramName + '_file') + ] + if len([p for p in param_values if p is not None]) > 1: + raise ValueError( + f'invalid DSN: more than one of ' + + f'{(paramName + ", ") if value else ""}' + + f'?{paramName}=, ?{paramName}_env=, ?{paramName}_file= ' + + f'was specified' + ) + + if currentValue is None: + param = ( + value if (value is not None and value != '') + else query.get(paramName) + ) + paramSource = source + + if param is None: + env = query.get(paramName + '_env') + if env is not None: + param = os.getenv(env) + if param is None: + raise ValueError( + f'{paramName}_env environment variable "{env}" ' + + f'doesn\'t exist') + paramSource = paramSource + f' ({paramName}_env: {env})' + if param is None: + filename = query.get(paramName + '_file') + if filename is not None: + with open(filename) as f: + param = f.read() + paramSource = ( + paramSource + f' ({paramName}_file: {filename})' + ) + + param = formatter(param) if param is not None else None + + setter(param, paramSource) + + query.pop(paramName, None) + query.pop(paramName + '_env', None) + query.pop(paramName + '_file', None) + + handle_dsn_part( + 'host', host, resolved_config._host, resolved_config.set_host + ) + + handle_dsn_part( + 'port', port, resolved_config._port, resolved_config.set_port + ) + + def strip_leading_slash(str): + return str[1:] if str.startswith('/') else str + + if ( + 'branch' in query or + 'branch_env' in query or + 'branch_file' in query + ): + if ( + 'database' in query or + 'database_env' in query or + 'database_file' in query + ): + raise ValueError( + f"invalid DSN: `database` and `branch` cannot be present " + f"at the same time" + ) + if resolved_config._database is None: + # Only update the config if 'database' has not been already + # resolved. + handle_dsn_part( + 'branch', strip_leading_slash(database), + resolved_config._branch, resolved_config.set_branch, + strip_leading_slash + ) + else: + # Clean up the query, if config already has 'database' + query.pop('branch', None) + query.pop('branch_env', None) + query.pop('branch_file', None) + + else: + if resolved_config._branch is None: + # Only update the config if 'branch' has not been already + # resolved. + handle_dsn_part( + 'database', strip_leading_slash(database), + resolved_config._database, resolved_config.set_database, + strip_leading_slash + ) + else: + # Clean up the query, if config already has 'branch' + query.pop('database', None) + query.pop('database_env', None) + query.pop('database_file', None) + + handle_dsn_part( + 'user', user, resolved_config._user, resolved_config.set_user + ) + + handle_dsn_part( + 'password', password, + resolved_config._password, resolved_config.set_password + ) + + handle_dsn_part( + 'secret_key', None, + resolved_config._secret_key, resolved_config.set_secret_key + ) + + handle_dsn_part( + 'tls_ca_file', None, + resolved_config._tls_ca_data, resolved_config.set_tls_ca_file + ) + + handle_dsn_part( + 'tls_server_name', None, + resolved_config._tls_server_name, + resolved_config.set_tls_server_name + ) + + handle_dsn_part( + 'tls_security', None, + resolved_config._tls_security, + resolved_config.set_tls_security + ) + + handle_dsn_part( + 'wait_until_available', None, + resolved_config._wait_until_available, + resolved_config.set_wait_until_available + ) + + resolved_config.add_server_settings(query) + + +def _jwt_base64_decode(payload): + remainder = len(payload) % 4 + if remainder == 2: + payload += '==' + elif remainder == 3: + payload += '=' + elif remainder != 0: + raise errors.ClientConnectionError("Invalid secret key") + payload = base64.urlsafe_b64decode(payload.encode("utf-8")) + return json.loads(payload.decode("utf-8")) + + +def _parse_cloud_instance_name_into_config( + resolved_config: ResolvedConnectConfig, + source: str, + org_slug: str, + instance_name: str, +): + org_slug = org_slug.lower() + instance_name = instance_name.lower() + + label = f"{instance_name}--{org_slug}" + if len(label) > DOMAIN_LABEL_MAX_LENGTH: + raise ValueError( + f"invalid instance name: cloud instance name length cannot exceed " + f"{DOMAIN_LABEL_MAX_LENGTH - 1} characters: " + f"{org_slug}/{instance_name}" + ) + secret_key = resolved_config.secret_key + if secret_key is None: + try: + config_dir = platform.config_dir() + if resolved_config._cloud_profile is None: + profile = profile_src = "default" + else: + profile = resolved_config._cloud_profile + profile_src = resolved_config._cloud_profile_source + path = config_dir / "cloud-credentials" / f"{profile}.json" + with open(path, "rt") as f: + secret_key = json.load(f)["secret_key"] + except Exception: + raise errors.ClientConnectionError( + "Cannot connect to cloud instances without secret key." + ) + resolved_config.set_secret_key( + secret_key, + f"cloud-credentials/{profile}.json specified by {profile_src}", + ) + try: + dns_zone = _jwt_base64_decode(secret_key.split(".", 2)[1])["iss"] + except errors.EdgeDBError: + raise + except Exception: + raise errors.ClientConnectionError("Invalid secret key") + payload = f"{org_slug}/{instance_name}".encode("utf-8") + dns_bucket = binascii.crc_hqx(payload, 0) % 100 + host = f"{label}.c-{dns_bucket:02d}.i.{dns_zone}" + resolved_config.set_host(host, source) + + +def _resolve_config_options( + resolved_config: ResolvedConnectConfig, + compound_error: str, + *, + dsn=None, + instance_name=None, + credentials=None, + credentials_file=None, + host=None, + port=None, + database=None, + branch=None, + user=None, + password=None, + secret_key=None, + tls_ca=None, + tls_ca_file=None, + tls_security=None, + tls_server_name=None, + server_settings=None, + wait_until_available=None, + cloud_profile=None, +): + if database is not None: + if branch is not None: + raise errors.ClientConnectionError( + f"{database[1]} and {branch[1]} are mutually exclusive" + ) + if resolved_config._branch is None: + # Only update the config if 'branch' has not been already + # resolved. + resolved_config.set_database(*database) + if branch is not None: + if resolved_config._database is None: + # Only update the config if 'database' has not been already + # resolved. + resolved_config.set_branch(*branch) + if user is not None: + resolved_config.set_user(*user) + if password is not None: + resolved_config.set_password(*password) + if secret_key is not None: + resolved_config.set_secret_key(*secret_key) + if tls_ca_file is not None: + if tls_ca is not None: + raise errors.ClientConnectionError( + f"{tls_ca[1]} and {tls_ca_file[1]} are mutually exclusive" + ) + resolved_config.set_tls_ca_file(*tls_ca_file) + if tls_ca is not None: + resolved_config.set_tls_ca_data(*tls_ca) + if tls_security is not None: + resolved_config.set_tls_security(*tls_security) + if tls_server_name is not None: + resolved_config.set_tls_server_name(*tls_server_name) + if server_settings is not None: + resolved_config.add_server_settings(server_settings[0]) + if wait_until_available is not None: + resolved_config.set_wait_until_available(*wait_until_available) + if cloud_profile is not None: + resolved_config._set_param('cloud_profile', *cloud_profile) + + compound_params = [ + dsn, + instance_name, + credentials, + credentials_file, + host or port, + ] + compound_params_count = len([p for p in compound_params if p is not None]) + + if compound_params_count > 1: + raise errors.ClientConnectionError(compound_error) + + elif compound_params_count == 1: + if dsn is not None or host is not None or port is not None: + if port is not None: + resolved_config.set_port(*port) + if dsn is None: + dsn = ( + 'edgedb://' + + (_prepare_host_for_dsn(host[0]) if host else ''), + host[1] if host is not None else port[1] + ) + _parse_dsn_into_config(resolved_config, dsn) + else: + if credentials_file is not None: + creds = cred_utils.read_credentials(credentials_file[0]) + source = "credentials" + elif credentials is not None: + try: + cred_data = json.loads(credentials[0]) + except ValueError as e: + raise RuntimeError(f"cannot read credentials") from e + else: + creds = cred_utils.validate_credentials(cred_data) + source = "credentials" + elif INSTANCE_NAME_RE.match(instance_name[0]): + source = instance_name[1] + creds = cred_utils.read_credentials( + cred_utils.get_credentials_path(instance_name[0]), + ) + else: + name_match = CLOUD_INSTANCE_NAME_RE.match(instance_name[0]) + if name_match is None: + raise ValueError( + f'invalid DSN or instance name: "{instance_name[0]}"' + ) + source = instance_name[1] + org, inst = name_match.groups() + _parse_cloud_instance_name_into_config( + resolved_config, source, org, inst + ) + return True + + resolved_config.set_host(creds.get('host'), source) + resolved_config.set_port(creds.get('port'), source) + if 'database' in creds and resolved_config._branch is None: + # Only update the config if 'branch' has not been already + # resolved. + resolved_config.set_database(creds.get('database'), source) + + elif 'branch' in creds and resolved_config._database is None: + # Only update the config if 'database' has not been already + # resolved. + resolved_config.set_branch(creds.get('branch'), source) + resolved_config.set_user(creds.get('user'), source) + resolved_config.set_password(creds.get('password'), source) + resolved_config.set_tls_ca_data(creds.get('tls_ca'), source) + resolved_config.set_tls_security( + creds.get('tls_security'), + source + ) + + return True + + else: + return False + + +def find_gel_project_dir(): + dir = os.getcwd() + dev = os.stat(dir).st_dev + + while True: + gel_toml = os.path.join(dir, 'gel.toml') + edgedb_toml = os.path.join(dir, 'edgedb.toml') + if not os.path.isfile(gel_toml) and not os.path.isfile(edgedb_toml): + parent = os.path.dirname(dir) + if parent == dir: + raise errors.ClientConnectionError( + f'no `gel.toml` found and ' + f'no connection options specified' + ) + parent_dev = os.stat(parent).st_dev + if parent_dev != dev: + raise errors.ClientConnectionError( + f'no `gel.toml` found and ' + f'no connection options specified' + f'(stopped searching for `edgedb.toml` at file system' + f'boundary {dir!r})' + ) + dir = parent + dev = parent_dev + continue + return dir + + +def parse_connect_arguments( + *, + dsn, + host, + port, + credentials, + credentials_file, + database, + branch, + user, + password, + secret_key, + tls_ca, + tls_ca_file, + tls_security, + tls_server_name, + timeout, + command_timeout, + wait_until_available, + server_settings, +) -> typing.Tuple[ResolvedConnectConfig, ClientConfiguration]: + + if command_timeout is not None: + try: + if isinstance(command_timeout, bool): + raise ValueError + command_timeout = float(command_timeout) + if command_timeout <= 0: + raise ValueError + except ValueError: + raise ValueError( + 'invalid command_timeout value: ' + 'expected greater than 0 float (got {!r})'.format( + command_timeout)) from None + + connect_config = _parse_connect_dsn_and_args( + dsn=dsn, + host=host, + port=port, + credentials=credentials, + credentials_file=credentials_file, + database=database, + branch=branch, + user=user, + password=password, + secret_key=secret_key, + tls_ca=tls_ca, + tls_ca_file=tls_ca_file, + tls_security=tls_security, + tls_server_name=tls_server_name, + server_settings=server_settings, + wait_until_available=wait_until_available, + ) + + client_config = ClientConfiguration( + connect_timeout=timeout, + command_timeout=command_timeout, + wait_until_available=connect_config.wait_until_available, + ) + + return connect_config, client_config + + +def check_alpn_protocol(ssl_obj): + if ssl_obj.selected_alpn_protocol() != 'edgedb-binary': + raise errors.ClientConnectionFailedError( + "The server doesn't support the edgedb-binary protocol." + ) + + +def render_client_no_connection_error(prefix, addr, attempts, duration): + if isinstance(addr, str): + msg = ( + f'{prefix}' + f'\n\tAfter {attempts} attempts in {duration:.1f} sec' + f'\n\tIs the server running locally and accepting ' + f'\n\tconnections on Unix domain socket {addr!r}?' + ) + else: + msg = ( + f'{prefix}' + f'\n\tAfter {attempts} attempts in {duration:.1f} sec' + f'\n\tIs the server running on host {addr[0]!r} ' + f'and accepting ' + f'\n\tTCP/IP connections on port {addr[1]}?' + ) + return msg + + +def _extract_errno(s): + """Extract multiple errnos from error string + + When we connect to a host that has multiple underlying IP addresses, say + ``localhost`` having ``::1`` and ``127.0.0.1``, we get + ``OSError("Multiple exceptions:...")`` error without ``.errno`` attribute + set. There are multiple ones in the text, so we extract all of them. + """ + result = [] + for match in ERRNO_RE.finditer(s): + result.append(int(match.group(1))) + if result: + return result + + +def wrap_error(e): + message = str(e) + if e.errno is None: + errnos = _extract_errno(message) + else: + errnos = [e.errno] + + if errnos: + is_temp = any((code in TEMPORARY_ERROR_CODES for code in errnos)) + else: + is_temp = isinstance(e, TEMPORARY_ERRORS) + + if is_temp: + return errors.ClientConnectionFailedTemporarilyError(message) + else: + return errors.ClientConnectionFailedError(message) diff --git a/edgedb/connresource.py b/gel/connresource.py similarity index 100% rename from edgedb/connresource.py rename to gel/connresource.py diff --git a/gel/credentials.py b/gel/credentials.py new file mode 100644 index 00000000..71146fd9 --- /dev/null +++ b/gel/credentials.py @@ -0,0 +1,119 @@ +import os +import pathlib +import typing +import json + +from . import platform + + +class RequiredCredentials(typing.TypedDict, total=True): + port: int + user: str + + +class Credentials(RequiredCredentials, total=False): + host: typing.Optional[str] + password: typing.Optional[str] + # It's OK for database and branch to appear in credentials, as long as + # they match. + database: typing.Optional[str] + branch: typing.Optional[str] + tls_ca: typing.Optional[str] + tls_security: typing.Optional[str] + + +def get_credentials_path(instance_name: str) -> pathlib.Path: + return platform.search_config_dir("credentials", instance_name + ".json") + + +def read_credentials(path: os.PathLike) -> Credentials: + try: + with open(path, encoding='utf-8') as f: + credentials = json.load(f) + return validate_credentials(credentials) + except Exception as e: + raise RuntimeError( + f"cannot read credentials at {path}" + ) from e + + +def validate_credentials(data: dict) -> Credentials: + port = data.get('port') + if port is None: + port = 5656 + if not isinstance(port, int) or port < 1 or port > 65535: + raise ValueError("invalid `port` value") + + user = data.get('user') + if user is None: + raise ValueError("`user` key is required") + if not isinstance(user, str): + raise ValueError("`user` must be a string") + + result = { # required keys + "user": user, + "port": port, + } + + host = data.get('host') + if host is not None: + if not isinstance(host, str): + raise ValueError("`host` must be a string") + result['host'] = host + + database = data.get('database') + if database is not None: + if not isinstance(database, str): + raise ValueError("`database` must be a string") + result['database'] = database + + branch = data.get('branch') + if branch is not None: + if not isinstance(branch, str): + raise ValueError("`branch` must be a string") + if database is not None and branch != database: + raise ValueError( + f"`database` and `branch` cannot be different") + result['branch'] = branch + + password = data.get('password') + if password is not None: + if not isinstance(password, str): + raise ValueError("`password` must be a string") + result['password'] = password + + ca = data.get('tls_ca') + if ca is not None: + if not isinstance(ca, str): + raise ValueError("`tls_ca` must be a string") + result['tls_ca'] = ca + + cert_data = data.get('tls_cert_data') + if cert_data is not None: + if not isinstance(cert_data, str): + raise ValueError("`tls_cert_data` must be a string") + if ca is not None and ca != cert_data: + raise ValueError( + f"tls_ca and tls_cert_data are both set and disagree") + result['tls_ca'] = cert_data + + verify = data.get('tls_verify_hostname') + if verify is not None: + if not isinstance(verify, bool): + raise ValueError("`tls_verify_hostname` must be a bool") + result['tls_security'] = 'strict' if verify else 'no_host_verification' + + tls_security = data.get('tls_security') + if tls_security is not None: + if not isinstance(tls_security, str): + raise ValueError("`tls_security` must be a string") + result['tls_security'] = tls_security + + missmatch = ValueError(f"tls_verify_hostname={verify} and " + f"tls_security={tls_security} are incompatible") + if tls_security == "strict" and verify is False: + raise missmatch + if tls_security in {"no_host_verification", "insecure"} and verify is True: + raise missmatch + + return result diff --git a/edgedb/datatypes/.gitignore b/gel/datatypes/.gitignore similarity index 100% rename from edgedb/datatypes/.gitignore rename to gel/datatypes/.gitignore diff --git a/edgedb/datatypes/NOTES b/gel/datatypes/NOTES similarity index 86% rename from edgedb/datatypes/NOTES rename to gel/datatypes/NOTES index 3a18a430..26068364 100644 --- a/edgedb/datatypes/NOTES +++ b/gel/datatypes/NOTES @@ -5,7 +5,7 @@ Objects don't have tp_clear; that's because there shouldn't be a situation when there's a ref cycle between them -- the way the data is serialized on the wire precludes that. -Furthermore, all edgedb.datatypes objects are immutable, so +Furthermore, all gel.datatypes objects are immutable, so it's not possible to create a ref cycle with only them in the reference chain by using the public API. diff --git a/edgedb/datatypes/__init__.pxd b/gel/datatypes/__init__.pxd similarity index 100% rename from edgedb/datatypes/__init__.pxd rename to gel/datatypes/__init__.pxd diff --git a/gel/datatypes/__init__.py b/gel/datatypes/__init__.py new file mode 100644 index 00000000..73609285 --- /dev/null +++ b/gel/datatypes/__init__.py @@ -0,0 +1,17 @@ +# +# This source file is part of the EdgeDB open source project. +# +# Copyright 2019-present MagicStack Inc. and the EdgeDB authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/edgedb/datatypes/args.c b/gel/datatypes/args.c similarity index 100% rename from edgedb/datatypes/args.c rename to gel/datatypes/args.c diff --git a/edgedb/datatypes/comp.c b/gel/datatypes/comp.c similarity index 100% rename from edgedb/datatypes/comp.c rename to gel/datatypes/comp.c diff --git a/edgedb/datatypes/config_memory.pxd b/gel/datatypes/config_memory.pxd similarity index 100% rename from edgedb/datatypes/config_memory.pxd rename to gel/datatypes/config_memory.pxd diff --git a/edgedb/datatypes/config_memory.pyx b/gel/datatypes/config_memory.pyx similarity index 97% rename from edgedb/datatypes/config_memory.pyx rename to gel/datatypes/config_memory.pyx index c28193e8..6ebf1ee4 100644 --- a/edgedb/datatypes/config_memory.pyx +++ b/gel/datatypes/config_memory.pyx @@ -43,7 +43,7 @@ cdef class ConfigMemory: return hash((ConfigMemory, self._bytes)) def __repr__(self): - return f'' + return f'' @cython.cdivision(True) def __str__(self): diff --git a/edgedb/datatypes/datatypes.h b/gel/datatypes/datatypes.h similarity index 94% rename from edgedb/datatypes/datatypes.h rename to gel/datatypes/datatypes.h index c6a7589b..d1077477 100644 --- a/edgedb/datatypes/datatypes.h +++ b/gel/datatypes/datatypes.h @@ -32,7 +32,7 @@ #define EDGE_POINTER_IS_LINK (1 << 2) -/* === edgedb.RecordDesc ==================================== */ +/* === gel.RecordDesc ==================================== */ extern PyTypeObject EdgeRecordDesc_Type; @@ -86,7 +86,7 @@ PyObject * EdgeRecordDesc_List(PyObject *, uint8_t, uint8_t); PyObject * EdgeRecordDesc_GetDataclassFields(PyObject *); -/* === edgedb.NamedTuple ==================================== */ +/* === gel.NamedTuple ==================================== */ #define EDGE_NAMEDTUPLE_FREELIST_SIZE 500 #define EDGE_NAMEDTUPLE_FREELIST_MAXSAVE 20 @@ -98,7 +98,7 @@ PyObject * EdgeNamedTuple_Type_New(PyObject *); PyObject * EdgeNamedTuple_New(PyObject *); -/* === edgedb.Object ======================================== */ +/* === gel.Object ======================================== */ #define EDGE_OBJECT_FREELIST_SIZE 2000 #define EDGE_OBJECT_FREELIST_MAXSAVE 20 diff --git a/edgedb/datatypes/datatypes.pxd b/gel/datatypes/datatypes.pxd similarity index 100% rename from edgedb/datatypes/datatypes.pxd rename to gel/datatypes/datatypes.pxd diff --git a/edgedb/datatypes/datatypes.pyx b/gel/datatypes/datatypes.pyx similarity index 100% rename from edgedb/datatypes/datatypes.pyx rename to gel/datatypes/datatypes.pyx diff --git a/edgedb/datatypes/date_duration.pxd b/gel/datatypes/date_duration.pxd similarity index 100% rename from edgedb/datatypes/date_duration.pxd rename to gel/datatypes/date_duration.pxd diff --git a/edgedb/datatypes/date_duration.pyx b/gel/datatypes/date_duration.pyx similarity index 97% rename from edgedb/datatypes/date_duration.pyx rename to gel/datatypes/date_duration.pyx index 728075f4..a8d6e6fc 100644 --- a/edgedb/datatypes/date_duration.pyx +++ b/gel/datatypes/date_duration.pyx @@ -43,7 +43,7 @@ cdef class DateDuration: return hash((DateDuration, self.days, self.months)) def __repr__(self): - return f'' + return f'' @cython.cdivision(True) def __str__(self): diff --git a/edgedb/datatypes/enum.pyx b/gel/datatypes/enum.pyx similarity index 97% rename from edgedb/datatypes/enum.pyx rename to gel/datatypes/enum.pyx index 9ca4b155..a9155ee1 100644 --- a/edgedb/datatypes/enum.pyx +++ b/gel/datatypes/enum.pyx @@ -25,7 +25,7 @@ class EnumValue(enum.Enum): return self._value_ def __repr__(self): - return f'' + return f'' @classmethod def _try_from(cls, value): diff --git a/edgedb/datatypes/freelist.h b/gel/datatypes/freelist.h similarity index 100% rename from edgedb/datatypes/freelist.h rename to gel/datatypes/freelist.h diff --git a/edgedb/datatypes/hash.c b/gel/datatypes/hash.c similarity index 100% rename from edgedb/datatypes/hash.c rename to gel/datatypes/hash.c diff --git a/edgedb/datatypes/internal.h b/gel/datatypes/internal.h similarity index 100% rename from edgedb/datatypes/internal.h rename to gel/datatypes/internal.h diff --git a/edgedb/datatypes/namedtuple.c b/gel/datatypes/namedtuple.c similarity index 96% rename from edgedb/datatypes/namedtuple.c rename to gel/datatypes/namedtuple.c index de0fa8ea..251a0e92 100644 --- a/edgedb/datatypes/namedtuple.c +++ b/gel/datatypes/namedtuple.c @@ -159,7 +159,7 @@ namedtuple_new(PyTypeObject *type, PyObject *args, PyObject *kwargs) { { PyErr_SetString( PyExc_ValueError, - "edgedb.NamedTuple requires at least one field/value"); + "gel.NamedTuple requires at least one field/value"); goto fail; } @@ -254,7 +254,7 @@ namedtuple_derived_new(PyTypeObject *type, PyObject *args, PyObject *kwargs) { if (args_size > size) { PyErr_Format( PyExc_ValueError, - "edgedb.NamedTuple only needs %zd arguments, %zd given", + "gel.NamedTuple only needs %zd arguments, %zd given", size, args_size); goto fail; } @@ -270,7 +270,7 @@ namedtuple_derived_new(PyTypeObject *type, PyObject *args, PyObject *kwargs) { } else { PyErr_Format( PyExc_ValueError, - "edgedb.NamedTuple requires %zd arguments, %zd given", + "gel.NamedTuple requires %zd arguments, %zd given", size, args_size); goto fail; } @@ -278,7 +278,7 @@ namedtuple_derived_new(PyTypeObject *type, PyObject *args, PyObject *kwargs) { if (PyDict_Size(kwargs) > size - args_size) { PyErr_SetString( PyExc_ValueError, - "edgedb.NamedTuple got extra keyword arguments"); + "gel.NamedTuple got extra keyword arguments"); goto fail; } for (Py_ssize_t i = args_size; i < size; i++) { @@ -294,7 +294,7 @@ namedtuple_derived_new(PyTypeObject *type, PyObject *args, PyObject *kwargs) { } else { PyErr_Format( PyExc_ValueError, - "edgedb.NamedTuple missing required argument: %U", + "gel.NamedTuple missing required argument: %U", key); Py_CLEAR(key); goto fail; @@ -407,7 +407,7 @@ static PyType_Slot namedtuple_slots[] = { static PyType_Spec namedtuple_spec = { - "edgedb.DerivedNamedTuple", + "gel.DerivedNamedTuple", sizeof(PyTupleObject) - sizeof(PyObject *), sizeof(PyObject *), Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC, @@ -417,7 +417,7 @@ static PyType_Spec namedtuple_spec = { PyTypeObject EdgeNamedTuple_Type = { PyVarObject_HEAD_INIT(NULL, 0) - "edgedb.NamedTuple", + "gel.NamedTuple", .tp_basicsize = sizeof(PyTupleObject) - sizeof(PyObject *), .tp_itemsize = sizeof(PyObject *), .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, diff --git a/edgedb/datatypes/object.c b/gel/datatypes/object.c similarity index 99% rename from edgedb/datatypes/object.c rename to gel/datatypes/object.c index 1d3029be..a5965bc1 100644 --- a/edgedb/datatypes/object.c +++ b/gel/datatypes/object.c @@ -86,7 +86,7 @@ EdgeObject_GetRecordDesc(PyObject *o) if (!EdgeObject_Check(o)) { PyErr_Format( PyExc_TypeError, - "an instance of edgedb.Object expected"); + "an instance of gel.Object expected"); return NULL; } @@ -324,7 +324,7 @@ static PyMappingMethods object_as_mapping = { PyTypeObject EdgeObject_Type = { PyVarObject_HEAD_INIT(NULL, 0) - "edgedb.Object", + "gel.Object", .tp_basicsize = sizeof(EdgeObject) - sizeof(PyObject *), .tp_itemsize = sizeof(PyObject *), .tp_dealloc = (destructor)object_dealloc, diff --git a/edgedb/datatypes/pythoncapi_compat.h b/gel/datatypes/pythoncapi_compat.h similarity index 100% rename from edgedb/datatypes/pythoncapi_compat.h rename to gel/datatypes/pythoncapi_compat.h diff --git a/gel/datatypes/range.py b/gel/datatypes/range.py new file mode 100644 index 00000000..e3fd3d1e --- /dev/null +++ b/gel/datatypes/range.py @@ -0,0 +1,165 @@ +# +# This source file is part of the EdgeDB open source project. +# +# Copyright 2022-present MagicStack Inc. and the EdgeDB authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import (TypeVar, Any, Generic, Optional, Iterable, Iterator, + Sequence) + +T = TypeVar("T") + + +class Range(Generic[T]): + + __slots__ = ("_lower", "_upper", "_inc_lower", "_inc_upper", "_empty") + + def __init__( + self, + lower: Optional[T] = None, + upper: Optional[T] = None, + *, + inc_lower: bool = True, + inc_upper: bool = False, + empty: bool = False, + ) -> None: + self._empty = empty + + if empty: + if ( + lower != upper + or lower is not None and inc_upper and inc_lower + ): + raise ValueError( + "conflicting arguments in range constructor: " + "\"empty\" is `true` while the specified bounds " + "suggest otherwise" + ) + + self._lower = self._upper = None + self._inc_lower = self._inc_upper = False + else: + self._lower = lower + self._upper = upper + self._inc_lower = lower is not None and inc_lower + self._inc_upper = upper is not None and inc_upper + + @property + def lower(self) -> Optional[T]: + return self._lower + + @property + def inc_lower(self) -> bool: + return self._inc_lower + + @property + def upper(self) -> Optional[T]: + return self._upper + + @property + def inc_upper(self) -> bool: + return self._inc_upper + + def is_empty(self) -> bool: + return self._empty + + def __bool__(self): + return not self.is_empty() + + def __eq__(self, other) -> bool: + if isinstance(other, Range): + o = other + else: + return NotImplemented + + return ( + self._lower, + self._upper, + self._inc_lower, + self._inc_upper, + self._empty, + ) == ( + o._lower, + o._upper, + o._inc_lower, + o._inc_upper, + o._empty, + ) + + def __hash__(self) -> int: + return hash(( + self._lower, + self._upper, + self._inc_lower, + self._inc_upper, + self._empty, + )) + + def __str__(self) -> str: + if self._empty: + desc = "empty" + else: + lb = "(" if not self._inc_lower else "[" + if self._lower is not None: + lb += repr(self._lower) + + if self._upper is not None: + ub = repr(self._upper) + else: + ub = "" + + ub += ")" if self._inc_upper else "]" + + desc = f"{lb}, {ub}" + + return f"" + + __repr__ = __str__ + + +# TODO: maybe we should implement range and multirange operations as well as +# normalization of the sub-ranges? +class MultiRange(Iterable[T]): + + _ranges: Sequence[T] + + def __init__(self, iterable: Optional[Iterable[T]] = None) -> None: + if iterable is not None: + self._ranges = tuple(iterable) + else: + self._ranges = tuple() + + def __len__(self) -> int: + return len(self._ranges) + + def __iter__(self) -> Iterator[T]: + return iter(self._ranges) + + def __reversed__(self) -> Iterator[T]: + return reversed(self._ranges) + + def __str__(self) -> str: + return f'' + + __repr__ = __str__ + + def __eq__(self, other: Any) -> bool: + if isinstance(other, MultiRange): + return set(self._ranges) == set(other._ranges) + else: + return NotImplemented + + def __hash__(self) -> int: + return hash(self._ranges) diff --git a/edgedb/datatypes/record_desc.c b/gel/datatypes/record_desc.c similarity index 99% rename from edgedb/datatypes/record_desc.c rename to gel/datatypes/record_desc.c index 69d24512..0dc98192 100644 --- a/edgedb/datatypes/record_desc.c +++ b/gel/datatypes/record_desc.c @@ -202,7 +202,7 @@ static PyMethodDef record_desc_methods[] = { PyTypeObject EdgeRecordDesc_Type = { PyVarObject_HEAD_INIT(NULL, 0) - .tp_name = "edgedb.RecordDescriptor", + .tp_name = "gel.RecordDescriptor", .tp_basicsize = sizeof(EdgeRecordDescObject), .tp_dealloc = (destructor)record_desc_dealloc, .tp_getattro = PyObject_GenericGetAttr, diff --git a/edgedb/datatypes/relative_duration.pxd b/gel/datatypes/relative_duration.pxd similarity index 100% rename from edgedb/datatypes/relative_duration.pxd rename to gel/datatypes/relative_duration.pxd diff --git a/edgedb/datatypes/relative_duration.pyx b/gel/datatypes/relative_duration.pyx similarity index 98% rename from edgedb/datatypes/relative_duration.pyx rename to gel/datatypes/relative_duration.pyx index 27c0c648..3489b315 100644 --- a/edgedb/datatypes/relative_duration.pyx +++ b/gel/datatypes/relative_duration.pyx @@ -50,7 +50,7 @@ cdef class RelativeDuration: return hash((RelativeDuration, self.microseconds, self.days, self.months)) def __repr__(self): - return f'' + return f'' @cython.cdivision(True) def __str__(self): diff --git a/edgedb/datatypes/repr.c b/gel/datatypes/repr.c similarity index 100% rename from edgedb/datatypes/repr.c rename to gel/datatypes/repr.c diff --git a/gel/describe.py b/gel/describe.py new file mode 100644 index 00000000..92e38854 --- /dev/null +++ b/gel/describe.py @@ -0,0 +1,98 @@ +# +# This source file is part of the EdgeDB open source project. +# +# Copyright 2022-present MagicStack Inc. and the EdgeDB authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +import dataclasses +import typing +import uuid + +from . import enums + + +@dataclasses.dataclass(frozen=True) +class AnyType: + desc_id: uuid.UUID + name: typing.Optional[str] + + +@dataclasses.dataclass(frozen=True) +class Element: + type: AnyType + cardinality: enums.Cardinality + is_implicit: bool + kind: enums.ElementKind + + +@dataclasses.dataclass(frozen=True) +class SequenceType(AnyType): + element_type: AnyType + + +@dataclasses.dataclass(frozen=True) +class SetType(SequenceType): + pass + + +@dataclasses.dataclass(frozen=True) +class ObjectType(AnyType): + elements: typing.Dict[str, Element] + + +@dataclasses.dataclass(frozen=True) +class BaseScalarType(AnyType): + pass + + +@dataclasses.dataclass(frozen=True) +class ScalarType(AnyType): + base_type: BaseScalarType + + +@dataclasses.dataclass(frozen=True) +class TupleType(AnyType): + element_types: typing.Tuple[AnyType] + + +@dataclasses.dataclass(frozen=True) +class NamedTupleType(AnyType): + element_types: typing.Dict[str, AnyType] + + +@dataclasses.dataclass(frozen=True) +class ArrayType(SequenceType): + pass + + +@dataclasses.dataclass(frozen=True) +class EnumType(AnyType): + members: typing.Tuple[str] + + +@dataclasses.dataclass(frozen=True) +class SparseObjectType(ObjectType): + pass + + +@dataclasses.dataclass(frozen=True) +class RangeType(AnyType): + value_type: AnyType + + +@dataclasses.dataclass(frozen=True) +class MultiRangeType(AnyType): + value_type: AnyType diff --git a/gel/enums.py b/gel/enums.py new file mode 100644 index 00000000..6312f596 --- /dev/null +++ b/gel/enums.py @@ -0,0 +1,74 @@ +# +# This source file is part of the EdgeDB open source project. +# +# Copyright 2021-present MagicStack Inc. and the EdgeDB authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +import enum + + +class Capability(enum.IntFlag): + + NONE = 0 # noqa + MODIFICATIONS = 1 << 0 # noqa + SESSION_CONFIG = 1 << 1 # noqa + TRANSACTION = 1 << 2 # noqa + DDL = 1 << 3 # noqa + PERSISTENT_CONFIG = 1 << 4 # noqa + + ALL = 0xFFFF_FFFF_FFFF_FFFF # noqa + EXECUTE = ALL & ~TRANSACTION & ~SESSION_CONFIG # noqa + LEGACY_EXECUTE = ALL & ~TRANSACTION # noqa + + +class CompilationFlag(enum.IntFlag): + + INJECT_OUTPUT_TYPE_IDS = 1 << 0 # noqa + INJECT_OUTPUT_TYPE_NAMES = 1 << 1 # noqa + INJECT_OUTPUT_OBJECT_IDS = 1 << 2 # noqa + + +class Cardinality(enum.Enum): + # Cardinality isn't applicable for the query: + # * the query is a command like CONFIGURE that + # does not return any data; + # * the query is composed of multiple queries. + NO_RESULT = 0x6e + + # Cardinality is 1 or 0 + AT_MOST_ONE = 0x6f + + # Cardinality is 1 + ONE = 0x41 + + # Cardinality is >= 0 + MANY = 0x6d + + # Cardinality is >= 1 + AT_LEAST_ONE = 0x4d + + def is_single(self) -> bool: + return self in {Cardinality.AT_MOST_ONE, Cardinality.ONE} + + def is_multi(self) -> bool: + return self in {Cardinality.AT_LEAST_ONE, Cardinality.MANY} + + +class ElementKind(enum.Enum): + + LINK = 1 # noqa + PROPERTY = 2 # noqa + LINK_PROPERTY = 3 # noqa diff --git a/gel/errors/__init__.py b/gel/errors/__init__.py new file mode 100644 index 00000000..87f3711d --- /dev/null +++ b/gel/errors/__init__.py @@ -0,0 +1,518 @@ +# AUTOGENERATED FROM "edb/api/errors.txt" WITH +# $ edb gen-errors \ +# --import 'from gel.errors._base import *\nfrom gel.errors.tags import *' \ +# --extra-all "_base.__all__" \ +# --stdout \ +# --client + + +# flake8: noqa + + +from gel.errors._base import * +from gel.errors.tags import * + + +__all__ = _base.__all__ + ( # type: ignore + 'InternalServerError', + 'UnsupportedFeatureError', + 'ProtocolError', + 'BinaryProtocolError', + 'UnsupportedProtocolVersionError', + 'TypeSpecNotFoundError', + 'UnexpectedMessageError', + 'InputDataError', + 'ParameterTypeMismatchError', + 'StateMismatchError', + 'ResultCardinalityMismatchError', + 'CapabilityError', + 'UnsupportedCapabilityError', + 'DisabledCapabilityError', + 'QueryError', + 'InvalidSyntaxError', + 'EdgeQLSyntaxError', + 'SchemaSyntaxError', + 'GraphQLSyntaxError', + 'InvalidTypeError', + 'InvalidTargetError', + 'InvalidLinkTargetError', + 'InvalidPropertyTargetError', + 'InvalidReferenceError', + 'UnknownModuleError', + 'UnknownLinkError', + 'UnknownPropertyError', + 'UnknownUserError', + 'UnknownDatabaseError', + 'UnknownParameterError', + 'SchemaError', + 'SchemaDefinitionError', + 'InvalidDefinitionError', + 'InvalidModuleDefinitionError', + 'InvalidLinkDefinitionError', + 'InvalidPropertyDefinitionError', + 'InvalidUserDefinitionError', + 'InvalidDatabaseDefinitionError', + 'InvalidOperatorDefinitionError', + 'InvalidAliasDefinitionError', + 'InvalidFunctionDefinitionError', + 'InvalidConstraintDefinitionError', + 'InvalidCastDefinitionError', + 'DuplicateDefinitionError', + 'DuplicateModuleDefinitionError', + 'DuplicateLinkDefinitionError', + 'DuplicatePropertyDefinitionError', + 'DuplicateUserDefinitionError', + 'DuplicateDatabaseDefinitionError', + 'DuplicateOperatorDefinitionError', + 'DuplicateViewDefinitionError', + 'DuplicateFunctionDefinitionError', + 'DuplicateConstraintDefinitionError', + 'DuplicateCastDefinitionError', + 'DuplicateMigrationError', + 'SessionTimeoutError', + 'IdleSessionTimeoutError', + 'QueryTimeoutError', + 'TransactionTimeoutError', + 'IdleTransactionTimeoutError', + 'ExecutionError', + 'InvalidValueError', + 'DivisionByZeroError', + 'NumericOutOfRangeError', + 'AccessPolicyError', + 'QueryAssertionError', + 'IntegrityError', + 'ConstraintViolationError', + 'CardinalityViolationError', + 'MissingRequiredError', + 'TransactionError', + 'TransactionConflictError', + 'TransactionSerializationError', + 'TransactionDeadlockError', + 'WatchError', + 'ConfigurationError', + 'AccessError', + 'AuthenticationError', + 'AvailabilityError', + 'BackendUnavailableError', + 'ServerOfflineError', + 'BackendError', + 'UnsupportedBackendFeatureError', + 'LogMessage', + 'WarningMessage', + 'ClientError', + 'ClientConnectionError', + 'ClientConnectionFailedError', + 'ClientConnectionFailedTemporarilyError', + 'ClientConnectionTimeoutError', + 'ClientConnectionClosedError', + 'InterfaceError', + 'QueryArgumentError', + 'MissingArgumentError', + 'UnknownArgumentError', + 'InvalidArgumentError', + 'NoDataError', + 'InternalClientError', +) + + +class InternalServerError(EdgeDBError): + _code = 0x_01_00_00_00 + + +class UnsupportedFeatureError(EdgeDBError): + _code = 0x_02_00_00_00 + + +class ProtocolError(EdgeDBError): + _code = 0x_03_00_00_00 + + +class BinaryProtocolError(ProtocolError): + _code = 0x_03_01_00_00 + + +class UnsupportedProtocolVersionError(BinaryProtocolError): + _code = 0x_03_01_00_01 + + +class TypeSpecNotFoundError(BinaryProtocolError): + _code = 0x_03_01_00_02 + + +class UnexpectedMessageError(BinaryProtocolError): + _code = 0x_03_01_00_03 + + +class InputDataError(ProtocolError): + _code = 0x_03_02_00_00 + + +class ParameterTypeMismatchError(InputDataError): + _code = 0x_03_02_01_00 + + +class StateMismatchError(InputDataError): + _code = 0x_03_02_02_00 + tags = frozenset({SHOULD_RETRY}) + + +class ResultCardinalityMismatchError(ProtocolError): + _code = 0x_03_03_00_00 + + +class CapabilityError(ProtocolError): + _code = 0x_03_04_00_00 + + +class UnsupportedCapabilityError(CapabilityError): + _code = 0x_03_04_01_00 + + +class DisabledCapabilityError(CapabilityError): + _code = 0x_03_04_02_00 + + +class QueryError(EdgeDBError): + _code = 0x_04_00_00_00 + + +class InvalidSyntaxError(QueryError): + _code = 0x_04_01_00_00 + + +class EdgeQLSyntaxError(InvalidSyntaxError): + _code = 0x_04_01_01_00 + + +class SchemaSyntaxError(InvalidSyntaxError): + _code = 0x_04_01_02_00 + + +class GraphQLSyntaxError(InvalidSyntaxError): + _code = 0x_04_01_03_00 + + +class InvalidTypeError(QueryError): + _code = 0x_04_02_00_00 + + +class InvalidTargetError(InvalidTypeError): + _code = 0x_04_02_01_00 + + +class InvalidLinkTargetError(InvalidTargetError): + _code = 0x_04_02_01_01 + + +class InvalidPropertyTargetError(InvalidTargetError): + _code = 0x_04_02_01_02 + + +class InvalidReferenceError(QueryError): + _code = 0x_04_03_00_00 + + +class UnknownModuleError(InvalidReferenceError): + _code = 0x_04_03_00_01 + + +class UnknownLinkError(InvalidReferenceError): + _code = 0x_04_03_00_02 + + +class UnknownPropertyError(InvalidReferenceError): + _code = 0x_04_03_00_03 + + +class UnknownUserError(InvalidReferenceError): + _code = 0x_04_03_00_04 + + +class UnknownDatabaseError(InvalidReferenceError): + _code = 0x_04_03_00_05 + + +class UnknownParameterError(InvalidReferenceError): + _code = 0x_04_03_00_06 + + +class SchemaError(QueryError): + _code = 0x_04_04_00_00 + + +class SchemaDefinitionError(QueryError): + _code = 0x_04_05_00_00 + + +class InvalidDefinitionError(SchemaDefinitionError): + _code = 0x_04_05_01_00 + + +class InvalidModuleDefinitionError(InvalidDefinitionError): + _code = 0x_04_05_01_01 + + +class InvalidLinkDefinitionError(InvalidDefinitionError): + _code = 0x_04_05_01_02 + + +class InvalidPropertyDefinitionError(InvalidDefinitionError): + _code = 0x_04_05_01_03 + + +class InvalidUserDefinitionError(InvalidDefinitionError): + _code = 0x_04_05_01_04 + + +class InvalidDatabaseDefinitionError(InvalidDefinitionError): + _code = 0x_04_05_01_05 + + +class InvalidOperatorDefinitionError(InvalidDefinitionError): + _code = 0x_04_05_01_06 + + +class InvalidAliasDefinitionError(InvalidDefinitionError): + _code = 0x_04_05_01_07 + + +class InvalidFunctionDefinitionError(InvalidDefinitionError): + _code = 0x_04_05_01_08 + + +class InvalidConstraintDefinitionError(InvalidDefinitionError): + _code = 0x_04_05_01_09 + + +class InvalidCastDefinitionError(InvalidDefinitionError): + _code = 0x_04_05_01_0A + + +class DuplicateDefinitionError(SchemaDefinitionError): + _code = 0x_04_05_02_00 + + +class DuplicateModuleDefinitionError(DuplicateDefinitionError): + _code = 0x_04_05_02_01 + + +class DuplicateLinkDefinitionError(DuplicateDefinitionError): + _code = 0x_04_05_02_02 + + +class DuplicatePropertyDefinitionError(DuplicateDefinitionError): + _code = 0x_04_05_02_03 + + +class DuplicateUserDefinitionError(DuplicateDefinitionError): + _code = 0x_04_05_02_04 + + +class DuplicateDatabaseDefinitionError(DuplicateDefinitionError): + _code = 0x_04_05_02_05 + + +class DuplicateOperatorDefinitionError(DuplicateDefinitionError): + _code = 0x_04_05_02_06 + + +class DuplicateViewDefinitionError(DuplicateDefinitionError): + _code = 0x_04_05_02_07 + + +class DuplicateFunctionDefinitionError(DuplicateDefinitionError): + _code = 0x_04_05_02_08 + + +class DuplicateConstraintDefinitionError(DuplicateDefinitionError): + _code = 0x_04_05_02_09 + + +class DuplicateCastDefinitionError(DuplicateDefinitionError): + _code = 0x_04_05_02_0A + + +class DuplicateMigrationError(DuplicateDefinitionError): + _code = 0x_04_05_02_0B + + +class SessionTimeoutError(QueryError): + _code = 0x_04_06_00_00 + + +class IdleSessionTimeoutError(SessionTimeoutError): + _code = 0x_04_06_01_00 + tags = frozenset({SHOULD_RETRY}) + + +class QueryTimeoutError(SessionTimeoutError): + _code = 0x_04_06_02_00 + + +class TransactionTimeoutError(SessionTimeoutError): + _code = 0x_04_06_0A_00 + + +class IdleTransactionTimeoutError(TransactionTimeoutError): + _code = 0x_04_06_0A_01 + + +class ExecutionError(EdgeDBError): + _code = 0x_05_00_00_00 + + +class InvalidValueError(ExecutionError): + _code = 0x_05_01_00_00 + + +class DivisionByZeroError(InvalidValueError): + _code = 0x_05_01_00_01 + + +class NumericOutOfRangeError(InvalidValueError): + _code = 0x_05_01_00_02 + + +class AccessPolicyError(InvalidValueError): + _code = 0x_05_01_00_03 + + +class QueryAssertionError(InvalidValueError): + _code = 0x_05_01_00_04 + + +class IntegrityError(ExecutionError): + _code = 0x_05_02_00_00 + + +class ConstraintViolationError(IntegrityError): + _code = 0x_05_02_00_01 + + +class CardinalityViolationError(IntegrityError): + _code = 0x_05_02_00_02 + + +class MissingRequiredError(IntegrityError): + _code = 0x_05_02_00_03 + + +class TransactionError(ExecutionError): + _code = 0x_05_03_00_00 + + +class TransactionConflictError(TransactionError): + _code = 0x_05_03_01_00 + tags = frozenset({SHOULD_RETRY}) + + +class TransactionSerializationError(TransactionConflictError): + _code = 0x_05_03_01_01 + tags = frozenset({SHOULD_RETRY}) + + +class TransactionDeadlockError(TransactionConflictError): + _code = 0x_05_03_01_02 + tags = frozenset({SHOULD_RETRY}) + + +class WatchError(ExecutionError): + _code = 0x_05_04_00_00 + + +class ConfigurationError(EdgeDBError): + _code = 0x_06_00_00_00 + + +class AccessError(EdgeDBError): + _code = 0x_07_00_00_00 + + +class AuthenticationError(AccessError): + _code = 0x_07_01_00_00 + + +class AvailabilityError(EdgeDBError): + _code = 0x_08_00_00_00 + + +class BackendUnavailableError(AvailabilityError): + _code = 0x_08_00_00_01 + tags = frozenset({SHOULD_RETRY}) + + +class ServerOfflineError(AvailabilityError): + _code = 0x_08_00_00_02 + tags = frozenset({SHOULD_RECONNECT, SHOULD_RETRY}) + + +class BackendError(EdgeDBError): + _code = 0x_09_00_00_00 + + +class UnsupportedBackendFeatureError(BackendError): + _code = 0x_09_00_01_00 + + +class LogMessage(EdgeDBMessage): + _code = 0x_F0_00_00_00 + + +class WarningMessage(LogMessage): + _code = 0x_F0_01_00_00 + + +class ClientError(EdgeDBError): + _code = 0x_FF_00_00_00 + + +class ClientConnectionError(ClientError): + _code = 0x_FF_01_00_00 + + +class ClientConnectionFailedError(ClientConnectionError): + _code = 0x_FF_01_01_00 + + +class ClientConnectionFailedTemporarilyError(ClientConnectionFailedError): + _code = 0x_FF_01_01_01 + tags = frozenset({SHOULD_RECONNECT, SHOULD_RETRY}) + + +class ClientConnectionTimeoutError(ClientConnectionError): + _code = 0x_FF_01_02_00 + tags = frozenset({SHOULD_RECONNECT, SHOULD_RETRY}) + + +class ClientConnectionClosedError(ClientConnectionError): + _code = 0x_FF_01_03_00 + tags = frozenset({SHOULD_RECONNECT, SHOULD_RETRY}) + + +class InterfaceError(ClientError): + _code = 0x_FF_02_00_00 + + +class QueryArgumentError(InterfaceError): + _code = 0x_FF_02_01_00 + + +class MissingArgumentError(QueryArgumentError): + _code = 0x_FF_02_01_01 + + +class UnknownArgumentError(QueryArgumentError): + _code = 0x_FF_02_01_02 + + +class InvalidArgumentError(QueryArgumentError): + _code = 0x_FF_02_01_03 + + +class NoDataError(ClientError): + _code = 0x_FF_03_00_00 + + +class InternalClientError(ClientError): + _code = 0x_FF_04_00_00 + diff --git a/gel/errors/_base.py b/gel/errors/_base.py new file mode 100644 index 00000000..3028f829 --- /dev/null +++ b/gel/errors/_base.py @@ -0,0 +1,363 @@ +# +# This source file is part of the EdgeDB open source project. +# +# Copyright 2016-present MagicStack Inc. and the EdgeDB authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +import io +import os +import traceback +import unicodedata +import warnings + +__all__ = ( + 'EdgeDBError', 'EdgeDBMessage', +) + + +class Meta(type): + + def __new__(mcls, name, bases, dct): + cls = super().__new__(mcls, name, bases, dct) + + code = dct.get('_code') + if code is not None: + mcls._index[code] = cls + + # If it's a base class add it to the base class index + b1, b2, b3, b4 = _decode(code) + if b1 == 0 or b2 == 0 or b3 == 0 or b4 == 0: + mcls._base_class_index[(b1, b2, b3, b4)] = cls + + return cls + + +class EdgeDBMessageMeta(Meta): + + _base_class_index = {} + _index = {} + + +class EdgeDBMessage(Warning, metaclass=EdgeDBMessageMeta): + + _code = None + + def __init__(self, severity, message): + super().__init__(message) + self._severity = severity + + def get_severity(self): + return self._severity + + def get_severity_name(self): + return _severity_name(self._severity) + + def get_code(self): + return self._code + + @staticmethod + def _from_code(code, severity, message, *args, **kwargs): + cls = _lookup_message_cls(code) + exc = cls(severity, message, *args, **kwargs) + exc._code = code + return exc + + +class EdgeDBErrorMeta(Meta): + + _base_class_index = {} + _index = {} + + +class EdgeDBError(Exception, metaclass=EdgeDBErrorMeta): + + _code = None + _query = None + tags = frozenset() + + def __init__(self, *args, **kwargs): + self._attrs = {} + super().__init__(*args, **kwargs) + + def has_tag(self, tag): + return tag in self.tags + + @property + def _position(self): + # not a stable API method + return int(self._read_str_field(FIELD_POSITION_START, -1)) + + @property + def _position_start(self): + # not a stable API method + return int(self._read_str_field(FIELD_CHARACTER_START, -1)) + + @property + def _position_end(self): + # not a stable API method + return int(self._read_str_field(FIELD_CHARACTER_END, -1)) + + @property + def _line(self): + # not a stable API method + return int(self._read_str_field(FIELD_LINE_START, -1)) + + @property + def _col(self): + # not a stable API method + return int(self._read_str_field(FIELD_COLUMN_START, -1)) + + @property + def _hint(self): + # not a stable API method + return self._read_str_field(FIELD_HINT) + + @property + def _details(self): + # not a stable API method + return self._read_str_field(FIELD_DETAILS) + + def _read_str_field(self, key, default=None): + val = self._attrs.get(key) + if isinstance(val, bytes): + return val.decode('utf-8') + elif val is not None: + return val + return default + + def get_code(self): + return self._code + + def get_server_context(self): + return self._read_str_field(FIELD_SERVER_TRACEBACK) + + @staticmethod + def _from_code(code, *args, **kwargs): + cls = _lookup_error_cls(code) + exc = cls(*args, **kwargs) + exc._code = code + return exc + + @staticmethod + def _from_json(data): + exc = EdgeDBError._from_code(data['code'], data['message']) + exc._attrs = { + field: data[name] + for name, field in _JSON_FIELDS.items() + if name in data + } + return exc + + def __str__(self): + msg = super().__str__() + if SHOW_HINT and self._query and self._position_start >= 0: + try: + return _format_error( + msg, + self._query, + self._position_start, + max(1, self._position_end - self._position_start), + self._line if self._line > 0 else "?", + self._col if self._col > 0 else "?", + self._hint or "error", + self._details, + ) + except Exception: + return "".join( + ( + msg, + LINESEP, + LINESEP, + "During formatting of the above exception, " + "another exception occurred:", + LINESEP, + LINESEP, + traceback.format_exc(), + ) + ) + else: + return msg + + +def _lookup_cls(code: int, *, meta: type, default: type): + try: + return meta._index[code] + except KeyError: + pass + + b1, b2, b3, _ = _decode(code) + + try: + return meta._base_class_index[(b1, b2, b3, 0)] + except KeyError: + pass + try: + return meta._base_class_index[(b1, b2, 0, 0)] + except KeyError: + pass + try: + return meta._base_class_index[(b1, 0, 0, 0)] + except KeyError: + pass + + return default + + +def _lookup_error_cls(code: int): + return _lookup_cls(code, meta=EdgeDBErrorMeta, default=EdgeDBError) + + +def _lookup_message_cls(code: int): + return _lookup_cls(code, meta=EdgeDBMessageMeta, default=EdgeDBMessage) + + +def _decode(code: int): + return tuple(code.to_bytes(4, 'big')) + + +def _severity_name(severity): + if severity <= EDGE_SEVERITY_DEBUG: + return 'DEBUG' + if severity <= EDGE_SEVERITY_INFO: + return 'INFO' + if severity <= EDGE_SEVERITY_NOTICE: + return 'NOTICE' + if severity <= EDGE_SEVERITY_WARNING: + return 'WARNING' + if severity <= EDGE_SEVERITY_ERROR: + return 'ERROR' + if severity <= EDGE_SEVERITY_FATAL: + return 'FATAL' + return 'PANIC' + + +def _format_error(msg, query, start, offset, line, col, hint, details): + c = get_color() + rv = io.StringIO() + rv.write(f"{c.BOLD}{msg}{c.ENDC}{LINESEP}") + lines = query.splitlines(keepends=True) + num_len = len(str(len(lines))) + rv.write(f"{c.BLUE}{'':>{num_len}} ┌─{c.ENDC} query:{line}:{col}{LINESEP}") + rv.write(f"{c.BLUE}{'':>{num_len}} │ {c.ENDC}{LINESEP}") + for num, line in enumerate(lines): + length = len(line) + line = line.rstrip() # we'll use our own line separator + if start >= length: + # skip lines before the error + start -= length + continue + + if start >= 0: + # Error starts in current line, write the line before the error + first_half = repr(line[:start])[1:-1] + line = line[start:] + length -= start + rv.write(f"{c.BLUE}{num + 1:>{num_len}} │ {c.ENDC}{first_half}") + start = _unicode_width(first_half) + else: + # Multi-line error continues + rv.write(f"{c.BLUE}{num + 1:>{num_len}} │ {c.FAIL}│ {c.ENDC}") + + if offset > length: + # Error is ending beyond current line + line = repr(line)[1:-1] + rv.write(f"{c.FAIL}{line}{c.ENDC}{LINESEP}") + if start >= 0: + # Multi-line error starts + rv.write(f"{c.BLUE}{'':>{num_len}} │ " + f"{c.FAIL}╭─{'─' * start}^{c.ENDC}{LINESEP}") + offset -= length + start = -1 # mark multi-line + else: + # Error is ending within current line + first_half = repr(line[:offset])[1:-1] + line = repr(line[offset:])[1:-1] + rv.write(f"{c.FAIL}{first_half}{c.ENDC}{line}{LINESEP}") + size = _unicode_width(first_half) + if start >= 0: + # Mark single-line error + rv.write(f"{c.BLUE}{'':>{num_len}} │ {' ' * start}" + f"{c.FAIL}{'^' * size} {hint}{c.ENDC}") + else: + # End of multi-line error + rv.write(f"{c.BLUE}{'':>{num_len}} │ " + f"{c.FAIL}╰─{'─' * (size - 1)}^ {hint}{c.ENDC}") + break + + if details: + rv.write(f"{LINESEP}Details: {details}") + + return rv.getvalue() + + +def _unicode_width(text): + return sum(0 if unicodedata.category(c) in ('Mn', 'Cf') else + 2 if unicodedata.east_asian_width(c) == "W" else 1 + for c in text) + + +FIELD_HINT = 0x_00_01 +FIELD_DETAILS = 0x_00_02 +FIELD_SERVER_TRACEBACK = 0x_01_01 + +# XXX: Subject to be changed/deprecated. +FIELD_POSITION_START = 0x_FF_F1 +FIELD_POSITION_END = 0x_FF_F2 +FIELD_LINE_START = 0x_FF_F3 +FIELD_COLUMN_START = 0x_FF_F4 +FIELD_UTF16_COLUMN_START = 0x_FF_F5 +FIELD_LINE_END = 0x_FF_F6 +FIELD_COLUMN_END = 0x_FF_F7 +FIELD_UTF16_COLUMN_END = 0x_FF_F8 +FIELD_CHARACTER_START = 0x_FF_F9 +FIELD_CHARACTER_END = 0x_FF_FA + + +EDGE_SEVERITY_DEBUG = 20 +EDGE_SEVERITY_INFO = 40 +EDGE_SEVERITY_NOTICE = 60 +EDGE_SEVERITY_WARNING = 80 +EDGE_SEVERITY_ERROR = 120 +EDGE_SEVERITY_FATAL = 200 +EDGE_SEVERITY_PANIC = 255 + + +# Fields to include in the json dump of the type +_JSON_FIELDS = { + 'hint': FIELD_HINT, + 'details': FIELD_DETAILS, + 'start': FIELD_CHARACTER_START, + 'end': FIELD_CHARACTER_END, + 'line': FIELD_LINE_START, + 'col': FIELD_COLUMN_START, +} + + +LINESEP = os.linesep + +try: + SHOW_HINT = {"default": True, "enabled": True, "disabled": False}[ + os.getenv("EDGEDB_ERROR_HINT", "default") + ] +except KeyError: + warnings.warn( + "EDGEDB_ERROR_HINT can only be one of: default, enabled or disabled", + stacklevel=1, + ) + SHOW_HINT = False + + +from gel.color import get_color diff --git a/gel/errors/tags.py b/gel/errors/tags.py new file mode 100644 index 00000000..275b31ac --- /dev/null +++ b/gel/errors/tags.py @@ -0,0 +1,25 @@ +__all__ = [ + 'Tag', + 'SHOULD_RECONNECT', + 'SHOULD_RETRY', +] + + +class Tag(object): + """Error tag + + Tags are used to differentiate certain properties of errors that apply to + error classes across hierarchy. + + Use ``error.has_tag(tag_name)`` to check for a tag. + """ + + def __init__(self, name): + self.name = name + + def __repr__(self): + return f'' + + +SHOULD_RECONNECT = Tag('SHOULD_RECONNECT') +SHOULD_RETRY = Tag('SHOULD_RETRY') diff --git a/gel/introspect.py b/gel/introspect.py new file mode 100644 index 00000000..4f7884d6 --- /dev/null +++ b/gel/introspect.py @@ -0,0 +1,66 @@ +# +# This source file is part of the EdgeDB open source project. +# +# Copyright 2019-present MagicStack Inc. and the EdgeDB authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +# IMPORTANT: this private API is subject to change. + + +import functools +import typing + +from gel.datatypes import datatypes as dt +from gel.enums import ElementKind + + +class PointerDescription(typing.NamedTuple): + + name: str + kind: ElementKind + implicit: bool + + +class ObjectDescription(typing.NamedTuple): + + pointers: typing.Tuple[PointerDescription, ...] + + +@functools.lru_cache() +def _introspect_object_desc(desc) -> ObjectDescription: + pointers = [] + # Call __dir__ directly as dir() scrambles the order. + for name in desc.__dir__(): + if desc.is_link(name): + kind = ElementKind.LINK + elif desc.is_linkprop(name): + continue + else: + kind = ElementKind.PROPERTY + + pointers.append( + PointerDescription( + name=name, + kind=kind, + implicit=desc.is_implicit(name))) + + return ObjectDescription( + pointers=tuple(pointers)) + + +def introspect_object(obj) -> ObjectDescription: + return _introspect_object_desc( + dt.get_object_descriptor(obj)) diff --git a/gel/options.py b/gel/options.py new file mode 100644 index 00000000..6f83d421 --- /dev/null +++ b/gel/options.py @@ -0,0 +1,492 @@ +import abc +import enum +import logging +import random +import typing +import sys +from collections import namedtuple + +from . import errors + + +logger = logging.getLogger('gel') + + +_RetryRule = namedtuple("_RetryRule", ["attempts", "backoff"]) + + +def default_backoff(attempt): + return (2 ** attempt) * 0.1 + random.randrange(100) * 0.001 + + +WarningHandler = typing.Callable[ + [typing.Tuple[errors.EdgeDBError, ...], typing.Any], + typing.Any, +] + + +def raise_warnings(warnings, res): + if ( + len(warnings) > 1 + and sys.version_info >= (3, 11) + ): + raise ExceptionGroup( # noqa + "Query produced warnings", warnings + ) + else: + raise warnings[0] + + +def log_warnings(warnings, res): + for w in warnings: + logger.warning("EdgeDB warning: %s", str(w)) + return res + + +class RetryCondition: + """Specific condition to retry on for fine-grained control""" + TransactionConflict = enum.auto() + NetworkError = enum.auto() + + +class IsolationLevel: + """Isolation level for transaction""" + Serializable = "SERIALIZABLE" + + +class RetryOptions: + """An immutable class that contains rules for `transaction()`""" + __slots__ = ['_default', '_overrides'] + + def __init__(self, attempts: int, backoff=default_backoff): + self._default = _RetryRule(attempts, backoff) + self._overrides = None + + def with_rule(self, condition, attempts=None, backoff=None): + default = self._default + overrides = self._overrides + if overrides is None: + overrides = {} + else: + overrides = overrides.copy() + overrides[condition] = _RetryRule( + default.attempts if attempts is None else attempts, + default.backoff if backoff is None else backoff, + ) + result = RetryOptions.__new__(RetryOptions) + result._default = default + result._overrides = overrides + return result + + @classmethod + def defaults(cls): + return cls( + attempts=3, + backoff=default_backoff, + ) + + def get_rule_for_exception(self, exception): + default = self._default + overrides = self._overrides + res = default + if overrides: + if isinstance(exception, errors.TransactionConflictError): + res = overrides.get(RetryCondition.TransactionConflict, res) + elif isinstance(exception, errors.ClientError): + res = overrides.get(RetryCondition.NetworkError, res) + return res + + +class TransactionOptions: + """Options for `transaction()`""" + __slots__ = ['_isolation', '_readonly', '_deferrable'] + + def __init__( + self, + isolation: IsolationLevel=IsolationLevel.Serializable, + readonly: bool = False, + deferrable: bool = False, + ): + self._isolation = isolation + self._readonly = readonly + self._deferrable = deferrable + + @classmethod + def defaults(cls): + return cls() + + def start_transaction_query(self): + isolation = str(self._isolation) + if self._readonly: + mode = 'READ ONLY' + else: + mode = 'READ WRITE' + + if self._deferrable: + defer = 'DEFERRABLE' + else: + defer = 'NOT DEFERRABLE' + + return f'START TRANSACTION ISOLATION {isolation}, {mode}, {defer};' + + def __repr__(self): + return ( + f'<{self.__class__.__name__} ' + f'isolation:{self._isolation}, ' + f'readonly:{self._readonly}, ' + f'deferrable:{self._deferrable}>' + ) + + +class State: + __slots__ = ['_module', '_aliases', '_config', '_globals'] + + def __init__( + self, + default_module: typing.Optional[str] = None, + module_aliases: typing.Mapping[str, str] = None, + config: typing.Mapping[str, typing.Any] = None, + globals_: typing.Mapping[str, typing.Any] = None, + ): + self._module = default_module + self._aliases = {} if module_aliases is None else dict(module_aliases) + self._config = {} if config is None else dict(config) + self._globals = ( + {} if globals_ is None else self.with_globals(globals_)._globals + ) + + @classmethod + def _new(cls, default_module, module_aliases, config, globals_): + rv = cls.__new__(cls) + rv._module = default_module + rv._aliases = module_aliases + rv._config = config + rv._globals = globals_ + return rv + + @classmethod + def defaults(cls): + return cls() + + def with_default_module(self, module: typing.Optional[str] = None): + return self._new( + default_module=module, + module_aliases=self._aliases, + config=self._config, + globals_=self._globals, + ) + + def with_module_aliases(self, *args, **aliases): + if len(args) > 1: + raise errors.InvalidArgumentError( + "with_module_aliases() takes from 0 to 1 positional arguments " + "but {} were given".format(len(args)) + ) + aliases_dict = args[0] if args else {} + aliases_dict.update(aliases) + new_aliases = self._aliases.copy() + new_aliases.update(aliases_dict) + return self._new( + default_module=self._module, + module_aliases=new_aliases, + config=self._config, + globals_=self._globals, + ) + + def with_config(self, *args, **config): + if len(args) > 1: + raise errors.InvalidArgumentError( + "with_config() takes from 0 to 1 positional arguments " + "but {} were given".format(len(args)) + ) + config_dict = args[0] if args else {} + config_dict.update(config) + new_config = self._config.copy() + new_config.update(config_dict) + return self._new( + default_module=self._module, + module_aliases=self._aliases, + config=new_config, + globals_=self._globals, + ) + + def resolve(self, name: str) -> str: + parts = name.split("::", 1) + if len(parts) == 1: + return f"{self._module or 'default'}::{name}" + elif len(parts) == 2: + mod, name = parts + mod = self._aliases.get(mod, mod) + return f"{mod}::{name}" + else: + raise AssertionError('broken split') + + def with_globals(self, *args, **globals_): + if len(args) > 1: + raise errors.InvalidArgumentError( + "with_globals() takes from 0 to 1 positional arguments " + "but {} were given".format(len(args)) + ) + new_globals = self._globals.copy() + if args: + for k, v in args[0].items(): + new_globals[self.resolve(k)] = v + for k, v in globals_.items(): + new_globals[self.resolve(k)] = v + return self._new( + default_module=self._module, + module_aliases=self._aliases, + config=self._config, + globals_=new_globals, + ) + + def without_module_aliases(self, *aliases): + if not aliases: + new_aliases = {} + else: + new_aliases = self._aliases.copy() + for alias in aliases: + new_aliases.pop(alias, None) + return self._new( + default_module=self._module, + module_aliases=new_aliases, + config=self._config, + globals_=self._globals, + ) + + def without_config(self, *config_names): + if not config_names: + new_config = {} + else: + new_config = self._config.copy() + for name in config_names: + new_config.pop(name, None) + return self._new( + default_module=self._module, + module_aliases=self._aliases, + config=new_config, + globals_=self._globals, + ) + + def without_globals(self, *global_names): + if not global_names: + new_globals = {} + else: + new_globals = self._globals.copy() + for name in global_names: + new_globals.pop(self.resolve(name), None) + return self._new( + default_module=self._module, + module_aliases=self._aliases, + config=self._config, + globals_=new_globals, + ) + + def as_dict(self): + rv = {} + if self._module is not None: + rv["module"] = self._module + if self._aliases: + rv["aliases"] = list(self._aliases.items()) + if self._config: + rv["config"] = self._config + if self._globals: + rv["globals"] = self._globals + return rv + + +class _OptionsMixin: + def __init__(self, *args, **kwargs): + self._options = _Options.defaults() + super().__init__(*args, **kwargs) + + @abc.abstractmethod + def _shallow_clone(self): + pass + + def with_transaction_options(self, options: TransactionOptions = None): + """Returns object with adjusted options for future transactions. + + :param options TransactionOptions: + Object that encapsulates transaction options. + + This method returns a "shallow copy" of the current object + with modified transaction options. + + Both ``self`` and returned object can be used after, but when using + them transaction options applied will be different. + + Transaction options are used by the ``transaction`` method. + """ + result = self._shallow_clone() + result._options = self._options.with_transaction_options(options) + return result + + def with_retry_options(self, options: RetryOptions=None): + """Returns object with adjusted options for future retrying + transactions. + + :param options RetryOptions: + Object that encapsulates retry options. + + This method returns a "shallow copy" of the current object + with modified retry options. + + Both ``self`` and returned object can be used after, but when using + them retry options applied will be different. + """ + + result = self._shallow_clone() + result._options = self._options.with_retry_options(options) + return result + + def with_warning_handler(self, warning_handler: WarningHandler=None): + """Returns object with adjusted options for handling warnings. + + :param warning_handler WarningHandler: + Function for handling warnings. It is passed a tuple of warnings + and the query result and returns a potentially updated query + result. + + This method returns a "shallow copy" of the current object + with modified retry options. + + Both ``self`` and returned object can be used after, but when using + them retry options applied will be different. + """ + + result = self._shallow_clone() + result._options = self._options.with_warning_handler(warning_handler) + return result + + def with_state(self, state: State): + result = self._shallow_clone() + result._options = self._options.with_state(state) + return result + + def with_default_module(self, module: typing.Optional[str] = None): + result = self._shallow_clone() + result._options = self._options.with_state( + self._options.state.with_default_module(module) + ) + return result + + def with_module_aliases(self, *args, **aliases): + result = self._shallow_clone() + result._options = self._options.with_state( + self._options.state.with_module_aliases(*args, **aliases) + ) + return result + + def with_config(self, *args, **config): + result = self._shallow_clone() + result._options = self._options.with_state( + self._options.state.with_config(*args, **config) + ) + return result + + def with_globals(self, *args, **globals_): + result = self._shallow_clone() + result._options = self._options.with_state( + self._options.state.with_globals(*args, **globals_) + ) + return result + + def without_module_aliases(self, *aliases): + result = self._shallow_clone() + result._options = self._options.with_state( + self._options.state.without_module_aliases(*aliases) + ) + return result + + def without_config(self, *config_names): + result = self._shallow_clone() + result._options = self._options.with_state( + self._options.state.without_config(*config_names) + ) + return result + + def without_globals(self, *global_names): + result = self._shallow_clone() + result._options = self._options.with_state( + self._options.state.without_globals(*global_names) + ) + return result + + +class _Options: + """Internal class for storing connection options""" + + __slots__ = [ + '_retry_options', '_transaction_options', '_state', + '_warning_handler' + ] + + def __init__( + self, + retry_options: RetryOptions, + transaction_options: TransactionOptions, + state: State, + warning_handler: WarningHandler, + ): + self._retry_options = retry_options + self._transaction_options = transaction_options + self._state = state + self._warning_handler = warning_handler + + @property + def retry_options(self): + return self._retry_options + + @property + def transaction_options(self): + return self._transaction_options + + @property + def state(self): + return self._state + + @property + def warning_handler(self): + return self._warning_handler + + def with_retry_options(self, options: RetryOptions): + return _Options( + options, + self._transaction_options, + self._state, + self._warning_handler, + ) + + def with_transaction_options(self, options: TransactionOptions): + return _Options( + self._retry_options, + options, + self._state, + self._warning_handler, + ) + + def with_state(self, state: State): + return _Options( + self._retry_options, + self._transaction_options, + state, + self._warning_handler, + ) + + def with_warning_handler(self, warning_handler: WarningHandler): + return _Options( + self._retry_options, + self._transaction_options, + self._state, + warning_handler, + ) + + @classmethod + def defaults(cls): + return cls( + RetryOptions.defaults(), + TransactionOptions.defaults(), + State.defaults(), + log_warnings, + ) diff --git a/edgedb/pgproto b/gel/pgproto similarity index 100% rename from edgedb/pgproto rename to gel/pgproto diff --git a/gel/platform.py b/gel/platform.py new file mode 100644 index 00000000..55410532 --- /dev/null +++ b/gel/platform.py @@ -0,0 +1,52 @@ +import functools +import os +import pathlib +import sys + +if sys.platform == "darwin": + def config_dir() -> pathlib.Path: + return ( + pathlib.Path.home() / "Library" / "Application Support" / "edgedb" + ) + + IS_WINDOWS = False + +elif sys.platform == "win32": + import ctypes + from ctypes import windll + + def config_dir() -> pathlib.Path: + path_buf = ctypes.create_unicode_buffer(255) + csidl = 28 # CSIDL_LOCAL_APPDATA + windll.shell32.SHGetFolderPathW(0, csidl, 0, 0, path_buf) + return pathlib.Path(path_buf.value) / "EdgeDB" / "config" + + IS_WINDOWS = True + +else: + def config_dir() -> pathlib.Path: + xdg_conf_dir = pathlib.Path(os.environ.get("XDG_CONFIG_HOME", ".")) + if not xdg_conf_dir.is_absolute(): + xdg_conf_dir = pathlib.Path.home() / ".config" + return xdg_conf_dir / "edgedb" + + IS_WINDOWS = False + + +def old_config_dir() -> pathlib.Path: + return pathlib.Path.home() / ".edgedb" + + +def search_config_dir(*suffix): + rv = functools.reduce(lambda p1, p2: p1 / p2, [config_dir(), *suffix]) + if rv.exists(): + return rv + + fallback = functools.reduce( + lambda p1, p2: p1 / p2, [old_config_dir(), *suffix] + ) + if fallback.exists(): + return fallback + + # None of the searched files exists, return the new path. + return rv diff --git a/edgedb/protocol/.gitignore b/gel/protocol/.gitignore similarity index 100% rename from edgedb/protocol/.gitignore rename to gel/protocol/.gitignore diff --git a/gel/protocol/__init__.py b/gel/protocol/__init__.py new file mode 100644 index 00000000..46ceb625 --- /dev/null +++ b/gel/protocol/__init__.py @@ -0,0 +1,17 @@ +# +# This source file is part of the EdgeDB open source project. +# +# Copyright 2016-present MagicStack Inc. and the EdgeDB authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/edgedb/protocol/asyncio_proto.pxd b/gel/protocol/asyncio_proto.pxd similarity index 95% rename from edgedb/protocol/asyncio_proto.pxd rename to gel/protocol/asyncio_proto.pxd index 4eeb68e9..8919c24b 100644 --- a/edgedb/protocol/asyncio_proto.pxd +++ b/gel/protocol/asyncio_proto.pxd @@ -19,7 +19,7 @@ from . cimport protocol -from edgedb.pgproto.debug cimport PG_DEBUG +from gel.pgproto.debug cimport PG_DEBUG cdef class AsyncIOProtocol(protocol.SansIOProtocolBackwardsCompatible): diff --git a/edgedb/protocol/asyncio_proto.pyx b/gel/protocol/asyncio_proto.pyx similarity index 98% rename from edgedb/protocol/asyncio_proto.pyx rename to gel/protocol/asyncio_proto.pyx index db3f616f..39d3b62a 100644 --- a/edgedb/protocol/asyncio_proto.pyx +++ b/gel/protocol/asyncio_proto.pyx @@ -19,8 +19,8 @@ import asyncio -from edgedb import errors -from edgedb.pgproto.pgproto cimport ( +from gel import errors +from gel.pgproto.pgproto cimport ( WriteBuffer, ReadBuffer, ) diff --git a/edgedb/protocol/blocking_proto.pxd b/gel/protocol/blocking_proto.pxd similarity index 95% rename from edgedb/protocol/blocking_proto.pxd rename to gel/protocol/blocking_proto.pxd index cd19c2bb..13e4707e 100644 --- a/edgedb/protocol/blocking_proto.pxd +++ b/gel/protocol/blocking_proto.pxd @@ -19,7 +19,7 @@ from . cimport protocol -from edgedb.pgproto.debug cimport PG_DEBUG +from gel.pgproto.debug cimport PG_DEBUG cdef class BlockingIOProtocol(protocol.SansIOProtocolBackwardsCompatible): diff --git a/edgedb/protocol/blocking_proto.pyx b/gel/protocol/blocking_proto.pyx similarity index 99% rename from edgedb/protocol/blocking_proto.pyx rename to gel/protocol/blocking_proto.pyx index ea4c1c16..839fbe42 100644 --- a/edgedb/protocol/blocking_proto.pyx +++ b/gel/protocol/blocking_proto.pyx @@ -19,7 +19,7 @@ import socket import time -from edgedb.pgproto.pgproto cimport ( +from gel.pgproto.pgproto cimport ( WriteBuffer, ReadBuffer, ) diff --git a/edgedb/protocol/codecs/array.pxd b/gel/protocol/codecs/array.pxd similarity index 100% rename from edgedb/protocol/codecs/array.pxd rename to gel/protocol/codecs/array.pxd diff --git a/edgedb/protocol/codecs/array.pyx b/gel/protocol/codecs/array.pyx similarity index 100% rename from edgedb/protocol/codecs/array.pyx rename to gel/protocol/codecs/array.pyx diff --git a/edgedb/protocol/codecs/base.pxd b/gel/protocol/codecs/base.pxd similarity index 100% rename from edgedb/protocol/codecs/base.pxd rename to gel/protocol/codecs/base.pxd diff --git a/edgedb/protocol/codecs/base.pyx b/gel/protocol/codecs/base.pyx similarity index 100% rename from edgedb/protocol/codecs/base.pyx rename to gel/protocol/codecs/base.pyx diff --git a/edgedb/protocol/codecs/codecs.pxd b/gel/protocol/codecs/codecs.pxd similarity index 100% rename from edgedb/protocol/codecs/codecs.pxd rename to gel/protocol/codecs/codecs.pxd diff --git a/edgedb/protocol/codecs/codecs.pyx b/gel/protocol/codecs/codecs.pyx similarity index 99% rename from edgedb/protocol/codecs/codecs.pyx rename to gel/protocol/codecs/codecs.pyx index 2887320b..afa3b83f 100644 --- a/edgedb/protocol/codecs/codecs.pyx +++ b/gel/protocol/codecs/codecs.pyx @@ -21,9 +21,9 @@ import array import decimal import uuid import datetime -from edgedb import describe -from edgedb import enums -from edgedb.datatypes import datatypes +from gel import describe +from gel import enums +from gel.datatypes import datatypes from libc.string cimport memcpy from cpython.bytes cimport PyBytes_FromStringAndSize diff --git a/edgedb/protocol/codecs/edb_types.pxi b/gel/protocol/codecs/edb_types.pxi similarity index 100% rename from edgedb/protocol/codecs/edb_types.pxi rename to gel/protocol/codecs/edb_types.pxi diff --git a/edgedb/protocol/codecs/enum.pxd b/gel/protocol/codecs/enum.pxd similarity index 100% rename from edgedb/protocol/codecs/enum.pxd rename to gel/protocol/codecs/enum.pxd diff --git a/edgedb/protocol/codecs/enum.pyx b/gel/protocol/codecs/enum.pyx similarity index 92% rename from edgedb/protocol/codecs/enum.pyx rename to gel/protocol/codecs/enum.pyx index d7fe5620..0020397f 100644 --- a/edgedb/protocol/codecs/enum.pyx +++ b/gel/protocol/codecs/enum.pyx @@ -28,7 +28,7 @@ cdef class EnumCodec(BaseCodec): obj = self.cls._try_from(obj) except (TypeError, ValueError): raise TypeError( - f'a str or edgedb.EnumValue(__tid__={self.cls.__tid__}) ' + f'a str or gel.EnumValue(__tid__={self.cls.__tid__}) ' f'is expected as a valid enum argument, ' f'got {type(obj).__name__}') from None pgproto.text_encode(DEFAULT_CODEC_CONTEXT, buf, str(obj)) @@ -49,8 +49,8 @@ cdef class EnumCodec(BaseCodec): cls = "DerivedEnumValue" bases = (datatypes.EnumValue,) classdict = enum.EnumMeta.__prepare__(cls, bases) - classdict["__module__"] = "edgedb" - classdict["__qualname__"] = "edgedb.DerivedEnumValue" + classdict["__module__"] = "gel" + classdict["__qualname__"] = "gel.DerivedEnumValue" classdict["__tid__"] = pgproto.UUID(tid) for label in enum_labels: classdict[label.upper()] = label diff --git a/edgedb/protocol/codecs/namedtuple.pxd b/gel/protocol/codecs/namedtuple.pxd similarity index 100% rename from edgedb/protocol/codecs/namedtuple.pxd rename to gel/protocol/codecs/namedtuple.pxd diff --git a/edgedb/protocol/codecs/namedtuple.pyx b/gel/protocol/codecs/namedtuple.pyx similarity index 100% rename from edgedb/protocol/codecs/namedtuple.pyx rename to gel/protocol/codecs/namedtuple.pyx diff --git a/edgedb/protocol/codecs/object.pxd b/gel/protocol/codecs/object.pxd similarity index 100% rename from edgedb/protocol/codecs/object.pxd rename to gel/protocol/codecs/object.pxd diff --git a/edgedb/protocol/codecs/object.pyx b/gel/protocol/codecs/object.pyx similarity index 100% rename from edgedb/protocol/codecs/object.pyx rename to gel/protocol/codecs/object.pyx diff --git a/edgedb/protocol/codecs/range.pxd b/gel/protocol/codecs/range.pxd similarity index 100% rename from edgedb/protocol/codecs/range.pxd rename to gel/protocol/codecs/range.pxd diff --git a/edgedb/protocol/codecs/range.pyx b/gel/protocol/codecs/range.pyx similarity index 99% rename from edgedb/protocol/codecs/range.pyx rename to gel/protocol/codecs/range.pyx index ea573b89..52a656c4 100644 --- a/edgedb/protocol/codecs/range.pyx +++ b/gel/protocol/codecs/range.pyx @@ -17,7 +17,7 @@ # -from edgedb.datatypes import range as range_mod +from gel.datatypes import range as range_mod cdef uint8_t RANGE_EMPTY = 0x01 diff --git a/edgedb/protocol/codecs/scalar.pxd b/gel/protocol/codecs/scalar.pxd similarity index 100% rename from edgedb/protocol/codecs/scalar.pxd rename to gel/protocol/codecs/scalar.pxd diff --git a/edgedb/protocol/codecs/scalar.pyx b/gel/protocol/codecs/scalar.pyx similarity index 100% rename from edgedb/protocol/codecs/scalar.pyx rename to gel/protocol/codecs/scalar.pyx diff --git a/edgedb/protocol/codecs/set.pxd b/gel/protocol/codecs/set.pxd similarity index 100% rename from edgedb/protocol/codecs/set.pxd rename to gel/protocol/codecs/set.pxd diff --git a/edgedb/protocol/codecs/set.pyx b/gel/protocol/codecs/set.pyx similarity index 100% rename from edgedb/protocol/codecs/set.pyx rename to gel/protocol/codecs/set.pyx diff --git a/edgedb/protocol/codecs/tuple.pxd b/gel/protocol/codecs/tuple.pxd similarity index 100% rename from edgedb/protocol/codecs/tuple.pxd rename to gel/protocol/codecs/tuple.pxd diff --git a/edgedb/protocol/codecs/tuple.pyx b/gel/protocol/codecs/tuple.pyx similarity index 100% rename from edgedb/protocol/codecs/tuple.pyx rename to gel/protocol/codecs/tuple.pyx diff --git a/edgedb/protocol/consts.pxi b/gel/protocol/consts.pxi similarity index 100% rename from edgedb/protocol/consts.pxi rename to gel/protocol/consts.pxi diff --git a/edgedb/protocol/cpythonx.pxd b/gel/protocol/cpythonx.pxd similarity index 100% rename from edgedb/protocol/cpythonx.pxd rename to gel/protocol/cpythonx.pxd diff --git a/edgedb/protocol/lru.pxd b/gel/protocol/lru.pxd similarity index 100% rename from edgedb/protocol/lru.pxd rename to gel/protocol/lru.pxd diff --git a/edgedb/protocol/lru.pyx b/gel/protocol/lru.pyx similarity index 100% rename from edgedb/protocol/lru.pyx rename to gel/protocol/lru.pyx diff --git a/edgedb/protocol/protocol.pxd b/gel/protocol/protocol.pxd similarity index 97% rename from edgedb/protocol/protocol.pxd rename to gel/protocol/protocol.pxd index b6e143ae..7c5f943f 100644 --- a/edgedb/protocol/protocol.pxd +++ b/gel/protocol/protocol.pxd @@ -23,14 +23,14 @@ cimport cpython from libc.stdint cimport int16_t, int32_t, uint16_t, \ uint32_t, int64_t, uint64_t -from edgedb.pgproto.pgproto cimport ( +from gel.pgproto.pgproto cimport ( WriteBuffer, ReadBuffer, FRBuffer, ) -from edgedb.pgproto cimport pgproto -from edgedb.pgproto.debug cimport PG_DEBUG +from gel.pgproto cimport pgproto +from gel.pgproto.debug cimport PG_DEBUG include "./lru.pxd" diff --git a/edgedb/protocol/protocol.pyx b/gel/protocol/protocol.pyx similarity index 99% rename from edgedb/protocol/protocol.pyx rename to gel/protocol/protocol.pyx index e464a039..e22efef7 100644 --- a/edgedb/protocol/protocol.pyx +++ b/gel/protocol/protocol.pyx @@ -30,7 +30,7 @@ import types import typing import weakref -from edgedb.pgproto.pgproto cimport ( +from gel.pgproto.pgproto cimport ( WriteBuffer, ReadBuffer, @@ -44,22 +44,22 @@ from edgedb.pgproto.pgproto cimport ( frb_get_len, ) -from edgedb.pgproto import pgproto -from edgedb.pgproto cimport pgproto -from edgedb.pgproto cimport hton -from edgedb.pgproto.pgproto import UUID +from gel.pgproto import pgproto +from gel.pgproto cimport pgproto +from gel.pgproto cimport hton +from gel.pgproto.pgproto import UUID from libc.stdint cimport int8_t, uint8_t, int16_t, uint16_t, \ int32_t, uint32_t, int64_t, uint64_t, \ UINT32_MAX -from edgedb.datatypes cimport datatypes +from gel.datatypes cimport datatypes from . cimport cpythonx -from edgedb import enums -from edgedb import errors -from edgedb import scram +from gel import enums +from gel import errors +from gel import scram include "./consts.pxi" diff --git a/edgedb/protocol/protocol_v0.pxd b/gel/protocol/protocol_v0.pxd similarity index 100% rename from edgedb/protocol/protocol_v0.pxd rename to gel/protocol/protocol_v0.pxd diff --git a/edgedb/protocol/protocol_v0.pyx b/gel/protocol/protocol_v0.pyx similarity index 99% rename from edgedb/protocol/protocol_v0.pyx rename to gel/protocol/protocol_v0.pyx index 2a4cb80b..1ba6852c 100644 --- a/edgedb/protocol/protocol_v0.pyx +++ b/gel/protocol/protocol_v0.pyx @@ -17,7 +17,7 @@ # -from edgedb import enums +from gel import enums DEF QUERY_OPT_IMPLICIT_LIMIT = 0xFF01 diff --git a/gel/py.typed b/gel/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/gel/scram/__init__.py b/gel/scram/__init__.py new file mode 100644 index 00000000..62f3e712 --- /dev/null +++ b/gel/scram/__init__.py @@ -0,0 +1,434 @@ +# +# This source file is part of the EdgeDB open source project. +# +# Copyright 2019-present MagicStack Inc. and the EdgeDB authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Helpers for SCRAM authentication.""" + +import base64 +import hashlib +import hmac +import os +import typing + +from .saslprep import saslprep + + +RAW_NONCE_LENGTH = 18 + +# Per recommendations in RFC 7677. +DEFAULT_SALT_LENGTH = 16 +DEFAULT_ITERATIONS = 4096 + + +def generate_salt(length: int = DEFAULT_SALT_LENGTH) -> bytes: + return os.urandom(length) + + +def generate_nonce(length: int = RAW_NONCE_LENGTH) -> str: + return B64(os.urandom(length)) + + +def build_verifier(password: str, *, salt: typing.Optional[bytes] = None, + iterations: int = DEFAULT_ITERATIONS) -> str: + """Build the SCRAM verifier for the given password. + + Returns a string in the following format: + + "$:$:" + + The salt and keys are base64-encoded values. + """ + password = saslprep(password).encode('utf-8') + + if salt is None: + salt = generate_salt() + + salted_password = get_salted_password(password, salt, iterations) + client_key = get_client_key(salted_password) + stored_key = H(client_key) + server_key = get_server_key(salted_password) + + return (f'SCRAM-SHA-256${iterations}:{B64(salt)}$' + f'{B64(stored_key)}:{B64(server_key)}') + + +class SCRAMVerifier(typing.NamedTuple): + + mechanism: str + iterations: int + salt: bytes + stored_key: bytes + server_key: bytes + + +def parse_verifier(verifier: str) -> SCRAMVerifier: + + parts = verifier.split('$') + if len(parts) != 3: + raise ValueError('invalid SCRAM verifier') + + mechanism = parts[0] + if mechanism != 'SCRAM-SHA-256': + raise ValueError('invalid SCRAM verifier') + + iterations, _, salt = parts[1].partition(':') + stored_key, _, server_key = parts[2].partition(':') + if not salt or not server_key: + raise ValueError('invalid SCRAM verifier') + + try: + iterations = int(iterations) + except ValueError: + raise ValueError('invalid SCRAM verifier') from None + + return SCRAMVerifier( + mechanism=mechanism, + iterations=iterations, + salt=base64.b64decode(salt), + stored_key=base64.b64decode(stored_key), + server_key=base64.b64decode(server_key), + ) + + +def parse_client_first_message(resp: bytes): + + # Relevant bits of RFC 5802: + # + # saslname = 1*(value-safe-char / "=2C" / "=3D") + # ;; Conforms to . + # + # authzid = "a=" saslname + # ;; Protocol specific. + # + # cb-name = 1*(ALPHA / DIGIT / "." / "-") + # ;; See RFC 5056, Section 7. + # ;; E.g., "tls-server-end-point" or + # ;; "tls-unique". + # + # gs2-cbind-flag = ("p=" cb-name) / "n" / "y" + # ;; "n" -> client doesn't support channel binding. + # ;; "y" -> client does support channel binding + # ;; but thinks the server does not. + # ;; "p" -> client requires channel binding. + # ;; The selected channel binding follows "p=". + # + # gs2-header = gs2-cbind-flag "," [ authzid ] "," + # ;; GS2 header for SCRAM + # ;; (the actual GS2 header includes an optional + # ;; flag to indicate that the GSS mechanism is not + # ;; "standard", but since SCRAM is "standard", we + # ;; don't include that flag). + # + # username = "n=" saslname + # ;; Usernames are prepared using SASLprep. + # + # reserved-mext = "m=" 1*(value-char) + # ;; Reserved for signaling mandatory extensions. + # ;; The exact syntax will be defined in + # ;; the future. + # + # nonce = "r=" c-nonce [s-nonce] + # ;; Second part provided by server. + # + # c-nonce = printable + # + # client-first-message-bare = + # [reserved-mext ","] + # username "," nonce ["," extensions] + # + # client-first-message = + # gs2-header client-first-message-bare + + attrs = resp.split(b',') + + cb_attr = attrs[0] + if cb_attr == b'y': + cb = True + elif cb_attr == b'n': + cb = False + elif cb_attr[0:1] == b'p': + _, _, cb = cb_attr.partition(b'=') + if not cb: + raise ValueError('malformed SCRAM message') + else: + raise ValueError('malformed SCRAM message') + + authzid_attr = attrs[1] + if authzid_attr: + if authzid_attr[0:1] != b'a': + raise ValueError('malformed SCRAM message') + _, _, authzid = authzid_attr.partition(b'=') + else: + authzid = None + + user_attr = attrs[2] + if user_attr[0:1] == b'm': + raise ValueError('unsupported SCRAM extensions in message') + elif user_attr[0:1] != b'n': + raise ValueError('malformed SCRAM message') + + _, _, user = user_attr.partition(b'=') + + nonce_attr = attrs[3] + if nonce_attr[0:1] != b'r': + raise ValueError('malformed SCRAM message') + + _, _, nonce_bin = nonce_attr.partition(b'=') + nonce = nonce_bin.decode('ascii') + if not nonce.isprintable(): + raise ValueError('invalid characters in client nonce') + + # ["," extensions] are ignored + + return len(cb_attr) + 2, cb, authzid, user, nonce + + +def parse_client_final_message( + msg: bytes, client_nonce: str, server_nonce: str): + + # Relevant bits of RFC 5802: + # + # gs2-header = gs2-cbind-flag "," [ authzid ] "," + # ;; GS2 header for SCRAM + # ;; (the actual GS2 header includes an optional + # ;; flag to indicate that the GSS mechanism is not + # ;; "standard", but since SCRAM is "standard", we + # ;; don't include that flag). + # + # cbind-input = gs2-header [ cbind-data ] + # ;; cbind-data MUST be present for + # ;; gs2-cbind-flag of "p" and MUST be absent + # ;; for "y" or "n". + # + # channel-binding = "c=" base64 + # ;; base64 encoding of cbind-input. + # + # proof = "p=" base64 + # + # client-final-message-without-proof = + # channel-binding "," nonce ["," + # extensions] + # + # client-final-message = + # client-final-message-without-proof "," proof + + attrs = msg.split(b',') + + cb_attr = attrs[0] + if cb_attr[0:1] != b'c': + raise ValueError('malformed SCRAM message') + + _, _, cb_data = cb_attr.partition(b'=') + + nonce_attr = attrs[1] + if nonce_attr[0:1] != b'r': + raise ValueError('malformed SCRAM message') + + _, _, nonce_bin = nonce_attr.partition(b'=') + nonce = nonce_bin.decode('ascii') + + expected_nonce = f'{client_nonce}{server_nonce}' + + if nonce != expected_nonce: + raise ValueError( + 'invalid SCRAM client-final message: nonce does not match') + + proof = None + + for attr in attrs[2:]: + if attr[0:1] == b'p': + _, _, proof = attr.partition(b'=') + proof_attr_len = len(attr) + proof = base64.b64decode(proof) + elif proof is not None: + raise ValueError('malformed SCRAM message') + + if proof is None: + raise ValueError('malformed SCRAM message') + + return cb_data, proof, proof_attr_len + 1 + + +def build_client_first_message(client_nonce: str, username: str) -> str: + + bare = f'n={saslprep(username)},r={client_nonce}' + return f'n,,{bare}', bare + + +def build_server_first_message(server_nonce: str, client_nonce: str, + salt: bytes, iterations: int) -> str: + + return ( + f'r={client_nonce}{server_nonce},' + f's={B64(salt)},i={iterations}' + ) + + +def build_auth_message( + client_first_bare: bytes, + server_first: bytes, client_final: bytes) -> bytes: + + return b'%b,%b,%b' % (client_first_bare, server_first, client_final) + + +def build_client_final_message( + password: str, + salt: bytes, + iterations: int, + client_first_bare: bytes, + server_first: bytes, + server_nonce: str) -> str: + + client_final = f'c=biws,r={server_nonce}' + + AuthMessage = build_auth_message( + client_first_bare, server_first, client_final.encode('utf-8')) + + SaltedPassword = get_salted_password( + saslprep(password).encode('utf-8'), + salt, + iterations) + + ClientKey = get_client_key(SaltedPassword) + StoredKey = H(ClientKey) + ClientSignature = HMAC(StoredKey, AuthMessage) + ClientProof = XOR(ClientKey, ClientSignature) + + ServerKey = get_server_key(SaltedPassword) + ServerProof = HMAC(ServerKey, AuthMessage) + + return f'{client_final},p={B64(ClientProof)}', ServerProof + + +def build_server_final_message( + client_first_bare: bytes, server_first: bytes, + client_final: bytes, server_key: bytes) -> str: + + AuthMessage = build_auth_message( + client_first_bare, server_first, client_final) + ServerSignature = HMAC(server_key, AuthMessage) + return f'v={B64(ServerSignature)}' + + +def parse_server_first_message(msg: bytes): + + attrs = msg.split(b',') + + nonce_attr = attrs[0] + if nonce_attr[0:1] != b'r': + raise ValueError('malformed SCRAM message') + + _, _, nonce_bin = nonce_attr.partition(b'=') + nonce = nonce_bin.decode('ascii') + if not nonce.isprintable(): + raise ValueError('malformed SCRAM message') + + salt_attr = attrs[1] + if salt_attr[0:1] != b's': + raise ValueError('malformed SCRAM message') + + _, _, salt_b64 = salt_attr.partition(b'=') + salt = base64.b64decode(salt_b64) + + iter_attr = attrs[2] + if iter_attr[0:1] != b'i': + raise ValueError('malformed SCRAM message') + + _, _, iterations = iter_attr.partition(b'=') + + try: + itercount = int(iterations) + except ValueError: + raise ValueError('malformed SCRAM message') from None + + return nonce, salt, itercount + + +def parse_server_final_message(msg: bytes): + + attrs = msg.split(b',') + + nonce_attr = attrs[0] + if nonce_attr[0:1] != b'v': + raise ValueError('malformed SCRAM message') + + _, _, signature_b64 = nonce_attr.partition(b'=') + signature = base64.b64decode(signature_b64) + + return signature + + +def verify_password(password: bytes, verifier: str) -> bool: + """Check the given password against a verifier. + + Returns True if the password is OK, False otherwise. + """ + + password = saslprep(password).encode('utf-8') + v = parse_verifier(verifier) + salted_password = get_salted_password(password, v.salt, v.iterations) + computed_key = get_server_key(salted_password) + return v.server_key == computed_key + + +def verify_client_proof(client_first: bytes, server_first: bytes, + client_final: bytes, StoredKey: bytes, + ClientProof: bytes) -> bool: + AuthMessage = build_auth_message(client_first, server_first, client_final) + ClientSignature = HMAC(StoredKey, AuthMessage) + ClientKey = XOR(ClientProof, ClientSignature) + return H(ClientKey) == StoredKey + + +def B64(val: bytes) -> str: + """Return base64-encoded string representation of input binary data.""" + return base64.b64encode(val).decode() + + +def HMAC(key: bytes, msg: bytes) -> bytes: + return hmac.new(key, msg, digestmod=hashlib.sha256).digest() + + +def XOR(a: bytes, b: bytes) -> bytes: + if len(a) != len(b): + raise ValueError('scram.XOR received operands of unequal length') + xint = int.from_bytes(a, 'big') ^ int.from_bytes(b, 'big') + return xint.to_bytes(len(a), 'big') + + +def H(s: bytes) -> bytes: + return hashlib.sha256(s).digest() + + +def get_salted_password(password: bytes, salt: bytes, + iterations: int) -> bytes: + # U1 := HMAC(str, salt + INT(1)) + H_i = U_i = HMAC(password, salt + b'\x00\x00\x00\x01') + + for _ in range(iterations - 1): + U_i = HMAC(password, U_i) + H_i = XOR(H_i, U_i) + + return H_i + + +def get_client_key(salted_password: bytes) -> bytes: + return HMAC(salted_password, b'Client Key') + + +def get_server_key(salted_password: bytes) -> bytes: + return HMAC(salted_password, b'Server Key') diff --git a/gel/scram/saslprep.py b/gel/scram/saslprep.py new file mode 100644 index 00000000..79eb84d8 --- /dev/null +++ b/gel/scram/saslprep.py @@ -0,0 +1,82 @@ +# Copyright 2016-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import stringprep +import unicodedata + + +# RFC4013 section 2.3 prohibited output. +_PROHIBITED = ( + # A strict reading of RFC 4013 requires table c12 here, but + # characters from it are mapped to SPACE in the Map step. Can + # normalization reintroduce them somehow? + stringprep.in_table_c12, + stringprep.in_table_c21_c22, + stringprep.in_table_c3, + stringprep.in_table_c4, + stringprep.in_table_c5, + stringprep.in_table_c6, + stringprep.in_table_c7, + stringprep.in_table_c8, + stringprep.in_table_c9) + + +def saslprep(data: str, prohibit_unassigned_code_points=True): + """An implementation of RFC4013 SASLprep.""" + + if data == '': + return data + + if prohibit_unassigned_code_points: + prohibited = _PROHIBITED + (stringprep.in_table_a1,) + else: + prohibited = _PROHIBITED + + # RFC3454 section 2, step 1 - Map + # RFC4013 section 2.1 mappings + # Map Non-ASCII space characters to SPACE (U+0020). Map + # commonly mapped to nothing characters to, well, nothing. + in_table_c12 = stringprep.in_table_c12 + in_table_b1 = stringprep.in_table_b1 + data = u"".join( + [u"\u0020" if in_table_c12(elt) else elt + for elt in data if not in_table_b1(elt)]) + + # RFC3454 section 2, step 2 - Normalize + # RFC4013 section 2.2 normalization + data = unicodedata.ucd_3_2_0.normalize('NFKC', data) + + in_table_d1 = stringprep.in_table_d1 + if in_table_d1(data[0]): + if not in_table_d1(data[-1]): + # RFC3454, Section 6, #3. If a string contains any + # RandALCat character, the first and last characters + # MUST be RandALCat characters. + raise ValueError("SASLprep: failed bidirectional check") + # RFC3454, Section 6, #2. If a string contains any RandALCat + # character, it MUST NOT contain any LCat character. + prohibited = prohibited + (stringprep.in_table_d2,) + else: + # RFC3454, Section 6, #3. Following the logic of #3, if + # the first character is not a RandALCat, no other character + # can be either. + prohibited = prohibited + (in_table_d1,) + + # RFC3454 section 2, step 3 and 4 - Prohibit and check bidi + for char in data: + if any(in_table(char) for in_table in prohibited): + raise ValueError( + "SASLprep: failed prohibited character check") + + return data diff --git a/gel/transaction.py b/gel/transaction.py new file mode 100644 index 00000000..8f3ab29f --- /dev/null +++ b/gel/transaction.py @@ -0,0 +1,224 @@ +# +# This source file is part of the EdgeDB open source project. +# +# Copyright 2016-present MagicStack Inc. and the EdgeDB authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +import enum + +from . import abstract +from . import errors +from . import options + + +class TransactionState(enum.Enum): + NEW = 0 + STARTED = 1 + COMMITTED = 2 + ROLLEDBACK = 3 + FAILED = 4 + + +class BaseTransaction: + + __slots__ = ( + '_client', + '_connection', + '_options', + '_state', + '__retry', + '__iteration', + '__started', + ) + + def __init__(self, retry, client, iteration): + self._client = client + self._connection = None + self._options = retry._options.transaction_options + self._state = TransactionState.NEW + self.__retry = retry + self.__iteration = iteration + self.__started = False + + def is_active(self) -> bool: + return self._state is TransactionState.STARTED + + def __check_state_base(self, opname): + if self._state is TransactionState.COMMITTED: + raise errors.InterfaceError( + 'cannot {}; the transaction is already committed'.format( + opname)) + if self._state is TransactionState.ROLLEDBACK: + raise errors.InterfaceError( + 'cannot {}; the transaction is already rolled back'.format( + opname)) + if self._state is TransactionState.FAILED: + raise errors.InterfaceError( + 'cannot {}; the transaction is in error state'.format( + opname)) + + def __check_state(self, opname): + if self._state is not TransactionState.STARTED: + if self._state is TransactionState.NEW: + raise errors.InterfaceError( + 'cannot {}; the transaction is not yet started'.format( + opname)) + self.__check_state_base(opname) + + def _make_start_query(self): + self.__check_state_base('start') + if self._state is TransactionState.STARTED: + raise errors.InterfaceError( + 'cannot start; the transaction is already started') + + return self._options.start_transaction_query() + + def _make_commit_query(self): + self.__check_state('commit') + return 'COMMIT;' + + def _make_rollback_query(self): + self.__check_state('rollback') + return 'ROLLBACK;' + + def __repr__(self): + attrs = [] + attrs.append('state:{}'.format(self._state.name.lower())) + attrs.append(repr(self._options)) + + if self.__class__.__module__.startswith('gel.'): + mod = 'gel' + else: + mod = self.__class__.__module__ + + return '<{}.{} {} {:#x}>'.format( + mod, self.__class__.__name__, ' '.join(attrs), id(self)) + + async def _ensure_transaction(self): + if not self.__started: + self.__started = True + query = self._make_start_query() + self._connection = await self._client._impl.acquire() + if self._connection.is_closed(): + await self._connection.connect( + single_attempt=self.__iteration != 0 + ) + try: + await self._privileged_execute(query) + except BaseException: + self._state = TransactionState.FAILED + raise + else: + self._state = TransactionState.STARTED + + async def _exit(self, extype, ex): + if not self.__started: + return False + + try: + if extype is None: + query = self._make_commit_query() + state = TransactionState.COMMITTED + else: + query = self._make_rollback_query() + state = TransactionState.ROLLEDBACK + try: + await self._privileged_execute(query) + except BaseException: + self._state = TransactionState.FAILED + if extype is None: + # COMMIT itself may fail; recover in connection + await self._privileged_execute("ROLLBACK;") + raise + else: + self._state = state + except errors.EdgeDBError as err: + if ex is None: + # On commit we don't know if commit is succeeded before the + # database have received it or after it have been done but + # network is dropped before we were able to receive a response. + # On a TransactionError, though, we know the we need + # to retry. + # TODO(tailhook) should other errors have retries? + if ( + isinstance(err, errors.TransactionError) + and err.has_tag(errors.SHOULD_RETRY) + and self.__retry._retry(err) + ): + pass + else: + raise err + # If we were going to rollback, look at original error + # to find out whether we want to retry, regardless of + # the rollback error. + # In this case we ignore rollback issue as original error is more + # important, e.g. in case `CancelledError` it's important + # to propagate it to cancel the whole task. + # NOTE: rollback error is always swallowed, should we use + # on_log_message for it? + finally: + await self._client._impl.release(self._connection) + + if ( + extype is not None and + issubclass(extype, errors.EdgeDBError) and + ex.has_tag(errors.SHOULD_RETRY) + ): + return self.__retry._retry(ex) + + def _get_query_cache(self) -> abstract.QueryCache: + return self._client._get_query_cache() + + def _get_state(self) -> options.State: + return self._client._get_state() + + def _get_warning_handler(self) -> options.WarningHandler: + return self._client._get_warning_handler() + + async def _query(self, query_context: abstract.QueryContext): + await self._ensure_transaction() + return await self._connection.raw_query(query_context) + + async def _execute(self, execute_context: abstract.ExecuteContext) -> None: + await self._ensure_transaction() + await self._connection._execute(execute_context) + + async def _privileged_execute(self, query: str) -> None: + await self._connection.privileged_execute(abstract.ExecuteContext( + query=abstract.QueryWithArgs(query, (), {}), + cache=self._get_query_cache(), + state=self._get_state(), + warning_handler=self._get_warning_handler(), + )) + + +class BaseRetry: + + def __init__(self, owner): + self._owner = owner + self._iteration = 0 + self._done = False + self._next_backoff = 0 + self._options = owner._options + + def _retry(self, exc): + self._last_exception = exc + rule = self._options.retry_options.get_rule_for_exception(exc) + if self._iteration >= rule.attempts: + return False + self._done = False + self._next_backoff = rule.backoff(self._iteration) + return True diff --git a/setup.py b/setup.py index 293030af..f4f445a8 100644 --- a/setup.py +++ b/setup.py @@ -98,7 +98,7 @@ readme = f.read() -with open(str(_ROOT / 'edgedb' / '_version.py')) as f: +with open(str(_ROOT / 'gel' / '_version.py')) as f: for line in f: if line.startswith('__version__ ='): _, _, version = line.partition('=') @@ -106,7 +106,7 @@ break else: raise RuntimeError( - 'unable to read the version from edgedb/_version.py') + 'unable to read the version from gel/_version.py') if (_ROOT / '.git').is_dir() and 'dev' in VERSION: @@ -267,8 +267,8 @@ def finalize_options(self): INCLUDE_DIRS = [ - 'edgedb/pgproto/', - 'edgedb/datatypes', + 'gel/pgproto/', + 'gel/datatypes', ] @@ -295,47 +295,50 @@ def finalize_options(self): url='https://github.com/edgedb/edgedb-python', license='Apache License, Version 2.0', packages=setuptools.find_packages(), - provides=['edgedb'], + provides=['edgedb', 'gel'], zip_safe=False, include_package_data=True, - package_data={'edgedb': ['py.typed']}, + package_data={ + 'edgedb': ['py.typed'], + 'gel': ['py.typed'], + }, ext_modules=[ distutils_extension.Extension( - "edgedb.pgproto.pgproto", - ["edgedb/pgproto/pgproto.pyx"], + "gel.pgproto.pgproto", + ["gel/pgproto/pgproto.pyx"], extra_compile_args=CFLAGS, extra_link_args=LDFLAGS), distutils_extension.Extension( - "edgedb.datatypes.datatypes", - ["edgedb/datatypes/args.c", - "edgedb/datatypes/record_desc.c", - "edgedb/datatypes/namedtuple.c", - "edgedb/datatypes/object.c", - "edgedb/datatypes/hash.c", - "edgedb/datatypes/repr.c", - "edgedb/datatypes/comp.c", - "edgedb/datatypes/datatypes.pyx"], + "gel.datatypes.datatypes", + ["gel/datatypes/args.c", + "gel/datatypes/record_desc.c", + "gel/datatypes/namedtuple.c", + "gel/datatypes/object.c", + "gel/datatypes/hash.c", + "gel/datatypes/repr.c", + "gel/datatypes/comp.c", + "gel/datatypes/datatypes.pyx"], extra_compile_args=CFLAGS, extra_link_args=LDFLAGS), distutils_extension.Extension( - "edgedb.protocol.protocol", - ["edgedb/protocol/protocol.pyx"], + "gel.protocol.protocol", + ["gel/protocol/protocol.pyx"], extra_compile_args=CFLAGS, extra_link_args=LDFLAGS, include_dirs=INCLUDE_DIRS), distutils_extension.Extension( - "edgedb.protocol.asyncio_proto", - ["edgedb/protocol/asyncio_proto.pyx"], + "gel.protocol.asyncio_proto", + ["gel/protocol/asyncio_proto.pyx"], extra_compile_args=CFLAGS, extra_link_args=LDFLAGS, include_dirs=INCLUDE_DIRS), distutils_extension.Extension( - "edgedb.protocol.blocking_proto", - ["edgedb/protocol/blocking_proto.pyx"], + "gel.protocol.blocking_proto", + ["gel/protocol/blocking_proto.pyx"], extra_compile_args=CFLAGS, extra_link_args=LDFLAGS, include_dirs=INCLUDE_DIRS), @@ -349,7 +352,8 @@ def finalize_options(self): setup_requires=setup_requires, entry_points={ "console_scripts": [ - "edgedb-py=edgedb.codegen.cli:main", + "edgedb-py=gel.codegen.cli:main", + "gel-py=gel.codegen.cli:main", ] } ) diff --git a/tests/codegen/linked/test_linked_async_edgeql.py.assert b/tests/codegen/linked/test_linked_async_edgeql.py.assert index e54e2622..33058af6 100644 --- a/tests/codegen/linked/test_linked_async_edgeql.py.assert +++ b/tests/codegen/linked/test_linked_async_edgeql.py.assert @@ -1,13 +1,13 @@ # AUTOGENERATED FROM 'linked/test_linked.edgeql' WITH: -# $ edgedb-py +# $ gel-py from __future__ import annotations -import edgedb +import gel async def test_linked( - executor: edgedb.AsyncIOExecutor, + executor: gel.AsyncIOExecutor, ) -> int: return await executor.query_single( """\ diff --git a/tests/codegen/linked/test_linked_edgeql.py.assert b/tests/codegen/linked/test_linked_edgeql.py.assert index 00a1b169..d8d7aa04 100644 --- a/tests/codegen/linked/test_linked_edgeql.py.assert +++ b/tests/codegen/linked/test_linked_edgeql.py.assert @@ -1,13 +1,13 @@ # AUTOGENERATED FROM 'linked/test_linked.edgeql' WITH: -# $ edgedb-py --target blocking --no-skip-pydantic-validation +# $ gel-py --target blocking --no-skip-pydantic-validation from __future__ import annotations -import edgedb +import gel def test_linked( - executor: edgedb.Executor, + executor: gel.Executor, ) -> int: return executor.query_single( """\ diff --git a/tests/codegen/test-project1/generated_async_edgeql.py.assert b/tests/codegen/test-project1/generated_async_edgeql.py.assert index f4e8fea7..22ba6c8d 100644 --- a/tests/codegen/test-project1/generated_async_edgeql.py.assert +++ b/tests/codegen/test-project1/generated_async_edgeql.py.assert @@ -3,12 +3,12 @@ # 'select_scalar.edgeql' # 'linked/test_linked.edgeql' # WITH: -# $ edgedb-py --target async --file --no-skip-pydantic-validation +# $ gel-py --target async --file --no-skip-pydantic-validation from __future__ import annotations import dataclasses -import edgedb +import gel import uuid @@ -24,7 +24,7 @@ class SelectOptionalJsonResultItemSnakeCase: async def select_optional_json( - executor: edgedb.AsyncIOExecutor, + executor: gel.AsyncIOExecutor, arg0: str | None, ) -> list[tuple[str, SelectOptionalJsonResultItem]]: return await executor.query( @@ -40,7 +40,7 @@ async def select_optional_json( async def select_scalar( - executor: edgedb.AsyncIOExecutor, + executor: gel.AsyncIOExecutor, ) -> int: return await executor.query_single( """\ @@ -50,7 +50,7 @@ async def select_scalar( async def test_linked( - executor: edgedb.AsyncIOExecutor, + executor: gel.AsyncIOExecutor, ) -> int: return await executor.query_single( """\ diff --git a/tests/codegen/test-project1/generated_async_edgeql.py.assert5 b/tests/codegen/test-project1/generated_async_edgeql.py.assert5 index f4e8fea7..22ba6c8d 100644 --- a/tests/codegen/test-project1/generated_async_edgeql.py.assert5 +++ b/tests/codegen/test-project1/generated_async_edgeql.py.assert5 @@ -3,12 +3,12 @@ # 'select_scalar.edgeql' # 'linked/test_linked.edgeql' # WITH: -# $ edgedb-py --target async --file --no-skip-pydantic-validation +# $ gel-py --target async --file --no-skip-pydantic-validation from __future__ import annotations import dataclasses -import edgedb +import gel import uuid @@ -24,7 +24,7 @@ class SelectOptionalJsonResultItemSnakeCase: async def select_optional_json( - executor: edgedb.AsyncIOExecutor, + executor: gel.AsyncIOExecutor, arg0: str | None, ) -> list[tuple[str, SelectOptionalJsonResultItem]]: return await executor.query( @@ -40,7 +40,7 @@ async def select_optional_json( async def select_scalar( - executor: edgedb.AsyncIOExecutor, + executor: gel.AsyncIOExecutor, ) -> int: return await executor.query_single( """\ @@ -50,7 +50,7 @@ async def select_scalar( async def test_linked( - executor: edgedb.AsyncIOExecutor, + executor: gel.AsyncIOExecutor, ) -> int: return await executor.query_single( """\ diff --git a/tests/codegen/test-project1/select_optional_json_async_edgeql.py.assert b/tests/codegen/test-project1/select_optional_json_async_edgeql.py.assert index f70213b0..f294ffd2 100644 --- a/tests/codegen/test-project1/select_optional_json_async_edgeql.py.assert +++ b/tests/codegen/test-project1/select_optional_json_async_edgeql.py.assert @@ -1,10 +1,10 @@ # AUTOGENERATED FROM 'select_optional_json.edgeql' WITH: -# $ edgedb-py +# $ gel-py from __future__ import annotations import dataclasses -import edgedb +import gel import typing import uuid @@ -37,7 +37,7 @@ class SelectOptionalJsonResultItemSnakeCase(NoPydanticValidation): async def select_optional_json( - executor: edgedb.AsyncIOExecutor, + executor: gel.AsyncIOExecutor, arg0: typing.Optional[str], ) -> typing.List[typing.Tuple[str, SelectOptionalJsonResultItem]]: return await executor.query( diff --git a/tests/codegen/test-project1/select_optional_json_edgeql.py.assert b/tests/codegen/test-project1/select_optional_json_edgeql.py.assert index 004dd27c..61d2cd0d 100644 --- a/tests/codegen/test-project1/select_optional_json_edgeql.py.assert +++ b/tests/codegen/test-project1/select_optional_json_edgeql.py.assert @@ -1,10 +1,10 @@ # AUTOGENERATED FROM 'select_optional_json.edgeql' WITH: -# $ edgedb-py --target blocking --no-skip-pydantic-validation +# $ gel-py --target blocking --no-skip-pydantic-validation from __future__ import annotations import dataclasses -import edgedb +import gel import typing import uuid @@ -21,7 +21,7 @@ class SelectOptionalJsonResultItemSnakeCase: def select_optional_json( - executor: edgedb.Executor, + executor: gel.Executor, arg0: typing.Optional[str], ) -> list[tuple[str, SelectOptionalJsonResultItem]]: return executor.query( diff --git a/tests/codegen/test-project1/select_scalar_async_edgeql.py.assert b/tests/codegen/test-project1/select_scalar_async_edgeql.py.assert index 6312fe0e..2a6dc130 100644 --- a/tests/codegen/test-project1/select_scalar_async_edgeql.py.assert +++ b/tests/codegen/test-project1/select_scalar_async_edgeql.py.assert @@ -1,13 +1,13 @@ # AUTOGENERATED FROM 'select_scalar.edgeql' WITH: -# $ edgedb-py +# $ gel-py from __future__ import annotations -import edgedb +import gel async def select_scalar( - executor: edgedb.AsyncIOExecutor, + executor: gel.AsyncIOExecutor, ) -> int: return await executor.query_single( """\ diff --git a/tests/codegen/test-project1/select_scalar_edgeql.py.assert b/tests/codegen/test-project1/select_scalar_edgeql.py.assert index 4c870f1b..d8b16a53 100644 --- a/tests/codegen/test-project1/select_scalar_edgeql.py.assert +++ b/tests/codegen/test-project1/select_scalar_edgeql.py.assert @@ -1,13 +1,13 @@ # AUTOGENERATED FROM 'select_scalar.edgeql' WITH: -# $ edgedb-py --target blocking --no-skip-pydantic-validation +# $ gel-py --target blocking --no-skip-pydantic-validation from __future__ import annotations -import edgedb +import gel def select_scalar( - executor: edgedb.Executor, + executor: gel.Executor, ) -> int: return executor.query_single( """\ diff --git a/tests/codegen/test-project2/argnames/query_one_async_edgeql.py.assert b/tests/codegen/test-project2/argnames/query_one_async_edgeql.py.assert index cbfb0961..ecf680f7 100644 --- a/tests/codegen/test-project2/argnames/query_one_async_edgeql.py.assert +++ b/tests/codegen/test-project2/argnames/query_one_async_edgeql.py.assert @@ -1,13 +1,13 @@ # AUTOGENERATED FROM 'argnames/query_one.edgeql' WITH: -# $ edgedb-py +# $ gel-py from __future__ import annotations -import edgedb +import gel async def query_one( - executor: edgedb.AsyncIOExecutor, + executor: gel.AsyncIOExecutor, *, arg_name_with_underscores: int, ) -> int: diff --git a/tests/codegen/test-project2/argnames/query_one_edgeql.py.assert b/tests/codegen/test-project2/argnames/query_one_edgeql.py.assert index c7cfe601..56701a68 100644 --- a/tests/codegen/test-project2/argnames/query_one_edgeql.py.assert +++ b/tests/codegen/test-project2/argnames/query_one_edgeql.py.assert @@ -1,13 +1,13 @@ # AUTOGENERATED FROM 'argnames/query_one.edgeql' WITH: -# $ edgedb-py --target blocking --no-skip-pydantic-validation +# $ gel-py --target blocking --no-skip-pydantic-validation from __future__ import annotations -import edgedb +import gel def query_one( - executor: edgedb.Executor, + executor: gel.Executor, *, arg_name_with_underscores: int, ) -> int: diff --git a/tests/codegen/test-project2/generated_async_edgeql.py.assert b/tests/codegen/test-project2/generated_async_edgeql.py.assert index 007f2af1..4efdfad8 100644 --- a/tests/codegen/test-project2/generated_async_edgeql.py.assert +++ b/tests/codegen/test-project2/generated_async_edgeql.py.assert @@ -9,15 +9,15 @@ # 'scalar/select_scalar.edgeql' # 'scalar/select_scalars.edgeql' # WITH: -# $ edgedb-py --target async --file --no-skip-pydantic-validation +# $ gel-py --target async --file --no-skip-pydantic-validation from __future__ import annotations import array import dataclasses import datetime -import edgedb import enum +import gel import typing import uuid @@ -91,26 +91,26 @@ class MyQueryResult: ab: datetime.timedelta | None ac: int ad: int | None - ae: edgedb.RelativeDuration - af: edgedb.RelativeDuration | None - ag: edgedb.DateDuration - ah: edgedb.DateDuration | None - ai: edgedb.ConfigMemory - aj: edgedb.ConfigMemory | None - ak: edgedb.Range[int] - al: edgedb.Range[int] | None - am: edgedb.Range[int] - an: edgedb.Range[int] | None - ao: edgedb.Range[float] - ap: edgedb.Range[float] | None - aq: edgedb.Range[float] - ar: edgedb.Range[float] | None - as_: edgedb.Range[datetime.datetime] - at: edgedb.Range[datetime.datetime] | None - au: edgedb.Range[datetime.datetime] - av: edgedb.Range[datetime.datetime] | None - aw: edgedb.Range[datetime.date] - ax: edgedb.Range[datetime.date] | None + ae: gel.RelativeDuration + af: gel.RelativeDuration | None + ag: gel.DateDuration + ah: gel.DateDuration | None + ai: gel.ConfigMemory + aj: gel.ConfigMemory | None + ak: gel.Range[int] + al: gel.Range[int] | None + am: gel.Range[int] + an: gel.Range[int] | None + ao: gel.Range[float] + ap: gel.Range[float] | None + aq: gel.Range[float] + ar: gel.Range[float] | None + as_: gel.Range[datetime.datetime] + at: gel.Range[datetime.datetime] | None + au: gel.Range[datetime.datetime] + av: gel.Range[datetime.datetime] | None + aw: gel.Range[datetime.date] + ax: gel.Range[datetime.date] | None ay: MyScalar az: MyScalar | None ba: MyEnum @@ -141,7 +141,7 @@ class SelectObjectResultParamsItem: async def custom_vector_input( - executor: edgedb.AsyncIOExecutor, + executor: gel.AsyncIOExecutor, *, input: V3 | None = None, ) -> int | None: @@ -154,7 +154,7 @@ async def custom_vector_input( async def link_prop( - executor: edgedb.AsyncIOExecutor, + executor: gel.AsyncIOExecutor, ) -> list[LinkPropResult]: return await executor.query( """\ @@ -181,7 +181,7 @@ async def link_prop( async def my_query( - executor: edgedb.AsyncIOExecutor, + executor: gel.AsyncIOExecutor, *, a: uuid.UUID, b: uuid.UUID | None = None, @@ -213,26 +213,26 @@ async def my_query( ab: datetime.timedelta | None = None, ac: int, ad: int | None = None, - ae: edgedb.RelativeDuration, - af: edgedb.RelativeDuration | None = None, - ag: edgedb.DateDuration, - ah: edgedb.DateDuration | None = None, - ai: edgedb.ConfigMemory, - aj: edgedb.ConfigMemory | None = None, - ak: edgedb.Range[int], - al: edgedb.Range[int] | None = None, - am: edgedb.Range[int], - an: edgedb.Range[int] | None = None, - ao: edgedb.Range[float], - ap: edgedb.Range[float] | None = None, - aq: edgedb.Range[float], - ar: edgedb.Range[float] | None = None, - as_: edgedb.Range[datetime.datetime], - at: edgedb.Range[datetime.datetime] | None = None, - au: edgedb.Range[datetime.datetime], - av: edgedb.Range[datetime.datetime] | None = None, - aw: edgedb.Range[datetime.date], - ax: edgedb.Range[datetime.date] | None = None, + ae: gel.RelativeDuration, + af: gel.RelativeDuration | None = None, + ag: gel.DateDuration, + ah: gel.DateDuration | None = None, + ai: gel.ConfigMemory, + aj: gel.ConfigMemory | None = None, + ak: gel.Range[int], + al: gel.Range[int] | None = None, + am: gel.Range[int], + an: gel.Range[int] | None = None, + ao: gel.Range[float], + ap: gel.Range[float] | None = None, + aq: gel.Range[float], + ar: gel.Range[float] | None = None, + as_: gel.Range[datetime.datetime], + at: gel.Range[datetime.datetime] | None = None, + au: gel.Range[datetime.datetime], + av: gel.Range[datetime.datetime] | None = None, + aw: gel.Range[datetime.date], + ax: gel.Range[datetime.date] | None = None, bc: typing.Sequence[float], bd: typing.Sequence[float] | None = None, ) -> MyQueryResult: @@ -356,7 +356,7 @@ async def my_query( async def query_one( - executor: edgedb.AsyncIOExecutor, + executor: gel.AsyncIOExecutor, *, arg_name_with_underscores: int, ) -> int: @@ -369,7 +369,7 @@ async def query_one( async def select_args( - executor: edgedb.AsyncIOExecutor, + executor: gel.AsyncIOExecutor, *, arg_str: str, arg_datetime: datetime.datetime, @@ -387,7 +387,7 @@ async def select_args( async def select_object( - executor: edgedb.AsyncIOExecutor, + executor: gel.AsyncIOExecutor, ) -> SelectObjectResult | None: return await executor.query_single( """\ @@ -405,7 +405,7 @@ async def select_object( async def select_objects( - executor: edgedb.AsyncIOExecutor, + executor: gel.AsyncIOExecutor, ) -> list[SelectObjectResult]: return await executor.query( """\ @@ -422,7 +422,7 @@ async def select_objects( async def select_scalar( - executor: edgedb.AsyncIOExecutor, + executor: gel.AsyncIOExecutor, ) -> int: return await executor.query_single( """\ @@ -432,8 +432,8 @@ async def select_scalar( async def select_scalars( - executor: edgedb.AsyncIOExecutor, -) -> list[edgedb.ConfigMemory]: + executor: gel.AsyncIOExecutor, +) -> list[gel.ConfigMemory]: return await executor.query( """\ select {1, 2, 3};\ diff --git a/tests/codegen/test-project2/generated_async_edgeql.py.assert3 b/tests/codegen/test-project2/generated_async_edgeql.py.assert3 index 2095cc9a..4a1a7444 100644 --- a/tests/codegen/test-project2/generated_async_edgeql.py.assert3 +++ b/tests/codegen/test-project2/generated_async_edgeql.py.assert3 @@ -9,15 +9,15 @@ # 'scalar/select_scalar.edgeql' # 'scalar/select_scalars.edgeql' # WITH: -# $ edgedb-py --target async --file --no-skip-pydantic-validation +# $ gel-py --target async --file --no-skip-pydantic-validation from __future__ import annotations import array import dataclasses import datetime -import edgedb import enum +import gel import typing import uuid @@ -92,26 +92,26 @@ class MyQueryResult: ab: datetime.timedelta | None ac: int ad: int | None - ae: edgedb.RelativeDuration - af: edgedb.RelativeDuration | None - ag: edgedb.DateDuration - ah: edgedb.DateDuration | None - ai: edgedb.ConfigMemory - aj: edgedb.ConfigMemory | None - ak: edgedb.Range[int] - al: edgedb.Range[int] | None - am: edgedb.Range[int] - an: edgedb.Range[int] | None - ao: edgedb.Range[float] - ap: edgedb.Range[float] | None - aq: edgedb.Range[float] - ar: edgedb.Range[float] | None - as_: edgedb.Range[datetime.datetime] - at: edgedb.Range[datetime.datetime] | None - au: edgedb.Range[datetime.datetime] - av: edgedb.Range[datetime.datetime] | None - aw: edgedb.Range[datetime.date] - ax: edgedb.Range[datetime.date] | None + ae: gel.RelativeDuration + af: gel.RelativeDuration | None + ag: gel.DateDuration + ah: gel.DateDuration | None + ai: gel.ConfigMemory + aj: gel.ConfigMemory | None + ak: gel.Range[int] + al: gel.Range[int] | None + am: gel.Range[int] + an: gel.Range[int] | None + ao: gel.Range[float] + ap: gel.Range[float] | None + aq: gel.Range[float] + ar: gel.Range[float] | None + as_: gel.Range[datetime.datetime] + at: gel.Range[datetime.datetime] | None + au: gel.Range[datetime.datetime] + av: gel.Range[datetime.datetime] | None + aw: gel.Range[datetime.date] + ax: gel.Range[datetime.date] | None ay: MyScalar az: MyScalar | None ba: MyEnum @@ -143,7 +143,7 @@ class SelectObjectResultParamsItem: async def custom_vector_input( - executor: edgedb.AsyncIOExecutor, + executor: gel.AsyncIOExecutor, *, input: Input | None = None, ) -> int | None: @@ -156,7 +156,7 @@ async def custom_vector_input( async def link_prop( - executor: edgedb.AsyncIOExecutor, + executor: gel.AsyncIOExecutor, ) -> list[LinkPropResult]: return await executor.query( """\ @@ -183,7 +183,7 @@ async def link_prop( async def my_query( - executor: edgedb.AsyncIOExecutor, + executor: gel.AsyncIOExecutor, *, a: uuid.UUID, b: uuid.UUID | None = None, @@ -215,26 +215,26 @@ async def my_query( ab: datetime.timedelta | None = None, ac: int, ad: int | None = None, - ae: edgedb.RelativeDuration, - af: edgedb.RelativeDuration | None = None, - ag: edgedb.DateDuration, - ah: edgedb.DateDuration | None = None, - ai: edgedb.ConfigMemory, - aj: edgedb.ConfigMemory | None = None, - ak: edgedb.Range[int], - al: edgedb.Range[int] | None = None, - am: edgedb.Range[int], - an: edgedb.Range[int] | None = None, - ao: edgedb.Range[float], - ap: edgedb.Range[float] | None = None, - aq: edgedb.Range[float], - ar: edgedb.Range[float] | None = None, - as_: edgedb.Range[datetime.datetime], - at: edgedb.Range[datetime.datetime] | None = None, - au: edgedb.Range[datetime.datetime], - av: edgedb.Range[datetime.datetime] | None = None, - aw: edgedb.Range[datetime.date], - ax: edgedb.Range[datetime.date] | None = None, + ae: gel.RelativeDuration, + af: gel.RelativeDuration | None = None, + ag: gel.DateDuration, + ah: gel.DateDuration | None = None, + ai: gel.ConfigMemory, + aj: gel.ConfigMemory | None = None, + ak: gel.Range[int], + al: gel.Range[int] | None = None, + am: gel.Range[int], + an: gel.Range[int] | None = None, + ao: gel.Range[float], + ap: gel.Range[float] | None = None, + aq: gel.Range[float], + ar: gel.Range[float] | None = None, + as_: gel.Range[datetime.datetime], + at: gel.Range[datetime.datetime] | None = None, + au: gel.Range[datetime.datetime], + av: gel.Range[datetime.datetime] | None = None, + aw: gel.Range[datetime.date], + ax: gel.Range[datetime.date] | None = None, bc: typing.Sequence[float], bd: typing.Sequence[float] | None = None, ) -> MyQueryResult: @@ -358,7 +358,7 @@ async def my_query( async def query_one( - executor: edgedb.AsyncIOExecutor, + executor: gel.AsyncIOExecutor, *, arg_name_with_underscores: int, ) -> int: @@ -371,7 +371,7 @@ async def query_one( async def select_args( - executor: edgedb.AsyncIOExecutor, + executor: gel.AsyncIOExecutor, *, arg_str: str, arg_datetime: datetime.datetime, @@ -389,7 +389,7 @@ async def select_args( async def select_object( - executor: edgedb.AsyncIOExecutor, + executor: gel.AsyncIOExecutor, ) -> SelectObjectResult | None: return await executor.query_single( """\ @@ -407,7 +407,7 @@ async def select_object( async def select_objects( - executor: edgedb.AsyncIOExecutor, + executor: gel.AsyncIOExecutor, ) -> list[SelectObjectResult]: return await executor.query( """\ @@ -424,7 +424,7 @@ async def select_objects( async def select_scalar( - executor: edgedb.AsyncIOExecutor, + executor: gel.AsyncIOExecutor, ) -> int: return await executor.query_single( """\ @@ -434,8 +434,8 @@ async def select_scalar( async def select_scalars( - executor: edgedb.AsyncIOExecutor, -) -> list[edgedb.ConfigMemory]: + executor: gel.AsyncIOExecutor, +) -> list[gel.ConfigMemory]: return await executor.query( """\ select {1, 2, 3};\ diff --git a/tests/codegen/test-project2/generated_async_edgeql.py.assert5 b/tests/codegen/test-project2/generated_async_edgeql.py.assert5 index 75d52ea7..8907bf73 100644 --- a/tests/codegen/test-project2/generated_async_edgeql.py.assert5 +++ b/tests/codegen/test-project2/generated_async_edgeql.py.assert5 @@ -9,15 +9,15 @@ # 'scalar/select_scalar.edgeql' # 'scalar/select_scalars.edgeql' # WITH: -# $ edgedb-py --target async --file --no-skip-pydantic-validation +# $ gel-py --target async --file --no-skip-pydantic-validation from __future__ import annotations import array import dataclasses import datetime -import edgedb import enum +import gel import typing import uuid @@ -92,26 +92,26 @@ class MyQueryResult: ab: datetime.timedelta | None ac: int ad: int | None - ae: edgedb.RelativeDuration - af: edgedb.RelativeDuration | None - ag: edgedb.DateDuration - ah: edgedb.DateDuration | None - ai: edgedb.ConfigMemory - aj: edgedb.ConfigMemory | None - ak: edgedb.Range[int] - al: edgedb.Range[int] | None - am: edgedb.Range[int] - an: edgedb.Range[int] | None - ao: edgedb.Range[float] - ap: edgedb.Range[float] | None - aq: edgedb.Range[float] - ar: edgedb.Range[float] | None - as_: edgedb.Range[datetime.datetime] - at: edgedb.Range[datetime.datetime] | None - au: edgedb.Range[datetime.datetime] - av: edgedb.Range[datetime.datetime] | None - aw: edgedb.Range[datetime.date] - ax: edgedb.Range[datetime.date] | None + ae: gel.RelativeDuration + af: gel.RelativeDuration | None + ag: gel.DateDuration + ah: gel.DateDuration | None + ai: gel.ConfigMemory + aj: gel.ConfigMemory | None + ak: gel.Range[int] + al: gel.Range[int] | None + am: gel.Range[int] + an: gel.Range[int] | None + ao: gel.Range[float] + ap: gel.Range[float] | None + aq: gel.Range[float] + ar: gel.Range[float] | None + as_: gel.Range[datetime.datetime] + at: gel.Range[datetime.datetime] | None + au: gel.Range[datetime.datetime] + av: gel.Range[datetime.datetime] | None + aw: gel.Range[datetime.date] + ax: gel.Range[datetime.date] | None ay: MyScalar az: MyScalar | None ba: MyEnum @@ -143,7 +143,7 @@ class SelectObjectResultParamsItem: async def custom_vector_input( - executor: edgedb.AsyncIOExecutor, + executor: gel.AsyncIOExecutor, *, input: V3 | None = None, ) -> int | None: @@ -156,7 +156,7 @@ async def custom_vector_input( async def link_prop( - executor: edgedb.AsyncIOExecutor, + executor: gel.AsyncIOExecutor, ) -> list[LinkPropResult]: return await executor.query( """\ @@ -183,7 +183,7 @@ async def link_prop( async def my_query( - executor: edgedb.AsyncIOExecutor, + executor: gel.AsyncIOExecutor, *, a: uuid.UUID, b: uuid.UUID | None = None, @@ -215,26 +215,26 @@ async def my_query( ab: datetime.timedelta | None = None, ac: int, ad: int | None = None, - ae: edgedb.RelativeDuration, - af: edgedb.RelativeDuration | None = None, - ag: edgedb.DateDuration, - ah: edgedb.DateDuration | None = None, - ai: edgedb.ConfigMemory, - aj: edgedb.ConfigMemory | None = None, - ak: edgedb.Range[int], - al: edgedb.Range[int] | None = None, - am: edgedb.Range[int], - an: edgedb.Range[int] | None = None, - ao: edgedb.Range[float], - ap: edgedb.Range[float] | None = None, - aq: edgedb.Range[float], - ar: edgedb.Range[float] | None = None, - as_: edgedb.Range[datetime.datetime], - at: edgedb.Range[datetime.datetime] | None = None, - au: edgedb.Range[datetime.datetime], - av: edgedb.Range[datetime.datetime] | None = None, - aw: edgedb.Range[datetime.date], - ax: edgedb.Range[datetime.date] | None = None, + ae: gel.RelativeDuration, + af: gel.RelativeDuration | None = None, + ag: gel.DateDuration, + ah: gel.DateDuration | None = None, + ai: gel.ConfigMemory, + aj: gel.ConfigMemory | None = None, + ak: gel.Range[int], + al: gel.Range[int] | None = None, + am: gel.Range[int], + an: gel.Range[int] | None = None, + ao: gel.Range[float], + ap: gel.Range[float] | None = None, + aq: gel.Range[float], + ar: gel.Range[float] | None = None, + as_: gel.Range[datetime.datetime], + at: gel.Range[datetime.datetime] | None = None, + au: gel.Range[datetime.datetime], + av: gel.Range[datetime.datetime] | None = None, + aw: gel.Range[datetime.date], + ax: gel.Range[datetime.date] | None = None, bc: typing.Sequence[float], bd: typing.Sequence[float] | None = None, ) -> MyQueryResult: @@ -358,7 +358,7 @@ async def my_query( async def query_one( - executor: edgedb.AsyncIOExecutor, + executor: gel.AsyncIOExecutor, *, arg_name_with_underscores: int, ) -> int: @@ -371,7 +371,7 @@ async def query_one( async def select_args( - executor: edgedb.AsyncIOExecutor, + executor: gel.AsyncIOExecutor, *, arg_str: str, arg_datetime: datetime.datetime, @@ -389,7 +389,7 @@ async def select_args( async def select_object( - executor: edgedb.AsyncIOExecutor, + executor: gel.AsyncIOExecutor, ) -> SelectObjectResult | None: return await executor.query_single( """\ @@ -407,7 +407,7 @@ async def select_object( async def select_objects( - executor: edgedb.AsyncIOExecutor, + executor: gel.AsyncIOExecutor, ) -> list[SelectObjectResult]: return await executor.query( """\ @@ -424,7 +424,7 @@ async def select_objects( async def select_scalar( - executor: edgedb.AsyncIOExecutor, + executor: gel.AsyncIOExecutor, ) -> int: return await executor.query_single( """\ @@ -434,8 +434,8 @@ async def select_scalar( async def select_scalars( - executor: edgedb.AsyncIOExecutor, -) -> list[edgedb.ConfigMemory]: + executor: gel.AsyncIOExecutor, +) -> list[gel.ConfigMemory]: return await executor.query( """\ select {1, 2, 3};\ diff --git a/tests/codegen/test-project2/object/link_prop_async_edgeql.py.assert b/tests/codegen/test-project2/object/link_prop_async_edgeql.py.assert index 95263c5b..bc323374 100644 --- a/tests/codegen/test-project2/object/link_prop_async_edgeql.py.assert +++ b/tests/codegen/test-project2/object/link_prop_async_edgeql.py.assert @@ -1,11 +1,11 @@ # AUTOGENERATED FROM 'object/link_prop.edgeql' WITH: -# $ edgedb-py +# $ gel-py from __future__ import annotations import dataclasses import datetime -import edgedb +import gel import typing import uuid @@ -52,7 +52,7 @@ class LinkPropResultFriendsItem(NoPydanticValidation): async def link_prop( - executor: edgedb.AsyncIOExecutor, + executor: gel.AsyncIOExecutor, ) -> typing.List[LinkPropResult]: return await executor.query( """\ diff --git a/tests/codegen/test-project2/object/link_prop_edgeql.py.assert b/tests/codegen/test-project2/object/link_prop_edgeql.py.assert index 3eaf4ddb..d3912dac 100644 --- a/tests/codegen/test-project2/object/link_prop_edgeql.py.assert +++ b/tests/codegen/test-project2/object/link_prop_edgeql.py.assert @@ -1,11 +1,11 @@ # AUTOGENERATED FROM 'object/link_prop.edgeql' WITH: -# $ edgedb-py --target blocking --no-skip-pydantic-validation +# $ gel-py --target blocking --no-skip-pydantic-validation from __future__ import annotations import dataclasses import datetime -import edgedb +import gel import typing import uuid @@ -36,7 +36,7 @@ class LinkPropResultFriendsItem: def link_prop( - executor: edgedb.Executor, + executor: gel.Executor, ) -> list[LinkPropResult]: return executor.query( """\ diff --git a/tests/codegen/test-project2/object/select_object_async_edgeql.py.assert b/tests/codegen/test-project2/object/select_object_async_edgeql.py.assert index 3e121c7a..37e1ec36 100644 --- a/tests/codegen/test-project2/object/select_object_async_edgeql.py.assert +++ b/tests/codegen/test-project2/object/select_object_async_edgeql.py.assert @@ -1,10 +1,10 @@ # AUTOGENERATED FROM 'object/select_object.edgeql' WITH: -# $ edgedb-py +# $ gel-py from __future__ import annotations import dataclasses -import edgedb +import gel import typing import uuid @@ -41,7 +41,7 @@ class SelectObjectResultParamsItem(NoPydanticValidation): async def select_object( - executor: edgedb.AsyncIOExecutor, + executor: gel.AsyncIOExecutor, ) -> typing.Optional[SelectObjectResult]: return await executor.query_single( """\ diff --git a/tests/codegen/test-project2/object/select_object_edgeql.py.assert b/tests/codegen/test-project2/object/select_object_edgeql.py.assert index c3ebb60e..2fa0df0a 100644 --- a/tests/codegen/test-project2/object/select_object_edgeql.py.assert +++ b/tests/codegen/test-project2/object/select_object_edgeql.py.assert @@ -1,10 +1,10 @@ # AUTOGENERATED FROM 'object/select_object.edgeql' WITH: -# $ edgedb-py --target blocking --no-skip-pydantic-validation +# $ gel-py --target blocking --no-skip-pydantic-validation from __future__ import annotations import dataclasses -import edgedb +import gel import typing import uuid @@ -25,7 +25,7 @@ class SelectObjectResultParamsItem: def select_object( - executor: edgedb.Executor, + executor: gel.Executor, ) -> typing.Optional[SelectObjectResult]: return executor.query_single( """\ diff --git a/tests/codegen/test-project2/object/select_objects_async_edgeql.py.assert b/tests/codegen/test-project2/object/select_objects_async_edgeql.py.assert index 25e8ec92..f4e97aba 100644 --- a/tests/codegen/test-project2/object/select_objects_async_edgeql.py.assert +++ b/tests/codegen/test-project2/object/select_objects_async_edgeql.py.assert @@ -1,10 +1,10 @@ # AUTOGENERATED FROM 'object/select_objects.edgeql' WITH: -# $ edgedb-py +# $ gel-py from __future__ import annotations import dataclasses -import edgedb +import gel import typing import uuid @@ -41,7 +41,7 @@ class SelectObjectsResultParamsItem(NoPydanticValidation): async def select_objects( - executor: edgedb.AsyncIOExecutor, + executor: gel.AsyncIOExecutor, ) -> typing.List[SelectObjectsResult]: return await executor.query( """\ diff --git a/tests/codegen/test-project2/object/select_objects_edgeql.py.assert b/tests/codegen/test-project2/object/select_objects_edgeql.py.assert index fed335d7..875b3cfa 100644 --- a/tests/codegen/test-project2/object/select_objects_edgeql.py.assert +++ b/tests/codegen/test-project2/object/select_objects_edgeql.py.assert @@ -1,10 +1,10 @@ # AUTOGENERATED FROM 'object/select_objects.edgeql' WITH: -# $ edgedb-py --target blocking --no-skip-pydantic-validation +# $ gel-py --target blocking --no-skip-pydantic-validation from __future__ import annotations import dataclasses -import edgedb +import gel import typing import uuid @@ -25,7 +25,7 @@ class SelectObjectsResultParamsItem: def select_objects( - executor: edgedb.Executor, + executor: gel.Executor, ) -> list[SelectObjectsResult]: return executor.query( """\ diff --git a/tests/codegen/test-project2/parpkg/select_args_async_edgeql.py.assert b/tests/codegen/test-project2/parpkg/select_args_async_edgeql.py.assert index 7643f4f4..4f3ece14 100644 --- a/tests/codegen/test-project2/parpkg/select_args_async_edgeql.py.assert +++ b/tests/codegen/test-project2/parpkg/select_args_async_edgeql.py.assert @@ -1,11 +1,11 @@ # AUTOGENERATED FROM 'parpkg/select_args.edgeql' WITH: -# $ edgedb-py +# $ gel-py from __future__ import annotations import dataclasses import datetime -import edgedb +import gel class NoPydanticValidation: @@ -31,7 +31,7 @@ class SelectArgsResult(NoPydanticValidation): async def select_args( - executor: edgedb.AsyncIOExecutor, + executor: gel.AsyncIOExecutor, *, arg_str: str, arg_datetime: datetime.datetime, diff --git a/tests/codegen/test-project2/parpkg/select_args_async_edgeql.py.assert5 b/tests/codegen/test-project2/parpkg/select_args_async_edgeql.py.assert5 index b318bb5c..0e48f41d 100644 --- a/tests/codegen/test-project2/parpkg/select_args_async_edgeql.py.assert5 +++ b/tests/codegen/test-project2/parpkg/select_args_async_edgeql.py.assert5 @@ -1,11 +1,11 @@ # AUTOGENERATED FROM 'parpkg/select_args.edgeql' WITH: -# $ edgedb-py +# $ gel-py from __future__ import annotations import dataclasses import datetime -import edgedb +import gel import uuid @@ -33,7 +33,7 @@ class SelectArgsResult(NoPydanticValidation): async def select_args( - executor: edgedb.AsyncIOExecutor, + executor: gel.AsyncIOExecutor, *, arg_str: str, arg_datetime: datetime.datetime, diff --git a/tests/codegen/test-project2/parpkg/select_args_edgeql.py.assert b/tests/codegen/test-project2/parpkg/select_args_edgeql.py.assert index 88fd6cae..5adeb922 100644 --- a/tests/codegen/test-project2/parpkg/select_args_edgeql.py.assert +++ b/tests/codegen/test-project2/parpkg/select_args_edgeql.py.assert @@ -1,11 +1,11 @@ # AUTOGENERATED FROM 'parpkg/select_args.edgeql' WITH: -# $ edgedb-py --target blocking --no-skip-pydantic-validation +# $ gel-py --target blocking --no-skip-pydantic-validation from __future__ import annotations import dataclasses import datetime -import edgedb +import gel @dataclasses.dataclass @@ -15,7 +15,7 @@ class SelectArgsResult: def select_args( - executor: edgedb.Executor, + executor: gel.Executor, *, arg_str: str, arg_datetime: datetime.datetime, diff --git a/tests/codegen/test-project2/parpkg/select_args_edgeql.py.assert5 b/tests/codegen/test-project2/parpkg/select_args_edgeql.py.assert5 index cc0d8cfc..ca9cfddb 100644 --- a/tests/codegen/test-project2/parpkg/select_args_edgeql.py.assert5 +++ b/tests/codegen/test-project2/parpkg/select_args_edgeql.py.assert5 @@ -1,11 +1,11 @@ # AUTOGENERATED FROM 'parpkg/select_args.edgeql' WITH: -# $ edgedb-py --target blocking --no-skip-pydantic-validation +# $ gel-py --target blocking --no-skip-pydantic-validation from __future__ import annotations import dataclasses import datetime -import edgedb +import gel import uuid @@ -17,7 +17,7 @@ class SelectArgsResult: def select_args( - executor: edgedb.Executor, + executor: gel.Executor, *, arg_str: str, arg_datetime: datetime.datetime, diff --git a/tests/codegen/test-project2/parpkg/subpkg/my_query_async_edgeql.py.assert b/tests/codegen/test-project2/parpkg/subpkg/my_query_async_edgeql.py.assert index 79b133af..ddb0d695 100644 --- a/tests/codegen/test-project2/parpkg/subpkg/my_query_async_edgeql.py.assert +++ b/tests/codegen/test-project2/parpkg/subpkg/my_query_async_edgeql.py.assert @@ -1,13 +1,13 @@ # AUTOGENERATED FROM 'parpkg/subpkg/my_query.edgeql' WITH: -# $ edgedb-py +# $ gel-py from __future__ import annotations import array import dataclasses import datetime -import edgedb import enum +import gel import typing import uuid @@ -71,26 +71,26 @@ class MyQueryResult(NoPydanticValidation): ab: typing.Optional[datetime.timedelta] ac: int ad: typing.Optional[int] - ae: edgedb.RelativeDuration - af: typing.Optional[edgedb.RelativeDuration] - ag: edgedb.DateDuration - ah: typing.Optional[edgedb.DateDuration] - ai: edgedb.ConfigMemory - aj: typing.Optional[edgedb.ConfigMemory] - ak: edgedb.Range[int] - al: typing.Optional[edgedb.Range[int]] - am: edgedb.Range[int] - an: typing.Optional[edgedb.Range[int]] - ao: edgedb.Range[float] - ap: typing.Optional[edgedb.Range[float]] - aq: edgedb.Range[float] - ar: typing.Optional[edgedb.Range[float]] - as_: edgedb.Range[datetime.datetime] - at: typing.Optional[edgedb.Range[datetime.datetime]] - au: edgedb.Range[datetime.datetime] - av: typing.Optional[edgedb.Range[datetime.datetime]] - aw: edgedb.Range[datetime.date] - ax: typing.Optional[edgedb.Range[datetime.date]] + ae: gel.RelativeDuration + af: typing.Optional[gel.RelativeDuration] + ag: gel.DateDuration + ah: typing.Optional[gel.DateDuration] + ai: gel.ConfigMemory + aj: typing.Optional[gel.ConfigMemory] + ak: gel.Range[int] + al: typing.Optional[gel.Range[int]] + am: gel.Range[int] + an: typing.Optional[gel.Range[int]] + ao: gel.Range[float] + ap: typing.Optional[gel.Range[float]] + aq: gel.Range[float] + ar: typing.Optional[gel.Range[float]] + as_: gel.Range[datetime.datetime] + at: typing.Optional[gel.Range[datetime.datetime]] + au: gel.Range[datetime.datetime] + av: typing.Optional[gel.Range[datetime.datetime]] + aw: gel.Range[datetime.date] + ax: typing.Optional[gel.Range[datetime.date]] ay: MyScalar az: typing.Optional[MyScalar] ba: MyEnum @@ -100,7 +100,7 @@ class MyQueryResult(NoPydanticValidation): async def my_query( - executor: edgedb.AsyncIOExecutor, + executor: gel.AsyncIOExecutor, *, a: uuid.UUID, b: typing.Optional[uuid.UUID] = None, @@ -132,26 +132,26 @@ async def my_query( ab: typing.Optional[datetime.timedelta] = None, ac: int, ad: typing.Optional[int] = None, - ae: edgedb.RelativeDuration, - af: typing.Optional[edgedb.RelativeDuration] = None, - ag: edgedb.DateDuration, - ah: typing.Optional[edgedb.DateDuration] = None, - ai: edgedb.ConfigMemory, - aj: typing.Optional[edgedb.ConfigMemory] = None, - ak: edgedb.Range[int], - al: typing.Optional[edgedb.Range[int]] = None, - am: edgedb.Range[int], - an: typing.Optional[edgedb.Range[int]] = None, - ao: edgedb.Range[float], - ap: typing.Optional[edgedb.Range[float]] = None, - aq: edgedb.Range[float], - ar: typing.Optional[edgedb.Range[float]] = None, - as_: edgedb.Range[datetime.datetime], - at: typing.Optional[edgedb.Range[datetime.datetime]] = None, - au: edgedb.Range[datetime.datetime], - av: typing.Optional[edgedb.Range[datetime.datetime]] = None, - aw: edgedb.Range[datetime.date], - ax: typing.Optional[edgedb.Range[datetime.date]] = None, + ae: gel.RelativeDuration, + af: typing.Optional[gel.RelativeDuration] = None, + ag: gel.DateDuration, + ah: typing.Optional[gel.DateDuration] = None, + ai: gel.ConfigMemory, + aj: typing.Optional[gel.ConfigMemory] = None, + ak: gel.Range[int], + al: typing.Optional[gel.Range[int]] = None, + am: gel.Range[int], + an: typing.Optional[gel.Range[int]] = None, + ao: gel.Range[float], + ap: typing.Optional[gel.Range[float]] = None, + aq: gel.Range[float], + ar: typing.Optional[gel.Range[float]] = None, + as_: gel.Range[datetime.datetime], + at: typing.Optional[gel.Range[datetime.datetime]] = None, + au: gel.Range[datetime.datetime], + av: typing.Optional[gel.Range[datetime.datetime]] = None, + aw: gel.Range[datetime.date], + ax: typing.Optional[gel.Range[datetime.date]] = None, bc: typing.Sequence[float], bd: typing.Optional[typing.Sequence[float]] = None, ) -> MyQueryResult: diff --git a/tests/codegen/test-project2/parpkg/subpkg/my_query_async_edgeql.py.assert5 b/tests/codegen/test-project2/parpkg/subpkg/my_query_async_edgeql.py.assert5 index 05e4a7e2..7b7138cb 100644 --- a/tests/codegen/test-project2/parpkg/subpkg/my_query_async_edgeql.py.assert5 +++ b/tests/codegen/test-project2/parpkg/subpkg/my_query_async_edgeql.py.assert5 @@ -1,13 +1,13 @@ # AUTOGENERATED FROM 'parpkg/subpkg/my_query.edgeql' WITH: -# $ edgedb-py +# $ gel-py from __future__ import annotations import array import dataclasses import datetime -import edgedb import enum +import gel import typing import uuid @@ -72,26 +72,26 @@ class MyQueryResult(NoPydanticValidation): ab: typing.Optional[datetime.timedelta] ac: int ad: typing.Optional[int] - ae: edgedb.RelativeDuration - af: typing.Optional[edgedb.RelativeDuration] - ag: edgedb.DateDuration - ah: typing.Optional[edgedb.DateDuration] - ai: edgedb.ConfigMemory - aj: typing.Optional[edgedb.ConfigMemory] - ak: edgedb.Range[int] - al: typing.Optional[edgedb.Range[int]] - am: edgedb.Range[int] - an: typing.Optional[edgedb.Range[int]] - ao: edgedb.Range[float] - ap: typing.Optional[edgedb.Range[float]] - aq: edgedb.Range[float] - ar: typing.Optional[edgedb.Range[float]] - as_: edgedb.Range[datetime.datetime] - at: typing.Optional[edgedb.Range[datetime.datetime]] - au: edgedb.Range[datetime.datetime] - av: typing.Optional[edgedb.Range[datetime.datetime]] - aw: edgedb.Range[datetime.date] - ax: typing.Optional[edgedb.Range[datetime.date]] + ae: gel.RelativeDuration + af: typing.Optional[gel.RelativeDuration] + ag: gel.DateDuration + ah: typing.Optional[gel.DateDuration] + ai: gel.ConfigMemory + aj: typing.Optional[gel.ConfigMemory] + ak: gel.Range[int] + al: typing.Optional[gel.Range[int]] + am: gel.Range[int] + an: typing.Optional[gel.Range[int]] + ao: gel.Range[float] + ap: typing.Optional[gel.Range[float]] + aq: gel.Range[float] + ar: typing.Optional[gel.Range[float]] + as_: gel.Range[datetime.datetime] + at: typing.Optional[gel.Range[datetime.datetime]] + au: gel.Range[datetime.datetime] + av: typing.Optional[gel.Range[datetime.datetime]] + aw: gel.Range[datetime.date] + ax: typing.Optional[gel.Range[datetime.date]] ay: MyScalar az: typing.Optional[MyScalar] ba: MyEnum @@ -101,7 +101,7 @@ class MyQueryResult(NoPydanticValidation): async def my_query( - executor: edgedb.AsyncIOExecutor, + executor: gel.AsyncIOExecutor, *, a: uuid.UUID, b: typing.Optional[uuid.UUID] = None, @@ -133,26 +133,26 @@ async def my_query( ab: typing.Optional[datetime.timedelta] = None, ac: int, ad: typing.Optional[int] = None, - ae: edgedb.RelativeDuration, - af: typing.Optional[edgedb.RelativeDuration] = None, - ag: edgedb.DateDuration, - ah: typing.Optional[edgedb.DateDuration] = None, - ai: edgedb.ConfigMemory, - aj: typing.Optional[edgedb.ConfigMemory] = None, - ak: edgedb.Range[int], - al: typing.Optional[edgedb.Range[int]] = None, - am: edgedb.Range[int], - an: typing.Optional[edgedb.Range[int]] = None, - ao: edgedb.Range[float], - ap: typing.Optional[edgedb.Range[float]] = None, - aq: edgedb.Range[float], - ar: typing.Optional[edgedb.Range[float]] = None, - as_: edgedb.Range[datetime.datetime], - at: typing.Optional[edgedb.Range[datetime.datetime]] = None, - au: edgedb.Range[datetime.datetime], - av: typing.Optional[edgedb.Range[datetime.datetime]] = None, - aw: edgedb.Range[datetime.date], - ax: typing.Optional[edgedb.Range[datetime.date]] = None, + ae: gel.RelativeDuration, + af: typing.Optional[gel.RelativeDuration] = None, + ag: gel.DateDuration, + ah: typing.Optional[gel.DateDuration] = None, + ai: gel.ConfigMemory, + aj: typing.Optional[gel.ConfigMemory] = None, + ak: gel.Range[int], + al: typing.Optional[gel.Range[int]] = None, + am: gel.Range[int], + an: typing.Optional[gel.Range[int]] = None, + ao: gel.Range[float], + ap: typing.Optional[gel.Range[float]] = None, + aq: gel.Range[float], + ar: typing.Optional[gel.Range[float]] = None, + as_: gel.Range[datetime.datetime], + at: typing.Optional[gel.Range[datetime.datetime]] = None, + au: gel.Range[datetime.datetime], + av: typing.Optional[gel.Range[datetime.datetime]] = None, + aw: gel.Range[datetime.date], + ax: typing.Optional[gel.Range[datetime.date]] = None, bc: typing.Sequence[float], bd: typing.Optional[typing.Sequence[float]] = None, ) -> MyQueryResult: diff --git a/tests/codegen/test-project2/parpkg/subpkg/my_query_edgeql.py.assert b/tests/codegen/test-project2/parpkg/subpkg/my_query_edgeql.py.assert index fa79a94c..b22baa49 100644 --- a/tests/codegen/test-project2/parpkg/subpkg/my_query_edgeql.py.assert +++ b/tests/codegen/test-project2/parpkg/subpkg/my_query_edgeql.py.assert @@ -1,13 +1,13 @@ # AUTOGENERATED FROM 'parpkg/subpkg/my_query.edgeql' WITH: -# $ edgedb-py --target blocking --no-skip-pydantic-validation +# $ gel-py --target blocking --no-skip-pydantic-validation from __future__ import annotations import array import dataclasses import datetime -import edgedb import enum +import gel import typing import uuid @@ -55,26 +55,26 @@ class MyQueryResult: ab: typing.Optional[datetime.timedelta] ac: int ad: typing.Optional[int] - ae: edgedb.RelativeDuration - af: typing.Optional[edgedb.RelativeDuration] - ag: edgedb.DateDuration - ah: typing.Optional[edgedb.DateDuration] - ai: edgedb.ConfigMemory - aj: typing.Optional[edgedb.ConfigMemory] - ak: edgedb.Range[int] - al: typing.Optional[edgedb.Range[int]] - am: edgedb.Range[int] - an: typing.Optional[edgedb.Range[int]] - ao: edgedb.Range[float] - ap: typing.Optional[edgedb.Range[float]] - aq: edgedb.Range[float] - ar: typing.Optional[edgedb.Range[float]] - as_: edgedb.Range[datetime.datetime] - at: typing.Optional[edgedb.Range[datetime.datetime]] - au: edgedb.Range[datetime.datetime] - av: typing.Optional[edgedb.Range[datetime.datetime]] - aw: edgedb.Range[datetime.date] - ax: typing.Optional[edgedb.Range[datetime.date]] + ae: gel.RelativeDuration + af: typing.Optional[gel.RelativeDuration] + ag: gel.DateDuration + ah: typing.Optional[gel.DateDuration] + ai: gel.ConfigMemory + aj: typing.Optional[gel.ConfigMemory] + ak: gel.Range[int] + al: typing.Optional[gel.Range[int]] + am: gel.Range[int] + an: typing.Optional[gel.Range[int]] + ao: gel.Range[float] + ap: typing.Optional[gel.Range[float]] + aq: gel.Range[float] + ar: typing.Optional[gel.Range[float]] + as_: gel.Range[datetime.datetime] + at: typing.Optional[gel.Range[datetime.datetime]] + au: gel.Range[datetime.datetime] + av: typing.Optional[gel.Range[datetime.datetime]] + aw: gel.Range[datetime.date] + ax: typing.Optional[gel.Range[datetime.date]] ay: MyScalar az: typing.Optional[MyScalar] ba: MyEnum @@ -84,7 +84,7 @@ class MyQueryResult: def my_query( - executor: edgedb.Executor, + executor: gel.Executor, *, a: uuid.UUID, b: typing.Optional[uuid.UUID] = None, @@ -116,26 +116,26 @@ def my_query( ab: typing.Optional[datetime.timedelta] = None, ac: int, ad: typing.Optional[int] = None, - ae: edgedb.RelativeDuration, - af: typing.Optional[edgedb.RelativeDuration] = None, - ag: edgedb.DateDuration, - ah: typing.Optional[edgedb.DateDuration] = None, - ai: edgedb.ConfigMemory, - aj: typing.Optional[edgedb.ConfigMemory] = None, - ak: edgedb.Range[int], - al: typing.Optional[edgedb.Range[int]] = None, - am: edgedb.Range[int], - an: typing.Optional[edgedb.Range[int]] = None, - ao: edgedb.Range[float], - ap: typing.Optional[edgedb.Range[float]] = None, - aq: edgedb.Range[float], - ar: typing.Optional[edgedb.Range[float]] = None, - as_: edgedb.Range[datetime.datetime], - at: typing.Optional[edgedb.Range[datetime.datetime]] = None, - au: edgedb.Range[datetime.datetime], - av: typing.Optional[edgedb.Range[datetime.datetime]] = None, - aw: edgedb.Range[datetime.date], - ax: typing.Optional[edgedb.Range[datetime.date]] = None, + ae: gel.RelativeDuration, + af: typing.Optional[gel.RelativeDuration] = None, + ag: gel.DateDuration, + ah: typing.Optional[gel.DateDuration] = None, + ai: gel.ConfigMemory, + aj: typing.Optional[gel.ConfigMemory] = None, + ak: gel.Range[int], + al: typing.Optional[gel.Range[int]] = None, + am: gel.Range[int], + an: typing.Optional[gel.Range[int]] = None, + ao: gel.Range[float], + ap: typing.Optional[gel.Range[float]] = None, + aq: gel.Range[float], + ar: typing.Optional[gel.Range[float]] = None, + as_: gel.Range[datetime.datetime], + at: typing.Optional[gel.Range[datetime.datetime]] = None, + au: gel.Range[datetime.datetime], + av: typing.Optional[gel.Range[datetime.datetime]] = None, + aw: gel.Range[datetime.date], + ax: typing.Optional[gel.Range[datetime.date]] = None, bc: typing.Sequence[float], bd: typing.Optional[typing.Sequence[float]] = None, ) -> MyQueryResult: diff --git a/tests/codegen/test-project2/parpkg/subpkg/my_query_edgeql.py.assert5 b/tests/codegen/test-project2/parpkg/subpkg/my_query_edgeql.py.assert5 index c5a85ea8..2d60d18a 100644 --- a/tests/codegen/test-project2/parpkg/subpkg/my_query_edgeql.py.assert5 +++ b/tests/codegen/test-project2/parpkg/subpkg/my_query_edgeql.py.assert5 @@ -1,13 +1,13 @@ # AUTOGENERATED FROM 'parpkg/subpkg/my_query.edgeql' WITH: -# $ edgedb-py --target blocking --no-skip-pydantic-validation +# $ gel-py --target blocking --no-skip-pydantic-validation from __future__ import annotations import array import dataclasses import datetime -import edgedb import enum +import gel import typing import uuid @@ -56,26 +56,26 @@ class MyQueryResult: ab: typing.Optional[datetime.timedelta] ac: int ad: typing.Optional[int] - ae: edgedb.RelativeDuration - af: typing.Optional[edgedb.RelativeDuration] - ag: edgedb.DateDuration - ah: typing.Optional[edgedb.DateDuration] - ai: edgedb.ConfigMemory - aj: typing.Optional[edgedb.ConfigMemory] - ak: edgedb.Range[int] - al: typing.Optional[edgedb.Range[int]] - am: edgedb.Range[int] - an: typing.Optional[edgedb.Range[int]] - ao: edgedb.Range[float] - ap: typing.Optional[edgedb.Range[float]] - aq: edgedb.Range[float] - ar: typing.Optional[edgedb.Range[float]] - as_: edgedb.Range[datetime.datetime] - at: typing.Optional[edgedb.Range[datetime.datetime]] - au: edgedb.Range[datetime.datetime] - av: typing.Optional[edgedb.Range[datetime.datetime]] - aw: edgedb.Range[datetime.date] - ax: typing.Optional[edgedb.Range[datetime.date]] + ae: gel.RelativeDuration + af: typing.Optional[gel.RelativeDuration] + ag: gel.DateDuration + ah: typing.Optional[gel.DateDuration] + ai: gel.ConfigMemory + aj: typing.Optional[gel.ConfigMemory] + ak: gel.Range[int] + al: typing.Optional[gel.Range[int]] + am: gel.Range[int] + an: typing.Optional[gel.Range[int]] + ao: gel.Range[float] + ap: typing.Optional[gel.Range[float]] + aq: gel.Range[float] + ar: typing.Optional[gel.Range[float]] + as_: gel.Range[datetime.datetime] + at: typing.Optional[gel.Range[datetime.datetime]] + au: gel.Range[datetime.datetime] + av: typing.Optional[gel.Range[datetime.datetime]] + aw: gel.Range[datetime.date] + ax: typing.Optional[gel.Range[datetime.date]] ay: MyScalar az: typing.Optional[MyScalar] ba: MyEnum @@ -85,7 +85,7 @@ class MyQueryResult: def my_query( - executor: edgedb.Executor, + executor: gel.Executor, *, a: uuid.UUID, b: typing.Optional[uuid.UUID] = None, @@ -117,26 +117,26 @@ def my_query( ab: typing.Optional[datetime.timedelta] = None, ac: int, ad: typing.Optional[int] = None, - ae: edgedb.RelativeDuration, - af: typing.Optional[edgedb.RelativeDuration] = None, - ag: edgedb.DateDuration, - ah: typing.Optional[edgedb.DateDuration] = None, - ai: edgedb.ConfigMemory, - aj: typing.Optional[edgedb.ConfigMemory] = None, - ak: edgedb.Range[int], - al: typing.Optional[edgedb.Range[int]] = None, - am: edgedb.Range[int], - an: typing.Optional[edgedb.Range[int]] = None, - ao: edgedb.Range[float], - ap: typing.Optional[edgedb.Range[float]] = None, - aq: edgedb.Range[float], - ar: typing.Optional[edgedb.Range[float]] = None, - as_: edgedb.Range[datetime.datetime], - at: typing.Optional[edgedb.Range[datetime.datetime]] = None, - au: edgedb.Range[datetime.datetime], - av: typing.Optional[edgedb.Range[datetime.datetime]] = None, - aw: edgedb.Range[datetime.date], - ax: typing.Optional[edgedb.Range[datetime.date]] = None, + ae: gel.RelativeDuration, + af: typing.Optional[gel.RelativeDuration] = None, + ag: gel.DateDuration, + ah: typing.Optional[gel.DateDuration] = None, + ai: gel.ConfigMemory, + aj: typing.Optional[gel.ConfigMemory] = None, + ak: gel.Range[int], + al: typing.Optional[gel.Range[int]] = None, + am: gel.Range[int], + an: typing.Optional[gel.Range[int]] = None, + ao: gel.Range[float], + ap: typing.Optional[gel.Range[float]] = None, + aq: gel.Range[float], + ar: typing.Optional[gel.Range[float]] = None, + as_: gel.Range[datetime.datetime], + at: typing.Optional[gel.Range[datetime.datetime]] = None, + au: gel.Range[datetime.datetime], + av: typing.Optional[gel.Range[datetime.datetime]] = None, + aw: gel.Range[datetime.date], + ax: typing.Optional[gel.Range[datetime.date]] = None, bc: typing.Sequence[float], bd: typing.Optional[typing.Sequence[float]] = None, ) -> MyQueryResult: diff --git a/tests/codegen/test-project2/scalar/custom_vector_input_async_edgeql.py.assert b/tests/codegen/test-project2/scalar/custom_vector_input_async_edgeql.py.assert index 39272b4f..277e472a 100644 --- a/tests/codegen/test-project2/scalar/custom_vector_input_async_edgeql.py.assert +++ b/tests/codegen/test-project2/scalar/custom_vector_input_async_edgeql.py.assert @@ -1,9 +1,9 @@ # AUTOGENERATED FROM 'scalar/custom_vector_input.edgeql' WITH: -# $ edgedb-py +# $ gel-py from __future__ import annotations -import edgedb +import gel import typing @@ -11,7 +11,7 @@ V3 = typing.Sequence[float] async def custom_vector_input( - executor: edgedb.AsyncIOExecutor, + executor: gel.AsyncIOExecutor, *, input: typing.Optional[V3] = None, ) -> typing.Optional[int]: diff --git a/tests/codegen/test-project2/scalar/custom_vector_input_async_edgeql.py.assert3 b/tests/codegen/test-project2/scalar/custom_vector_input_async_edgeql.py.assert3 index a16beb5e..1a1f0e0d 100644 --- a/tests/codegen/test-project2/scalar/custom_vector_input_async_edgeql.py.assert3 +++ b/tests/codegen/test-project2/scalar/custom_vector_input_async_edgeql.py.assert3 @@ -1,9 +1,9 @@ # AUTOGENERATED FROM 'scalar/custom_vector_input.edgeql' WITH: -# $ edgedb-py +# $ gel-py from __future__ import annotations -import edgedb +import gel import typing @@ -11,7 +11,7 @@ Input = typing.Sequence[float] async def custom_vector_input( - executor: edgedb.AsyncIOExecutor, + executor: gel.AsyncIOExecutor, *, input: typing.Optional[Input] = None, ) -> typing.Optional[int]: diff --git a/tests/codegen/test-project2/scalar/custom_vector_input_edgeql.py.assert b/tests/codegen/test-project2/scalar/custom_vector_input_edgeql.py.assert index ed2947c0..6ccca1a6 100644 --- a/tests/codegen/test-project2/scalar/custom_vector_input_edgeql.py.assert +++ b/tests/codegen/test-project2/scalar/custom_vector_input_edgeql.py.assert @@ -1,9 +1,9 @@ # AUTOGENERATED FROM 'scalar/custom_vector_input.edgeql' WITH: -# $ edgedb-py --target blocking --no-skip-pydantic-validation +# $ gel-py --target blocking --no-skip-pydantic-validation from __future__ import annotations -import edgedb +import gel import typing @@ -11,7 +11,7 @@ V3 = typing.Sequence[float] def custom_vector_input( - executor: edgedb.Executor, + executor: gel.Executor, *, input: typing.Optional[V3] = None, ) -> typing.Optional[int]: diff --git a/tests/codegen/test-project2/scalar/custom_vector_input_edgeql.py.assert3 b/tests/codegen/test-project2/scalar/custom_vector_input_edgeql.py.assert3 index 4cf4659a..509bd484 100644 --- a/tests/codegen/test-project2/scalar/custom_vector_input_edgeql.py.assert3 +++ b/tests/codegen/test-project2/scalar/custom_vector_input_edgeql.py.assert3 @@ -1,9 +1,9 @@ # AUTOGENERATED FROM 'scalar/custom_vector_input.edgeql' WITH: -# $ edgedb-py --target blocking --no-skip-pydantic-validation +# $ gel-py --target blocking --no-skip-pydantic-validation from __future__ import annotations -import edgedb +import gel import typing @@ -11,7 +11,7 @@ Input = typing.Sequence[float] def custom_vector_input( - executor: edgedb.Executor, + executor: gel.Executor, *, input: typing.Optional[Input] = None, ) -> typing.Optional[int]: diff --git a/tests/codegen/test-project2/scalar/select_scalar_async_edgeql.py.assert b/tests/codegen/test-project2/scalar/select_scalar_async_edgeql.py.assert index a9fc5917..327be5e3 100644 --- a/tests/codegen/test-project2/scalar/select_scalar_async_edgeql.py.assert +++ b/tests/codegen/test-project2/scalar/select_scalar_async_edgeql.py.assert @@ -1,13 +1,13 @@ # AUTOGENERATED FROM 'scalar/select_scalar.edgeql' WITH: -# $ edgedb-py +# $ gel-py from __future__ import annotations -import edgedb +import gel async def select_scalar( - executor: edgedb.AsyncIOExecutor, + executor: gel.AsyncIOExecutor, ) -> int: return await executor.query_single( """\ diff --git a/tests/codegen/test-project2/scalar/select_scalar_edgeql.py.assert b/tests/codegen/test-project2/scalar/select_scalar_edgeql.py.assert index c8cfeb1b..0337455b 100644 --- a/tests/codegen/test-project2/scalar/select_scalar_edgeql.py.assert +++ b/tests/codegen/test-project2/scalar/select_scalar_edgeql.py.assert @@ -1,13 +1,13 @@ # AUTOGENERATED FROM 'scalar/select_scalar.edgeql' WITH: -# $ edgedb-py --target blocking --no-skip-pydantic-validation +# $ gel-py --target blocking --no-skip-pydantic-validation from __future__ import annotations -import edgedb +import gel def select_scalar( - executor: edgedb.Executor, + executor: gel.Executor, ) -> int: return executor.query_single( """\ diff --git a/tests/codegen/test-project2/scalar/select_scalars_async_edgeql.py.assert b/tests/codegen/test-project2/scalar/select_scalars_async_edgeql.py.assert index d7e9b233..a7291850 100644 --- a/tests/codegen/test-project2/scalar/select_scalars_async_edgeql.py.assert +++ b/tests/codegen/test-project2/scalar/select_scalars_async_edgeql.py.assert @@ -1,15 +1,15 @@ # AUTOGENERATED FROM 'scalar/select_scalars.edgeql' WITH: -# $ edgedb-py +# $ gel-py from __future__ import annotations -import edgedb +import gel import typing async def select_scalars( - executor: edgedb.AsyncIOExecutor, -) -> typing.List[edgedb.ConfigMemory]: + executor: gel.AsyncIOExecutor, +) -> typing.List[gel.ConfigMemory]: return await executor.query( """\ select {1, 2, 3};\ diff --git a/tests/codegen/test-project2/scalar/select_scalars_edgeql.py.assert b/tests/codegen/test-project2/scalar/select_scalars_edgeql.py.assert index 74016c24..9cf2ebb4 100644 --- a/tests/codegen/test-project2/scalar/select_scalars_edgeql.py.assert +++ b/tests/codegen/test-project2/scalar/select_scalars_edgeql.py.assert @@ -1,14 +1,14 @@ # AUTOGENERATED FROM 'scalar/select_scalars.edgeql' WITH: -# $ edgedb-py --target blocking --no-skip-pydantic-validation +# $ gel-py --target blocking --no-skip-pydantic-validation from __future__ import annotations -import edgedb +import gel def select_scalars( - executor: edgedb.Executor, -) -> list[edgedb.ConfigMemory]: + executor: gel.Executor, +) -> list[gel.ConfigMemory]: return executor.query( """\ select {1, 2, 3};\ diff --git a/tests/test_async_query.py b/tests/test_async_query.py index da731bb3..851afd02 100644 --- a/tests/test_async_query.py +++ b/tests/test_async_query.py @@ -29,7 +29,7 @@ import edgedb from edgedb import abstract -from edgedb import _testbase as tb +from gel import _testbase as tb from edgedb.options import RetryOptions from edgedb.protocol import protocol @@ -988,7 +988,7 @@ async def test_enum_argument_01(self): await tx.query_single('SELECT $0', 'Oups') with self.assertRaisesRegex( - edgedb.InvalidArgumentError, 'a str or edgedb.EnumValue'): + edgedb.InvalidArgumentError, 'a str or gel.EnumValue'): await self.client.query_single('SELECT $0', 123) async def test_enum_argument_02(self): diff --git a/tests/test_async_retry.py b/tests/test_async_retry.py index d47c5603..0fb13d08 100644 --- a/tests/test_async_retry.py +++ b/tests/test_async_retry.py @@ -24,7 +24,7 @@ import edgedb from edgedb import errors from edgedb import RetryOptions -from edgedb import _testbase as tb +from gel import _testbase as tb log = logging.getLogger(__name__) diff --git a/tests/test_async_tx.py b/tests/test_async_tx.py index 8ceeb239..030b4672 100644 --- a/tests/test_async_tx.py +++ b/tests/test_async_tx.py @@ -21,7 +21,7 @@ import edgedb -from edgedb import _testbase as tb +from gel import _testbase as tb from edgedb import TransactionOptions from edgedb.options import RetryOptions diff --git a/tests/test_asyncio_client.py b/tests/test_asyncio_client.py index b4d5b03d..3971029c 100644 --- a/tests/test_asyncio_client.py +++ b/tests/test_asyncio_client.py @@ -21,7 +21,7 @@ import edgedb -from edgedb import _testbase as tb +from gel import _testbase as tb from edgedb import errors from edgedb import asyncio_client diff --git a/tests/test_blocking_client.py b/tests/test_blocking_client.py index 099baa71..4ebc153d 100644 --- a/tests/test_blocking_client.py +++ b/tests/test_blocking_client.py @@ -24,7 +24,7 @@ import edgedb -from edgedb import _testbase as tb +from gel import _testbase as tb from edgedb import errors from edgedb import blocking_client diff --git a/tests/test_codegen.py b/tests/test_codegen.py index 801b00ae..3124f550 100644 --- a/tests/test_codegen.py +++ b/tests/test_codegen.py @@ -24,7 +24,7 @@ import os import tempfile -from edgedb import _testbase as tb +from gel import _testbase as tb ASSERT_SUFFIX = os.environ.get("EDGEDB_TEST_CODEGEN_ASSERT_SUFFIX", ".assert") @@ -91,7 +91,7 @@ async def run(*args, extra_env=None): p.returncode, args, output=await p.stdout.read(), ) - cmd = env.get("EDGEDB_PYTHON_TEST_CODEGEN_CMD", "edgedb-py") + cmd = env.get("EDGEDB_PYTHON_TEST_CODEGEN_CMD", "gel-py") await run( cmd, extra_env={"EDGEDB_PYTHON_CODEGEN_PY_VER": "3.8.5"} ) diff --git a/tests/test_con_utils.py b/tests/test_con_utils.py index 24027f06..13454671 100644 --- a/tests/test_con_utils.py +++ b/tests/test_con_utils.py @@ -29,7 +29,7 @@ from unittest import mock -from edgedb import con_utils +from gel import con_utils from edgedb import errors @@ -389,9 +389,9 @@ def test_connect_params(self): self.run_testcase(testcase) - @mock.patch("edgedb.platform.config_dir", + @mock.patch("gel.platform.config_dir", lambda: pathlib.Path("/home/user/.config/edgedb")) - @mock.patch("edgedb.platform.IS_WINDOWS", False) + @mock.patch("gel.platform.IS_WINDOWS", False) @mock.patch("pathlib.Path.exists", lambda p: True) @mock.patch("os.path.realpath", lambda p: p) def test_stash_path(self): @@ -423,7 +423,7 @@ def test_project_config(self): "database": "inst1_db", })) - with mock.patch('edgedb.platform.config_dir', + with mock.patch('gel.platform.config_dir', lambda: home / '.edgedb'), \ mock.patch('os.getcwd', lambda: str(project)): stash_path = con_utils._stash_path(project) diff --git a/tests/test_connect.py b/tests/test_connect.py index e22f1958..48cebade 100644 --- a/tests/test_connect.py +++ b/tests/test_connect.py @@ -21,7 +21,7 @@ import edgedb -from edgedb import _testbase as tb +from gel import _testbase as tb class TestConnect(tb.AsyncQueryTestCase): diff --git a/tests/test_datetime.py b/tests/test_datetime.py index ff7dfbc0..bb19a5b1 100644 --- a/tests/test_datetime.py +++ b/tests/test_datetime.py @@ -20,7 +20,7 @@ import random from datetime import timedelta -from edgedb import _testbase as tb +from gel import _testbase as tb from edgedb import errors from edgedb.datatypes.datatypes import RelativeDuration, DateDuration @@ -178,7 +178,7 @@ async def test_relative_duration_02(self): self.assertEqual(d1.months, 0) self.assertEqual(d1.microseconds, 1) - self.assertEqual(repr(d1), '') + self.assertEqual(repr(d1), '') async def test_relative_duration_03(self): # Make sure that when we break down the microseconds into the bigger diff --git a/tests/test_enum.py b/tests/test_enum.py index 1f3b6fae..fa0040c4 100644 --- a/tests/test_enum.py +++ b/tests/test_enum.py @@ -22,7 +22,7 @@ import uuid import edgedb -from edgedb import _testbase as tb +from gel import _testbase as tb class TestEnum(tb.AsyncQueryTestCase): @@ -40,7 +40,7 @@ async def test_enum_01(self): self.assertTrue(isinstance(ct_red, edgedb.EnumValue)) self.assertTrue(isinstance(ct_red.__tid__, uuid.UUID)) - self.assertEqual(repr(ct_red), "") + self.assertEqual(repr(ct_red), "") self.assertEqual(str(ct_red), 'red') with self.assertRaises(TypeError): diff --git a/tests/test_errors.py b/tests/test_errors.py index f76efa28..209c80e3 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -20,8 +20,8 @@ import unittest -from edgedb import errors -from edgedb.errors import _base as base_errors +from gel import errors +from gel.errors import _base as base_errors class TestErrors(unittest.TestCase): diff --git a/tests/test_globals.py b/tests/test_globals.py index f4c72029..ee9a2bb0 100644 --- a/tests/test_globals.py +++ b/tests/test_globals.py @@ -17,7 +17,7 @@ # -from edgedb import _testbase as tb +from gel import _testbase as tb from edgedb import errors diff --git a/tests/test_memory.py b/tests/test_memory.py index 63c032e2..0004d1a7 100644 --- a/tests/test_memory.py +++ b/tests/test_memory.py @@ -16,7 +16,7 @@ # limitations under the License. # -from edgedb import _testbase as tb +from gel import _testbase as tb class TestConfigMemory(tb.SyncQueryTestCase): diff --git a/tests/test_namedtuples.py b/tests/test_namedtuples.py index 04a74ddc..3a525449 100644 --- a/tests/test_namedtuples.py +++ b/tests/test_namedtuples.py @@ -20,7 +20,7 @@ from collections import namedtuple, UserDict import edgedb -from edgedb import _testbase as tb +from gel import _testbase as tb class TestNamedTupleTypes(tb.SyncQueryTestCase): diff --git a/tests/test_postgis.py b/tests/test_postgis.py index 9d6583ef..029ff57d 100644 --- a/tests/test_postgis.py +++ b/tests/test_postgis.py @@ -18,7 +18,7 @@ from collections import namedtuple -from edgedb import _testbase as tb +from gel import _testbase as tb Geo = namedtuple('Geo', ['wkb']) diff --git a/tests/test_proto.py b/tests/test_proto.py index 48be7f15..9b984ce5 100644 --- a/tests/test_proto.py +++ b/tests/test_proto.py @@ -20,7 +20,7 @@ import edgedb -from edgedb import _testbase as tb +from gel import _testbase as tb class TestProto(tb.SyncQueryTestCase): diff --git a/tests/test_state.py b/tests/test_state.py index d59ef1d5..e7047765 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -17,7 +17,7 @@ # -from edgedb import _testbase as tb +from gel import _testbase as tb from edgedb import State diff --git a/tests/test_sync_query.py b/tests/test_sync_query.py index 36566282..f6d27f60 100644 --- a/tests/test_sync_query.py +++ b/tests/test_sync_query.py @@ -27,7 +27,7 @@ import edgedb from edgedb import abstract -from edgedb import _testbase as tb +from gel import _testbase as tb from edgedb.protocol import protocol @@ -767,7 +767,7 @@ def test_enum_argument_01(self): tx.query_single('SELECT $0', 'Oups') with self.assertRaisesRegex( - edgedb.InvalidArgumentError, 'a str or edgedb.EnumValue' + edgedb.InvalidArgumentError, 'a str or gel.EnumValue' ): self.client.query_single('SELECT $0', 123) diff --git a/tests/test_sync_retry.py b/tests/test_sync_retry.py index ae32c633..8b1ef1a6 100644 --- a/tests/test_sync_retry.py +++ b/tests/test_sync_retry.py @@ -23,10 +23,8 @@ import unittest.mock from concurrent import futures -import edgedb -from edgedb import _testbase as tb -from edgedb import errors -from edgedb import RetryOptions +import gel +from gel import _testbase as tb class Barrier: @@ -82,7 +80,7 @@ def test_sync_retry_02(self): }; ''') 1 / 0 - with self.assertRaises(edgedb.NoDataError): + with self.assertRaises(gel.NoDataError): self.client.query_required_single(''' SELECT test::Counter FILTER .name = 'counter_retry_02' @@ -109,9 +107,9 @@ def cleanup(): self.addCleanup(cleanup) - _start.side_effect = errors.BackendUnavailableError() + _start.side_effect = gel.BackendUnavailableError() - with self.assertRaises(errors.BackendUnavailableError): + with self.assertRaises(gel.BackendUnavailableError): for tx in self.client.transaction(): with tx: tx.execute(''' @@ -119,7 +117,7 @@ def cleanup(): name := 'counter_retry_begin' }; ''') - with self.assertRaises(edgedb.NoDataError): + with self.assertRaises(gel.NoDataError): self.client.query_required_single(''' SELECT test::Counter FILTER .name = 'counter_retry_begin' @@ -134,7 +132,7 @@ def cleanup(): def recover_after_first_error(*_, **__): patcher.stop() - raise errors.BackendUnavailableError() + raise gel.BackendUnavailableError() _start.side_effect = recover_after_first_error call_count = _start.call_count @@ -156,16 +154,16 @@ def test_sync_retry_conflict(self): self.execute_conflict('counter2') def test_sync_conflict_no_retry(self): - with self.assertRaises(edgedb.TransactionSerializationError): + with self.assertRaises(gel.TransactionSerializationError): self.execute_conflict( 'counter3', - RetryOptions(attempts=1, backoff=edgedb.default_backoff) + gel.RetryOptions(attempts=1, backoff=gel.default_backoff) ) def execute_conflict(self, name='counter2', options=None): con_args = self.get_connect_args().copy() con_args.update(database=self.get_database_name()) - client2 = edgedb.create_client(**con_args) + client2 = gel.create_client(**con_args) self.addCleanup(client2.close) barrier = Barrier(2) @@ -243,13 +241,13 @@ def test_sync_transaction_interface_errors(self): for tx in self.client.transaction(): tx.start() - with self.assertRaisesRegex(edgedb.InterfaceError, + with self.assertRaisesRegex(gel.InterfaceError, r'.*Use `with transaction:`'): for tx in self.client.transaction(): tx.execute("SELECT 123") with self.assertRaisesRegex( - edgedb.InterfaceError, + gel.InterfaceError, r"already in a `with` block", ): for tx in self.client.transaction(): diff --git a/tests/test_sync_tx.py b/tests/test_sync_tx.py index 497af782..e7774f8f 100644 --- a/tests/test_sync_tx.py +++ b/tests/test_sync_tx.py @@ -21,7 +21,7 @@ import edgedb -from edgedb import _testbase as tb +from gel import _testbase as tb from edgedb import TransactionOptions diff --git a/tests/test_vector.py b/tests/test_vector.py index 514ded14..b4708281 100644 --- a/tests/test_vector.py +++ b/tests/test_vector.py @@ -16,7 +16,7 @@ # limitations under the License. # -from edgedb import _testbase as tb +from gel import _testbase as tb import edgedb import array diff --git a/tools/make_import_shims.py b/tools/make_import_shims.py new file mode 100644 index 00000000..60325af8 --- /dev/null +++ b/tools/make_import_shims.py @@ -0,0 +1,42 @@ +import os +import sys + +MODS = sorted(['gel', 'gel._taskgroup', 'gel._version', 'gel.abstract', 'gel.ai', 'gel.ai.core', 'gel.ai.types', 'gel.asyncio_client', 'gel.base_client', 'gel.blocking_client', 'gel.codegen', 'gel.color', 'gel.con_utils', 'gel.credentials', 'gel.datatypes', 'gel.datatypes.datatypes', 'gel.datatypes.range', 'gel.describe', 'gel.enums', 'gel.errors', 'gel.errors._base', 'gel.errors.tags', 'gel.introspect', 'gel.options', 'gel.pgproto', 'gel.pgproto.pgproto', 'gel.pgproto.types', 'gel.platform', 'gel.protocol', 'gel.protocol.asyncio_proto', 'gel.protocol.blocking_proto', 'gel.protocol.protocol', 'gel.scram', 'gel.scram.saslprep', 'gel.transaction']) + + + +def main(): + for mod in MODS: + is_package = any(k.startswith(mod + '.') for k in MODS) + + nmod = 'edgedb' + mod[len('gel'):] + slash_name = nmod.replace('.', '/') + if is_package: + os.mkdir(slash_name) + fname = slash_name + '/__init__.py' + else: + fname = slash_name + '.py' + + # import * skips things not in __all__ or with underscores at + # the start, so we have to do some nonsense. + with open(fname, 'w') as f: + f.write(f'''\ +# Auto-generated shim +import {mod} as _mod +import sys as _sys +_cur = _sys.modules['{nmod}'] +for _k in vars(_mod): + if not _k.startswith('__') or _k in ('__all__', '__doc__'): + setattr(_cur, _k, getattr(_mod, _k)) +del _cur +del _sys +del _mod +del _k +''') + + with open('edgedb/py.typed', 'w') as f: + pass + + +if __name__ == '__main__': + main()