diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index ca0eecc0..015cf07d 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -4,14 +4,15 @@ on:
workflow_dispatch:
inputs: {}
- pull_request:
- branches:
- - "master"
- - "ci"
- - "release/[0-9]+.x"
- - "release/[0-9]+.[0-9]+.x"
- paths:
- - "edgedb/_version.py"
+ # XXX: Commented out to prevent firing during this refactor.
+ # pull_request:
+ # branches:
+ # - "master"
+ # - "ci"
+ # - "release/[0-9]+.x"
+ # - "release/[0-9]+.[0-9]+.x"
+ # paths:
+ # - "gel/_version.py"
jobs:
validate-release-request:
diff --git a/.gitignore b/.gitignore
index 84d8e3dc..4f1a755b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -29,3 +29,4 @@ docs/_build
/edgedb/**/*.html
/tmp
/wheelhouse
+/env*
diff --git a/.gitmodules b/.gitmodules
index 1830ced8..cdefde59 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -1,5 +1,5 @@
[submodule "edgedb/pgproto"]
- path = edgedb/pgproto
+ path = gel/pgproto
url = https://github.com/MagicStack/py-pgproto.git
[submodule "tests/shared-client-testcases"]
path = tests/shared-client-testcases
diff --git a/Makefile b/Makefile
index 4244be36..9f641039 100644
--- a/Makefile
+++ b/Makefile
@@ -11,25 +11,25 @@ all: compile
clean:
rm -fr $(ROOT)/dist/
rm -fr $(ROOT)/doc/_build/
- rm -fr $(ROOT)/edgedb/pgproto/*.c
- rm -fr $(ROOT)/edgedb/pgproto/*.html
- rm -fr $(ROOT)/edgedb/pgproto/codecs/*.html
- rm -fr $(ROOT)/edgedb/protocol/*.c
- rm -fr $(ROOT)/edgedb/protocol/*.html
- rm -fr $(ROOT)/edgedb/protocol/*.so
- rm -fr $(ROOT)/edgedb/datatypes/*.so
- rm -fr $(ROOT)/edgedb/datatypes/datatypes.c
+ rm -fr $(ROOT)/gel/pgproto/*.c
+ rm -fr $(ROOT)/gel/pgproto/*.html
+ rm -fr $(ROOT)/gel/pgproto/codecs/*.html
+ rm -fr $(ROOT)/gel/protocol/*.c
+ rm -fr $(ROOT)/gel/protocol/*.html
+ rm -fr $(ROOT)/gel/protocol/*.so
+ rm -fr $(ROOT)/gel/datatypes/*.so
+ rm -fr $(ROOT)/gel/datatypes/datatypes.c
rm -fr $(ROOT)/build
- rm -fr $(ROOT)/edgedb/protocol/codecs/*.html
+ rm -fr $(ROOT)/gel/protocol/codecs/*.html
find . -name '__pycache__' | xargs rm -rf
_touch:
- rm -fr $(ROOT)/edgedb/datatypes/datatypes.c
- rm -fr $(ROOT)/edgedb/protocol/protocol.c
- find $(ROOT)/edgedb/protocol -name '*.pyx' | xargs touch
- find $(ROOT)/edgedb/datatypes -name '*.pyx' | xargs touch
- find $(ROOT)/edgedb/datatypes -name '*.c' | xargs touch
+ rm -fr $(ROOT)/gel/datatypes/datatypes.c
+ rm -fr $(ROOT)/gel/protocol/protocol.c
+ find $(ROOT)/gel/protocol -name '*.pyx' | xargs touch
+ find $(ROOT)/gel/datatypes -name '*.pyx' | xargs touch
+ find $(ROOT)/gel/datatypes -name '*.c' | xargs touch
compile: _touch
@@ -39,12 +39,12 @@ compile: _touch
gen-errors:
edb gen-errors --import "$(echo "from edgedb.errors._base import *"; echo "from edgedb.errors.tags import *")" \
--extra-all "_base.__all__" --stdout --client > $(ROOT)/.errors
- mv $(ROOT)/.errors $(ROOT)/edgedb/errors/__init__.py
+ mv $(ROOT)/.errors $(ROOT)/gel/errors/__init__.py
$(PYTHON) tools/gen_init.py
gen-types:
- edb gen-types --stdout > $(ROOT)/edgedb/protocol/codecs/edb_types.pxi
+ edb gen-types --stdout > $(ROOT)/gel/protocol/codecs/edb_types.pxi
debug: _touch
diff --git a/edgedb/__init__.py b/edgedb/__init__.py
index f714cc29..711d3ee0 100644
--- a/edgedb/__init__.py
+++ b/edgedb/__init__.py
@@ -1,284 +1,11 @@
-#
-# This source file is part of the EdgeDB open source project.
-#
-# Copyright 2016-present MagicStack Inc. and the EdgeDB authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-
-# flake8: noqa
-
-from ._version import __version__
-
-from edgedb.datatypes.datatypes import (
- Tuple, NamedTuple, EnumValue, RelativeDuration, DateDuration, ConfigMemory
-)
-from edgedb.datatypes.datatypes import Set, Object, Array
-from edgedb.datatypes.range import Range, MultiRange
-
-from .abstract import (
- Executor, AsyncIOExecutor, ReadOnlyExecutor, AsyncIOReadOnlyExecutor,
-)
-
-from .asyncio_client import (
- create_async_client,
- AsyncIOClient
-)
-
-from .blocking_client import create_client, Client
-from .enums import Cardinality, ElementKind
-from .options import RetryCondition, IsolationLevel, default_backoff
-from .options import RetryOptions, TransactionOptions
-from .options import State
-
-from .errors._base import EdgeDBError, EdgeDBMessage
-
-__all__ = [
- "Array",
- "AsyncIOClient",
- "AsyncIOExecutor",
- "AsyncIOReadOnlyExecutor",
- "Cardinality",
- "Client",
- "ConfigMemory",
- "DateDuration",
- "EdgeDBError",
- "EdgeDBMessage",
- "ElementKind",
- "EnumValue",
- "Executor",
- "IsolationLevel",
- "NamedTuple",
- "Object",
- "Range",
- "ReadOnlyExecutor",
- "RelativeDuration",
- "RetryCondition",
- "RetryOptions",
- "Set",
- "State",
- "TransactionOptions",
- "Tuple",
- "create_async_client",
- "create_client",
- "default_backoff",
-]
-
-
-# The below is generated by `make gen-errors`.
-# DO NOT MODIFY BY HAND.
-#
-#
-from .errors import (
- InternalServerError,
- UnsupportedFeatureError,
- ProtocolError,
- BinaryProtocolError,
- UnsupportedProtocolVersionError,
- TypeSpecNotFoundError,
- UnexpectedMessageError,
- InputDataError,
- ParameterTypeMismatchError,
- StateMismatchError,
- ResultCardinalityMismatchError,
- CapabilityError,
- UnsupportedCapabilityError,
- DisabledCapabilityError,
- QueryError,
- InvalidSyntaxError,
- EdgeQLSyntaxError,
- SchemaSyntaxError,
- GraphQLSyntaxError,
- InvalidTypeError,
- InvalidTargetError,
- InvalidLinkTargetError,
- InvalidPropertyTargetError,
- InvalidReferenceError,
- UnknownModuleError,
- UnknownLinkError,
- UnknownPropertyError,
- UnknownUserError,
- UnknownDatabaseError,
- UnknownParameterError,
- SchemaError,
- SchemaDefinitionError,
- InvalidDefinitionError,
- InvalidModuleDefinitionError,
- InvalidLinkDefinitionError,
- InvalidPropertyDefinitionError,
- InvalidUserDefinitionError,
- InvalidDatabaseDefinitionError,
- InvalidOperatorDefinitionError,
- InvalidAliasDefinitionError,
- InvalidFunctionDefinitionError,
- InvalidConstraintDefinitionError,
- InvalidCastDefinitionError,
- DuplicateDefinitionError,
- DuplicateModuleDefinitionError,
- DuplicateLinkDefinitionError,
- DuplicatePropertyDefinitionError,
- DuplicateUserDefinitionError,
- DuplicateDatabaseDefinitionError,
- DuplicateOperatorDefinitionError,
- DuplicateViewDefinitionError,
- DuplicateFunctionDefinitionError,
- DuplicateConstraintDefinitionError,
- DuplicateCastDefinitionError,
- DuplicateMigrationError,
- SessionTimeoutError,
- IdleSessionTimeoutError,
- QueryTimeoutError,
- TransactionTimeoutError,
- IdleTransactionTimeoutError,
- ExecutionError,
- InvalidValueError,
- DivisionByZeroError,
- NumericOutOfRangeError,
- AccessPolicyError,
- QueryAssertionError,
- IntegrityError,
- ConstraintViolationError,
- CardinalityViolationError,
- MissingRequiredError,
- TransactionError,
- TransactionConflictError,
- TransactionSerializationError,
- TransactionDeadlockError,
- WatchError,
- ConfigurationError,
- AccessError,
- AuthenticationError,
- AvailabilityError,
- BackendUnavailableError,
- ServerOfflineError,
- BackendError,
- UnsupportedBackendFeatureError,
- LogMessage,
- WarningMessage,
- ClientError,
- ClientConnectionError,
- ClientConnectionFailedError,
- ClientConnectionFailedTemporarilyError,
- ClientConnectionTimeoutError,
- ClientConnectionClosedError,
- InterfaceError,
- QueryArgumentError,
- MissingArgumentError,
- UnknownArgumentError,
- InvalidArgumentError,
- NoDataError,
- InternalClientError,
-)
-
-__all__.extend([
- "InternalServerError",
- "UnsupportedFeatureError",
- "ProtocolError",
- "BinaryProtocolError",
- "UnsupportedProtocolVersionError",
- "TypeSpecNotFoundError",
- "UnexpectedMessageError",
- "InputDataError",
- "ParameterTypeMismatchError",
- "StateMismatchError",
- "ResultCardinalityMismatchError",
- "CapabilityError",
- "UnsupportedCapabilityError",
- "DisabledCapabilityError",
- "QueryError",
- "InvalidSyntaxError",
- "EdgeQLSyntaxError",
- "SchemaSyntaxError",
- "GraphQLSyntaxError",
- "InvalidTypeError",
- "InvalidTargetError",
- "InvalidLinkTargetError",
- "InvalidPropertyTargetError",
- "InvalidReferenceError",
- "UnknownModuleError",
- "UnknownLinkError",
- "UnknownPropertyError",
- "UnknownUserError",
- "UnknownDatabaseError",
- "UnknownParameterError",
- "SchemaError",
- "SchemaDefinitionError",
- "InvalidDefinitionError",
- "InvalidModuleDefinitionError",
- "InvalidLinkDefinitionError",
- "InvalidPropertyDefinitionError",
- "InvalidUserDefinitionError",
- "InvalidDatabaseDefinitionError",
- "InvalidOperatorDefinitionError",
- "InvalidAliasDefinitionError",
- "InvalidFunctionDefinitionError",
- "InvalidConstraintDefinitionError",
- "InvalidCastDefinitionError",
- "DuplicateDefinitionError",
- "DuplicateModuleDefinitionError",
- "DuplicateLinkDefinitionError",
- "DuplicatePropertyDefinitionError",
- "DuplicateUserDefinitionError",
- "DuplicateDatabaseDefinitionError",
- "DuplicateOperatorDefinitionError",
- "DuplicateViewDefinitionError",
- "DuplicateFunctionDefinitionError",
- "DuplicateConstraintDefinitionError",
- "DuplicateCastDefinitionError",
- "DuplicateMigrationError",
- "SessionTimeoutError",
- "IdleSessionTimeoutError",
- "QueryTimeoutError",
- "TransactionTimeoutError",
- "IdleTransactionTimeoutError",
- "ExecutionError",
- "InvalidValueError",
- "DivisionByZeroError",
- "NumericOutOfRangeError",
- "AccessPolicyError",
- "QueryAssertionError",
- "IntegrityError",
- "ConstraintViolationError",
- "CardinalityViolationError",
- "MissingRequiredError",
- "TransactionError",
- "TransactionConflictError",
- "TransactionSerializationError",
- "TransactionDeadlockError",
- "WatchError",
- "ConfigurationError",
- "AccessError",
- "AuthenticationError",
- "AvailabilityError",
- "BackendUnavailableError",
- "ServerOfflineError",
- "BackendError",
- "UnsupportedBackendFeatureError",
- "LogMessage",
- "WarningMessage",
- "ClientError",
- "ClientConnectionError",
- "ClientConnectionFailedError",
- "ClientConnectionFailedTemporarilyError",
- "ClientConnectionTimeoutError",
- "ClientConnectionClosedError",
- "InterfaceError",
- "QueryArgumentError",
- "MissingArgumentError",
- "UnknownArgumentError",
- "InvalidArgumentError",
- "NoDataError",
- "InternalClientError",
-])
-#
+# Auto-generated shim
+import gel as _mod
+import sys as _sys
+_cur = _sys.modules['edgedb']
+for _k in vars(_mod):
+ if not _k.startswith('__') or _k in ('__all__', '__doc__'):
+ setattr(_cur, _k, getattr(_mod, _k))
+del _cur
+del _sys
+del _mod
+del _k
diff --git a/edgedb/_taskgroup.py b/edgedb/_taskgroup.py
index 0a1859d7..93ca15db 100644
--- a/edgedb/_taskgroup.py
+++ b/edgedb/_taskgroup.py
@@ -1,295 +1,11 @@
-#
-# This source file is part of the EdgeDB open source project.
-#
-# Copyright 2016-present MagicStack Inc. and the EdgeDB authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-
-import asyncio
-import functools
-import itertools
-import textwrap
-import traceback
-
-
-class TaskGroup:
-
- def __init__(self, *, name=None):
- if name is None:
- self._name = f'tg-{_name_counter()}'
- else:
- self._name = str(name)
-
- self._entered = False
- self._exiting = False
- self._aborting = False
- self._loop = None
- self._parent_task = None
- self._parent_cancel_requested = False
- self._tasks = set()
- self._unfinished_tasks = 0
- self._errors = []
- self._base_error = None
- self._on_completed_fut = None
-
- def get_name(self):
- return self._name
-
- def __repr__(self):
- msg = f''
- return msg
-
- async def __aenter__(self):
- if self._entered:
- raise RuntimeError(
- f"TaskGroup {self!r} has been already entered")
- self._entered = True
-
- if self._loop is None:
- self._loop = asyncio.get_event_loop()
-
- if hasattr(asyncio, 'current_task'):
- self._parent_task = asyncio.current_task(self._loop)
- else:
- self._parent_task = asyncio.Task.current_task(self._loop)
-
- if self._parent_task is None:
- raise RuntimeError(
- f'TaskGroup {self!r} cannot determine the parent task')
- self._patch_task(self._parent_task)
-
- return self
-
- async def __aexit__(self, et, exc, tb):
- self._exiting = True
- propagate_cancelation = False
-
- if (exc is not None and
- self._is_base_error(exc) and
- self._base_error is None):
- self._base_error = exc
-
- if et is asyncio.CancelledError:
- if self._parent_cancel_requested:
- # Only if we did request task to cancel ourselves
- # we mark it as no longer cancelled.
- self._parent_task.__cancel_requested__ = False
- else:
- propagate_cancelation = True
-
- if et is not None and not self._aborting:
- # Our parent task is being cancelled:
- #
- # async with TaskGroup() as g:
- # g.create_task(...)
- # await ... # <- CancelledError
- #
- if et is asyncio.CancelledError:
- propagate_cancelation = True
-
- # or there's an exception in "async with":
- #
- # async with TaskGroup() as g:
- # g.create_task(...)
- # 1 / 0
- #
- self._abort()
-
- # We use while-loop here because "self._on_completed_fut"
- # can be cancelled multiple times if our parent task
- # is being cancelled repeatedly (or even once, when
- # our own cancellation is already in progress)
- while self._unfinished_tasks:
- if self._on_completed_fut is None:
- self._on_completed_fut = self._loop.create_future()
-
- try:
- await self._on_completed_fut
- except asyncio.CancelledError:
- if not self._aborting:
- # Our parent task is being cancelled:
- #
- # async def wrapper():
- # async with TaskGroup() as g:
- # g.create_task(foo)
- #
- # "wrapper" is being cancelled while "foo" is
- # still running.
- propagate_cancelation = True
- self._abort()
-
- self._on_completed_fut = None
-
- assert self._unfinished_tasks == 0
- self._on_completed_fut = None # no longer needed
-
- if self._base_error is not None:
- raise self._base_error
-
- if propagate_cancelation:
- # The wrapping task was cancelled; since we're done with
- # closing all child tasks, just propagate the cancellation
- # request now.
- raise asyncio.CancelledError()
-
- if et is not None and et is not asyncio.CancelledError:
- self._errors.append(exc)
-
- if self._errors:
- # Exceptions are heavy objects that can have object
- # cycles (bad for GC); let's not keep a reference to
- # a bunch of them.
- errors = self._errors
- self._errors = None
-
- me = TaskGroupError('unhandled errors in a TaskGroup',
- errors=errors)
- raise me from None
-
- def create_task(self, coro):
- if not self._entered:
- raise RuntimeError(f"TaskGroup {self!r} has not been entered")
- if self._exiting:
- raise RuntimeError(f"TaskGroup {self!r} is awaiting in exit")
- task = self._loop.create_task(coro)
- task.add_done_callback(self._on_task_done)
- self._unfinished_tasks += 1
- self._tasks.add(task)
- return task
-
- def _is_base_error(self, exc):
- assert isinstance(exc, BaseException)
- return not isinstance(exc, Exception)
-
- def _patch_task(self, task):
- # In Python 3.8 we'll need proper API on asyncio.Task to
- # make TaskGroups possible. We need to be able to access
- # information about task cancellation, more specifically,
- # we need a flag to say if a task was cancelled or not.
- # We also need to be able to flip that flag.
-
- def _task_cancel(task, orig_cancel):
- task.__cancel_requested__ = True
- return orig_cancel()
-
- if hasattr(task, '__cancel_requested__'):
- return
-
- task.__cancel_requested__ = False
- # confirm that we were successful at adding the new attribute:
- assert not task.__cancel_requested__
-
- orig_cancel = task.cancel
- task.cancel = functools.partial(_task_cancel, task, orig_cancel)
-
- def _abort(self):
- self._aborting = True
-
- for t in self._tasks:
- if not t.done():
- t.cancel()
-
- def _on_task_done(self, task):
- self._unfinished_tasks -= 1
- assert self._unfinished_tasks >= 0
-
- if self._exiting and not self._unfinished_tasks:
- if not self._on_completed_fut.done():
- self._on_completed_fut.set_result(True)
-
- if task.cancelled():
- return
-
- exc = task.exception()
- if exc is None:
- return
-
- self._errors.append(exc)
- if self._is_base_error(exc) and self._base_error is None:
- self._base_error = exc
-
- if self._parent_task.done():
- # Not sure if this case is possible, but we want to handle
- # it anyways.
- self._loop.call_exception_handler({
- 'message': f'Task {task!r} has errored out but its parent '
- f'task {self._parent_task} is already completed',
- 'exception': exc,
- 'task': task,
- })
- return
-
- self._abort()
- if not self._parent_task.__cancel_requested__:
- # If parent task *is not* being cancelled, it means that we want
- # to manually cancel it to abort whatever is being run right now
- # in the TaskGroup. But we want to mark parent task as
- # "not cancelled" later in __aexit__. Example situation that
- # we need to handle:
- #
- # async def foo():
- # try:
- # async with TaskGroup() as g:
- # g.create_task(crash_soon())
- # await something # <- this needs to be canceled
- # # by the TaskGroup, e.g.
- # # foo() needs to be cancelled
- # except Exception:
- # # Ignore any exceptions raised in the TaskGroup
- # pass
- # await something_else # this line has to be called
- # # after TaskGroup is finished.
- self._parent_cancel_requested = True
- self._parent_task.cancel()
-
-
-class MultiError(Exception):
-
- def __init__(self, msg, *args, errors=()):
- if errors:
- types = set(type(e).__name__ for e in errors)
- msg = f'{msg}; {len(errors)} sub errors: ({", ".join(types)})'
- for er in errors:
- msg += f'\n + {type(er).__name__}: {er}'
- if er.__traceback__:
- er_tb = ''.join(traceback.format_tb(er.__traceback__))
- er_tb = textwrap.indent(er_tb, ' | ')
- msg += f'\n{er_tb}\n'
- super().__init__(msg, *args)
- self.__errors__ = tuple(errors)
-
- def get_error_types(self):
- return {type(e) for e in self.__errors__}
-
- def __reduce__(self):
- return (type(self), (self.args,), {'__errors__': self.__errors__})
-
-
-class TaskGroupError(MultiError):
- pass
-
-
-_name_counter = itertools.count(1).__next__
+# Auto-generated shim
+import gel._taskgroup as _mod
+import sys as _sys
+_cur = _sys.modules['edgedb._taskgroup']
+for _k in vars(_mod):
+ if not _k.startswith('__') or _k in ('__all__', '__doc__'):
+ setattr(_cur, _k, getattr(_mod, _k))
+del _cur
+del _sys
+del _mod
+del _k
diff --git a/edgedb/_version.py b/edgedb/_version.py
index 41998d40..5f0756de 100644
--- a/edgedb/_version.py
+++ b/edgedb/_version.py
@@ -1,31 +1,11 @@
-#
-# This source file is part of the EdgeDB open source project.
-#
-# Copyright 2016-present MagicStack Inc. and the EdgeDB authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-# This file MUST NOT contain anything but the __version__ assignment.
-#
-# When making a release, change the value of __version__
-# to an appropriate value, and open a pull request against
-# the correct branch (master if making a new feature release).
-# The commit message MUST contain a properly formatted release
-# log, and the commit must be signed.
-#
-# The release automation will: build and test the packages for the
-# supported platforms, publish the packages on PyPI, merge the PR
-# to the target branch, create a Git tag pointing to the commit.
-
-__version__ = '3.0.0b2'
+# Auto-generated shim
+import gel._version as _mod
+import sys as _sys
+_cur = _sys.modules['edgedb._version']
+for _k in vars(_mod):
+ if not _k.startswith('__') or _k in ('__all__', '__doc__'):
+ setattr(_cur, _k, getattr(_mod, _k))
+del _cur
+del _sys
+del _mod
+del _k
diff --git a/edgedb/abstract.py b/edgedb/abstract.py
index 0c2c06cc..8ba2b2a2 100644
--- a/edgedb/abstract.py
+++ b/edgedb/abstract.py
@@ -1,437 +1,11 @@
-#
-# This source file is part of the EdgeDB open source project.
-#
-# Copyright 2020-present MagicStack Inc. and the EdgeDB authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-
-from __future__ import annotations
-import abc
-import dataclasses
-import typing
-
-from . import describe
-from . import enums
-from . import options
-from .protocol import protocol
-
-__all__ = (
- "QueryWithArgs",
- "QueryCache",
- "QueryOptions",
- "QueryContext",
- "Executor",
- "AsyncIOExecutor",
- "ReadOnlyExecutor",
- "AsyncIOReadOnlyExecutor",
- "DescribeContext",
- "DescribeResult",
-)
-
-
-class QueryWithArgs(typing.NamedTuple):
- query: str
- args: typing.Tuple
- kwargs: typing.Dict[str, typing.Any]
- input_language: protocol.InputLanguage = protocol.InputLanguage.EDGEQL
-
-
-class QueryCache(typing.NamedTuple):
- codecs_registry: protocol.CodecsRegistry
- query_cache: protocol.LRUMapping
-
-
-class QueryOptions(typing.NamedTuple):
- output_format: protocol.OutputFormat
- expect_one: bool
- required_one: bool
-
-
-class QueryContext(typing.NamedTuple):
- query: QueryWithArgs
- cache: QueryCache
- query_options: QueryOptions
- retry_options: typing.Optional[options.RetryOptions]
- state: typing.Optional[options.State]
- warning_handler: options.WarningHandler
-
- def lower(
- self, *, allow_capabilities: enums.Capability
- ) -> protocol.ExecuteContext:
- return protocol.ExecuteContext(
- query=self.query.query,
- args=self.query.args,
- kwargs=self.query.kwargs,
- reg=self.cache.codecs_registry,
- qc=self.cache.query_cache,
- input_language=self.query.input_language,
- output_format=self.query_options.output_format,
- expect_one=self.query_options.expect_one,
- required_one=self.query_options.required_one,
- allow_capabilities=allow_capabilities,
- state=self.state.as_dict() if self.state else None,
- )
-
-
-class ExecuteContext(typing.NamedTuple):
- query: QueryWithArgs
- cache: QueryCache
- state: typing.Optional[options.State]
- warning_handler: options.WarningHandler
-
- def lower(
- self, *, allow_capabilities: enums.Capability
- ) -> protocol.ExecuteContext:
- return protocol.ExecuteContext(
- query=self.query.query,
- args=self.query.args,
- kwargs=self.query.kwargs,
- reg=self.cache.codecs_registry,
- qc=self.cache.query_cache,
- input_language=self.query.input_language,
- output_format=protocol.OutputFormat.NONE,
- allow_capabilities=allow_capabilities,
- state=self.state.as_dict() if self.state else None,
- )
-
-
-@dataclasses.dataclass
-class DescribeContext:
- query: str
- state: typing.Optional[options.State]
- inject_type_names: bool
- input_language: protocol.InputLanguage
- output_format: protocol.OutputFormat
- expect_one: bool
-
- def lower(
- self, *, allow_capabilities: enums.Capability
- ) -> protocol.ExecuteContext:
- return protocol.ExecuteContext(
- query=self.query,
- args=None,
- kwargs=None,
- reg=protocol.CodecsRegistry(),
- qc=protocol.LRUMapping(maxsize=1),
- input_language=self.input_language,
- output_format=self.output_format,
- expect_one=self.expect_one,
- inline_typenames=self.inject_type_names,
- allow_capabilities=allow_capabilities,
- state=self.state.as_dict() if self.state else None,
- )
-
-
-@dataclasses.dataclass
-class DescribeResult:
- input_type: typing.Optional[describe.AnyType]
- output_type: typing.Optional[describe.AnyType]
- output_cardinality: enums.Cardinality
- capabilities: enums.Capability
-
-
-_query_opts = QueryOptions(
- output_format=protocol.OutputFormat.BINARY,
- expect_one=False,
- required_one=False,
-)
-_query_single_opts = QueryOptions(
- output_format=protocol.OutputFormat.BINARY,
- expect_one=True,
- required_one=False,
-)
-_query_required_single_opts = QueryOptions(
- output_format=protocol.OutputFormat.BINARY,
- expect_one=True,
- required_one=True,
-)
-_query_json_opts = QueryOptions(
- output_format=protocol.OutputFormat.JSON,
- expect_one=False,
- required_one=False,
-)
-_query_single_json_opts = QueryOptions(
- output_format=protocol.OutputFormat.JSON,
- expect_one=True,
- required_one=False,
-)
-_query_required_single_json_opts = QueryOptions(
- output_format=protocol.OutputFormat.JSON,
- expect_one=True,
- required_one=True,
-)
-
-
-class BaseReadOnlyExecutor(abc.ABC):
- __slots__ = ()
-
- @abc.abstractmethod
- def _get_query_cache(self) -> QueryCache:
- ...
-
- def _get_retry_options(self) -> typing.Optional[options.RetryOptions]:
- return None
-
- @abc.abstractmethod
- def _get_state(self) -> options.State:
- ...
-
- @abc.abstractmethod
- def _get_warning_handler(self) -> options.WarningHandler:
- ...
-
-
-class ReadOnlyExecutor(BaseReadOnlyExecutor):
- """Subclasses can execute *at least* read-only queries"""
-
- __slots__ = ()
-
- @abc.abstractmethod
- def _query(self, query_context: QueryContext):
- ...
-
- def query(self, query: str, *args, **kwargs) -> list:
- return self._query(QueryContext(
- query=QueryWithArgs(query, args, kwargs),
- cache=self._get_query_cache(),
- query_options=_query_opts,
- retry_options=self._get_retry_options(),
- state=self._get_state(),
- warning_handler=self._get_warning_handler(),
- ))
-
- def query_single(
- self, query: str, *args, **kwargs
- ) -> typing.Union[typing.Any, None]:
- return self._query(QueryContext(
- query=QueryWithArgs(query, args, kwargs),
- cache=self._get_query_cache(),
- query_options=_query_single_opts,
- retry_options=self._get_retry_options(),
- state=self._get_state(),
- warning_handler=self._get_warning_handler(),
- ))
-
- def query_required_single(self, query: str, *args, **kwargs) -> typing.Any:
- return self._query(QueryContext(
- query=QueryWithArgs(query, args, kwargs),
- cache=self._get_query_cache(),
- query_options=_query_required_single_opts,
- retry_options=self._get_retry_options(),
- state=self._get_state(),
- warning_handler=self._get_warning_handler(),
- ))
-
- def query_json(self, query: str, *args, **kwargs) -> str:
- return self._query(QueryContext(
- query=QueryWithArgs(query, args, kwargs),
- cache=self._get_query_cache(),
- query_options=_query_json_opts,
- retry_options=self._get_retry_options(),
- state=self._get_state(),
- warning_handler=self._get_warning_handler(),
- ))
-
- def query_single_json(self, query: str, *args, **kwargs) -> str:
- return self._query(QueryContext(
- query=QueryWithArgs(query, args, kwargs),
- cache=self._get_query_cache(),
- query_options=_query_single_json_opts,
- retry_options=self._get_retry_options(),
- state=self._get_state(),
- warning_handler=self._get_warning_handler(),
- ))
-
- def query_required_single_json(self, query: str, *args, **kwargs) -> str:
- return self._query(QueryContext(
- query=QueryWithArgs(query, args, kwargs),
- cache=self._get_query_cache(),
- query_options=_query_required_single_json_opts,
- retry_options=self._get_retry_options(),
- state=self._get_state(),
- warning_handler=self._get_warning_handler(),
- ))
-
- def query_sql(self, query: str, *args, **kwargs) -> typing.Any:
- return self._query(QueryContext(
- query=QueryWithArgs(
- query,
- args,
- kwargs,
- input_language=protocol.InputLanguage.SQL,
- ),
- cache=self._get_query_cache(),
- query_options=_query_opts,
- retry_options=self._get_retry_options(),
- state=self._get_state(),
- warning_handler=self._get_warning_handler(),
- ))
-
- @abc.abstractmethod
- def _execute(self, execute_context: ExecuteContext):
- ...
-
- def execute(self, commands: str, *args, **kwargs) -> None:
- self._execute(ExecuteContext(
- query=QueryWithArgs(commands, args, kwargs),
- cache=self._get_query_cache(),
- state=self._get_state(),
- warning_handler=self._get_warning_handler(),
- ))
-
- def execute_sql(self, commands: str, *args, **kwargs) -> None:
- self._execute(ExecuteContext(
- query=QueryWithArgs(
- commands,
- args,
- kwargs,
- input_language=protocol.InputLanguage.SQL,
- ),
- cache=self._get_query_cache(),
- state=self._get_state(),
- warning_handler=self._get_warning_handler(),
- ))
-
-
-class Executor(ReadOnlyExecutor):
- """Subclasses can execute both read-only and modification queries"""
-
- __slots__ = ()
-
-
-class AsyncIOReadOnlyExecutor(BaseReadOnlyExecutor):
- """Subclasses can execute *at least* read-only queries"""
-
- __slots__ = ()
-
- @abc.abstractmethod
- async def _query(self, query_context: QueryContext):
- ...
-
- async def query(self, query: str, *args, **kwargs) -> list:
- return await self._query(QueryContext(
- query=QueryWithArgs(query, args, kwargs),
- cache=self._get_query_cache(),
- query_options=_query_opts,
- retry_options=self._get_retry_options(),
- state=self._get_state(),
- warning_handler=self._get_warning_handler(),
- ))
-
- async def query_single(self, query: str, *args, **kwargs) -> typing.Any:
- return await self._query(QueryContext(
- query=QueryWithArgs(query, args, kwargs),
- cache=self._get_query_cache(),
- query_options=_query_single_opts,
- retry_options=self._get_retry_options(),
- state=self._get_state(),
- warning_handler=self._get_warning_handler(),
- ))
-
- async def query_required_single(
- self,
- query: str,
- *args,
- **kwargs
- ) -> typing.Any:
- return await self._query(QueryContext(
- query=QueryWithArgs(query, args, kwargs),
- cache=self._get_query_cache(),
- query_options=_query_required_single_opts,
- retry_options=self._get_retry_options(),
- state=self._get_state(),
- warning_handler=self._get_warning_handler(),
- ))
-
- async def query_json(self, query: str, *args, **kwargs) -> str:
- return await self._query(QueryContext(
- query=QueryWithArgs(query, args, kwargs),
- cache=self._get_query_cache(),
- query_options=_query_json_opts,
- retry_options=self._get_retry_options(),
- state=self._get_state(),
- warning_handler=self._get_warning_handler(),
- ))
-
- async def query_single_json(self, query: str, *args, **kwargs) -> str:
- return await self._query(QueryContext(
- query=QueryWithArgs(query, args, kwargs),
- cache=self._get_query_cache(),
- query_options=_query_single_json_opts,
- retry_options=self._get_retry_options(),
- state=self._get_state(),
- warning_handler=self._get_warning_handler(),
- ))
-
- async def query_required_single_json(
- self,
- query: str,
- *args,
- **kwargs
- ) -> str:
- return await self._query(QueryContext(
- query=QueryWithArgs(query, args, kwargs),
- cache=self._get_query_cache(),
- query_options=_query_required_single_json_opts,
- retry_options=self._get_retry_options(),
- state=self._get_state(),
- warning_handler=self._get_warning_handler(),
- ))
-
- async def query_sql(self, query: str, *args, **kwargs) -> typing.Any:
- return await self._query(QueryContext(
- query=QueryWithArgs(
- query,
- args,
- kwargs,
- input_language=protocol.InputLanguage.SQL,
- ),
- cache=self._get_query_cache(),
- query_options=_query_opts,
- retry_options=self._get_retry_options(),
- state=self._get_state(),
- warning_handler=self._get_warning_handler(),
- ))
-
- @abc.abstractmethod
- async def _execute(self, execute_context: ExecuteContext) -> None:
- ...
-
- async def execute(self, commands: str, *args, **kwargs) -> None:
- await self._execute(ExecuteContext(
- query=QueryWithArgs(commands, args, kwargs),
- cache=self._get_query_cache(),
- state=self._get_state(),
- warning_handler=self._get_warning_handler(),
- ))
-
- async def execute_sql(self, commands: str, *args, **kwargs) -> None:
- await self._execute(ExecuteContext(
- query=QueryWithArgs(
- commands,
- args,
- kwargs,
- input_language=protocol.InputLanguage.SQL,
- ),
- cache=self._get_query_cache(),
- state=self._get_state(),
- warning_handler=self._get_warning_handler(),
- ))
-
-
-class AsyncIOExecutor(AsyncIOReadOnlyExecutor):
- """Subclasses can execute both read-only and modification queries"""
-
- __slots__ = ()
+# Auto-generated shim
+import gel.abstract as _mod
+import sys as _sys
+_cur = _sys.modules['edgedb.abstract']
+for _k in vars(_mod):
+ if not _k.startswith('__') or _k in ('__all__', '__doc__'):
+ setattr(_cur, _k, getattr(_mod, _k))
+del _cur
+del _sys
+del _mod
+del _k
diff --git a/edgedb/ai/__init__.py b/edgedb/ai/__init__.py
index 96111c2b..71c802a7 100644
--- a/edgedb/ai/__init__.py
+++ b/edgedb/ai/__init__.py
@@ -1,32 +1,11 @@
-#
-# This source file is part of the EdgeDB open source project.
-#
-# Copyright 2024-present MagicStack Inc. and the EdgeDB authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-from .types import AIOptions, ChatParticipantRole, Prompt, QueryContext
-from .core import create_ai, EdgeDBAI
-from .core import create_async_ai, AsyncEdgeDBAI
-
-__all__ = [
- "AIOptions",
- "ChatParticipantRole",
- "Prompt",
- "QueryContext",
- "create_ai",
- "EdgeDBAI",
- "create_async_ai",
- "AsyncEdgeDBAI",
-]
+# Auto-generated shim
+import gel.ai as _mod
+import sys as _sys
+_cur = _sys.modules['edgedb.ai']
+for _k in vars(_mod):
+ if not _k.startswith('__') or _k in ('__all__', '__doc__'):
+ setattr(_cur, _k, getattr(_mod, _k))
+del _cur
+del _sys
+del _mod
+del _k
diff --git a/edgedb/ai/core.py b/edgedb/ai/core.py
index 69fe235d..ec1645ca 100644
--- a/edgedb/ai/core.py
+++ b/edgedb/ai/core.py
@@ -1,191 +1,11 @@
-#
-# This source file is part of the EdgeDB open source project.
-#
-# Copyright 2024-present MagicStack Inc. and the EdgeDB authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-from __future__ import annotations
-import typing
-
-import edgedb
-import httpx
-import httpx_sse
-
-from . import types
-
-
-def create_ai(client: edgedb.Client, **kwargs) -> EdgeDBAI:
- client.ensure_connected()
- return EdgeDBAI(client, types.AIOptions(**kwargs))
-
-
-async def create_async_ai(
- client: edgedb.AsyncIOClient, **kwargs
-) -> AsyncEdgeDBAI:
- await client.ensure_connected()
- return AsyncEdgeDBAI(client, types.AIOptions(**kwargs))
-
-
-class BaseEdgeDBAI:
- options: types.AIOptions
- context: types.QueryContext
- client_cls = NotImplemented
-
- def __init__(
- self,
- client: typing.Union[edgedb.Client, edgedb.AsyncIOClient],
- options: types.AIOptions,
- **kwargs,
- ):
- pool = client._impl
- host, port = pool._working_addr
- params = pool._working_params
- proto = "http" if params.tls_security == "insecure" else "https"
- branch = params.branch
- self.options = options
- self.context = types.QueryContext(**kwargs)
- args = dict(
- base_url=f"{proto}://{host}:{port}/branch/{branch}/ext/ai",
- verify=params.ssl_ctx,
- )
- if params.password is not None:
- args["auth"] = (params.user, params.password)
- elif params.secret_key is not None:
- args["headers"] = {"Authorization": f"Bearer {params.secret_key}"}
- self._init_client(**args)
-
- def _init_client(self, **kwargs):
- raise NotImplementedError
-
- def with_config(self, **kwargs) -> typing.Self:
- cls = type(self)
- rv = cls.__new__(cls)
- rv.options = self.options.derive(kwargs)
- rv.context = self.context
- rv.client = self.client
- return rv
-
- def with_context(self, **kwargs) -> typing.Self:
- cls = type(self)
- rv = cls.__new__(cls)
- rv.options = self.options
- rv.context = self.context.derive(kwargs)
- rv.client = self.client
- return rv
-
- def _make_rag_request(
- self,
- *,
- message: str,
- context: typing.Optional[types.QueryContext] = None,
- stream: bool,
- ) -> types.RAGRequest:
- if context is None:
- context = self.context
- return types.RAGRequest(
- model=self.options.model,
- prompt=self.options.prompt,
- context=context,
- query=message,
- stream=stream,
- )
-
-
-class EdgeDBAI(BaseEdgeDBAI):
- client: httpx.Client
-
- def _init_client(self, **kwargs):
- self.client = httpx.Client(**kwargs)
-
- def query_rag(
- self, message: str, context: typing.Optional[types.QueryContext] = None
- ) -> str:
- resp = self.client.post(
- **self._make_rag_request(
- context=context,
- message=message,
- stream=False,
- ).to_httpx_request()
- )
- resp.raise_for_status()
- return resp.json()["response"]
-
- def stream_rag(
- self, message: str, context: typing.Optional[types.QueryContext] = None
- ) -> typing.Iterator[str]:
- with httpx_sse.connect_sse(
- self.client,
- "post",
- **self._make_rag_request(
- context=context,
- message=message,
- stream=True,
- ).to_httpx_request(),
- ) as event_source:
- event_source.response.raise_for_status()
- for sse in event_source.iter_sse():
- yield sse.data
-
- def generate_embeddings(self, *inputs: str, model: str) -> list[float]:
- resp = self.client.post(
- "/embeddings", json={"input": inputs, "model": model}
- )
- resp.raise_for_status()
- return resp.json()["data"][0]["embedding"]
-
-
-class AsyncEdgeDBAI(BaseEdgeDBAI):
- client: httpx.AsyncClient
-
- def _init_client(self, **kwargs):
- self.client = httpx.AsyncClient(**kwargs)
-
- async def query_rag(
- self, message: str, context: typing.Optional[types.QueryContext] = None
- ) -> str:
- resp = await self.client.post(
- **self._make_rag_request(
- context=context,
- message=message,
- stream=False,
- ).to_httpx_request()
- )
- resp.raise_for_status()
- return resp.json()["response"]
-
- async def stream_rag(
- self, message: str, context: typing.Optional[types.QueryContext] = None
- ) -> typing.Iterator[str]:
- async with httpx_sse.aconnect_sse(
- self.client,
- "post",
- **self._make_rag_request(
- context=context,
- message=message,
- stream=True,
- ).to_httpx_request(),
- ) as event_source:
- event_source.response.raise_for_status()
- async for sse in event_source.aiter_sse():
- yield sse.data
-
- async def generate_embeddings(
- self, *inputs: str, model: str
- ) -> list[float]:
- resp = await self.client.post(
- "/embeddings", json={"input": inputs, "model": model}
- )
- resp.raise_for_status()
- return resp.json()["data"][0]["embedding"]
+# Auto-generated shim
+import gel.ai.core as _mod
+import sys as _sys
+_cur = _sys.modules['edgedb.ai.core']
+for _k in vars(_mod):
+ if not _k.startswith('__') or _k in ('__all__', '__doc__'):
+ setattr(_cur, _k, getattr(_mod, _k))
+del _cur
+del _sys
+del _mod
+del _k
diff --git a/edgedb/ai/types.py b/edgedb/ai/types.py
index 41bf24c0..e4d1f77d 100644
--- a/edgedb/ai/types.py
+++ b/edgedb/ai/types.py
@@ -1,81 +1,11 @@
-#
-# This source file is part of the EdgeDB open source project.
-#
-# Copyright 2024-present MagicStack Inc. and the EdgeDB authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-import typing
-
-import dataclasses as dc
-import enum
-
-
-class ChatParticipantRole(enum.Enum):
- SYSTEM = "system"
- USER = "user"
- ASSISTANT = "assistant"
- TOOL = "tool"
-
-
-class Custom(typing.TypedDict):
- role: ChatParticipantRole
- content: str
-
-
-class Prompt:
- name: typing.Optional[str]
- id: typing.Optional[str]
- custom: typing.Optional[typing.List[Custom]]
-
-
-@dc.dataclass
-class AIOptions:
- model: str
- prompt: typing.Optional[Prompt] = None
-
- def derive(self, kwargs):
- return AIOptions(**{**dc.asdict(self), **kwargs})
-
-
-@dc.dataclass
-class QueryContext:
- query: str = ""
- variables: typing.Optional[typing.Dict[str, typing.Any]] = None
- globals: typing.Optional[typing.Dict[str, typing.Any]] = None
- max_object_count: typing.Optional[int] = None
-
- def derive(self, kwargs):
- return QueryContext(**{**dc.asdict(self), **kwargs})
-
-
-@dc.dataclass
-class RAGRequest:
- model: str
- prompt: typing.Optional[Prompt]
- context: QueryContext
- query: str
- stream: typing.Optional[bool]
-
- def to_httpx_request(self) -> typing.Dict[str, typing.Any]:
- return dict(
- url="/rag",
- headers={
- "Content-Type": "application/json",
- "Accept": (
- "text/event-stream" if self.stream else "application/json"
- ),
- },
- json=dc.asdict(self),
- )
+# Auto-generated shim
+import gel.ai.types as _mod
+import sys as _sys
+_cur = _sys.modules['edgedb.ai.types']
+for _k in vars(_mod):
+ if not _k.startswith('__') or _k in ('__all__', '__doc__'):
+ setattr(_cur, _k, getattr(_mod, _k))
+del _cur
+del _sys
+del _mod
+del _k
diff --git a/edgedb/asyncio_client.py b/edgedb/asyncio_client.py
index 78625040..bc80c00f 100644
--- a/edgedb/asyncio_client.py
+++ b/edgedb/asyncio_client.py
@@ -1,448 +1,11 @@
-#
-# This source file is part of the EdgeDB open source project.
-#
-# Copyright 2019-present MagicStack Inc. and the EdgeDB authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-
-import asyncio
-import contextlib
-import logging
-import socket
-import ssl
-import typing
-
-from . import abstract
-from . import base_client
-from . import con_utils
-from . import errors
-from . import transaction
-from .protocol import asyncio_proto
-from .protocol.protocol import InputLanguage, OutputFormat
-
-
-__all__ = (
- 'create_async_client', 'AsyncIOClient'
-)
-
-
-logger = logging.getLogger(__name__)
-
-
-class AsyncIOConnection(base_client.BaseConnection):
- __slots__ = ("_loop",)
-
- def __init__(self, loop, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self._loop = loop
-
- def is_closed(self):
- protocol = self._protocol
- return protocol is None or not protocol.connected
-
- async def connect_addr(self, addr, timeout):
- try:
- await asyncio.wait_for(self._connect_addr(addr), timeout)
- except asyncio.TimeoutError as e:
- raise TimeoutError from e
-
- async def sleep(self, seconds):
- await asyncio.sleep(seconds)
-
- async def aclose(self):
- """Send graceful termination message wait for connection to drop."""
- if not self.is_closed():
- try:
- self._protocol.terminate()
- await self._protocol.wait_for_disconnect()
- except (Exception, asyncio.CancelledError):
- self.terminate()
- raise
- finally:
- self._cleanup()
-
- def _protocol_factory(self):
- return asyncio_proto.AsyncIOProtocol(self._params, self._loop)
-
- async def _connect_addr(self, addr):
- tr = None
-
- try:
- if isinstance(addr, str):
- # UNIX socket
- tr, pr = await self._loop.create_unix_connection(
- self._protocol_factory, addr
- )
- else:
- try:
- tr, pr = await self._loop.create_connection(
- self._protocol_factory,
- *addr,
- ssl=self._params.ssl_ctx,
- server_hostname=(
- self._params.tls_server_name or addr[0]
- ),
- )
- except ssl.CertificateError as e:
- raise con_utils.wrap_error(e) from e
- except ssl.SSLError as e:
- raise con_utils.wrap_error(e) from e
- else:
- con_utils.check_alpn_protocol(
- tr.get_extra_info('ssl_object')
- )
- except socket.gaierror as e:
- # All name resolution errors are considered temporary
- raise errors.ClientConnectionFailedTemporarilyError(str(e)) from e
- except OSError as e:
- raise con_utils.wrap_error(e) from e
- except Exception:
- if tr is not None:
- tr.close()
- raise
-
- pr.set_connection(self)
-
- try:
- await pr.connect()
- except OSError as e:
- if tr is not None:
- tr.close()
- raise con_utils.wrap_error(e) from e
- except BaseException:
- if tr is not None:
- tr.close()
- raise
-
- self._protocol = pr
- self._addr = addr
-
- def _dispatch_log_message(self, msg):
- for cb in self._log_listeners:
- self._loop.call_soon(cb, self, msg)
-
-
-class _PoolConnectionHolder(base_client.PoolConnectionHolder):
- __slots__ = ()
- _event_class = asyncio.Event
-
- async def close(self, *, wait=True):
- if self._con is None:
- return
- if wait:
- await self._con.aclose()
- else:
- self._pool._loop.create_task(self._con.aclose())
-
- async def wait_until_released(self, timeout=None):
- await self._release_event.wait()
-
-
-class _AsyncIOPoolImpl(base_client.BasePoolImpl):
- __slots__ = ('_loop',)
- _holder_class = _PoolConnectionHolder
-
- def __init__(
- self,
- connect_args,
- *,
- max_concurrency: typing.Optional[int],
- connection_class,
- ):
- if not issubclass(connection_class, AsyncIOConnection):
- raise TypeError(
- f'connection_class is expected to be a subclass of '
- f'edgedb.asyncio_client.AsyncIOConnection, '
- f'got {connection_class}')
- self._loop = None
- super().__init__(
- connect_args,
- lambda *args: connection_class(self._loop, *args),
- max_concurrency=max_concurrency,
- )
-
- def _ensure_initialized(self):
- if self._loop is None:
- self._loop = asyncio.get_event_loop()
- self._queue = asyncio.LifoQueue(maxsize=self._max_concurrency)
- self._first_connect_lock = asyncio.Lock()
- self._resize_holder_pool()
-
- def _set_queue_maxsize(self, maxsize):
- self._queue._maxsize = maxsize
-
- async def _maybe_get_first_connection(self):
- async with self._first_connect_lock:
- if self._working_addr is None:
- return await self._get_first_connection()
-
- async def acquire(self, timeout=None):
- self._ensure_initialized()
-
- async def _acquire_impl():
- ch = await self._queue.get() # type: _PoolConnectionHolder
- try:
- proxy = await ch.acquire() # type: AsyncIOConnection
- except (Exception, asyncio.CancelledError):
- self._queue.put_nowait(ch)
- raise
- else:
- # Record the timeout, as we will apply it by default
- # in release().
- ch._timeout = timeout
- return proxy
-
- if self._closing:
- raise errors.InterfaceError('pool is closing')
-
- if timeout is None:
- return await _acquire_impl()
- else:
- return await asyncio.wait_for(
- _acquire_impl(), timeout=timeout)
-
- async def _release(self, holder):
-
- if not isinstance(holder._con, AsyncIOConnection):
- raise errors.InterfaceError(
- f'release() received invalid connection: '
- f'{holder._con!r} does not belong to any connection pool'
- )
-
- timeout = None
-
- # Use asyncio.shield() to guarantee that task cancellation
- # does not prevent the connection from being returned to the
- # pool properly.
- return await asyncio.shield(holder.release(timeout))
-
- async def aclose(self):
- """Attempt to gracefully close all connections in the pool.
-
- Wait until all pool connections are released, close them and
- shut down the pool. If any error (including cancellation) occurs
- in ``close()`` the pool will terminate by calling
- _AsyncIOPoolImpl.terminate() .
-
- It is advisable to use :func:`python:asyncio.wait_for` to set
- a timeout.
- """
- if self._closed:
- return
-
- if not self._loop:
- self._closed = True
- return
-
- self._closing = True
-
- try:
- warning_callback = self._loop.call_later(
- 60, self._warn_on_long_close)
-
- release_coros = [
- ch.wait_until_released() for ch in self._holders]
- await asyncio.gather(*release_coros)
-
- close_coros = [
- ch.close() for ch in self._holders]
- await asyncio.gather(*close_coros)
-
- except (Exception, asyncio.CancelledError):
- self.terminate()
- raise
-
- finally:
- warning_callback.cancel()
- self._closed = True
- self._closing = False
-
- def _warn_on_long_close(self):
- logger.warning(
- 'AsyncIOClient.aclose() is taking over 60 seconds to complete. '
- 'Check if you have any unreleased connections left. '
- 'Use asyncio.wait_for() to set a timeout for '
- 'AsyncIOClient.aclose().')
-
-
-class AsyncIOIteration(transaction.BaseTransaction, abstract.AsyncIOExecutor):
-
- __slots__ = ("_managed", "_locked")
-
- def __init__(self, retry, client, iteration):
- super().__init__(retry, client, iteration)
- self._managed = False
- self._locked = False
-
- async def __aenter__(self):
- if self._managed:
- raise errors.InterfaceError(
- 'cannot enter context: already in an `async with` block')
- self._managed = True
- return self
-
- async def __aexit__(self, extype, ex, tb):
- with self._exclusive():
- self._managed = False
- return await self._exit(extype, ex)
-
- async def _ensure_transaction(self):
- if not self._managed:
- raise errors.InterfaceError(
- "Only managed retriable transactions are supported. "
- "Use `async with transaction:`"
- )
- await super()._ensure_transaction()
-
- async def _query(self, query_context: abstract.QueryContext):
- with self._exclusive():
- return await super()._query(query_context)
-
- async def _execute(self, execute_context: abstract.ExecuteContext) -> None:
- with self._exclusive():
- await super()._execute(execute_context)
-
- @contextlib.contextmanager
- def _exclusive(self):
- if self._locked:
- raise errors.InterfaceError(
- "concurrent queries within the same transaction "
- "are not allowed"
- )
- self._locked = True
- try:
- yield
- finally:
- self._locked = False
-
-
-class AsyncIORetry(transaction.BaseRetry):
-
- def __aiter__(self):
- return self
-
- async def __anext__(self):
- # Note: when changing this code consider also
- # updating Retry.__next__.
- if self._done:
- raise StopAsyncIteration
- if self._next_backoff:
- await asyncio.sleep(self._next_backoff)
- self._done = True
- iteration = AsyncIOIteration(self, self._owner, self._iteration)
- self._iteration += 1
- return iteration
-
-
-class AsyncIOClient(base_client.BaseClient, abstract.AsyncIOExecutor):
- """A lazy connection pool.
-
- A Client can be used to manage a set of connections to the database.
- Connections are first acquired from the pool, then used, and then released
- back to the pool. Once a connection is released, it's reset to close all
- open cursors and other resources *except* prepared statements.
-
- Clients are created by calling
- :func:`~edgedb.asyncio_client.create_async_client`.
- """
-
- __slots__ = ()
- _impl_class = _AsyncIOPoolImpl
-
- async def ensure_connected(self):
- await self._impl.ensure_connected()
- return self
-
- async def aclose(self):
- """Attempt to gracefully close all connections in the pool.
-
- Wait until all pool connections are released, close them and
- shut down the pool. If any error (including cancellation) occurs
- in ``aclose()`` the pool will terminate by calling
- AsyncIOClient.terminate() .
-
- It is advisable to use :func:`python:asyncio.wait_for` to set
- a timeout.
- """
- await self._impl.aclose()
-
- def transaction(self) -> AsyncIORetry:
- return AsyncIORetry(self)
-
- async def __aenter__(self):
- return await self.ensure_connected()
-
- async def __aexit__(self, exc_type, exc_val, exc_tb):
- await self.aclose()
-
- async def _describe_query(
- self,
- query: str,
- *,
- inject_type_names: bool = False,
- input_language: InputLanguage = InputLanguage.EDGEQL,
- output_format: OutputFormat = OutputFormat.BINARY,
- expect_one: bool = False,
- ) -> abstract.DescribeResult:
- return await self._describe(abstract.DescribeContext(
- query=query,
- state=self._get_state(),
- inject_type_names=inject_type_names,
- input_language=input_language,
- output_format=output_format,
- expect_one=expect_one,
- ))
-
-
-def create_async_client(
- dsn=None,
- *,
- max_concurrency=None,
- host: str = None,
- port: int = None,
- credentials: str = None,
- credentials_file: str = None,
- user: str = None,
- password: str = None,
- secret_key: str = None,
- database: str = None,
- branch: str = None,
- tls_ca: str = None,
- tls_ca_file: str = None,
- tls_security: str = None,
- wait_until_available: int = 30,
- timeout: int = 10,
-):
- return AsyncIOClient(
- connection_class=AsyncIOConnection,
- max_concurrency=max_concurrency,
-
- # connect arguments
- dsn=dsn,
- host=host,
- port=port,
- credentials=credentials,
- credentials_file=credentials_file,
- user=user,
- password=password,
- secret_key=secret_key,
- database=database,
- branch=branch,
- tls_ca=tls_ca,
- tls_ca_file=tls_ca_file,
- tls_security=tls_security,
- wait_until_available=wait_until_available,
- timeout=timeout,
- )
+# Auto-generated shim
+import gel.asyncio_client as _mod
+import sys as _sys
+_cur = _sys.modules['edgedb.asyncio_client']
+for _k in vars(_mod):
+ if not _k.startswith('__') or _k in ('__all__', '__doc__'):
+ setattr(_cur, _k, getattr(_mod, _k))
+del _cur
+del _sys
+del _mod
+del _k
diff --git a/edgedb/base_client.py b/edgedb/base_client.py
index e1610219..0a491133 100644
--- a/edgedb/base_client.py
+++ b/edgedb/base_client.py
@@ -1,734 +1,11 @@
-#
-# This source file is part of the EdgeDB open source project.
-#
-# Copyright 2016-present MagicStack Inc. and the EdgeDB authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-
-import abc
-import random
-import time
-import typing
-
-from . import abstract
-from . import con_utils
-from . import enums
-from . import errors
-from . import options as _options
-from .protocol import protocol
-
-
-BaseConnection_T = typing.TypeVar('BaseConnection_T', bound='BaseConnection')
-QUERY_CACHE_SIZE = 1000
-
-
-class BaseConnection(metaclass=abc.ABCMeta):
- _protocol: typing.Any
- _addr: typing.Optional[typing.Union[str, typing.Tuple[str, int]]]
- _addrs: typing.Iterable[typing.Union[str, typing.Tuple[str, int]]]
- _config: con_utils.ClientConfiguration
- _params: con_utils.ResolvedConnectConfig
- _log_listeners: typing.Set[
- typing.Callable[[BaseConnection_T, errors.EdgeDBMessage], None]
- ]
- __slots__ = (
- "__weakref__",
- "_protocol",
- "_addr",
- "_addrs",
- "_config",
- "_params",
- "_log_listeners",
- "_holder",
- )
-
- def __init__(
- self,
- addrs: typing.Iterable[typing.Union[str, typing.Tuple[str, int]]],
- config: con_utils.ClientConfiguration,
- params: con_utils.ResolvedConnectConfig,
- ):
- self._addr = None
- self._protocol = None
- self._addrs = addrs
- self._config = config
- self._params = params
- self._log_listeners = set()
- self._holder = None
-
- @abc.abstractmethod
- def _dispatch_log_message(self, msg):
- ...
-
- def _on_log_message(self, msg):
- if self._log_listeners:
- self._dispatch_log_message(msg)
-
- def connected_addr(self):
- return self._addr
-
- def _get_last_status(self) -> typing.Optional[str]:
- if self._protocol is None:
- return None
- status = self._protocol.last_status
- if status is not None:
- status = status.decode()
- return status
-
- def _cleanup(self):
- self._log_listeners.clear()
- if self._holder:
- self._holder._release_on_close()
- self._holder = None
-
- def add_log_listener(
- self: BaseConnection_T,
- callback: typing.Callable[[BaseConnection_T, errors.EdgeDBMessage],
- None]
- ) -> None:
- """Add a listener for EdgeDB log messages.
-
- :param callable callback:
- A callable receiving the following arguments:
- **connection**: a Connection the callback is registered with;
- **message**: the `edgedb.EdgeDBMessage` message.
- """
- self._log_listeners.add(callback)
-
- def remove_log_listener(
- self: BaseConnection_T,
- callback: typing.Callable[[BaseConnection_T, errors.EdgeDBMessage],
- None]
- ) -> None:
- """Remove a listening callback for log messages."""
- self._log_listeners.discard(callback)
-
- @property
- def dbname(self) -> str:
- return self._params.database
-
- @property
- def branch(self) -> str:
- return self._params.branch
-
- @abc.abstractmethod
- def is_closed(self) -> bool:
- ...
-
- @abc.abstractmethod
- async def connect_addr(self, addr, timeout):
- ...
-
- @abc.abstractmethod
- async def sleep(self, seconds):
- ...
-
- async def connect(self, *, single_attempt=False):
- start = time.monotonic()
- if single_attempt:
- max_time = 0
- else:
- max_time = start + self._config.wait_until_available
- iteration = 1
-
- while True:
- for addr in self._addrs:
- try:
- await self.connect_addr(addr, self._config.connect_timeout)
- except TimeoutError as e:
- if iteration == 1 or time.monotonic() < max_time:
- continue
- else:
- raise errors.ClientConnectionTimeoutError(
- f"connecting to {addr} failed in"
- f" {self._config.connect_timeout} sec"
- ) from e
- except errors.ClientConnectionError as e:
- if (
- e.has_tag(errors.SHOULD_RECONNECT) and
- (iteration == 1 or time.monotonic() < max_time)
- ):
- continue
- nice_err = e.__class__(
- con_utils.render_client_no_connection_error(
- e,
- addr,
- attempts=iteration,
- duration=time.monotonic() - start,
- ))
- raise nice_err from e.__cause__
- else:
- return
-
- iteration += 1
- await self.sleep(0.01 + random.random() * 0.2)
-
- async def privileged_execute(
- self, execute_context: abstract.ExecuteContext
- ):
- if self._protocol.is_legacy:
- await self._protocol.legacy_simple_query(
- execute_context.query.query, enums.Capability.ALL
- )
- else:
- await self._protocol.execute(
- execute_context.lower(allow_capabilities=enums.Capability.ALL)
- )
-
- def is_in_transaction(self) -> bool:
- """Return True if Connection is currently inside a transaction.
-
- :return bool: True if inside transaction, False otherwise.
- """
- return self._protocol.is_in_transaction()
-
- def get_settings(self) -> typing.Dict[str, typing.Any]:
- return self._protocol.get_settings()
-
- async def raw_query(self, query_context: abstract.QueryContext):
- if self.is_closed():
- await self.connect()
-
- reconnect = False
- i = 0
- if self._protocol.is_legacy:
- allow_capabilities = enums.Capability.LEGACY_EXECUTE
- else:
- allow_capabilities = enums.Capability.EXECUTE
- ctx = query_context.lower(allow_capabilities=allow_capabilities)
- while True:
- i += 1
- try:
- if reconnect:
- await self.connect(single_attempt=True)
- if self._protocol.is_legacy:
- return await self._protocol.legacy_execute_anonymous(ctx)
- else:
- res = await self._protocol.query(ctx)
- if ctx.warnings:
- res = query_context.warning_handler(ctx.warnings, res)
- return res
-
- except errors.EdgeDBError as e:
- if query_context.retry_options is None:
- raise
- if not e.has_tag(errors.SHOULD_RETRY):
- raise e
- # A query is read-only if it has no capabilities i.e.
- # capabilities == 0. Read-only queries are safe to retry.
- # Explicit transaction conflicts as well.
- if (
- ctx.capabilities != 0
- and not isinstance(e, errors.TransactionConflictError)
- ):
- raise e
- rule = query_context.retry_options.get_rule_for_exception(e)
- if i >= rule.attempts:
- raise e
- await self.sleep(rule.backoff(i))
- reconnect = self.is_closed()
-
- async def _execute(self, execute_context: abstract.ExecuteContext) -> None:
- if self._protocol.is_legacy:
- if execute_context.query.args or execute_context.query.kwargs:
- raise errors.InterfaceError(
- "Legacy protocol doesn't support arguments in execute()"
- )
- await self._protocol.legacy_simple_query(
- execute_context.query.query, enums.Capability.LEGACY_EXECUTE
- )
- else:
- ctx = execute_context.lower(
- allow_capabilities=enums.Capability.EXECUTE
- )
- res = await self._protocol.execute(ctx)
- if ctx.warnings:
- res = execute_context.warning_handler(ctx.warnings, res)
-
- async def describe(
- self, describe_context: abstract.DescribeContext
- ) -> abstract.DescribeResult:
- ctx = describe_context.lower(
- allow_capabilities=enums.Capability.EXECUTE
- )
- await self._protocol._parse(ctx)
- return abstract.DescribeResult(
- input_type=ctx.in_dc.make_type(describe_context),
- output_type=ctx.out_dc.make_type(describe_context),
- output_cardinality=enums.Cardinality(ctx.cardinality[0]),
- capabilities=ctx.capabilities,
- )
-
- def terminate(self):
- if not self.is_closed():
- try:
- self._protocol.abort()
- finally:
- self._cleanup()
-
- def __repr__(self):
- if self.is_closed():
- return '<{classname} [closed] {id:#x}>'.format(
- classname=self.__class__.__name__, id=id(self))
- else:
- return '<{classname} [connected to {addr}] {id:#x}>'.format(
- classname=self.__class__.__name__,
- addr=self.connected_addr(),
- id=id(self))
-
-
-class PoolConnectionHolder(abc.ABC):
- __slots__ = (
- "_con",
- "_pool",
- "_release_event",
- "_timeout",
- "_generation",
- )
- _event_class = NotImplemented
-
- def __init__(self, pool):
-
- self._pool = pool
- self._con = None
-
- self._timeout = None
- self._generation = None
-
- self._release_event = self._event_class()
- self._release_event.set()
-
- @abc.abstractmethod
- async def close(self, *, wait=True):
- ...
-
- @abc.abstractmethod
- async def wait_until_released(self, timeout=None):
- ...
-
- async def connect(self):
- if self._con is not None:
- raise errors.InternalClientError(
- 'PoolConnectionHolder.connect() called while another '
- 'connection already exists')
-
- self._con = await self._pool._get_new_connection()
- assert self._con._holder is None
- self._con._holder = self
- self._generation = self._pool._generation
-
- async def acquire(self) -> BaseConnection:
- if self._con is None or self._con.is_closed():
- self._con = None
- await self.connect()
-
- elif self._generation != self._pool._generation:
- # Connections have been expired, re-connect the holder.
- self._con._holder = None # don't release the connection
- await self.close(wait=False)
- self._con = None
- await self.connect()
-
- self._release_event.clear()
-
- return self._con
-
- async def release(self, timeout):
- if self._release_event.is_set():
- raise errors.InternalClientError(
- 'PoolConnectionHolder.release() called on '
- 'a free connection holder')
-
- if self._con.is_closed():
- # This is usually the case when the connection is broken rather
- # than closed by the user, so we need to call _release_on_close()
- # here to release the holder back to the queue, because
- # self._con._cleanup() was never called. On the other hand, it is
- # safe to call self._release() twice - the second call is no-op.
- self._release_on_close()
- return
-
- self._timeout = None
-
- if self._generation != self._pool._generation:
- # The connection has expired because it belongs to
- # an older generation (BasePoolImpl.expire_connections() has
- # been called.)
- await self.close()
- return
-
- # Free this connection holder and invalidate the
- # connection proxy.
- self._release()
-
- def terminate(self):
- if self._con is not None:
- # AsyncIOConnection.terminate() will call _release_on_close() to
- # finish holder cleanup.
- self._con.terminate()
-
- def _release_on_close(self):
- self._release()
- self._con = None
-
- def _release(self):
- """Release this connection holder."""
- if self._release_event.is_set():
- # The holder is not checked out.
- return
-
- self._release_event.set()
-
- # Put ourselves back to the pool queue.
- self._pool._queue.put_nowait(self)
-
-
-class BasePoolImpl(abc.ABC):
- __slots__ = (
- "_connect_args",
- "_codecs_registry",
- "_query_cache",
- "_connection_factory",
- "_queue",
- "_user_max_concurrency",
- "_max_concurrency",
- "_first_connect_lock",
- "_working_addr",
- "_working_config",
- "_working_params",
- "_holders",
- "_initialized",
- "_initializing",
- "_closing",
- "_closed",
- "_generation",
- )
-
- _holder_class = NotImplemented
-
- def __init__(
- self,
- connect_args,
- connection_factory,
- *,
- max_concurrency: typing.Optional[int],
- ):
- self._connection_factory = connection_factory
- self._connect_args = connect_args
- self._codecs_registry = protocol.CodecsRegistry()
- self._query_cache = protocol.LRUMapping(maxsize=QUERY_CACHE_SIZE)
-
- if max_concurrency is not None and max_concurrency <= 0:
- raise ValueError(
- 'max_concurrency is expected to be greater than zero'
- )
-
- self._user_max_concurrency = max_concurrency
- self._max_concurrency = max_concurrency if max_concurrency else 1
-
- self._holders = []
- self._queue = None
-
- self._first_connect_lock = None
- self._working_addr = None
- self._working_config = None
- self._working_params = None
-
- self._closing = False
- self._closed = False
- self._generation = 0
-
- @abc.abstractmethod
- def _ensure_initialized(self):
- ...
-
- @abc.abstractmethod
- def _set_queue_maxsize(self, maxsize):
- ...
-
- @abc.abstractmethod
- async def _maybe_get_first_connection(self):
- ...
-
- @abc.abstractmethod
- async def acquire(self, timeout=None):
- ...
-
- @abc.abstractmethod
- async def _release(self, connection):
- ...
-
- @property
- def codecs_registry(self):
- return self._codecs_registry
-
- @property
- def query_cache(self):
- return self._query_cache
-
- def _resize_holder_pool(self):
- resize_diff = self._max_concurrency - len(self._holders)
-
- if (resize_diff > 0):
- if self._queue.maxsize != self._max_concurrency:
- self._set_queue_maxsize(self._max_concurrency)
-
- for _ in range(resize_diff):
- ch = self._holder_class(self)
-
- self._holders.append(ch)
- self._queue.put_nowait(ch)
- elif resize_diff < 0:
- # TODO: shrink the pool
- pass
-
- def get_max_concurrency(self):
- return self._max_concurrency
-
- def get_free_size(self):
- if self._queue is None:
- # Queue has not been initialized yet
- return self._max_concurrency
-
- return self._queue.qsize()
-
- def set_connect_args(self, dsn=None, **connect_kwargs):
- r"""Set the new connection arguments for this pool.
-
- The new connection arguments will be used for all subsequent
- new connection attempts. Existing connections will remain until
- they expire. Use BasePoolImpl.expire_connections() to expedite
- the connection expiry.
-
- :param str dsn:
- Connection arguments specified using as a single string in
- the following format:
- ``edgedb://user:pass@host:port/database?option=value``.
-
- :param \*\*connect_kwargs:
- Keyword arguments for the
- :func:`~edgedb.asyncio_client.create_async_client` function.
- """
-
- connect_kwargs["dsn"] = dsn
- self._connect_args = connect_kwargs
- self._codecs_registry = protocol.CodecsRegistry()
- self._query_cache = protocol.LRUMapping(maxsize=QUERY_CACHE_SIZE)
- self._working_addr = None
- self._working_config = None
- self._working_params = None
-
- async def _get_first_connection(self):
- # First connection attempt on this pool.
- connect_config, client_config = con_utils.parse_connect_arguments(
- **self._connect_args,
- # ToDos
- command_timeout=None,
- server_settings=None,
- )
- con = self._connection_factory(
- [connect_config.address], client_config, connect_config
- )
- await con.connect()
- self._working_addr = con.connected_addr()
- self._working_config = client_config
- self._working_params = connect_config
-
- if self._user_max_concurrency is None:
- suggested_concurrency = con.get_settings().get(
- 'suggested_pool_concurrency')
- if suggested_concurrency:
- self._max_concurrency = suggested_concurrency
- self._resize_holder_pool()
- return con
-
- async def _get_new_connection(self):
- con = None
- if self._working_addr is None:
- con = await self._maybe_get_first_connection()
- if con is None:
- assert self._working_addr is not None
- # We've connected before and have a resolved address,
- # and parsed options and config.
- con = self._connection_factory(
- [self._working_addr],
- self._working_config,
- self._working_params,
- )
- await con.connect()
-
- return con
-
- async def release(self, connection):
-
- if not isinstance(connection, BaseConnection):
- raise errors.InterfaceError(
- f'BasePoolImpl.release() received invalid connection: '
- f'{connection!r} does not belong to any connection pool'
- )
-
- ch = connection._holder
- if ch is None:
- # Already released, do nothing.
- return
-
- if ch._pool is not self:
- raise errors.InterfaceError(
- f'BasePoolImpl.release() received invalid connection: '
- f'{connection!r} is not a member of this pool'
- )
-
- return await self._release(ch)
-
- def terminate(self):
- """Terminate all connections in the pool."""
- if self._closed:
- return
- for ch in self._holders:
- ch.terminate()
- self._closed = True
-
- def expire_connections(self):
- """Expire all currently open connections.
-
- Cause all currently open connections to get replaced on the
- next query.
- """
- self._generation += 1
-
- async def ensure_connected(self):
- self._ensure_initialized()
-
- for ch in self._holders:
- if ch._con is not None and not ch._con.is_closed():
- return
-
- ch = self._holders[0]
- ch._con = None
- await ch.connect()
-
-
-class BaseClient(abstract.BaseReadOnlyExecutor, _options._OptionsMixin):
- __slots__ = ("_impl", "_options")
- _impl_class = NotImplemented
-
- def __init__(
- self,
- *,
- connection_class,
- max_concurrency: typing.Optional[int],
- dsn=None,
- host: str = None,
- port: int = None,
- credentials: str = None,
- credentials_file: str = None,
- user: str = None,
- password: str = None,
- secret_key: str = None,
- database: str = None,
- branch: str = None,
- tls_ca: str = None,
- tls_ca_file: str = None,
- tls_security: str = None,
- tls_server_name: str = None,
- wait_until_available: int = 30,
- timeout: int = 10,
- **kwargs,
- ):
- super().__init__()
- connect_args = {
- "dsn": dsn,
- "host": host,
- "port": port,
- "credentials": credentials,
- "credentials_file": credentials_file,
- "user": user,
- "password": password,
- "secret_key": secret_key,
- "database": database,
- "branch": branch,
- "timeout": timeout,
- "tls_ca": tls_ca,
- "tls_ca_file": tls_ca_file,
- "tls_security": tls_security,
- "tls_server_name": tls_server_name,
- "wait_until_available": wait_until_available,
- }
-
- self._impl = self._impl_class(
- connect_args,
- connection_class=connection_class,
- max_concurrency=max_concurrency,
- **kwargs,
- )
-
- def _shallow_clone(self):
- new_client = self.__class__.__new__(self.__class__)
- new_client._impl = self._impl
- return new_client
-
- def _get_query_cache(self) -> abstract.QueryCache:
- return abstract.QueryCache(
- codecs_registry=self._impl.codecs_registry,
- query_cache=self._impl.query_cache,
- )
-
- def _get_retry_options(self) -> typing.Optional[_options.RetryOptions]:
- return self._options.retry_options
-
- def _get_state(self) -> _options.State:
- return self._options.state
-
- def _get_warning_handler(self) -> _options.WarningHandler:
- return self._options.warning_handler
-
- @property
- def max_concurrency(self) -> int:
- """Max number of connections in the pool."""
-
- return self._impl.get_max_concurrency()
-
- @property
- def free_size(self) -> int:
- """Number of available connections in the pool."""
-
- return self._impl.get_free_size()
-
- async def _query(self, query_context: abstract.QueryContext):
- con = await self._impl.acquire()
- try:
- return await con.raw_query(query_context)
- finally:
- await self._impl.release(con)
-
- async def _execute(self, execute_context: abstract.ExecuteContext) -> None:
- con = await self._impl.acquire()
- try:
- await con._execute(execute_context)
- finally:
- await self._impl.release(con)
-
- async def _describe(
- self, describe_context: abstract.DescribeContext
- ) -> abstract.DescribeResult:
- con = await self._impl.acquire()
- try:
- return await con.describe(describe_context)
- finally:
- await self._impl.release(con)
-
- def terminate(self):
- """Terminate all connections in the pool."""
- self._impl.terminate()
+# Auto-generated shim
+import gel.base_client as _mod
+import sys as _sys
+_cur = _sys.modules['edgedb.base_client']
+for _k in vars(_mod):
+ if not _k.startswith('__') or _k in ('__all__', '__doc__'):
+ setattr(_cur, _k, getattr(_mod, _k))
+del _cur
+del _sys
+del _mod
+del _k
diff --git a/edgedb/blocking_client.py b/edgedb/blocking_client.py
index f1706e59..12788a48 100644
--- a/edgedb/blocking_client.py
+++ b/edgedb/blocking_client.py
@@ -1,490 +1,11 @@
-#
-# This source file is part of the EdgeDB open source project.
-#
-# Copyright 2022-present MagicStack Inc. and the EdgeDB authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-
-import contextlib
-import datetime
-import queue
-import socket
-import ssl
-import threading
-import time
-import typing
-
-from . import abstract
-from . import base_client
-from . import con_utils
-from . import errors
-from . import transaction
-from .protocol import blocking_proto
-from .protocol.protocol import InputLanguage, OutputFormat
-
-
-DEFAULT_PING_BEFORE_IDLE_TIMEOUT = datetime.timedelta(seconds=5)
-MINIMUM_PING_WAIT_TIME = datetime.timedelta(seconds=1)
-
-
-class BlockingIOConnection(base_client.BaseConnection):
- __slots__ = ("_ping_wait_time",)
-
- async def connect_addr(self, addr, timeout):
- deadline = time.monotonic() + timeout
-
- if isinstance(addr, str):
- # UNIX socket
- res_list = [(socket.AF_UNIX, socket.SOCK_STREAM, -1, None, addr)]
- else:
- host, port = addr
- try:
- # getaddrinfo() doesn't take timeout!!
- res_list = socket.getaddrinfo(
- host, port, socket.AF_UNSPEC, socket.SOCK_STREAM
- )
- except socket.gaierror as e:
- # All name resolution errors are considered temporary
- err = errors.ClientConnectionFailedTemporarilyError(str(e))
- raise err from e
-
- for i, res in enumerate(res_list):
- af, socktype, proto, _, sa = res
- try:
- sock = socket.socket(af, socktype, proto)
- except OSError as e:
- sock.close()
- if i < len(res_list) - 1:
- continue
- else:
- raise con_utils.wrap_error(e) from e
- try:
- await self._connect_addr(sock, addr, sa, deadline)
- except TimeoutError:
- raise
- except Exception:
- if i < len(res_list) - 1:
- continue
- else:
- raise
- else:
- break
-
- async def _connect_addr(self, sock, addr, sa, deadline):
- try:
- time_left = deadline - time.monotonic()
- if time_left <= 0:
- raise TimeoutError
- try:
- sock.settimeout(time_left)
- sock.connect(sa)
- except OSError as e:
- raise con_utils.wrap_error(e) from e
-
- if not isinstance(addr, str):
- time_left = deadline - time.monotonic()
- if time_left <= 0:
- raise TimeoutError
- try:
- # Upgrade to TLS
- sock.settimeout(time_left)
- try:
- sock = self._params.ssl_ctx.wrap_socket(
- sock,
- server_hostname=(
- self._params.tls_server_name or addr[0]
- ),
- )
- except ssl.CertificateError as e:
- raise con_utils.wrap_error(e) from e
- except ssl.SSLError as e:
- raise con_utils.wrap_error(e) from e
- else:
- con_utils.check_alpn_protocol(sock)
- except OSError as e:
- raise con_utils.wrap_error(e) from e
-
- time_left = deadline - time.monotonic()
- if time_left <= 0:
- raise TimeoutError
-
- if not isinstance(addr, str):
- sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
-
- proto = blocking_proto.BlockingIOProtocol(self._params, sock)
- proto.set_connection(self)
-
- try:
- await proto.wait_for(proto.connect(), time_left)
- except TimeoutError:
- raise
- except OSError as e:
- raise con_utils.wrap_error(e) from e
-
- self._protocol = proto
- self._addr = addr
- self._ping_wait_time = max(
- (
- self.get_settings()
- .get("system_config")
- .session_idle_timeout
- - DEFAULT_PING_BEFORE_IDLE_TIMEOUT
- ),
- MINIMUM_PING_WAIT_TIME,
- ).total_seconds()
-
- except Exception:
- sock.close()
- raise
-
- async def sleep(self, seconds):
- time.sleep(seconds)
-
- def is_closed(self):
- proto = self._protocol
- return not (proto and proto.sock is not None and
- proto.sock.fileno() >= 0 and proto.connected)
-
- async def close(self, timeout=None):
- """Send graceful termination message wait for connection to drop."""
- if not self.is_closed():
- try:
- self._protocol.terminate()
- if timeout is None:
- await self._protocol.wait_for_disconnect()
- else:
- await self._protocol.wait_for(
- self._protocol.wait_for_disconnect(), timeout
- )
- except TimeoutError:
- self.terminate()
- raise errors.QueryTimeoutError()
- except Exception:
- self.terminate()
- raise
- finally:
- self._cleanup()
-
- def _dispatch_log_message(self, msg):
- for cb in self._log_listeners:
- cb(self, msg)
-
- async def raw_query(self, query_context: abstract.QueryContext):
- try:
- if (
- time.monotonic() - self._protocol.last_active_timestamp
- > self._ping_wait_time
- ):
- await self._protocol.ping()
- except (errors.IdleSessionTimeoutError, errors.ClientConnectionError):
- await self.connect()
-
- return await super().raw_query(query_context)
-
-
-class _PoolConnectionHolder(base_client.PoolConnectionHolder):
- __slots__ = ()
- _event_class = threading.Event
-
- async def close(self, *, wait=True, timeout=None):
- if self._con is None:
- return
- await self._con.close(timeout=timeout)
-
- async def wait_until_released(self, timeout=None):
- return self._release_event.wait(timeout)
-
-
-class _PoolImpl(base_client.BasePoolImpl):
- _holder_class = _PoolConnectionHolder
-
- def __init__(
- self,
- connect_args,
- *,
- max_concurrency: typing.Optional[int],
- connection_class,
- ):
- if not issubclass(connection_class, BlockingIOConnection):
- raise TypeError(
- f'connection_class is expected to be a subclass of '
- f'edgedb.blocking_client.BlockingIOConnection, '
- f'got {connection_class}')
- super().__init__(
- connect_args,
- connection_class,
- max_concurrency=max_concurrency,
- )
-
- def _ensure_initialized(self):
- if self._queue is None:
- self._queue = queue.LifoQueue(maxsize=self._max_concurrency)
- self._first_connect_lock = threading.Lock()
- self._resize_holder_pool()
-
- def _set_queue_maxsize(self, maxsize):
- with self._queue.mutex:
- self._queue.maxsize = maxsize
-
- async def _maybe_get_first_connection(self):
- with self._first_connect_lock:
- if self._working_addr is None:
- return await self._get_first_connection()
-
- async def acquire(self, timeout=None):
- self._ensure_initialized()
-
- if self._closing:
- raise errors.InterfaceError('pool is closing')
-
- ch = self._queue.get(timeout=timeout)
- try:
- con = await ch.acquire()
- except Exception:
- self._queue.put_nowait(ch)
- raise
- else:
- # Record the timeout, as we will apply it by default
- # in release().
- ch._timeout = timeout
- return con
-
- async def _release(self, holder):
- if not isinstance(holder._con, BlockingIOConnection):
- raise errors.InterfaceError(
- f'release() received invalid connection: '
- f'{holder._con!r} does not belong to any connection pool'
- )
-
- timeout = None
- return await holder.release(timeout)
-
- async def close(self, timeout=None):
- if self._closed:
- return
- self._closing = True
- try:
- if timeout is None:
- for ch in self._holders:
- await ch.wait_until_released()
- for ch in self._holders:
- await ch.close()
- else:
- deadline = time.monotonic() + timeout
- for ch in self._holders:
- secs = deadline - time.monotonic()
- if secs <= 0:
- raise TimeoutError
- if not await ch.wait_until_released(secs):
- raise TimeoutError
- for ch in self._holders:
- secs = deadline - time.monotonic()
- if secs <= 0:
- raise TimeoutError
- await ch.close(timeout=secs)
- except TimeoutError as e:
- self.terminate()
- raise errors.InterfaceError(
- "client is not fully closed in {} seconds; "
- "terminating now.".format(timeout)
- ) from e
- except Exception:
- self.terminate()
- raise
- finally:
- self._closed = True
- self._closing = False
-
-
-class Iteration(transaction.BaseTransaction, abstract.Executor):
-
- __slots__ = ("_managed", "_lock")
-
- def __init__(self, retry, client, iteration):
- super().__init__(retry, client, iteration)
- self._managed = False
- self._lock = threading.Lock()
-
- def __enter__(self):
- with self._exclusive():
- if self._managed:
- raise errors.InterfaceError(
- 'cannot enter context: already in a `with` block')
- self._managed = True
- return self
-
- def __exit__(self, extype, ex, tb):
- with self._exclusive():
- self._managed = False
- return self._client._iter_coroutine(self._exit(extype, ex))
-
- async def _ensure_transaction(self):
- if not self._managed:
- raise errors.InterfaceError(
- "Only managed retriable transactions are supported. "
- "Use `with transaction:`"
- )
- await super()._ensure_transaction()
-
- def _query(self, query_context: abstract.QueryContext):
- with self._exclusive():
- return self._client._iter_coroutine(super()._query(query_context))
-
- def _execute(self, execute_context: abstract.ExecuteContext) -> None:
- with self._exclusive():
- self._client._iter_coroutine(super()._execute(execute_context))
-
- @contextlib.contextmanager
- def _exclusive(self):
- if not self._lock.acquire(blocking=False):
- raise errors.InterfaceError(
- "concurrent queries within the same transaction "
- "are not allowed"
- )
- try:
- yield
- finally:
- self._lock.release()
-
-
-class Retry(transaction.BaseRetry):
-
- def __iter__(self):
- return self
-
- def __next__(self):
- # Note: when changing this code consider also
- # updating AsyncIORetry.__anext__.
- if self._done:
- raise StopIteration
- if self._next_backoff:
- time.sleep(self._next_backoff)
- self._done = True
- iteration = Iteration(self, self._owner, self._iteration)
- self._iteration += 1
- return iteration
-
-
-class Client(base_client.BaseClient, abstract.Executor):
- """A lazy connection pool.
-
- A Client can be used to manage a set of connections to the database.
- Connections are first acquired from the pool, then used, and then released
- back to the pool. Once a connection is released, it's reset to close all
- open cursors and other resources *except* prepared statements.
-
- Clients are created by calling
- :func:`~edgedb.blocking_client.create_client`.
- """
-
- __slots__ = ()
- _impl_class = _PoolImpl
-
- def _iter_coroutine(self, coro):
- try:
- coro.send(None)
- except StopIteration as ex:
- return ex.value
- finally:
- coro.close()
-
- def _query(self, query_context: abstract.QueryContext):
- return self._iter_coroutine(super()._query(query_context))
-
- def _execute(self, execute_context: abstract.ExecuteContext) -> None:
- self._iter_coroutine(super()._execute(execute_context))
-
- def ensure_connected(self):
- self._iter_coroutine(self._impl.ensure_connected())
- return self
-
- def transaction(self) -> Retry:
- return Retry(self)
-
- def close(self, timeout=None):
- """Attempt to gracefully close all connections in the client.
-
- Wait until all pool connections are released, close them and
- shut down the pool. If any error (including cancellation) occurs
- in ``close()`` the pool will terminate by calling
- Client.terminate() .
- """
- self._iter_coroutine(self._impl.close(timeout))
-
- def __enter__(self):
- return self.ensure_connected()
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- self.close()
-
- def _describe_query(
- self,
- query: str,
- *,
- inject_type_names: bool = False,
- input_language: InputLanguage = InputLanguage.EDGEQL,
- output_format: OutputFormat = OutputFormat.BINARY,
- expect_one: bool = False,
- ) -> abstract.DescribeResult:
- return self._iter_coroutine(self._describe(abstract.DescribeContext(
- query=query,
- state=self._get_state(),
- inject_type_names=inject_type_names,
- input_language=input_language,
- output_format=output_format,
- expect_one=expect_one,
- )))
-
-
-def create_client(
- dsn=None,
- *,
- max_concurrency=None,
- host: str = None,
- port: int = None,
- credentials: str = None,
- credentials_file: str = None,
- user: str = None,
- password: str = None,
- secret_key: str = None,
- database: str = None,
- branch: str = None,
- tls_ca: str = None,
- tls_ca_file: str = None,
- tls_security: str = None,
- wait_until_available: int = 30,
- timeout: int = 10,
-):
- return Client(
- connection_class=BlockingIOConnection,
- max_concurrency=max_concurrency,
-
- # connect arguments
- dsn=dsn,
- host=host,
- port=port,
- credentials=credentials,
- credentials_file=credentials_file,
- user=user,
- password=password,
- secret_key=secret_key,
- database=database,
- branch=branch,
- tls_ca=tls_ca,
- tls_ca_file=tls_ca_file,
- tls_security=tls_security,
- wait_until_available=wait_until_available,
- timeout=timeout,
- )
+# Auto-generated shim
+import gel.blocking_client as _mod
+import sys as _sys
+_cur = _sys.modules['edgedb.blocking_client']
+for _k in vars(_mod):
+ if not _k.startswith('__') or _k in ('__all__', '__doc__'):
+ setattr(_cur, _k, getattr(_mod, _k))
+del _cur
+del _sys
+del _mod
+del _k
diff --git a/edgedb/codegen.py b/edgedb/codegen.py
new file mode 100644
index 00000000..7bae2079
--- /dev/null
+++ b/edgedb/codegen.py
@@ -0,0 +1,11 @@
+# Auto-generated shim
+import gel.codegen as _mod
+import sys as _sys
+_cur = _sys.modules['edgedb.codegen']
+for _k in vars(_mod):
+ if not _k.startswith('__') or _k in ('__all__', '__doc__'):
+ setattr(_cur, _k, getattr(_mod, _k))
+del _cur
+del _sys
+del _mod
+del _k
diff --git a/edgedb/color.py b/edgedb/color.py
index 2b972aaa..c761ee80 100644
--- a/edgedb/color.py
+++ b/edgedb/color.py
@@ -1,61 +1,11 @@
-import os
-import sys
-import warnings
-
-COLOR = None
-
-
-class Color:
- HEADER = ""
- BLUE = ""
- CYAN = ""
- GREEN = ""
- WARNING = ""
- FAIL = ""
- ENDC = ""
- BOLD = ""
- UNDERLINE = ""
-
-
-def get_color() -> Color:
- global COLOR
-
- if COLOR is None:
- COLOR = Color()
- if type(USE_COLOR) is bool:
- use_color = USE_COLOR
- else:
- try:
- use_color = USE_COLOR()
- except Exception:
- use_color = False
- if use_color:
- COLOR.HEADER = '\033[95m'
- COLOR.BLUE = '\033[94m'
- COLOR.CYAN = '\033[96m'
- COLOR.GREEN = '\033[92m'
- COLOR.WARNING = '\033[93m'
- COLOR.FAIL = '\033[91m'
- COLOR.ENDC = '\033[0m'
- COLOR.BOLD = '\033[1m'
- COLOR.UNDERLINE = '\033[4m'
-
- return COLOR
-
-
-try:
- USE_COLOR = {
- "default": lambda: sys.stderr.isatty(),
- "auto": lambda: sys.stderr.isatty(),
- "enabled": True,
- "disabled": False,
- }[
- os.getenv("EDGEDB_COLOR_OUTPUT", "default")
- ]
-except KeyError:
- warnings.warn(
- "EDGEDB_COLOR_OUTPUT can only be one of: "
- "default, auto, enabled or disabled",
- stacklevel=1,
- )
- USE_COLOR = False
+# Auto-generated shim
+import gel.color as _mod
+import sys as _sys
+_cur = _sys.modules['edgedb.color']
+for _k in vars(_mod):
+ if not _k.startswith('__') or _k in ('__all__', '__doc__'):
+ setattr(_cur, _k, getattr(_mod, _k))
+del _cur
+del _sys
+del _mod
+del _k
diff --git a/edgedb/con_utils.py b/edgedb/con_utils.py
index 0863ead9..cad210b4 100644
--- a/edgedb/con_utils.py
+++ b/edgedb/con_utils.py
@@ -1,1278 +1,11 @@
-#
-# This source file is part of the EdgeDB open source project.
-#
-# Copyright 2016-present MagicStack Inc. and the EdgeDB authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-
-import base64
-import binascii
-import errno
-import json
-import os
-import re
-import ssl
-import typing
-import urllib.parse
-import warnings
-import hashlib
-
-from . import errors
-from . import credentials as cred_utils
-from . import platform
-
-
-EDGEDB_PORT = 5656
-ERRNO_RE = re.compile(r"\[Errno (\d+)\]")
-TEMPORARY_ERRORS = (
- ConnectionAbortedError,
- ConnectionRefusedError,
- ConnectionResetError,
- FileNotFoundError,
-)
-TEMPORARY_ERROR_CODES = frozenset({
- errno.ECONNREFUSED,
- errno.ECONNABORTED,
- errno.ECONNRESET,
- errno.ENOENT,
-})
-
-ISO_SECONDS_RE = re.compile(r'(-?\d+|-?\d+\.\d*|-?\d*\.\d+)S')
-ISO_MINUTES_RE = re.compile(r'(-?\d+|-?\d+\.\d*|-?\d*\.\d+)M')
-ISO_HOURS_RE = re.compile(r'(-?\d+|-?\d+\.\d*|-?\d*\.\d+)H')
-ISO_UNITLESS_HOURS_RE = re.compile(r'^(-?\d+|-?\d+\.\d*|-?\d*\.\d+)$')
-ISO_DAYS_RE = re.compile(r'(-?\d+|-?\d+\.\d*|-?\d*\.\d+)D')
-ISO_WEEKS_RE = re.compile(r'(-?\d+|-?\d+\.\d*|-?\d*\.\d+)W')
-ISO_MONTHS_RE = re.compile(r'(-?\d+|-?\d+\.\d*|-?\d*\.\d+)M')
-ISO_YEARS_RE = re.compile(r'(-?\d+|-?\d+\.\d*|-?\d*\.\d+)Y')
-
-HUMAN_HOURS_RE = re.compile(
- r'((?:(?:\s|^)-\s*)?\d*\.?\d*)\s*(?i:h(\s|\d|\.|$)|hours?(\s|$))',
-)
-HUMAN_MINUTES_RE = re.compile(
- r'((?:(?:\s|^)-\s*)?\d*\.?\d*)\s*(?i:m(\s|\d|\.|$)|minutes?(\s|$))',
-)
-HUMAN_SECONDS_RE = re.compile(
- r'((?:(?:\s|^)-\s*)?\d*\.?\d*)\s*(?i:s(\s|\d|\.|$)|seconds?(\s|$))',
-)
-HUMAN_MS_RE = re.compile(
- r'((?:(?:\s|^)-\s*)?\d*\.?\d*)\s*(?i:ms(\s|\d|\.|$)|milliseconds?(\s|$))',
-)
-HUMAN_US_RE = re.compile(
- r'((?:(?:\s|^)-\s*)?\d*\.?\d*)\s*(?i:us(\s|\d|\.|$)|microseconds?(\s|$))',
-)
-INSTANCE_NAME_RE = re.compile(
- r'^(\w(?:-?\w)*)$',
- re.ASCII,
-)
-CLOUD_INSTANCE_NAME_RE = re.compile(
- r'^([A-Za-z0-9_-](?:-?[A-Za-z0-9_])*)/([A-Za-z0-9](?:-?[A-Za-z0-9])*)$',
- re.ASCII,
-)
-DSN_RE = re.compile(
- r'^[a-z]+://',
- re.IGNORECASE,
-)
-DOMAIN_LABEL_MAX_LENGTH = 63
-
-
-class ClientConfiguration(typing.NamedTuple):
-
- connect_timeout: float
- command_timeout: float
- wait_until_available: float
-
-
-def _validate_port_spec(hosts, port):
- if isinstance(port, list):
- # If there is a list of ports, its length must
- # match that of the host list.
- if len(port) != len(hosts):
- raise errors.InterfaceError(
- 'could not match {} port numbers to {} hosts'.format(
- len(port), len(hosts)))
- else:
- port = [port for _ in range(len(hosts))]
-
- return port
-
-
-def _parse_hostlist(hostlist, port):
- if ',' in hostlist:
- # A comma-separated list of host addresses.
- hostspecs = hostlist.split(',')
- else:
- hostspecs = [hostlist]
-
- hosts = []
- hostlist_ports = []
-
- if not port:
- portspec = _getenv('PORT')
- if portspec:
- if ',' in portspec:
- default_port = [int(p) for p in portspec.split(',')]
- else:
- default_port = int(portspec)
- else:
- default_port = EDGEDB_PORT
-
- default_port = _validate_port_spec(hostspecs, default_port)
-
- else:
- port = _validate_port_spec(hostspecs, port)
-
- for i, hostspec in enumerate(hostspecs):
- addr, _, hostspec_port = hostspec.partition(':')
- hosts.append(addr)
-
- if not port:
- if hostspec_port:
- hostlist_ports.append(int(hostspec_port))
- else:
- hostlist_ports.append(default_port[i])
-
- if not port:
- port = hostlist_ports
-
- return hosts, port
-
-
-def _hash_path(path):
- path = os.path.realpath(path)
- if platform.IS_WINDOWS and not path.startswith('\\\\'):
- path = '\\\\?\\' + path
- return hashlib.sha1(str(path).encode('utf-8')).hexdigest()
-
-
-def _stash_path(path):
- base_name = os.path.basename(path)
- dir_name = base_name + '-' + _hash_path(path)
- return platform.search_config_dir('projects', dir_name)
-
-
-def _validate_tls_security(val: str) -> str:
- val = val.lower()
- if val not in {"insecure", "no_host_verification", "strict", "default"}:
- raise ValueError(
- "tls_security can only be one of "
- "`insecure`, `no_host_verification`, `strict` or `default`"
- )
-
- return val
-
-
-def _getenv_and_key(key: str) -> typing.Tuple[typing.Optional[str], str]:
- edgedb_key = f'EDGEDB_{key}'
- edgedb_val = os.getenv(edgedb_key)
- gel_key = f'GEL_{key}'
- gel_val = os.getenv(gel_key)
- if edgedb_val is not None and gel_val is not None:
- warnings.warn(
- f'Both {gel_key} and {edgedb_key} are set; '
- f'{edgedb_key} will be ignored',
- stacklevel=1,
- )
-
- if gel_val is None and edgedb_val is not None:
- return edgedb_val, edgedb_key
- else:
- return gel_val, gel_key
-
-
-def _getenv(key: str) -> typing.Optional[str]:
- return _getenv_and_key(key)[0]
-
-
-class ResolvedConnectConfig:
- _host = None
- _host_source = None
-
- _port = None
- _port_source = None
-
- # We keep track of database and branch separately, because we want to make
- # sure that we don't use both at the same time on the same configuration
- # level.
- _database = None
- _database_source = None
-
- _branch = None
- _branch_source = None
-
- _user = None
- _user_source = None
-
- _password = None
- _password_source = None
-
- _secret_key = None
- _secret_key_source = None
-
- _tls_ca_data = None
- _tls_ca_data_source = None
-
- _tls_server_name = None
- _tls_security = None
- _tls_security_source = None
-
- _wait_until_available = None
-
- _cloud_profile = None
- _cloud_profile_source = None
-
- server_settings = {}
-
- def _set_param(self, param, value, source, validator=None):
- param_name = '_' + param
- if getattr(self, param_name) is None:
- setattr(self, param_name + '_source', source)
- if value is not None:
- setattr(
- self,
- param_name,
- validator(value) if validator else value
- )
-
- def set_host(self, host, source):
- self._set_param('host', host, source, _validate_host)
-
- def set_port(self, port, source):
- self._set_param('port', port, source, _validate_port)
-
- def set_database(self, database, source):
- self._set_param('database', database, source, _validate_database)
-
- def set_branch(self, branch, source):
- self._set_param('branch', branch, source, _validate_branch)
-
- def set_user(self, user, source):
- self._set_param('user', user, source, _validate_user)
-
- def set_password(self, password, source):
- self._set_param('password', password, source)
-
- def set_secret_key(self, secret_key, source):
- self._set_param('secret_key', secret_key, source)
-
- def set_tls_ca_data(self, ca_data, source):
- self._set_param('tls_ca_data', ca_data, source)
-
- def set_tls_ca_file(self, ca_file, source):
- def read_ca_file(file_path):
- with open(file_path) as f:
- return f.read()
-
- self._set_param('tls_ca_data', ca_file, source, read_ca_file)
-
- def set_tls_server_name(self, ca_data, source):
- self._set_param('tls_server_name', ca_data, source)
-
- def set_tls_security(self, security, source):
- self._set_param('tls_security', security, source,
- _validate_tls_security)
-
- def set_wait_until_available(self, wait_until_available, source):
- self._set_param(
- 'wait_until_available',
- wait_until_available,
- source,
- _validate_wait_until_available,
- )
-
- def add_server_settings(self, server_settings):
- _validate_server_settings(server_settings)
- self.server_settings = {**server_settings, **self.server_settings}
-
- @property
- def address(self):
- return (
- self._host if self._host else 'localhost',
- self._port if self._port else 5656
- )
-
- # The properties actually merge database and branch, but "default" is
- # different. If you need to know the underlying config use the _database
- # and _branch.
- @property
- def database(self):
- return (
- self._database if self._database else
- self._branch if self._branch else
- 'edgedb'
- )
-
- @property
- def branch(self):
- return (
- self._database if self._database else
- self._branch if self._branch else
- '__default__'
- )
-
- @property
- def user(self):
- return self._user if self._user else 'edgedb'
-
- @property
- def password(self):
- return self._password
-
- @property
- def secret_key(self):
- return self._secret_key
-
- @property
- def tls_server_name(self):
- return self._tls_server_name
-
- @property
- def tls_security(self):
- tls_security = self._tls_security or 'default'
- security, security_key = _getenv_and_key('CLIENT_SECURITY')
- security = security or 'default'
- if security not in {'default', 'insecure_dev_mode', 'strict'}:
- raise ValueError(
- f'environment variable {security_key} should be '
- f'one of strict, insecure_dev_mode or default, '
- f'got: {security!r}')
-
- if security == 'default':
- pass
- elif security == 'insecure_dev_mode':
- if tls_security == 'default':
- tls_security = 'insecure'
- elif security == 'strict':
- if tls_security == 'default':
- tls_security = 'strict'
- elif tls_security in {'no_host_verification', 'insecure'}:
- raise ValueError(
- f'{security_key}=strict but '
- f'tls_security={tls_security}, tls_security must be '
- f'set to strict when {security_key} is strict')
-
- if tls_security != 'default':
- return tls_security
-
- if self._tls_ca_data is not None:
- return "no_host_verification"
-
- return "strict"
-
- _ssl_ctx = None
-
- @property
- def ssl_ctx(self):
- if (self._ssl_ctx):
- return self._ssl_ctx
-
- self._ssl_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
-
- if self._tls_ca_data:
- self._ssl_ctx.load_verify_locations(
- cadata=self._tls_ca_data
- )
- else:
- self._ssl_ctx.load_default_certs(ssl.Purpose.SERVER_AUTH)
- if platform.IS_WINDOWS:
- import certifi
- self._ssl_ctx.load_verify_locations(cafile=certifi.where())
-
- tls_security = self.tls_security
- self._ssl_ctx.check_hostname = tls_security == "strict"
-
- if tls_security in {"strict", "no_host_verification"}:
- self._ssl_ctx.verify_mode = ssl.CERT_REQUIRED
- else:
- self._ssl_ctx.verify_mode = ssl.CERT_NONE
-
- self._ssl_ctx.set_alpn_protocols(['edgedb-binary'])
-
- return self._ssl_ctx
-
- @property
- def wait_until_available(self):
- return (
- self._wait_until_available
- if self._wait_until_available is not None
- else 30
- )
-
-
-def _validate_host(host):
- if '/' in host:
- raise ValueError('unix socket paths not supported')
- if host == '' or ',' in host:
- raise ValueError(f'invalid host: "{host}"')
- return host
-
-
-def _prepare_host_for_dsn(host):
- host = _validate_host(host)
- if ':' in host:
- # IPv6
- host = f'[{host}]'
- return host
-
-
-def _validate_port(port):
- try:
- if isinstance(port, str):
- port = int(port)
- if not isinstance(port, int):
- raise ValueError()
- except Exception:
- raise ValueError(f'invalid port: {port}, not an integer')
- if port < 1 or port > 65535:
- raise ValueError(f'invalid port: {port}, must be between 1 and 65535')
- return port
-
-
-def _validate_database(database):
- if database == '':
- raise ValueError(f'invalid database name: {database}')
- return database
-
-
-def _validate_branch(branch):
- if branch == '':
- raise ValueError(f'invalid branch name: {branch}')
- return branch
-
-
-def _validate_user(user):
- if user == '':
- raise ValueError(f'invalid user name: {user}')
- return user
-
-
-def _pop_iso_unit(rgex: re.Pattern, string: str) -> typing.Tuple[float, str]:
- s = string
- total = 0
- match = rgex.search(string)
- if match:
- total += float(match.group(1))
- s = s.replace(match.group(0), "", 1)
-
- return (total, s)
-
-
-def _parse_iso_duration(string: str) -> typing.Union[float, int]:
- if not string.startswith("PT"):
- raise ValueError(f"invalid duration {string!r}")
-
- time = string[2:]
- match = ISO_UNITLESS_HOURS_RE.search(time)
- if match:
- hours = float(match.group(0))
- return 3600 * hours
-
- hours, time = _pop_iso_unit(ISO_HOURS_RE, time)
- minutes, time = _pop_iso_unit(ISO_MINUTES_RE, time)
- seconds, time = _pop_iso_unit(ISO_SECONDS_RE, time)
-
- if time:
- raise ValueError(f'invalid duration {string!r}')
-
- return 3600 * hours + 60 * minutes + seconds
-
-
-def _remove_white_space(s: str) -> str:
- return ''.join(c for c in s if not c.isspace())
-
-
-def _pop_human_duration_unit(
- rgex: re.Pattern,
- string: str,
-) -> typing.Tuple[float, bool, str]:
- match = rgex.search(string)
- if not match:
- return 0, False, string
-
- number = 0
- if match.group(1):
- literal = _remove_white_space(match.group(1))
- if literal.endswith('.'):
- return 0, False, string
-
- if literal.startswith('-.'):
- return 0, False, string
-
- number = float(literal)
- string = string.replace(
- match.group(0),
- match.group(2) or match.group(3) or "",
- 1,
- )
-
- return number, True, string
-
-
-def _parse_human_duration(string: str) -> float:
- found = False
-
- hour, f, s = _pop_human_duration_unit(HUMAN_HOURS_RE, string)
- found |= f
-
- minute, f, s = _pop_human_duration_unit(HUMAN_MINUTES_RE, s)
- found |= f
-
- second, f, s = _pop_human_duration_unit(HUMAN_SECONDS_RE, s)
- found |= f
-
- ms, f, s = _pop_human_duration_unit(HUMAN_MS_RE, s)
- found |= f
-
- us, f, s = _pop_human_duration_unit(HUMAN_US_RE, s)
- found |= f
-
- if s.strip() or not found:
- raise ValueError(f'invalid duration {string!r}')
-
- return 3600 * hour + 60 * minute + second + 0.001 * ms + 0.000001 * us
-
-
-def _parse_duration_str(string: str) -> float:
- if string.startswith('PT'):
- return _parse_iso_duration(string)
- return _parse_human_duration(string)
-
-
-def _validate_wait_until_available(wait_until_available):
- if isinstance(wait_until_available, str):
- return _parse_duration_str(wait_until_available)
-
- if isinstance(wait_until_available, (int, float)):
- return wait_until_available
-
- raise ValueError(f"invalid duration {wait_until_available!r}")
-
-
-def _validate_server_settings(server_settings):
- if (
- not isinstance(server_settings, dict) or
- not all(isinstance(k, str) for k in server_settings) or
- not all(isinstance(v, str) for v in server_settings.values())
- ):
- raise ValueError(
- 'server_settings is expected to be None or '
- 'a Dict[str, str]')
-
-
-def _parse_connect_dsn_and_args(
- *,
- dsn,
- host,
- port,
- credentials,
- credentials_file,
- user,
- password,
- secret_key,
- database,
- branch,
- tls_ca,
- tls_ca_file,
- tls_security,
- tls_server_name,
- server_settings,
- wait_until_available,
-):
- resolved_config = ResolvedConnectConfig()
-
- if dsn and DSN_RE.match(dsn):
- instance_name = None
- else:
- instance_name, dsn = dsn, None
-
- def _get(key: str) -> typing.Optional[typing.Tuple[str, str]]:
- val, env = _getenv_and_key(key)
- return (
- (val, f'"{env}" environment variable')
- if val is not None else None
- )
-
- # The cloud profile is potentially relevant to resolving credentials at
- # any stage, including the config stage when other environment variables
- # are not yet read.
- cloud_profile_tuple = _get('CLOUD_PROFILE')
- cloud_profile = cloud_profile_tuple[0] if cloud_profile_tuple else None
-
- has_compound_options = _resolve_config_options(
- resolved_config,
- 'Cannot have more than one of the following connection options: '
- + '"dsn", "credentials", "credentials_file" or "host"/"port"',
- dsn=(dsn, '"dsn" option') if dsn is not None else None,
- instance_name=(
- (instance_name, '"dsn" option (parsed as instance name)')
- if instance_name is not None else None
- ),
- credentials=(
- (credentials, '"credentials" option')
- if credentials is not None else None
- ),
- credentials_file=(
- (credentials_file, '"credentials_file" option')
- if credentials_file is not None else None
- ),
- host=(host, '"host" option') if host is not None else None,
- port=(port, '"port" option') if port is not None else None,
- database=(
- (database, '"database" option')
- if database is not None else None
- ),
- branch=(
- (branch, '"branch" option')
- if branch is not None else None
- ),
- user=(user, '"user" option') if user is not None else None,
- password=(
- (password, '"password" option')
- if password is not None else None
- ),
- secret_key=(
- (secret_key, '"secret_key" option')
- if secret_key is not None else None
- ),
- tls_ca=(
- (tls_ca, '"tls_ca" option')
- if tls_ca is not None else None
- ),
- tls_ca_file=(
- (tls_ca_file, '"tls_ca_file" option')
- if tls_ca_file is not None else None
- ),
- tls_security=(
- (tls_security, '"tls_security" option')
- if tls_security is not None else None
- ),
- tls_server_name=(
- (tls_server_name, '"tls_server_name" option')
- if tls_server_name is not None else None
- ),
- server_settings=(
- (server_settings, '"server_settings" option')
- if server_settings is not None else None
- ),
- wait_until_available=(
- (wait_until_available, '"wait_until_available" option')
- if wait_until_available is not None else None
- ),
- cloud_profile=cloud_profile_tuple,
- )
-
- if has_compound_options is False:
- env_port_tuple = _get("PORT")
- if (
- resolved_config._port is None
- and env_port_tuple
- and env_port_tuple[0].startswith('tcp://')
- ):
- # EDGEDB_PORT is set by 'docker --link' so ignore and warn
- warnings.warn('EDGEDB_PORT in "tcp://host:port" format, ' +
- 'so will be ignored', stacklevel=1)
- env_port_tuple = None
-
- has_compound_options = _resolve_config_options(
- resolved_config,
- # XXX
- 'Cannot have more than one of the following connection '
- + 'environment variables: "EDGEDB_DSN", "EDGEDB_INSTANCE", '
- + '"EDGEDB_CREDENTIALS_FILE" or "EDGEDB_HOST"/"EDGEDB_PORT"',
- dsn=_get('DSN'),
- instance_name=_get('INSTANCE'),
- credentials_file=_get('CREDENTIALS_FILE'),
- host=_get('HOST'),
- port=env_port_tuple,
- database=_get('DATABASE'),
- branch=_get('BRANCH'),
- user=_get('USER'),
- password=_get('PASSWORD'),
- secret_key=_get('SECRET_KEY'),
- tls_ca=_get('TLS_CA'),
- tls_ca_file=_get('TLS_CA_FILE'),
- tls_security=_get('CLIENT_TLS_SECURITY'),
- tls_server_name=_get('TLS_SERVER_NAME'),
- wait_until_available=_get('WAIT_UNTIL_AVAILABLE'),
- )
-
- if not has_compound_options:
- dir = find_edgedb_project_dir()
- stash_dir = _stash_path(dir)
- if os.path.exists(stash_dir):
- with open(os.path.join(stash_dir, 'instance-name'), 'rt') as f:
- instance_name = f.read().strip()
- cloud_profile_file = os.path.join(stash_dir, 'cloud-profile')
- if os.path.exists(cloud_profile_file):
- with open(cloud_profile_file, 'rt') as f:
- cloud_profile = f.read().strip()
- else:
- cloud_profile = None
-
- _resolve_config_options(
- resolved_config,
- '',
- instance_name=(
- instance_name,
- f'project linked instance ("{instance_name}")'
- ),
- cloud_profile=(
- cloud_profile,
- f'project defined cloud profile ("{cloud_profile}")'
- ),
- )
-
- opt_database_file = os.path.join(stash_dir, 'database')
- if os.path.exists(opt_database_file):
- with open(opt_database_file, 'rt') as f:
- database = f.read().strip()
- resolved_config.set_database(database, "project")
- else:
- raise errors.ClientConnectionError(
- f'Found `edgedb.toml` but the project is not initialized. '
- f'Run `edgedb project init`.'
- )
-
- return resolved_config
-
-
-def _parse_dsn_into_config(
- resolved_config: ResolvedConnectConfig,
- dsn: typing.Tuple[str, str]
-):
- dsn_str, source = dsn
-
- try:
- parsed = urllib.parse.urlparse(dsn_str)
- host = (
- urllib.parse.unquote(parsed.hostname) if parsed.hostname else None
- )
- port = parsed.port
- database = parsed.path
- user = parsed.username
- password = parsed.password
- except Exception as e:
- raise ValueError(f'invalid DSN or instance name: {str(e)}')
-
- if parsed.scheme != 'edgedb':
- raise ValueError(
- f'invalid DSN: scheme is expected to be '
- f'"edgedb", got {parsed.scheme!r}')
-
- query = (
- urllib.parse.parse_qs(parsed.query, keep_blank_values=True)
- if parsed.query != ''
- else {}
- )
- for key, val in query.items():
- if isinstance(val, list):
- if len(val) > 1:
- raise ValueError(
- f'invalid DSN: duplicate query parameter {key}')
- query[key] = val[-1]
-
- def handle_dsn_part(
- paramName, value, currentValue, setter,
- formatter=lambda val: val
- ):
- param_values = [
- (value if value != '' else None),
- query.get(paramName),
- query.get(paramName + '_env'),
- query.get(paramName + '_file')
- ]
- if len([p for p in param_values if p is not None]) > 1:
- raise ValueError(
- f'invalid DSN: more than one of ' +
- f'{(paramName + ", ") if value else ""}' +
- f'?{paramName}=, ?{paramName}_env=, ?{paramName}_file= ' +
- f'was specified'
- )
-
- if currentValue is None:
- param = (
- value if (value is not None and value != '')
- else query.get(paramName)
- )
- paramSource = source
-
- if param is None:
- env = query.get(paramName + '_env')
- if env is not None:
- param = os.getenv(env)
- if param is None:
- raise ValueError(
- f'{paramName}_env environment variable "{env}" ' +
- f'doesn\'t exist')
- paramSource = paramSource + f' ({paramName}_env: {env})'
- if param is None:
- filename = query.get(paramName + '_file')
- if filename is not None:
- with open(filename) as f:
- param = f.read()
- paramSource = (
- paramSource + f' ({paramName}_file: {filename})'
- )
-
- param = formatter(param) if param is not None else None
-
- setter(param, paramSource)
-
- query.pop(paramName, None)
- query.pop(paramName + '_env', None)
- query.pop(paramName + '_file', None)
-
- handle_dsn_part(
- 'host', host, resolved_config._host, resolved_config.set_host
- )
-
- handle_dsn_part(
- 'port', port, resolved_config._port, resolved_config.set_port
- )
-
- def strip_leading_slash(str):
- return str[1:] if str.startswith('/') else str
-
- if (
- 'branch' in query or
- 'branch_env' in query or
- 'branch_file' in query
- ):
- if (
- 'database' in query or
- 'database_env' in query or
- 'database_file' in query
- ):
- raise ValueError(
- f"invalid DSN: `database` and `branch` cannot be present "
- f"at the same time"
- )
- if resolved_config._database is None:
- # Only update the config if 'database' has not been already
- # resolved.
- handle_dsn_part(
- 'branch', strip_leading_slash(database),
- resolved_config._branch, resolved_config.set_branch,
- strip_leading_slash
- )
- else:
- # Clean up the query, if config already has 'database'
- query.pop('branch', None)
- query.pop('branch_env', None)
- query.pop('branch_file', None)
-
- else:
- if resolved_config._branch is None:
- # Only update the config if 'branch' has not been already
- # resolved.
- handle_dsn_part(
- 'database', strip_leading_slash(database),
- resolved_config._database, resolved_config.set_database,
- strip_leading_slash
- )
- else:
- # Clean up the query, if config already has 'branch'
- query.pop('database', None)
- query.pop('database_env', None)
- query.pop('database_file', None)
-
- handle_dsn_part(
- 'user', user, resolved_config._user, resolved_config.set_user
- )
-
- handle_dsn_part(
- 'password', password,
- resolved_config._password, resolved_config.set_password
- )
-
- handle_dsn_part(
- 'secret_key', None,
- resolved_config._secret_key, resolved_config.set_secret_key
- )
-
- handle_dsn_part(
- 'tls_ca_file', None,
- resolved_config._tls_ca_data, resolved_config.set_tls_ca_file
- )
-
- handle_dsn_part(
- 'tls_server_name', None,
- resolved_config._tls_server_name,
- resolved_config.set_tls_server_name
- )
-
- handle_dsn_part(
- 'tls_security', None,
- resolved_config._tls_security,
- resolved_config.set_tls_security
- )
-
- handle_dsn_part(
- 'wait_until_available', None,
- resolved_config._wait_until_available,
- resolved_config.set_wait_until_available
- )
-
- resolved_config.add_server_settings(query)
-
-
-def _jwt_base64_decode(payload):
- remainder = len(payload) % 4
- if remainder == 2:
- payload += '=='
- elif remainder == 3:
- payload += '='
- elif remainder != 0:
- raise errors.ClientConnectionError("Invalid secret key")
- payload = base64.urlsafe_b64decode(payload.encode("utf-8"))
- return json.loads(payload.decode("utf-8"))
-
-
-def _parse_cloud_instance_name_into_config(
- resolved_config: ResolvedConnectConfig,
- source: str,
- org_slug: str,
- instance_name: str,
-):
- org_slug = org_slug.lower()
- instance_name = instance_name.lower()
-
- label = f"{instance_name}--{org_slug}"
- if len(label) > DOMAIN_LABEL_MAX_LENGTH:
- raise ValueError(
- f"invalid instance name: cloud instance name length cannot exceed "
- f"{DOMAIN_LABEL_MAX_LENGTH - 1} characters: "
- f"{org_slug}/{instance_name}"
- )
- secret_key = resolved_config.secret_key
- if secret_key is None:
- try:
- config_dir = platform.config_dir()
- if resolved_config._cloud_profile is None:
- profile = profile_src = "default"
- else:
- profile = resolved_config._cloud_profile
- profile_src = resolved_config._cloud_profile_source
- path = config_dir / "cloud-credentials" / f"{profile}.json"
- with open(path, "rt") as f:
- secret_key = json.load(f)["secret_key"]
- except Exception:
- raise errors.ClientConnectionError(
- "Cannot connect to cloud instances without secret key."
- )
- resolved_config.set_secret_key(
- secret_key,
- f"cloud-credentials/{profile}.json specified by {profile_src}",
- )
- try:
- dns_zone = _jwt_base64_decode(secret_key.split(".", 2)[1])["iss"]
- except errors.EdgeDBError:
- raise
- except Exception:
- raise errors.ClientConnectionError("Invalid secret key")
- payload = f"{org_slug}/{instance_name}".encode("utf-8")
- dns_bucket = binascii.crc_hqx(payload, 0) % 100
- host = f"{label}.c-{dns_bucket:02d}.i.{dns_zone}"
- resolved_config.set_host(host, source)
-
-
-def _resolve_config_options(
- resolved_config: ResolvedConnectConfig,
- compound_error: str,
- *,
- dsn=None,
- instance_name=None,
- credentials=None,
- credentials_file=None,
- host=None,
- port=None,
- database=None,
- branch=None,
- user=None,
- password=None,
- secret_key=None,
- tls_ca=None,
- tls_ca_file=None,
- tls_security=None,
- tls_server_name=None,
- server_settings=None,
- wait_until_available=None,
- cloud_profile=None,
-):
- if database is not None:
- if branch is not None:
- raise errors.ClientConnectionError(
- f"{database[1]} and {branch[1]} are mutually exclusive"
- )
- if resolved_config._branch is None:
- # Only update the config if 'branch' has not been already
- # resolved.
- resolved_config.set_database(*database)
- if branch is not None:
- if resolved_config._database is None:
- # Only update the config if 'database' has not been already
- # resolved.
- resolved_config.set_branch(*branch)
- if user is not None:
- resolved_config.set_user(*user)
- if password is not None:
- resolved_config.set_password(*password)
- if secret_key is not None:
- resolved_config.set_secret_key(*secret_key)
- if tls_ca_file is not None:
- if tls_ca is not None:
- raise errors.ClientConnectionError(
- f"{tls_ca[1]} and {tls_ca_file[1]} are mutually exclusive"
- )
- resolved_config.set_tls_ca_file(*tls_ca_file)
- if tls_ca is not None:
- resolved_config.set_tls_ca_data(*tls_ca)
- if tls_security is not None:
- resolved_config.set_tls_security(*tls_security)
- if tls_server_name is not None:
- resolved_config.set_tls_server_name(*tls_server_name)
- if server_settings is not None:
- resolved_config.add_server_settings(server_settings[0])
- if wait_until_available is not None:
- resolved_config.set_wait_until_available(*wait_until_available)
- if cloud_profile is not None:
- resolved_config._set_param('cloud_profile', *cloud_profile)
-
- compound_params = [
- dsn,
- instance_name,
- credentials,
- credentials_file,
- host or port,
- ]
- compound_params_count = len([p for p in compound_params if p is not None])
-
- if compound_params_count > 1:
- raise errors.ClientConnectionError(compound_error)
-
- elif compound_params_count == 1:
- if dsn is not None or host is not None or port is not None:
- if port is not None:
- resolved_config.set_port(*port)
- if dsn is None:
- dsn = (
- 'edgedb://' +
- (_prepare_host_for_dsn(host[0]) if host else ''),
- host[1] if host is not None else port[1]
- )
- _parse_dsn_into_config(resolved_config, dsn)
- else:
- if credentials_file is not None:
- creds = cred_utils.read_credentials(credentials_file[0])
- source = "credentials"
- elif credentials is not None:
- try:
- cred_data = json.loads(credentials[0])
- except ValueError as e:
- raise RuntimeError(f"cannot read credentials") from e
- else:
- creds = cred_utils.validate_credentials(cred_data)
- source = "credentials"
- elif INSTANCE_NAME_RE.match(instance_name[0]):
- source = instance_name[1]
- creds = cred_utils.read_credentials(
- cred_utils.get_credentials_path(instance_name[0]),
- )
- else:
- name_match = CLOUD_INSTANCE_NAME_RE.match(instance_name[0])
- if name_match is None:
- raise ValueError(
- f'invalid DSN or instance name: "{instance_name[0]}"'
- )
- source = instance_name[1]
- org, inst = name_match.groups()
- _parse_cloud_instance_name_into_config(
- resolved_config, source, org, inst
- )
- return True
-
- resolved_config.set_host(creds.get('host'), source)
- resolved_config.set_port(creds.get('port'), source)
- if 'database' in creds and resolved_config._branch is None:
- # Only update the config if 'branch' has not been already
- # resolved.
- resolved_config.set_database(creds.get('database'), source)
-
- elif 'branch' in creds and resolved_config._database is None:
- # Only update the config if 'database' has not been already
- # resolved.
- resolved_config.set_branch(creds.get('branch'), source)
- resolved_config.set_user(creds.get('user'), source)
- resolved_config.set_password(creds.get('password'), source)
- resolved_config.set_tls_ca_data(creds.get('tls_ca'), source)
- resolved_config.set_tls_security(
- creds.get('tls_security'),
- source
- )
-
- return True
-
- else:
- return False
-
-
-def find_edgedb_project_dir():
- dir = os.getcwd()
- dev = os.stat(dir).st_dev
-
- while True:
- gel_toml = os.path.join(dir, 'gel.toml')
- edgedb_toml = os.path.join(dir, 'edgedb.toml')
- if not os.path.isfile(gel_toml) and not os.path.isfile(edgedb_toml):
- parent = os.path.dirname(dir)
- if parent == dir:
- raise errors.ClientConnectionError(
- f'no `gel.toml` found and '
- f'no connection options specified'
- )
- parent_dev = os.stat(parent).st_dev
- if parent_dev != dev:
- raise errors.ClientConnectionError(
- f'no `gel.toml` found and '
- f'no connection options specified'
- f'(stopped searching for `edgedb.toml` at file system'
- f'boundary {dir!r})'
- )
- dir = parent
- dev = parent_dev
- continue
- return dir
-
-
-def parse_connect_arguments(
- *,
- dsn,
- host,
- port,
- credentials,
- credentials_file,
- database,
- branch,
- user,
- password,
- secret_key,
- tls_ca,
- tls_ca_file,
- tls_security,
- tls_server_name,
- timeout,
- command_timeout,
- wait_until_available,
- server_settings,
-) -> typing.Tuple[ResolvedConnectConfig, ClientConfiguration]:
-
- if command_timeout is not None:
- try:
- if isinstance(command_timeout, bool):
- raise ValueError
- command_timeout = float(command_timeout)
- if command_timeout <= 0:
- raise ValueError
- except ValueError:
- raise ValueError(
- 'invalid command_timeout value: '
- 'expected greater than 0 float (got {!r})'.format(
- command_timeout)) from None
-
- connect_config = _parse_connect_dsn_and_args(
- dsn=dsn,
- host=host,
- port=port,
- credentials=credentials,
- credentials_file=credentials_file,
- database=database,
- branch=branch,
- user=user,
- password=password,
- secret_key=secret_key,
- tls_ca=tls_ca,
- tls_ca_file=tls_ca_file,
- tls_security=tls_security,
- tls_server_name=tls_server_name,
- server_settings=server_settings,
- wait_until_available=wait_until_available,
- )
-
- client_config = ClientConfiguration(
- connect_timeout=timeout,
- command_timeout=command_timeout,
- wait_until_available=connect_config.wait_until_available,
- )
-
- return connect_config, client_config
-
-
-def check_alpn_protocol(ssl_obj):
- if ssl_obj.selected_alpn_protocol() != 'edgedb-binary':
- raise errors.ClientConnectionFailedError(
- "The server doesn't support the edgedb-binary protocol."
- )
-
-
-def render_client_no_connection_error(prefix, addr, attempts, duration):
- if isinstance(addr, str):
- msg = (
- f'{prefix}'
- f'\n\tAfter {attempts} attempts in {duration:.1f} sec'
- f'\n\tIs the server running locally and accepting '
- f'\n\tconnections on Unix domain socket {addr!r}?'
- )
- else:
- msg = (
- f'{prefix}'
- f'\n\tAfter {attempts} attempts in {duration:.1f} sec'
- f'\n\tIs the server running on host {addr[0]!r} '
- f'and accepting '
- f'\n\tTCP/IP connections on port {addr[1]}?'
- )
- return msg
-
-
-def _extract_errno(s):
- """Extract multiple errnos from error string
-
- When we connect to a host that has multiple underlying IP addresses, say
- ``localhost`` having ``::1`` and ``127.0.0.1``, we get
- ``OSError("Multiple exceptions:...")`` error without ``.errno`` attribute
- set. There are multiple ones in the text, so we extract all of them.
- """
- result = []
- for match in ERRNO_RE.finditer(s):
- result.append(int(match.group(1)))
- if result:
- return result
-
-
-def wrap_error(e):
- message = str(e)
- if e.errno is None:
- errnos = _extract_errno(message)
- else:
- errnos = [e.errno]
-
- if errnos:
- is_temp = any((code in TEMPORARY_ERROR_CODES for code in errnos))
- else:
- is_temp = isinstance(e, TEMPORARY_ERRORS)
-
- if is_temp:
- return errors.ClientConnectionFailedTemporarilyError(message)
- else:
- return errors.ClientConnectionFailedError(message)
+# Auto-generated shim
+import gel.con_utils as _mod
+import sys as _sys
+_cur = _sys.modules['edgedb.con_utils']
+for _k in vars(_mod):
+ if not _k.startswith('__') or _k in ('__all__', '__doc__'):
+ setattr(_cur, _k, getattr(_mod, _k))
+del _cur
+del _sys
+del _mod
+del _k
diff --git a/edgedb/credentials.py b/edgedb/credentials.py
index 71146fd9..ea5f0390 100644
--- a/edgedb/credentials.py
+++ b/edgedb/credentials.py
@@ -1,119 +1,11 @@
-import os
-import pathlib
-import typing
-import json
-
-from . import platform
-
-
-class RequiredCredentials(typing.TypedDict, total=True):
- port: int
- user: str
-
-
-class Credentials(RequiredCredentials, total=False):
- host: typing.Optional[str]
- password: typing.Optional[str]
- # It's OK for database and branch to appear in credentials, as long as
- # they match.
- database: typing.Optional[str]
- branch: typing.Optional[str]
- tls_ca: typing.Optional[str]
- tls_security: typing.Optional[str]
-
-
-def get_credentials_path(instance_name: str) -> pathlib.Path:
- return platform.search_config_dir("credentials", instance_name + ".json")
-
-
-def read_credentials(path: os.PathLike) -> Credentials:
- try:
- with open(path, encoding='utf-8') as f:
- credentials = json.load(f)
- return validate_credentials(credentials)
- except Exception as e:
- raise RuntimeError(
- f"cannot read credentials at {path}"
- ) from e
-
-
-def validate_credentials(data: dict) -> Credentials:
- port = data.get('port')
- if port is None:
- port = 5656
- if not isinstance(port, int) or port < 1 or port > 65535:
- raise ValueError("invalid `port` value")
-
- user = data.get('user')
- if user is None:
- raise ValueError("`user` key is required")
- if not isinstance(user, str):
- raise ValueError("`user` must be a string")
-
- result = { # required keys
- "user": user,
- "port": port,
- }
-
- host = data.get('host')
- if host is not None:
- if not isinstance(host, str):
- raise ValueError("`host` must be a string")
- result['host'] = host
-
- database = data.get('database')
- if database is not None:
- if not isinstance(database, str):
- raise ValueError("`database` must be a string")
- result['database'] = database
-
- branch = data.get('branch')
- if branch is not None:
- if not isinstance(branch, str):
- raise ValueError("`branch` must be a string")
- if database is not None and branch != database:
- raise ValueError(
- f"`database` and `branch` cannot be different")
- result['branch'] = branch
-
- password = data.get('password')
- if password is not None:
- if not isinstance(password, str):
- raise ValueError("`password` must be a string")
- result['password'] = password
-
- ca = data.get('tls_ca')
- if ca is not None:
- if not isinstance(ca, str):
- raise ValueError("`tls_ca` must be a string")
- result['tls_ca'] = ca
-
- cert_data = data.get('tls_cert_data')
- if cert_data is not None:
- if not isinstance(cert_data, str):
- raise ValueError("`tls_cert_data` must be a string")
- if ca is not None and ca != cert_data:
- raise ValueError(
- f"tls_ca and tls_cert_data are both set and disagree")
- result['tls_ca'] = cert_data
-
- verify = data.get('tls_verify_hostname')
- if verify is not None:
- if not isinstance(verify, bool):
- raise ValueError("`tls_verify_hostname` must be a bool")
- result['tls_security'] = 'strict' if verify else 'no_host_verification'
-
- tls_security = data.get('tls_security')
- if tls_security is not None:
- if not isinstance(tls_security, str):
- raise ValueError("`tls_security` must be a string")
- result['tls_security'] = tls_security
-
- missmatch = ValueError(f"tls_verify_hostname={verify} and "
- f"tls_security={tls_security} are incompatible")
- if tls_security == "strict" and verify is False:
- raise missmatch
- if tls_security in {"no_host_verification", "insecure"} and verify is True:
- raise missmatch
-
- return result
+# Auto-generated shim
+import gel.credentials as _mod
+import sys as _sys
+_cur = _sys.modules['edgedb.credentials']
+for _k in vars(_mod):
+ if not _k.startswith('__') or _k in ('__all__', '__doc__'):
+ setattr(_cur, _k, getattr(_mod, _k))
+del _cur
+del _sys
+del _mod
+del _k
diff --git a/edgedb/datatypes/__init__.py b/edgedb/datatypes/__init__.py
index 73609285..6da01f78 100644
--- a/edgedb/datatypes/__init__.py
+++ b/edgedb/datatypes/__init__.py
@@ -1,17 +1,11 @@
-#
-# This source file is part of the EdgeDB open source project.
-#
-# Copyright 2019-present MagicStack Inc. and the EdgeDB authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
+# Auto-generated shim
+import gel.datatypes as _mod
+import sys as _sys
+_cur = _sys.modules['edgedb.datatypes']
+for _k in vars(_mod):
+ if not _k.startswith('__') or _k in ('__all__', '__doc__'):
+ setattr(_cur, _k, getattr(_mod, _k))
+del _cur
+del _sys
+del _mod
+del _k
diff --git a/edgedb/datatypes/datatypes.py b/edgedb/datatypes/datatypes.py
new file mode 100644
index 00000000..05dc5979
--- /dev/null
+++ b/edgedb/datatypes/datatypes.py
@@ -0,0 +1,11 @@
+# Auto-generated shim
+import gel.datatypes.datatypes as _mod
+import sys as _sys
+_cur = _sys.modules['edgedb.datatypes.datatypes']
+for _k in vars(_mod):
+ if not _k.startswith('__') or _k in ('__all__', '__doc__'):
+ setattr(_cur, _k, getattr(_mod, _k))
+del _cur
+del _sys
+del _mod
+del _k
diff --git a/edgedb/datatypes/range.py b/edgedb/datatypes/range.py
index e3fd3d1e..29e3c8cf 100644
--- a/edgedb/datatypes/range.py
+++ b/edgedb/datatypes/range.py
@@ -1,165 +1,11 @@
-#
-# This source file is part of the EdgeDB open source project.
-#
-# Copyright 2022-present MagicStack Inc. and the EdgeDB authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-from typing import (TypeVar, Any, Generic, Optional, Iterable, Iterator,
- Sequence)
-
-T = TypeVar("T")
-
-
-class Range(Generic[T]):
-
- __slots__ = ("_lower", "_upper", "_inc_lower", "_inc_upper", "_empty")
-
- def __init__(
- self,
- lower: Optional[T] = None,
- upper: Optional[T] = None,
- *,
- inc_lower: bool = True,
- inc_upper: bool = False,
- empty: bool = False,
- ) -> None:
- self._empty = empty
-
- if empty:
- if (
- lower != upper
- or lower is not None and inc_upper and inc_lower
- ):
- raise ValueError(
- "conflicting arguments in range constructor: "
- "\"empty\" is `true` while the specified bounds "
- "suggest otherwise"
- )
-
- self._lower = self._upper = None
- self._inc_lower = self._inc_upper = False
- else:
- self._lower = lower
- self._upper = upper
- self._inc_lower = lower is not None and inc_lower
- self._inc_upper = upper is not None and inc_upper
-
- @property
- def lower(self) -> Optional[T]:
- return self._lower
-
- @property
- def inc_lower(self) -> bool:
- return self._inc_lower
-
- @property
- def upper(self) -> Optional[T]:
- return self._upper
-
- @property
- def inc_upper(self) -> bool:
- return self._inc_upper
-
- def is_empty(self) -> bool:
- return self._empty
-
- def __bool__(self):
- return not self.is_empty()
-
- def __eq__(self, other) -> bool:
- if isinstance(other, Range):
- o = other
- else:
- return NotImplemented
-
- return (
- self._lower,
- self._upper,
- self._inc_lower,
- self._inc_upper,
- self._empty,
- ) == (
- o._lower,
- o._upper,
- o._inc_lower,
- o._inc_upper,
- o._empty,
- )
-
- def __hash__(self) -> int:
- return hash((
- self._lower,
- self._upper,
- self._inc_lower,
- self._inc_upper,
- self._empty,
- ))
-
- def __str__(self) -> str:
- if self._empty:
- desc = "empty"
- else:
- lb = "(" if not self._inc_lower else "["
- if self._lower is not None:
- lb += repr(self._lower)
-
- if self._upper is not None:
- ub = repr(self._upper)
- else:
- ub = ""
-
- ub += ")" if self._inc_upper else "]"
-
- desc = f"{lb}, {ub}"
-
- return f""
-
- __repr__ = __str__
-
-
-# TODO: maybe we should implement range and multirange operations as well as
-# normalization of the sub-ranges?
-class MultiRange(Iterable[T]):
-
- _ranges: Sequence[T]
-
- def __init__(self, iterable: Optional[Iterable[T]] = None) -> None:
- if iterable is not None:
- self._ranges = tuple(iterable)
- else:
- self._ranges = tuple()
-
- def __len__(self) -> int:
- return len(self._ranges)
-
- def __iter__(self) -> Iterator[T]:
- return iter(self._ranges)
-
- def __reversed__(self) -> Iterator[T]:
- return reversed(self._ranges)
-
- def __str__(self) -> str:
- return f''
-
- __repr__ = __str__
-
- def __eq__(self, other: Any) -> bool:
- if isinstance(other, MultiRange):
- return set(self._ranges) == set(other._ranges)
- else:
- return NotImplemented
-
- def __hash__(self) -> int:
- return hash(self._ranges)
+# Auto-generated shim
+import gel.datatypes.range as _mod
+import sys as _sys
+_cur = _sys.modules['edgedb.datatypes.range']
+for _k in vars(_mod):
+ if not _k.startswith('__') or _k in ('__all__', '__doc__'):
+ setattr(_cur, _k, getattr(_mod, _k))
+del _cur
+del _sys
+del _mod
+del _k
diff --git a/edgedb/describe.py b/edgedb/describe.py
index 92e38854..4b55e179 100644
--- a/edgedb/describe.py
+++ b/edgedb/describe.py
@@ -1,98 +1,11 @@
-#
-# This source file is part of the EdgeDB open source project.
-#
-# Copyright 2022-present MagicStack Inc. and the EdgeDB authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-
-import dataclasses
-import typing
-import uuid
-
-from . import enums
-
-
-@dataclasses.dataclass(frozen=True)
-class AnyType:
- desc_id: uuid.UUID
- name: typing.Optional[str]
-
-
-@dataclasses.dataclass(frozen=True)
-class Element:
- type: AnyType
- cardinality: enums.Cardinality
- is_implicit: bool
- kind: enums.ElementKind
-
-
-@dataclasses.dataclass(frozen=True)
-class SequenceType(AnyType):
- element_type: AnyType
-
-
-@dataclasses.dataclass(frozen=True)
-class SetType(SequenceType):
- pass
-
-
-@dataclasses.dataclass(frozen=True)
-class ObjectType(AnyType):
- elements: typing.Dict[str, Element]
-
-
-@dataclasses.dataclass(frozen=True)
-class BaseScalarType(AnyType):
- pass
-
-
-@dataclasses.dataclass(frozen=True)
-class ScalarType(AnyType):
- base_type: BaseScalarType
-
-
-@dataclasses.dataclass(frozen=True)
-class TupleType(AnyType):
- element_types: typing.Tuple[AnyType]
-
-
-@dataclasses.dataclass(frozen=True)
-class NamedTupleType(AnyType):
- element_types: typing.Dict[str, AnyType]
-
-
-@dataclasses.dataclass(frozen=True)
-class ArrayType(SequenceType):
- pass
-
-
-@dataclasses.dataclass(frozen=True)
-class EnumType(AnyType):
- members: typing.Tuple[str]
-
-
-@dataclasses.dataclass(frozen=True)
-class SparseObjectType(ObjectType):
- pass
-
-
-@dataclasses.dataclass(frozen=True)
-class RangeType(AnyType):
- value_type: AnyType
-
-
-@dataclasses.dataclass(frozen=True)
-class MultiRangeType(AnyType):
- value_type: AnyType
+# Auto-generated shim
+import gel.describe as _mod
+import sys as _sys
+_cur = _sys.modules['edgedb.describe']
+for _k in vars(_mod):
+ if not _k.startswith('__') or _k in ('__all__', '__doc__'):
+ setattr(_cur, _k, getattr(_mod, _k))
+del _cur
+del _sys
+del _mod
+del _k
diff --git a/edgedb/enums.py b/edgedb/enums.py
index 6312f596..feeb69ea 100644
--- a/edgedb/enums.py
+++ b/edgedb/enums.py
@@ -1,74 +1,11 @@
-#
-# This source file is part of the EdgeDB open source project.
-#
-# Copyright 2021-present MagicStack Inc. and the EdgeDB authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-
-import enum
-
-
-class Capability(enum.IntFlag):
-
- NONE = 0 # noqa
- MODIFICATIONS = 1 << 0 # noqa
- SESSION_CONFIG = 1 << 1 # noqa
- TRANSACTION = 1 << 2 # noqa
- DDL = 1 << 3 # noqa
- PERSISTENT_CONFIG = 1 << 4 # noqa
-
- ALL = 0xFFFF_FFFF_FFFF_FFFF # noqa
- EXECUTE = ALL & ~TRANSACTION & ~SESSION_CONFIG # noqa
- LEGACY_EXECUTE = ALL & ~TRANSACTION # noqa
-
-
-class CompilationFlag(enum.IntFlag):
-
- INJECT_OUTPUT_TYPE_IDS = 1 << 0 # noqa
- INJECT_OUTPUT_TYPE_NAMES = 1 << 1 # noqa
- INJECT_OUTPUT_OBJECT_IDS = 1 << 2 # noqa
-
-
-class Cardinality(enum.Enum):
- # Cardinality isn't applicable for the query:
- # * the query is a command like CONFIGURE that
- # does not return any data;
- # * the query is composed of multiple queries.
- NO_RESULT = 0x6e
-
- # Cardinality is 1 or 0
- AT_MOST_ONE = 0x6f
-
- # Cardinality is 1
- ONE = 0x41
-
- # Cardinality is >= 0
- MANY = 0x6d
-
- # Cardinality is >= 1
- AT_LEAST_ONE = 0x4d
-
- def is_single(self) -> bool:
- return self in {Cardinality.AT_MOST_ONE, Cardinality.ONE}
-
- def is_multi(self) -> bool:
- return self in {Cardinality.AT_LEAST_ONE, Cardinality.MANY}
-
-
-class ElementKind(enum.Enum):
-
- LINK = 1 # noqa
- PROPERTY = 2 # noqa
- LINK_PROPERTY = 3 # noqa
+# Auto-generated shim
+import gel.enums as _mod
+import sys as _sys
+_cur = _sys.modules['edgedb.enums']
+for _k in vars(_mod):
+ if not _k.startswith('__') or _k in ('__all__', '__doc__'):
+ setattr(_cur, _k, getattr(_mod, _k))
+del _cur
+del _sys
+del _mod
+del _k
diff --git a/edgedb/errors/__init__.py b/edgedb/errors/__init__.py
index 424edc91..f510b333 100644
--- a/edgedb/errors/__init__.py
+++ b/edgedb/errors/__init__.py
@@ -1,518 +1,11 @@
-# AUTOGENERATED FROM "edb/api/errors.txt" WITH
-# $ edb gen-errors \
-# --import 'from edgedb.errors._base import *\nfrom edgedb.errors.tags import *' \
-# --extra-all "_base.__all__" \
-# --stdout \
-# --client
-
-
-# flake8: noqa
-
-
-from edgedb.errors._base import *
-from edgedb.errors.tags import *
-
-
-__all__ = _base.__all__ + ( # type: ignore
- 'InternalServerError',
- 'UnsupportedFeatureError',
- 'ProtocolError',
- 'BinaryProtocolError',
- 'UnsupportedProtocolVersionError',
- 'TypeSpecNotFoundError',
- 'UnexpectedMessageError',
- 'InputDataError',
- 'ParameterTypeMismatchError',
- 'StateMismatchError',
- 'ResultCardinalityMismatchError',
- 'CapabilityError',
- 'UnsupportedCapabilityError',
- 'DisabledCapabilityError',
- 'QueryError',
- 'InvalidSyntaxError',
- 'EdgeQLSyntaxError',
- 'SchemaSyntaxError',
- 'GraphQLSyntaxError',
- 'InvalidTypeError',
- 'InvalidTargetError',
- 'InvalidLinkTargetError',
- 'InvalidPropertyTargetError',
- 'InvalidReferenceError',
- 'UnknownModuleError',
- 'UnknownLinkError',
- 'UnknownPropertyError',
- 'UnknownUserError',
- 'UnknownDatabaseError',
- 'UnknownParameterError',
- 'SchemaError',
- 'SchemaDefinitionError',
- 'InvalidDefinitionError',
- 'InvalidModuleDefinitionError',
- 'InvalidLinkDefinitionError',
- 'InvalidPropertyDefinitionError',
- 'InvalidUserDefinitionError',
- 'InvalidDatabaseDefinitionError',
- 'InvalidOperatorDefinitionError',
- 'InvalidAliasDefinitionError',
- 'InvalidFunctionDefinitionError',
- 'InvalidConstraintDefinitionError',
- 'InvalidCastDefinitionError',
- 'DuplicateDefinitionError',
- 'DuplicateModuleDefinitionError',
- 'DuplicateLinkDefinitionError',
- 'DuplicatePropertyDefinitionError',
- 'DuplicateUserDefinitionError',
- 'DuplicateDatabaseDefinitionError',
- 'DuplicateOperatorDefinitionError',
- 'DuplicateViewDefinitionError',
- 'DuplicateFunctionDefinitionError',
- 'DuplicateConstraintDefinitionError',
- 'DuplicateCastDefinitionError',
- 'DuplicateMigrationError',
- 'SessionTimeoutError',
- 'IdleSessionTimeoutError',
- 'QueryTimeoutError',
- 'TransactionTimeoutError',
- 'IdleTransactionTimeoutError',
- 'ExecutionError',
- 'InvalidValueError',
- 'DivisionByZeroError',
- 'NumericOutOfRangeError',
- 'AccessPolicyError',
- 'QueryAssertionError',
- 'IntegrityError',
- 'ConstraintViolationError',
- 'CardinalityViolationError',
- 'MissingRequiredError',
- 'TransactionError',
- 'TransactionConflictError',
- 'TransactionSerializationError',
- 'TransactionDeadlockError',
- 'WatchError',
- 'ConfigurationError',
- 'AccessError',
- 'AuthenticationError',
- 'AvailabilityError',
- 'BackendUnavailableError',
- 'ServerOfflineError',
- 'BackendError',
- 'UnsupportedBackendFeatureError',
- 'LogMessage',
- 'WarningMessage',
- 'ClientError',
- 'ClientConnectionError',
- 'ClientConnectionFailedError',
- 'ClientConnectionFailedTemporarilyError',
- 'ClientConnectionTimeoutError',
- 'ClientConnectionClosedError',
- 'InterfaceError',
- 'QueryArgumentError',
- 'MissingArgumentError',
- 'UnknownArgumentError',
- 'InvalidArgumentError',
- 'NoDataError',
- 'InternalClientError',
-)
-
-
-class InternalServerError(EdgeDBError):
- _code = 0x_01_00_00_00
-
-
-class UnsupportedFeatureError(EdgeDBError):
- _code = 0x_02_00_00_00
-
-
-class ProtocolError(EdgeDBError):
- _code = 0x_03_00_00_00
-
-
-class BinaryProtocolError(ProtocolError):
- _code = 0x_03_01_00_00
-
-
-class UnsupportedProtocolVersionError(BinaryProtocolError):
- _code = 0x_03_01_00_01
-
-
-class TypeSpecNotFoundError(BinaryProtocolError):
- _code = 0x_03_01_00_02
-
-
-class UnexpectedMessageError(BinaryProtocolError):
- _code = 0x_03_01_00_03
-
-
-class InputDataError(ProtocolError):
- _code = 0x_03_02_00_00
-
-
-class ParameterTypeMismatchError(InputDataError):
- _code = 0x_03_02_01_00
-
-
-class StateMismatchError(InputDataError):
- _code = 0x_03_02_02_00
- tags = frozenset({SHOULD_RETRY})
-
-
-class ResultCardinalityMismatchError(ProtocolError):
- _code = 0x_03_03_00_00
-
-
-class CapabilityError(ProtocolError):
- _code = 0x_03_04_00_00
-
-
-class UnsupportedCapabilityError(CapabilityError):
- _code = 0x_03_04_01_00
-
-
-class DisabledCapabilityError(CapabilityError):
- _code = 0x_03_04_02_00
-
-
-class QueryError(EdgeDBError):
- _code = 0x_04_00_00_00
-
-
-class InvalidSyntaxError(QueryError):
- _code = 0x_04_01_00_00
-
-
-class EdgeQLSyntaxError(InvalidSyntaxError):
- _code = 0x_04_01_01_00
-
-
-class SchemaSyntaxError(InvalidSyntaxError):
- _code = 0x_04_01_02_00
-
-
-class GraphQLSyntaxError(InvalidSyntaxError):
- _code = 0x_04_01_03_00
-
-
-class InvalidTypeError(QueryError):
- _code = 0x_04_02_00_00
-
-
-class InvalidTargetError(InvalidTypeError):
- _code = 0x_04_02_01_00
-
-
-class InvalidLinkTargetError(InvalidTargetError):
- _code = 0x_04_02_01_01
-
-
-class InvalidPropertyTargetError(InvalidTargetError):
- _code = 0x_04_02_01_02
-
-
-class InvalidReferenceError(QueryError):
- _code = 0x_04_03_00_00
-
-
-class UnknownModuleError(InvalidReferenceError):
- _code = 0x_04_03_00_01
-
-
-class UnknownLinkError(InvalidReferenceError):
- _code = 0x_04_03_00_02
-
-
-class UnknownPropertyError(InvalidReferenceError):
- _code = 0x_04_03_00_03
-
-
-class UnknownUserError(InvalidReferenceError):
- _code = 0x_04_03_00_04
-
-
-class UnknownDatabaseError(InvalidReferenceError):
- _code = 0x_04_03_00_05
-
-
-class UnknownParameterError(InvalidReferenceError):
- _code = 0x_04_03_00_06
-
-
-class SchemaError(QueryError):
- _code = 0x_04_04_00_00
-
-
-class SchemaDefinitionError(QueryError):
- _code = 0x_04_05_00_00
-
-
-class InvalidDefinitionError(SchemaDefinitionError):
- _code = 0x_04_05_01_00
-
-
-class InvalidModuleDefinitionError(InvalidDefinitionError):
- _code = 0x_04_05_01_01
-
-
-class InvalidLinkDefinitionError(InvalidDefinitionError):
- _code = 0x_04_05_01_02
-
-
-class InvalidPropertyDefinitionError(InvalidDefinitionError):
- _code = 0x_04_05_01_03
-
-
-class InvalidUserDefinitionError(InvalidDefinitionError):
- _code = 0x_04_05_01_04
-
-
-class InvalidDatabaseDefinitionError(InvalidDefinitionError):
- _code = 0x_04_05_01_05
-
-
-class InvalidOperatorDefinitionError(InvalidDefinitionError):
- _code = 0x_04_05_01_06
-
-
-class InvalidAliasDefinitionError(InvalidDefinitionError):
- _code = 0x_04_05_01_07
-
-
-class InvalidFunctionDefinitionError(InvalidDefinitionError):
- _code = 0x_04_05_01_08
-
-
-class InvalidConstraintDefinitionError(InvalidDefinitionError):
- _code = 0x_04_05_01_09
-
-
-class InvalidCastDefinitionError(InvalidDefinitionError):
- _code = 0x_04_05_01_0A
-
-
-class DuplicateDefinitionError(SchemaDefinitionError):
- _code = 0x_04_05_02_00
-
-
-class DuplicateModuleDefinitionError(DuplicateDefinitionError):
- _code = 0x_04_05_02_01
-
-
-class DuplicateLinkDefinitionError(DuplicateDefinitionError):
- _code = 0x_04_05_02_02
-
-
-class DuplicatePropertyDefinitionError(DuplicateDefinitionError):
- _code = 0x_04_05_02_03
-
-
-class DuplicateUserDefinitionError(DuplicateDefinitionError):
- _code = 0x_04_05_02_04
-
-
-class DuplicateDatabaseDefinitionError(DuplicateDefinitionError):
- _code = 0x_04_05_02_05
-
-
-class DuplicateOperatorDefinitionError(DuplicateDefinitionError):
- _code = 0x_04_05_02_06
-
-
-class DuplicateViewDefinitionError(DuplicateDefinitionError):
- _code = 0x_04_05_02_07
-
-
-class DuplicateFunctionDefinitionError(DuplicateDefinitionError):
- _code = 0x_04_05_02_08
-
-
-class DuplicateConstraintDefinitionError(DuplicateDefinitionError):
- _code = 0x_04_05_02_09
-
-
-class DuplicateCastDefinitionError(DuplicateDefinitionError):
- _code = 0x_04_05_02_0A
-
-
-class DuplicateMigrationError(DuplicateDefinitionError):
- _code = 0x_04_05_02_0B
-
-
-class SessionTimeoutError(QueryError):
- _code = 0x_04_06_00_00
-
-
-class IdleSessionTimeoutError(SessionTimeoutError):
- _code = 0x_04_06_01_00
- tags = frozenset({SHOULD_RETRY})
-
-
-class QueryTimeoutError(SessionTimeoutError):
- _code = 0x_04_06_02_00
-
-
-class TransactionTimeoutError(SessionTimeoutError):
- _code = 0x_04_06_0A_00
-
-
-class IdleTransactionTimeoutError(TransactionTimeoutError):
- _code = 0x_04_06_0A_01
-
-
-class ExecutionError(EdgeDBError):
- _code = 0x_05_00_00_00
-
-
-class InvalidValueError(ExecutionError):
- _code = 0x_05_01_00_00
-
-
-class DivisionByZeroError(InvalidValueError):
- _code = 0x_05_01_00_01
-
-
-class NumericOutOfRangeError(InvalidValueError):
- _code = 0x_05_01_00_02
-
-
-class AccessPolicyError(InvalidValueError):
- _code = 0x_05_01_00_03
-
-
-class QueryAssertionError(InvalidValueError):
- _code = 0x_05_01_00_04
-
-
-class IntegrityError(ExecutionError):
- _code = 0x_05_02_00_00
-
-
-class ConstraintViolationError(IntegrityError):
- _code = 0x_05_02_00_01
-
-
-class CardinalityViolationError(IntegrityError):
- _code = 0x_05_02_00_02
-
-
-class MissingRequiredError(IntegrityError):
- _code = 0x_05_02_00_03
-
-
-class TransactionError(ExecutionError):
- _code = 0x_05_03_00_00
-
-
-class TransactionConflictError(TransactionError):
- _code = 0x_05_03_01_00
- tags = frozenset({SHOULD_RETRY})
-
-
-class TransactionSerializationError(TransactionConflictError):
- _code = 0x_05_03_01_01
- tags = frozenset({SHOULD_RETRY})
-
-
-class TransactionDeadlockError(TransactionConflictError):
- _code = 0x_05_03_01_02
- tags = frozenset({SHOULD_RETRY})
-
-
-class WatchError(ExecutionError):
- _code = 0x_05_04_00_00
-
-
-class ConfigurationError(EdgeDBError):
- _code = 0x_06_00_00_00
-
-
-class AccessError(EdgeDBError):
- _code = 0x_07_00_00_00
-
-
-class AuthenticationError(AccessError):
- _code = 0x_07_01_00_00
-
-
-class AvailabilityError(EdgeDBError):
- _code = 0x_08_00_00_00
-
-
-class BackendUnavailableError(AvailabilityError):
- _code = 0x_08_00_00_01
- tags = frozenset({SHOULD_RETRY})
-
-
-class ServerOfflineError(AvailabilityError):
- _code = 0x_08_00_00_02
- tags = frozenset({SHOULD_RECONNECT, SHOULD_RETRY})
-
-
-class BackendError(EdgeDBError):
- _code = 0x_09_00_00_00
-
-
-class UnsupportedBackendFeatureError(BackendError):
- _code = 0x_09_00_01_00
-
-
-class LogMessage(EdgeDBMessage):
- _code = 0x_F0_00_00_00
-
-
-class WarningMessage(LogMessage):
- _code = 0x_F0_01_00_00
-
-
-class ClientError(EdgeDBError):
- _code = 0x_FF_00_00_00
-
-
-class ClientConnectionError(ClientError):
- _code = 0x_FF_01_00_00
-
-
-class ClientConnectionFailedError(ClientConnectionError):
- _code = 0x_FF_01_01_00
-
-
-class ClientConnectionFailedTemporarilyError(ClientConnectionFailedError):
- _code = 0x_FF_01_01_01
- tags = frozenset({SHOULD_RECONNECT, SHOULD_RETRY})
-
-
-class ClientConnectionTimeoutError(ClientConnectionError):
- _code = 0x_FF_01_02_00
- tags = frozenset({SHOULD_RECONNECT, SHOULD_RETRY})
-
-
-class ClientConnectionClosedError(ClientConnectionError):
- _code = 0x_FF_01_03_00
- tags = frozenset({SHOULD_RECONNECT, SHOULD_RETRY})
-
-
-class InterfaceError(ClientError):
- _code = 0x_FF_02_00_00
-
-
-class QueryArgumentError(InterfaceError):
- _code = 0x_FF_02_01_00
-
-
-class MissingArgumentError(QueryArgumentError):
- _code = 0x_FF_02_01_01
-
-
-class UnknownArgumentError(QueryArgumentError):
- _code = 0x_FF_02_01_02
-
-
-class InvalidArgumentError(QueryArgumentError):
- _code = 0x_FF_02_01_03
-
-
-class NoDataError(ClientError):
- _code = 0x_FF_03_00_00
-
-
-class InternalClientError(ClientError):
- _code = 0x_FF_04_00_00
-
+# Auto-generated shim
+import gel.errors as _mod
+import sys as _sys
+_cur = _sys.modules['edgedb.errors']
+for _k in vars(_mod):
+ if not _k.startswith('__') or _k in ('__all__', '__doc__'):
+ setattr(_cur, _k, getattr(_mod, _k))
+del _cur
+del _sys
+del _mod
+del _k
diff --git a/edgedb/errors/_base.py b/edgedb/errors/_base.py
index 03ce547c..d567937f 100644
--- a/edgedb/errors/_base.py
+++ b/edgedb/errors/_base.py
@@ -1,363 +1,11 @@
-#
-# This source file is part of the EdgeDB open source project.
-#
-# Copyright 2016-present MagicStack Inc. and the EdgeDB authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-
-import io
-import os
-import traceback
-import unicodedata
-import warnings
-
-__all__ = (
- 'EdgeDBError', 'EdgeDBMessage',
-)
-
-
-class Meta(type):
-
- def __new__(mcls, name, bases, dct):
- cls = super().__new__(mcls, name, bases, dct)
-
- code = dct.get('_code')
- if code is not None:
- mcls._index[code] = cls
-
- # If it's a base class add it to the base class index
- b1, b2, b3, b4 = _decode(code)
- if b1 == 0 or b2 == 0 or b3 == 0 or b4 == 0:
- mcls._base_class_index[(b1, b2, b3, b4)] = cls
-
- return cls
-
-
-class EdgeDBMessageMeta(Meta):
-
- _base_class_index = {}
- _index = {}
-
-
-class EdgeDBMessage(Warning, metaclass=EdgeDBMessageMeta):
-
- _code = None
-
- def __init__(self, severity, message):
- super().__init__(message)
- self._severity = severity
-
- def get_severity(self):
- return self._severity
-
- def get_severity_name(self):
- return _severity_name(self._severity)
-
- def get_code(self):
- return self._code
-
- @staticmethod
- def _from_code(code, severity, message, *args, **kwargs):
- cls = _lookup_message_cls(code)
- exc = cls(severity, message, *args, **kwargs)
- exc._code = code
- return exc
-
-
-class EdgeDBErrorMeta(Meta):
-
- _base_class_index = {}
- _index = {}
-
-
-class EdgeDBError(Exception, metaclass=EdgeDBErrorMeta):
-
- _code = None
- _query = None
- tags = frozenset()
-
- def __init__(self, *args, **kwargs):
- self._attrs = {}
- super().__init__(*args, **kwargs)
-
- def has_tag(self, tag):
- return tag in self.tags
-
- @property
- def _position(self):
- # not a stable API method
- return int(self._read_str_field(FIELD_POSITION_START, -1))
-
- @property
- def _position_start(self):
- # not a stable API method
- return int(self._read_str_field(FIELD_CHARACTER_START, -1))
-
- @property
- def _position_end(self):
- # not a stable API method
- return int(self._read_str_field(FIELD_CHARACTER_END, -1))
-
- @property
- def _line(self):
- # not a stable API method
- return int(self._read_str_field(FIELD_LINE_START, -1))
-
- @property
- def _col(self):
- # not a stable API method
- return int(self._read_str_field(FIELD_COLUMN_START, -1))
-
- @property
- def _hint(self):
- # not a stable API method
- return self._read_str_field(FIELD_HINT)
-
- @property
- def _details(self):
- # not a stable API method
- return self._read_str_field(FIELD_DETAILS)
-
- def _read_str_field(self, key, default=None):
- val = self._attrs.get(key)
- if isinstance(val, bytes):
- return val.decode('utf-8')
- elif val is not None:
- return val
- return default
-
- def get_code(self):
- return self._code
-
- def get_server_context(self):
- return self._read_str_field(FIELD_SERVER_TRACEBACK)
-
- @staticmethod
- def _from_code(code, *args, **kwargs):
- cls = _lookup_error_cls(code)
- exc = cls(*args, **kwargs)
- exc._code = code
- return exc
-
- @staticmethod
- def _from_json(data):
- exc = EdgeDBError._from_code(data['code'], data['message'])
- exc._attrs = {
- field: data[name]
- for name, field in _JSON_FIELDS.items()
- if name in data
- }
- return exc
-
- def __str__(self):
- msg = super().__str__()
- if SHOW_HINT and self._query and self._position_start >= 0:
- try:
- return _format_error(
- msg,
- self._query,
- self._position_start,
- max(1, self._position_end - self._position_start),
- self._line if self._line > 0 else "?",
- self._col if self._col > 0 else "?",
- self._hint or "error",
- self._details,
- )
- except Exception:
- return "".join(
- (
- msg,
- LINESEP,
- LINESEP,
- "During formatting of the above exception, "
- "another exception occurred:",
- LINESEP,
- LINESEP,
- traceback.format_exc(),
- )
- )
- else:
- return msg
-
-
-def _lookup_cls(code: int, *, meta: type, default: type):
- try:
- return meta._index[code]
- except KeyError:
- pass
-
- b1, b2, b3, _ = _decode(code)
-
- try:
- return meta._base_class_index[(b1, b2, b3, 0)]
- except KeyError:
- pass
- try:
- return meta._base_class_index[(b1, b2, 0, 0)]
- except KeyError:
- pass
- try:
- return meta._base_class_index[(b1, 0, 0, 0)]
- except KeyError:
- pass
-
- return default
-
-
-def _lookup_error_cls(code: int):
- return _lookup_cls(code, meta=EdgeDBErrorMeta, default=EdgeDBError)
-
-
-def _lookup_message_cls(code: int):
- return _lookup_cls(code, meta=EdgeDBMessageMeta, default=EdgeDBMessage)
-
-
-def _decode(code: int):
- return tuple(code.to_bytes(4, 'big'))
-
-
-def _severity_name(severity):
- if severity <= EDGE_SEVERITY_DEBUG:
- return 'DEBUG'
- if severity <= EDGE_SEVERITY_INFO:
- return 'INFO'
- if severity <= EDGE_SEVERITY_NOTICE:
- return 'NOTICE'
- if severity <= EDGE_SEVERITY_WARNING:
- return 'WARNING'
- if severity <= EDGE_SEVERITY_ERROR:
- return 'ERROR'
- if severity <= EDGE_SEVERITY_FATAL:
- return 'FATAL'
- return 'PANIC'
-
-
-def _format_error(msg, query, start, offset, line, col, hint, details):
- c = get_color()
- rv = io.StringIO()
- rv.write(f"{c.BOLD}{msg}{c.ENDC}{LINESEP}")
- lines = query.splitlines(keepends=True)
- num_len = len(str(len(lines)))
- rv.write(f"{c.BLUE}{'':>{num_len}} ┌─{c.ENDC} query:{line}:{col}{LINESEP}")
- rv.write(f"{c.BLUE}{'':>{num_len}} │ {c.ENDC}{LINESEP}")
- for num, line in enumerate(lines):
- length = len(line)
- line = line.rstrip() # we'll use our own line separator
- if start >= length:
- # skip lines before the error
- start -= length
- continue
-
- if start >= 0:
- # Error starts in current line, write the line before the error
- first_half = repr(line[:start])[1:-1]
- line = line[start:]
- length -= start
- rv.write(f"{c.BLUE}{num + 1:>{num_len}} │ {c.ENDC}{first_half}")
- start = _unicode_width(first_half)
- else:
- # Multi-line error continues
- rv.write(f"{c.BLUE}{num + 1:>{num_len}} │ {c.FAIL}│ {c.ENDC}")
-
- if offset > length:
- # Error is ending beyond current line
- line = repr(line)[1:-1]
- rv.write(f"{c.FAIL}{line}{c.ENDC}{LINESEP}")
- if start >= 0:
- # Multi-line error starts
- rv.write(f"{c.BLUE}{'':>{num_len}} │ "
- f"{c.FAIL}â•â”€{'─' * start}^{c.ENDC}{LINESEP}")
- offset -= length
- start = -1 # mark multi-line
- else:
- # Error is ending within current line
- first_half = repr(line[:offset])[1:-1]
- line = repr(line[offset:])[1:-1]
- rv.write(f"{c.FAIL}{first_half}{c.ENDC}{line}{LINESEP}")
- size = _unicode_width(first_half)
- if start >= 0:
- # Mark single-line error
- rv.write(f"{c.BLUE}{'':>{num_len}} │ {' ' * start}"
- f"{c.FAIL}{'^' * size} {hint}{c.ENDC}")
- else:
- # End of multi-line error
- rv.write(f"{c.BLUE}{'':>{num_len}} │ "
- f"{c.FAIL}╰─{'─' * (size - 1)}^ {hint}{c.ENDC}")
- break
-
- if details:
- rv.write(f"{LINESEP}Details: {details}")
-
- return rv.getvalue()
-
-
-def _unicode_width(text):
- return sum(0 if unicodedata.category(c) in ('Mn', 'Cf') else
- 2 if unicodedata.east_asian_width(c) == "W" else 1
- for c in text)
-
-
-FIELD_HINT = 0x_00_01
-FIELD_DETAILS = 0x_00_02
-FIELD_SERVER_TRACEBACK = 0x_01_01
-
-# XXX: Subject to be changed/deprecated.
-FIELD_POSITION_START = 0x_FF_F1
-FIELD_POSITION_END = 0x_FF_F2
-FIELD_LINE_START = 0x_FF_F3
-FIELD_COLUMN_START = 0x_FF_F4
-FIELD_UTF16_COLUMN_START = 0x_FF_F5
-FIELD_LINE_END = 0x_FF_F6
-FIELD_COLUMN_END = 0x_FF_F7
-FIELD_UTF16_COLUMN_END = 0x_FF_F8
-FIELD_CHARACTER_START = 0x_FF_F9
-FIELD_CHARACTER_END = 0x_FF_FA
-
-
-EDGE_SEVERITY_DEBUG = 20
-EDGE_SEVERITY_INFO = 40
-EDGE_SEVERITY_NOTICE = 60
-EDGE_SEVERITY_WARNING = 80
-EDGE_SEVERITY_ERROR = 120
-EDGE_SEVERITY_FATAL = 200
-EDGE_SEVERITY_PANIC = 255
-
-
-# Fields to include in the json dump of the type
-_JSON_FIELDS = {
- 'hint': FIELD_HINT,
- 'details': FIELD_DETAILS,
- 'start': FIELD_CHARACTER_START,
- 'end': FIELD_CHARACTER_END,
- 'line': FIELD_LINE_START,
- 'col': FIELD_COLUMN_START,
-}
-
-
-LINESEP = os.linesep
-
-try:
- SHOW_HINT = {"default": True, "enabled": True, "disabled": False}[
- os.getenv("EDGEDB_ERROR_HINT", "default")
- ]
-except KeyError:
- warnings.warn(
- "EDGEDB_ERROR_HINT can only be one of: default, enabled or disabled",
- stacklevel=1,
- )
- SHOW_HINT = False
-
-
-from edgedb.color import get_color
+# Auto-generated shim
+import gel.errors._base as _mod
+import sys as _sys
+_cur = _sys.modules['edgedb.errors._base']
+for _k in vars(_mod):
+ if not _k.startswith('__') or _k in ('__all__', '__doc__'):
+ setattr(_cur, _k, getattr(_mod, _k))
+del _cur
+del _sys
+del _mod
+del _k
diff --git a/edgedb/errors/tags.py b/edgedb/errors/tags.py
index 275b31ac..ff307bb9 100644
--- a/edgedb/errors/tags.py
+++ b/edgedb/errors/tags.py
@@ -1,25 +1,11 @@
-__all__ = [
- 'Tag',
- 'SHOULD_RECONNECT',
- 'SHOULD_RETRY',
-]
-
-
-class Tag(object):
- """Error tag
-
- Tags are used to differentiate certain properties of errors that apply to
- error classes across hierarchy.
-
- Use ``error.has_tag(tag_name)`` to check for a tag.
- """
-
- def __init__(self, name):
- self.name = name
-
- def __repr__(self):
- return f''
-
-
-SHOULD_RECONNECT = Tag('SHOULD_RECONNECT')
-SHOULD_RETRY = Tag('SHOULD_RETRY')
+# Auto-generated shim
+import gel.errors.tags as _mod
+import sys as _sys
+_cur = _sys.modules['edgedb.errors.tags']
+for _k in vars(_mod):
+ if not _k.startswith('__') or _k in ('__all__', '__doc__'):
+ setattr(_cur, _k, getattr(_mod, _k))
+del _cur
+del _sys
+del _mod
+del _k
diff --git a/edgedb/introspect.py b/edgedb/introspect.py
index 0db045e6..7fa06bfc 100644
--- a/edgedb/introspect.py
+++ b/edgedb/introspect.py
@@ -1,66 +1,11 @@
-#
-# This source file is part of the EdgeDB open source project.
-#
-# Copyright 2019-present MagicStack Inc. and the EdgeDB authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-
-# IMPORTANT: this private API is subject to change.
-
-
-import functools
-import typing
-
-from edgedb.datatypes import datatypes as dt
-from edgedb.enums import ElementKind
-
-
-class PointerDescription(typing.NamedTuple):
-
- name: str
- kind: ElementKind
- implicit: bool
-
-
-class ObjectDescription(typing.NamedTuple):
-
- pointers: typing.Tuple[PointerDescription, ...]
-
-
-@functools.lru_cache()
-def _introspect_object_desc(desc) -> ObjectDescription:
- pointers = []
- # Call __dir__ directly as dir() scrambles the order.
- for name in desc.__dir__():
- if desc.is_link(name):
- kind = ElementKind.LINK
- elif desc.is_linkprop(name):
- continue
- else:
- kind = ElementKind.PROPERTY
-
- pointers.append(
- PointerDescription(
- name=name,
- kind=kind,
- implicit=desc.is_implicit(name)))
-
- return ObjectDescription(
- pointers=tuple(pointers))
-
-
-def introspect_object(obj) -> ObjectDescription:
- return _introspect_object_desc(
- dt.get_object_descriptor(obj))
+# Auto-generated shim
+import gel.introspect as _mod
+import sys as _sys
+_cur = _sys.modules['edgedb.introspect']
+for _k in vars(_mod):
+ if not _k.startswith('__') or _k in ('__all__', '__doc__'):
+ setattr(_cur, _k, getattr(_mod, _k))
+del _cur
+del _sys
+del _mod
+del _k
diff --git a/edgedb/options.py b/edgedb/options.py
index cec9541a..ba17dbc2 100644
--- a/edgedb/options.py
+++ b/edgedb/options.py
@@ -1,492 +1,11 @@
-import abc
-import enum
-import logging
-import random
-import typing
-import sys
-from collections import namedtuple
-
-from . import errors
-
-
-logger = logging.getLogger('edgedb')
-
-
-_RetryRule = namedtuple("_RetryRule", ["attempts", "backoff"])
-
-
-def default_backoff(attempt):
- return (2 ** attempt) * 0.1 + random.randrange(100) * 0.001
-
-
-WarningHandler = typing.Callable[
- [typing.Tuple[errors.EdgeDBError, ...], typing.Any],
- typing.Any,
-]
-
-
-def raise_warnings(warnings, res):
- if (
- len(warnings) > 1
- and sys.version_info >= (3, 11)
- ):
- raise ExceptionGroup( # noqa
- "Query produced warnings", warnings
- )
- else:
- raise warnings[0]
-
-
-def log_warnings(warnings, res):
- for w in warnings:
- logger.warning("EdgeDB warning: %s", str(w))
- return res
-
-
-class RetryCondition:
- """Specific condition to retry on for fine-grained control"""
- TransactionConflict = enum.auto()
- NetworkError = enum.auto()
-
-
-class IsolationLevel:
- """Isolation level for transaction"""
- Serializable = "SERIALIZABLE"
-
-
-class RetryOptions:
- """An immutable class that contains rules for `transaction()`"""
- __slots__ = ['_default', '_overrides']
-
- def __init__(self, attempts: int, backoff=default_backoff):
- self._default = _RetryRule(attempts, backoff)
- self._overrides = None
-
- def with_rule(self, condition, attempts=None, backoff=None):
- default = self._default
- overrides = self._overrides
- if overrides is None:
- overrides = {}
- else:
- overrides = overrides.copy()
- overrides[condition] = _RetryRule(
- default.attempts if attempts is None else attempts,
- default.backoff if backoff is None else backoff,
- )
- result = RetryOptions.__new__(RetryOptions)
- result._default = default
- result._overrides = overrides
- return result
-
- @classmethod
- def defaults(cls):
- return cls(
- attempts=3,
- backoff=default_backoff,
- )
-
- def get_rule_for_exception(self, exception):
- default = self._default
- overrides = self._overrides
- res = default
- if overrides:
- if isinstance(exception, errors.TransactionConflictError):
- res = overrides.get(RetryCondition.TransactionConflict, res)
- elif isinstance(exception, errors.ClientError):
- res = overrides.get(RetryCondition.NetworkError, res)
- return res
-
-
-class TransactionOptions:
- """Options for `transaction()`"""
- __slots__ = ['_isolation', '_readonly', '_deferrable']
-
- def __init__(
- self,
- isolation: IsolationLevel=IsolationLevel.Serializable,
- readonly: bool = False,
- deferrable: bool = False,
- ):
- self._isolation = isolation
- self._readonly = readonly
- self._deferrable = deferrable
-
- @classmethod
- def defaults(cls):
- return cls()
-
- def start_transaction_query(self):
- isolation = str(self._isolation)
- if self._readonly:
- mode = 'READ ONLY'
- else:
- mode = 'READ WRITE'
-
- if self._deferrable:
- defer = 'DEFERRABLE'
- else:
- defer = 'NOT DEFERRABLE'
-
- return f'START TRANSACTION ISOLATION {isolation}, {mode}, {defer};'
-
- def __repr__(self):
- return (
- f'<{self.__class__.__name__} '
- f'isolation:{self._isolation}, '
- f'readonly:{self._readonly}, '
- f'deferrable:{self._deferrable}>'
- )
-
-
-class State:
- __slots__ = ['_module', '_aliases', '_config', '_globals']
-
- def __init__(
- self,
- default_module: typing.Optional[str] = None,
- module_aliases: typing.Mapping[str, str] = None,
- config: typing.Mapping[str, typing.Any] = None,
- globals_: typing.Mapping[str, typing.Any] = None,
- ):
- self._module = default_module
- self._aliases = {} if module_aliases is None else dict(module_aliases)
- self._config = {} if config is None else dict(config)
- self._globals = (
- {} if globals_ is None else self.with_globals(globals_)._globals
- )
-
- @classmethod
- def _new(cls, default_module, module_aliases, config, globals_):
- rv = cls.__new__(cls)
- rv._module = default_module
- rv._aliases = module_aliases
- rv._config = config
- rv._globals = globals_
- return rv
-
- @classmethod
- def defaults(cls):
- return cls()
-
- def with_default_module(self, module: typing.Optional[str] = None):
- return self._new(
- default_module=module,
- module_aliases=self._aliases,
- config=self._config,
- globals_=self._globals,
- )
-
- def with_module_aliases(self, *args, **aliases):
- if len(args) > 1:
- raise errors.InvalidArgumentError(
- "with_module_aliases() takes from 0 to 1 positional arguments "
- "but {} were given".format(len(args))
- )
- aliases_dict = args[0] if args else {}
- aliases_dict.update(aliases)
- new_aliases = self._aliases.copy()
- new_aliases.update(aliases_dict)
- return self._new(
- default_module=self._module,
- module_aliases=new_aliases,
- config=self._config,
- globals_=self._globals,
- )
-
- def with_config(self, *args, **config):
- if len(args) > 1:
- raise errors.InvalidArgumentError(
- "with_config() takes from 0 to 1 positional arguments "
- "but {} were given".format(len(args))
- )
- config_dict = args[0] if args else {}
- config_dict.update(config)
- new_config = self._config.copy()
- new_config.update(config_dict)
- return self._new(
- default_module=self._module,
- module_aliases=self._aliases,
- config=new_config,
- globals_=self._globals,
- )
-
- def resolve(self, name: str) -> str:
- parts = name.split("::", 1)
- if len(parts) == 1:
- return f"{self._module or 'default'}::{name}"
- elif len(parts) == 2:
- mod, name = parts
- mod = self._aliases.get(mod, mod)
- return f"{mod}::{name}"
- else:
- raise AssertionError('broken split')
-
- def with_globals(self, *args, **globals_):
- if len(args) > 1:
- raise errors.InvalidArgumentError(
- "with_globals() takes from 0 to 1 positional arguments "
- "but {} were given".format(len(args))
- )
- new_globals = self._globals.copy()
- if args:
- for k, v in args[0].items():
- new_globals[self.resolve(k)] = v
- for k, v in globals_.items():
- new_globals[self.resolve(k)] = v
- return self._new(
- default_module=self._module,
- module_aliases=self._aliases,
- config=self._config,
- globals_=new_globals,
- )
-
- def without_module_aliases(self, *aliases):
- if not aliases:
- new_aliases = {}
- else:
- new_aliases = self._aliases.copy()
- for alias in aliases:
- new_aliases.pop(alias, None)
- return self._new(
- default_module=self._module,
- module_aliases=new_aliases,
- config=self._config,
- globals_=self._globals,
- )
-
- def without_config(self, *config_names):
- if not config_names:
- new_config = {}
- else:
- new_config = self._config.copy()
- for name in config_names:
- new_config.pop(name, None)
- return self._new(
- default_module=self._module,
- module_aliases=self._aliases,
- config=new_config,
- globals_=self._globals,
- )
-
- def without_globals(self, *global_names):
- if not global_names:
- new_globals = {}
- else:
- new_globals = self._globals.copy()
- for name in global_names:
- new_globals.pop(self.resolve(name), None)
- return self._new(
- default_module=self._module,
- module_aliases=self._aliases,
- config=self._config,
- globals_=new_globals,
- )
-
- def as_dict(self):
- rv = {}
- if self._module is not None:
- rv["module"] = self._module
- if self._aliases:
- rv["aliases"] = list(self._aliases.items())
- if self._config:
- rv["config"] = self._config
- if self._globals:
- rv["globals"] = self._globals
- return rv
-
-
-class _OptionsMixin:
- def __init__(self, *args, **kwargs):
- self._options = _Options.defaults()
- super().__init__(*args, **kwargs)
-
- @abc.abstractmethod
- def _shallow_clone(self):
- pass
-
- def with_transaction_options(self, options: TransactionOptions = None):
- """Returns object with adjusted options for future transactions.
-
- :param options TransactionOptions:
- Object that encapsulates transaction options.
-
- This method returns a "shallow copy" of the current object
- with modified transaction options.
-
- Both ``self`` and returned object can be used after, but when using
- them transaction options applied will be different.
-
- Transaction options are used by the ``transaction`` method.
- """
- result = self._shallow_clone()
- result._options = self._options.with_transaction_options(options)
- return result
-
- def with_retry_options(self, options: RetryOptions=None):
- """Returns object with adjusted options for future retrying
- transactions.
-
- :param options RetryOptions:
- Object that encapsulates retry options.
-
- This method returns a "shallow copy" of the current object
- with modified retry options.
-
- Both ``self`` and returned object can be used after, but when using
- them retry options applied will be different.
- """
-
- result = self._shallow_clone()
- result._options = self._options.with_retry_options(options)
- return result
-
- def with_warning_handler(self, warning_handler: WarningHandler=None):
- """Returns object with adjusted options for handling warnings.
-
- :param warning_handler WarningHandler:
- Function for handling warnings. It is passed a tuple of warnings
- and the query result and returns a potentially updated query
- result.
-
- This method returns a "shallow copy" of the current object
- with modified retry options.
-
- Both ``self`` and returned object can be used after, but when using
- them retry options applied will be different.
- """
-
- result = self._shallow_clone()
- result._options = self._options.with_warning_handler(warning_handler)
- return result
-
- def with_state(self, state: State):
- result = self._shallow_clone()
- result._options = self._options.with_state(state)
- return result
-
- def with_default_module(self, module: typing.Optional[str] = None):
- result = self._shallow_clone()
- result._options = self._options.with_state(
- self._options.state.with_default_module(module)
- )
- return result
-
- def with_module_aliases(self, *args, **aliases):
- result = self._shallow_clone()
- result._options = self._options.with_state(
- self._options.state.with_module_aliases(*args, **aliases)
- )
- return result
-
- def with_config(self, *args, **config):
- result = self._shallow_clone()
- result._options = self._options.with_state(
- self._options.state.with_config(*args, **config)
- )
- return result
-
- def with_globals(self, *args, **globals_):
- result = self._shallow_clone()
- result._options = self._options.with_state(
- self._options.state.with_globals(*args, **globals_)
- )
- return result
-
- def without_module_aliases(self, *aliases):
- result = self._shallow_clone()
- result._options = self._options.with_state(
- self._options.state.without_module_aliases(*aliases)
- )
- return result
-
- def without_config(self, *config_names):
- result = self._shallow_clone()
- result._options = self._options.with_state(
- self._options.state.without_config(*config_names)
- )
- return result
-
- def without_globals(self, *global_names):
- result = self._shallow_clone()
- result._options = self._options.with_state(
- self._options.state.without_globals(*global_names)
- )
- return result
-
-
-class _Options:
- """Internal class for storing connection options"""
-
- __slots__ = [
- '_retry_options', '_transaction_options', '_state',
- '_warning_handler'
- ]
-
- def __init__(
- self,
- retry_options: RetryOptions,
- transaction_options: TransactionOptions,
- state: State,
- warning_handler: WarningHandler,
- ):
- self._retry_options = retry_options
- self._transaction_options = transaction_options
- self._state = state
- self._warning_handler = warning_handler
-
- @property
- def retry_options(self):
- return self._retry_options
-
- @property
- def transaction_options(self):
- return self._transaction_options
-
- @property
- def state(self):
- return self._state
-
- @property
- def warning_handler(self):
- return self._warning_handler
-
- def with_retry_options(self, options: RetryOptions):
- return _Options(
- options,
- self._transaction_options,
- self._state,
- self._warning_handler,
- )
-
- def with_transaction_options(self, options: TransactionOptions):
- return _Options(
- self._retry_options,
- options,
- self._state,
- self._warning_handler,
- )
-
- def with_state(self, state: State):
- return _Options(
- self._retry_options,
- self._transaction_options,
- state,
- self._warning_handler,
- )
-
- def with_warning_handler(self, warning_handler: WarningHandler):
- return _Options(
- self._retry_options,
- self._transaction_options,
- self._state,
- warning_handler,
- )
-
- @classmethod
- def defaults(cls):
- return cls(
- RetryOptions.defaults(),
- TransactionOptions.defaults(),
- State.defaults(),
- log_warnings,
- )
+# Auto-generated shim
+import gel.options as _mod
+import sys as _sys
+_cur = _sys.modules['edgedb.options']
+for _k in vars(_mod):
+ if not _k.startswith('__') or _k in ('__all__', '__doc__'):
+ setattr(_cur, _k, getattr(_mod, _k))
+del _cur
+del _sys
+del _mod
+del _k
diff --git a/edgedb/pgproto/__init__.py b/edgedb/pgproto/__init__.py
new file mode 100644
index 00000000..14e16623
--- /dev/null
+++ b/edgedb/pgproto/__init__.py
@@ -0,0 +1,11 @@
+# Auto-generated shim
+import gel.pgproto as _mod
+import sys as _sys
+_cur = _sys.modules['edgedb.pgproto']
+for _k in vars(_mod):
+ if not _k.startswith('__') or _k in ('__all__', '__doc__'):
+ setattr(_cur, _k, getattr(_mod, _k))
+del _cur
+del _sys
+del _mod
+del _k
diff --git a/edgedb/pgproto/pgproto.py b/edgedb/pgproto/pgproto.py
new file mode 100644
index 00000000..879e7585
--- /dev/null
+++ b/edgedb/pgproto/pgproto.py
@@ -0,0 +1,11 @@
+# Auto-generated shim
+import gel.pgproto.pgproto as _mod
+import sys as _sys
+_cur = _sys.modules['edgedb.pgproto.pgproto']
+for _k in vars(_mod):
+ if not _k.startswith('__') or _k in ('__all__', '__doc__'):
+ setattr(_cur, _k, getattr(_mod, _k))
+del _cur
+del _sys
+del _mod
+del _k
diff --git a/edgedb/pgproto/types.py b/edgedb/pgproto/types.py
new file mode 100644
index 00000000..68484dd3
--- /dev/null
+++ b/edgedb/pgproto/types.py
@@ -0,0 +1,11 @@
+# Auto-generated shim
+import gel.pgproto.types as _mod
+import sys as _sys
+_cur = _sys.modules['edgedb.pgproto.types']
+for _k in vars(_mod):
+ if not _k.startswith('__') or _k in ('__all__', '__doc__'):
+ setattr(_cur, _k, getattr(_mod, _k))
+del _cur
+del _sys
+del _mod
+del _k
diff --git a/edgedb/platform.py b/edgedb/platform.py
index 55410532..18eec595 100644
--- a/edgedb/platform.py
+++ b/edgedb/platform.py
@@ -1,52 +1,11 @@
-import functools
-import os
-import pathlib
-import sys
-
-if sys.platform == "darwin":
- def config_dir() -> pathlib.Path:
- return (
- pathlib.Path.home() / "Library" / "Application Support" / "edgedb"
- )
-
- IS_WINDOWS = False
-
-elif sys.platform == "win32":
- import ctypes
- from ctypes import windll
-
- def config_dir() -> pathlib.Path:
- path_buf = ctypes.create_unicode_buffer(255)
- csidl = 28 # CSIDL_LOCAL_APPDATA
- windll.shell32.SHGetFolderPathW(0, csidl, 0, 0, path_buf)
- return pathlib.Path(path_buf.value) / "EdgeDB" / "config"
-
- IS_WINDOWS = True
-
-else:
- def config_dir() -> pathlib.Path:
- xdg_conf_dir = pathlib.Path(os.environ.get("XDG_CONFIG_HOME", "."))
- if not xdg_conf_dir.is_absolute():
- xdg_conf_dir = pathlib.Path.home() / ".config"
- return xdg_conf_dir / "edgedb"
-
- IS_WINDOWS = False
-
-
-def old_config_dir() -> pathlib.Path:
- return pathlib.Path.home() / ".edgedb"
-
-
-def search_config_dir(*suffix):
- rv = functools.reduce(lambda p1, p2: p1 / p2, [config_dir(), *suffix])
- if rv.exists():
- return rv
-
- fallback = functools.reduce(
- lambda p1, p2: p1 / p2, [old_config_dir(), *suffix]
- )
- if fallback.exists():
- return fallback
-
- # None of the searched files exists, return the new path.
- return rv
+# Auto-generated shim
+import gel.platform as _mod
+import sys as _sys
+_cur = _sys.modules['edgedb.platform']
+for _k in vars(_mod):
+ if not _k.startswith('__') or _k in ('__all__', '__doc__'):
+ setattr(_cur, _k, getattr(_mod, _k))
+del _cur
+del _sys
+del _mod
+del _k
diff --git a/edgedb/protocol/__init__.py b/edgedb/protocol/__init__.py
index 46ceb625..f53f0a30 100644
--- a/edgedb/protocol/__init__.py
+++ b/edgedb/protocol/__init__.py
@@ -1,17 +1,11 @@
-#
-# This source file is part of the EdgeDB open source project.
-#
-# Copyright 2016-present MagicStack Inc. and the EdgeDB authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
+# Auto-generated shim
+import gel.protocol as _mod
+import sys as _sys
+_cur = _sys.modules['edgedb.protocol']
+for _k in vars(_mod):
+ if not _k.startswith('__') or _k in ('__all__', '__doc__'):
+ setattr(_cur, _k, getattr(_mod, _k))
+del _cur
+del _sys
+del _mod
+del _k
diff --git a/edgedb/protocol/asyncio_proto.py b/edgedb/protocol/asyncio_proto.py
new file mode 100644
index 00000000..189b76d5
--- /dev/null
+++ b/edgedb/protocol/asyncio_proto.py
@@ -0,0 +1,11 @@
+# Auto-generated shim
+import gel.protocol.asyncio_proto as _mod
+import sys as _sys
+_cur = _sys.modules['edgedb.protocol.asyncio_proto']
+for _k in vars(_mod):
+ if not _k.startswith('__') or _k in ('__all__', '__doc__'):
+ setattr(_cur, _k, getattr(_mod, _k))
+del _cur
+del _sys
+del _mod
+del _k
diff --git a/edgedb/protocol/blocking_proto.py b/edgedb/protocol/blocking_proto.py
new file mode 100644
index 00000000..5974af66
--- /dev/null
+++ b/edgedb/protocol/blocking_proto.py
@@ -0,0 +1,11 @@
+# Auto-generated shim
+import gel.protocol.blocking_proto as _mod
+import sys as _sys
+_cur = _sys.modules['edgedb.protocol.blocking_proto']
+for _k in vars(_mod):
+ if not _k.startswith('__') or _k in ('__all__', '__doc__'):
+ setattr(_cur, _k, getattr(_mod, _k))
+del _cur
+del _sys
+del _mod
+del _k
diff --git a/edgedb/protocol/protocol.py b/edgedb/protocol/protocol.py
new file mode 100644
index 00000000..811284d5
--- /dev/null
+++ b/edgedb/protocol/protocol.py
@@ -0,0 +1,11 @@
+# Auto-generated shim
+import gel.protocol.protocol as _mod
+import sys as _sys
+_cur = _sys.modules['edgedb.protocol.protocol']
+for _k in vars(_mod):
+ if not _k.startswith('__') or _k in ('__all__', '__doc__'):
+ setattr(_cur, _k, getattr(_mod, _k))
+del _cur
+del _sys
+del _mod
+del _k
diff --git a/edgedb/scram/__init__.py b/edgedb/scram/__init__.py
index 62f3e712..e0ca2611 100644
--- a/edgedb/scram/__init__.py
+++ b/edgedb/scram/__init__.py
@@ -1,434 +1,11 @@
-#
-# This source file is part of the EdgeDB open source project.
-#
-# Copyright 2019-present MagicStack Inc. and the EdgeDB authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-"""Helpers for SCRAM authentication."""
-
-import base64
-import hashlib
-import hmac
-import os
-import typing
-
-from .saslprep import saslprep
-
-
-RAW_NONCE_LENGTH = 18
-
-# Per recommendations in RFC 7677.
-DEFAULT_SALT_LENGTH = 16
-DEFAULT_ITERATIONS = 4096
-
-
-def generate_salt(length: int = DEFAULT_SALT_LENGTH) -> bytes:
- return os.urandom(length)
-
-
-def generate_nonce(length: int = RAW_NONCE_LENGTH) -> str:
- return B64(os.urandom(length))
-
-
-def build_verifier(password: str, *, salt: typing.Optional[bytes] = None,
- iterations: int = DEFAULT_ITERATIONS) -> str:
- """Build the SCRAM verifier for the given password.
-
- Returns a string in the following format:
-
- "$:$:"
-
- The salt and keys are base64-encoded values.
- """
- password = saslprep(password).encode('utf-8')
-
- if salt is None:
- salt = generate_salt()
-
- salted_password = get_salted_password(password, salt, iterations)
- client_key = get_client_key(salted_password)
- stored_key = H(client_key)
- server_key = get_server_key(salted_password)
-
- return (f'SCRAM-SHA-256${iterations}:{B64(salt)}$'
- f'{B64(stored_key)}:{B64(server_key)}')
-
-
-class SCRAMVerifier(typing.NamedTuple):
-
- mechanism: str
- iterations: int
- salt: bytes
- stored_key: bytes
- server_key: bytes
-
-
-def parse_verifier(verifier: str) -> SCRAMVerifier:
-
- parts = verifier.split('$')
- if len(parts) != 3:
- raise ValueError('invalid SCRAM verifier')
-
- mechanism = parts[0]
- if mechanism != 'SCRAM-SHA-256':
- raise ValueError('invalid SCRAM verifier')
-
- iterations, _, salt = parts[1].partition(':')
- stored_key, _, server_key = parts[2].partition(':')
- if not salt or not server_key:
- raise ValueError('invalid SCRAM verifier')
-
- try:
- iterations = int(iterations)
- except ValueError:
- raise ValueError('invalid SCRAM verifier') from None
-
- return SCRAMVerifier(
- mechanism=mechanism,
- iterations=iterations,
- salt=base64.b64decode(salt),
- stored_key=base64.b64decode(stored_key),
- server_key=base64.b64decode(server_key),
- )
-
-
-def parse_client_first_message(resp: bytes):
-
- # Relevant bits of RFC 5802:
- #
- # saslname = 1*(value-safe-char / "=2C" / "=3D")
- # ;; Conforms to .
- #
- # authzid = "a=" saslname
- # ;; Protocol specific.
- #
- # cb-name = 1*(ALPHA / DIGIT / "." / "-")
- # ;; See RFC 5056, Section 7.
- # ;; E.g., "tls-server-end-point" or
- # ;; "tls-unique".
- #
- # gs2-cbind-flag = ("p=" cb-name) / "n" / "y"
- # ;; "n" -> client doesn't support channel binding.
- # ;; "y" -> client does support channel binding
- # ;; but thinks the server does not.
- # ;; "p" -> client requires channel binding.
- # ;; The selected channel binding follows "p=".
- #
- # gs2-header = gs2-cbind-flag "," [ authzid ] ","
- # ;; GS2 header for SCRAM
- # ;; (the actual GS2 header includes an optional
- # ;; flag to indicate that the GSS mechanism is not
- # ;; "standard", but since SCRAM is "standard", we
- # ;; don't include that flag).
- #
- # username = "n=" saslname
- # ;; Usernames are prepared using SASLprep.
- #
- # reserved-mext = "m=" 1*(value-char)
- # ;; Reserved for signaling mandatory extensions.
- # ;; The exact syntax will be defined in
- # ;; the future.
- #
- # nonce = "r=" c-nonce [s-nonce]
- # ;; Second part provided by server.
- #
- # c-nonce = printable
- #
- # client-first-message-bare =
- # [reserved-mext ","]
- # username "," nonce ["," extensions]
- #
- # client-first-message =
- # gs2-header client-first-message-bare
-
- attrs = resp.split(b',')
-
- cb_attr = attrs[0]
- if cb_attr == b'y':
- cb = True
- elif cb_attr == b'n':
- cb = False
- elif cb_attr[0:1] == b'p':
- _, _, cb = cb_attr.partition(b'=')
- if not cb:
- raise ValueError('malformed SCRAM message')
- else:
- raise ValueError('malformed SCRAM message')
-
- authzid_attr = attrs[1]
- if authzid_attr:
- if authzid_attr[0:1] != b'a':
- raise ValueError('malformed SCRAM message')
- _, _, authzid = authzid_attr.partition(b'=')
- else:
- authzid = None
-
- user_attr = attrs[2]
- if user_attr[0:1] == b'm':
- raise ValueError('unsupported SCRAM extensions in message')
- elif user_attr[0:1] != b'n':
- raise ValueError('malformed SCRAM message')
-
- _, _, user = user_attr.partition(b'=')
-
- nonce_attr = attrs[3]
- if nonce_attr[0:1] != b'r':
- raise ValueError('malformed SCRAM message')
-
- _, _, nonce_bin = nonce_attr.partition(b'=')
- nonce = nonce_bin.decode('ascii')
- if not nonce.isprintable():
- raise ValueError('invalid characters in client nonce')
-
- # ["," extensions] are ignored
-
- return len(cb_attr) + 2, cb, authzid, user, nonce
-
-
-def parse_client_final_message(
- msg: bytes, client_nonce: str, server_nonce: str):
-
- # Relevant bits of RFC 5802:
- #
- # gs2-header = gs2-cbind-flag "," [ authzid ] ","
- # ;; GS2 header for SCRAM
- # ;; (the actual GS2 header includes an optional
- # ;; flag to indicate that the GSS mechanism is not
- # ;; "standard", but since SCRAM is "standard", we
- # ;; don't include that flag).
- #
- # cbind-input = gs2-header [ cbind-data ]
- # ;; cbind-data MUST be present for
- # ;; gs2-cbind-flag of "p" and MUST be absent
- # ;; for "y" or "n".
- #
- # channel-binding = "c=" base64
- # ;; base64 encoding of cbind-input.
- #
- # proof = "p=" base64
- #
- # client-final-message-without-proof =
- # channel-binding "," nonce [","
- # extensions]
- #
- # client-final-message =
- # client-final-message-without-proof "," proof
-
- attrs = msg.split(b',')
-
- cb_attr = attrs[0]
- if cb_attr[0:1] != b'c':
- raise ValueError('malformed SCRAM message')
-
- _, _, cb_data = cb_attr.partition(b'=')
-
- nonce_attr = attrs[1]
- if nonce_attr[0:1] != b'r':
- raise ValueError('malformed SCRAM message')
-
- _, _, nonce_bin = nonce_attr.partition(b'=')
- nonce = nonce_bin.decode('ascii')
-
- expected_nonce = f'{client_nonce}{server_nonce}'
-
- if nonce != expected_nonce:
- raise ValueError(
- 'invalid SCRAM client-final message: nonce does not match')
-
- proof = None
-
- for attr in attrs[2:]:
- if attr[0:1] == b'p':
- _, _, proof = attr.partition(b'=')
- proof_attr_len = len(attr)
- proof = base64.b64decode(proof)
- elif proof is not None:
- raise ValueError('malformed SCRAM message')
-
- if proof is None:
- raise ValueError('malformed SCRAM message')
-
- return cb_data, proof, proof_attr_len + 1
-
-
-def build_client_first_message(client_nonce: str, username: str) -> str:
-
- bare = f'n={saslprep(username)},r={client_nonce}'
- return f'n,,{bare}', bare
-
-
-def build_server_first_message(server_nonce: str, client_nonce: str,
- salt: bytes, iterations: int) -> str:
-
- return (
- f'r={client_nonce}{server_nonce},'
- f's={B64(salt)},i={iterations}'
- )
-
-
-def build_auth_message(
- client_first_bare: bytes,
- server_first: bytes, client_final: bytes) -> bytes:
-
- return b'%b,%b,%b' % (client_first_bare, server_first, client_final)
-
-
-def build_client_final_message(
- password: str,
- salt: bytes,
- iterations: int,
- client_first_bare: bytes,
- server_first: bytes,
- server_nonce: str) -> str:
-
- client_final = f'c=biws,r={server_nonce}'
-
- AuthMessage = build_auth_message(
- client_first_bare, server_first, client_final.encode('utf-8'))
-
- SaltedPassword = get_salted_password(
- saslprep(password).encode('utf-8'),
- salt,
- iterations)
-
- ClientKey = get_client_key(SaltedPassword)
- StoredKey = H(ClientKey)
- ClientSignature = HMAC(StoredKey, AuthMessage)
- ClientProof = XOR(ClientKey, ClientSignature)
-
- ServerKey = get_server_key(SaltedPassword)
- ServerProof = HMAC(ServerKey, AuthMessage)
-
- return f'{client_final},p={B64(ClientProof)}', ServerProof
-
-
-def build_server_final_message(
- client_first_bare: bytes, server_first: bytes,
- client_final: bytes, server_key: bytes) -> str:
-
- AuthMessage = build_auth_message(
- client_first_bare, server_first, client_final)
- ServerSignature = HMAC(server_key, AuthMessage)
- return f'v={B64(ServerSignature)}'
-
-
-def parse_server_first_message(msg: bytes):
-
- attrs = msg.split(b',')
-
- nonce_attr = attrs[0]
- if nonce_attr[0:1] != b'r':
- raise ValueError('malformed SCRAM message')
-
- _, _, nonce_bin = nonce_attr.partition(b'=')
- nonce = nonce_bin.decode('ascii')
- if not nonce.isprintable():
- raise ValueError('malformed SCRAM message')
-
- salt_attr = attrs[1]
- if salt_attr[0:1] != b's':
- raise ValueError('malformed SCRAM message')
-
- _, _, salt_b64 = salt_attr.partition(b'=')
- salt = base64.b64decode(salt_b64)
-
- iter_attr = attrs[2]
- if iter_attr[0:1] != b'i':
- raise ValueError('malformed SCRAM message')
-
- _, _, iterations = iter_attr.partition(b'=')
-
- try:
- itercount = int(iterations)
- except ValueError:
- raise ValueError('malformed SCRAM message') from None
-
- return nonce, salt, itercount
-
-
-def parse_server_final_message(msg: bytes):
-
- attrs = msg.split(b',')
-
- nonce_attr = attrs[0]
- if nonce_attr[0:1] != b'v':
- raise ValueError('malformed SCRAM message')
-
- _, _, signature_b64 = nonce_attr.partition(b'=')
- signature = base64.b64decode(signature_b64)
-
- return signature
-
-
-def verify_password(password: bytes, verifier: str) -> bool:
- """Check the given password against a verifier.
-
- Returns True if the password is OK, False otherwise.
- """
-
- password = saslprep(password).encode('utf-8')
- v = parse_verifier(verifier)
- salted_password = get_salted_password(password, v.salt, v.iterations)
- computed_key = get_server_key(salted_password)
- return v.server_key == computed_key
-
-
-def verify_client_proof(client_first: bytes, server_first: bytes,
- client_final: bytes, StoredKey: bytes,
- ClientProof: bytes) -> bool:
- AuthMessage = build_auth_message(client_first, server_first, client_final)
- ClientSignature = HMAC(StoredKey, AuthMessage)
- ClientKey = XOR(ClientProof, ClientSignature)
- return H(ClientKey) == StoredKey
-
-
-def B64(val: bytes) -> str:
- """Return base64-encoded string representation of input binary data."""
- return base64.b64encode(val).decode()
-
-
-def HMAC(key: bytes, msg: bytes) -> bytes:
- return hmac.new(key, msg, digestmod=hashlib.sha256).digest()
-
-
-def XOR(a: bytes, b: bytes) -> bytes:
- if len(a) != len(b):
- raise ValueError('scram.XOR received operands of unequal length')
- xint = int.from_bytes(a, 'big') ^ int.from_bytes(b, 'big')
- return xint.to_bytes(len(a), 'big')
-
-
-def H(s: bytes) -> bytes:
- return hashlib.sha256(s).digest()
-
-
-def get_salted_password(password: bytes, salt: bytes,
- iterations: int) -> bytes:
- # U1 := HMAC(str, salt + INT(1))
- H_i = U_i = HMAC(password, salt + b'\x00\x00\x00\x01')
-
- for _ in range(iterations - 1):
- U_i = HMAC(password, U_i)
- H_i = XOR(H_i, U_i)
-
- return H_i
-
-
-def get_client_key(salted_password: bytes) -> bytes:
- return HMAC(salted_password, b'Client Key')
-
-
-def get_server_key(salted_password: bytes) -> bytes:
- return HMAC(salted_password, b'Server Key')
+# Auto-generated shim
+import gel.scram as _mod
+import sys as _sys
+_cur = _sys.modules['edgedb.scram']
+for _k in vars(_mod):
+ if not _k.startswith('__') or _k in ('__all__', '__doc__'):
+ setattr(_cur, _k, getattr(_mod, _k))
+del _cur
+del _sys
+del _mod
+del _k
diff --git a/edgedb/scram/saslprep.py b/edgedb/scram/saslprep.py
index 79eb84d8..18128fe9 100644
--- a/edgedb/scram/saslprep.py
+++ b/edgedb/scram/saslprep.py
@@ -1,82 +1,11 @@
-# Copyright 2016-present MongoDB, Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import stringprep
-import unicodedata
-
-
-# RFC4013 section 2.3 prohibited output.
-_PROHIBITED = (
- # A strict reading of RFC 4013 requires table c12 here, but
- # characters from it are mapped to SPACE in the Map step. Can
- # normalization reintroduce them somehow?
- stringprep.in_table_c12,
- stringprep.in_table_c21_c22,
- stringprep.in_table_c3,
- stringprep.in_table_c4,
- stringprep.in_table_c5,
- stringprep.in_table_c6,
- stringprep.in_table_c7,
- stringprep.in_table_c8,
- stringprep.in_table_c9)
-
-
-def saslprep(data: str, prohibit_unassigned_code_points=True):
- """An implementation of RFC4013 SASLprep."""
-
- if data == '':
- return data
-
- if prohibit_unassigned_code_points:
- prohibited = _PROHIBITED + (stringprep.in_table_a1,)
- else:
- prohibited = _PROHIBITED
-
- # RFC3454 section 2, step 1 - Map
- # RFC4013 section 2.1 mappings
- # Map Non-ASCII space characters to SPACE (U+0020). Map
- # commonly mapped to nothing characters to, well, nothing.
- in_table_c12 = stringprep.in_table_c12
- in_table_b1 = stringprep.in_table_b1
- data = u"".join(
- [u"\u0020" if in_table_c12(elt) else elt
- for elt in data if not in_table_b1(elt)])
-
- # RFC3454 section 2, step 2 - Normalize
- # RFC4013 section 2.2 normalization
- data = unicodedata.ucd_3_2_0.normalize('NFKC', data)
-
- in_table_d1 = stringprep.in_table_d1
- if in_table_d1(data[0]):
- if not in_table_d1(data[-1]):
- # RFC3454, Section 6, #3. If a string contains any
- # RandALCat character, the first and last characters
- # MUST be RandALCat characters.
- raise ValueError("SASLprep: failed bidirectional check")
- # RFC3454, Section 6, #2. If a string contains any RandALCat
- # character, it MUST NOT contain any LCat character.
- prohibited = prohibited + (stringprep.in_table_d2,)
- else:
- # RFC3454, Section 6, #3. Following the logic of #3, if
- # the first character is not a RandALCat, no other character
- # can be either.
- prohibited = prohibited + (in_table_d1,)
-
- # RFC3454 section 2, step 3 and 4 - Prohibit and check bidi
- for char in data:
- if any(in_table(char) for in_table in prohibited):
- raise ValueError(
- "SASLprep: failed prohibited character check")
-
- return data
+# Auto-generated shim
+import gel.scram.saslprep as _mod
+import sys as _sys
+_cur = _sys.modules['edgedb.scram.saslprep']
+for _k in vars(_mod):
+ if not _k.startswith('__') or _k in ('__all__', '__doc__'):
+ setattr(_cur, _k, getattr(_mod, _k))
+del _cur
+del _sys
+del _mod
+del _k
diff --git a/edgedb/transaction.py b/edgedb/transaction.py
index 2e2d871a..91828a2b 100644
--- a/edgedb/transaction.py
+++ b/edgedb/transaction.py
@@ -1,224 +1,11 @@
-#
-# This source file is part of the EdgeDB open source project.
-#
-# Copyright 2016-present MagicStack Inc. and the EdgeDB authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-
-import enum
-
-from . import abstract
-from . import errors
-from . import options
-
-
-class TransactionState(enum.Enum):
- NEW = 0
- STARTED = 1
- COMMITTED = 2
- ROLLEDBACK = 3
- FAILED = 4
-
-
-class BaseTransaction:
-
- __slots__ = (
- '_client',
- '_connection',
- '_options',
- '_state',
- '__retry',
- '__iteration',
- '__started',
- )
-
- def __init__(self, retry, client, iteration):
- self._client = client
- self._connection = None
- self._options = retry._options.transaction_options
- self._state = TransactionState.NEW
- self.__retry = retry
- self.__iteration = iteration
- self.__started = False
-
- def is_active(self) -> bool:
- return self._state is TransactionState.STARTED
-
- def __check_state_base(self, opname):
- if self._state is TransactionState.COMMITTED:
- raise errors.InterfaceError(
- 'cannot {}; the transaction is already committed'.format(
- opname))
- if self._state is TransactionState.ROLLEDBACK:
- raise errors.InterfaceError(
- 'cannot {}; the transaction is already rolled back'.format(
- opname))
- if self._state is TransactionState.FAILED:
- raise errors.InterfaceError(
- 'cannot {}; the transaction is in error state'.format(
- opname))
-
- def __check_state(self, opname):
- if self._state is not TransactionState.STARTED:
- if self._state is TransactionState.NEW:
- raise errors.InterfaceError(
- 'cannot {}; the transaction is not yet started'.format(
- opname))
- self.__check_state_base(opname)
-
- def _make_start_query(self):
- self.__check_state_base('start')
- if self._state is TransactionState.STARTED:
- raise errors.InterfaceError(
- 'cannot start; the transaction is already started')
-
- return self._options.start_transaction_query()
-
- def _make_commit_query(self):
- self.__check_state('commit')
- return 'COMMIT;'
-
- def _make_rollback_query(self):
- self.__check_state('rollback')
- return 'ROLLBACK;'
-
- def __repr__(self):
- attrs = []
- attrs.append('state:{}'.format(self._state.name.lower()))
- attrs.append(repr(self._options))
-
- if self.__class__.__module__.startswith('edgedb.'):
- mod = 'edgedb'
- else:
- mod = self.__class__.__module__
-
- return '<{}.{} {} {:#x}>'.format(
- mod, self.__class__.__name__, ' '.join(attrs), id(self))
-
- async def _ensure_transaction(self):
- if not self.__started:
- self.__started = True
- query = self._make_start_query()
- self._connection = await self._client._impl.acquire()
- if self._connection.is_closed():
- await self._connection.connect(
- single_attempt=self.__iteration != 0
- )
- try:
- await self._privileged_execute(query)
- except BaseException:
- self._state = TransactionState.FAILED
- raise
- else:
- self._state = TransactionState.STARTED
-
- async def _exit(self, extype, ex):
- if not self.__started:
- return False
-
- try:
- if extype is None:
- query = self._make_commit_query()
- state = TransactionState.COMMITTED
- else:
- query = self._make_rollback_query()
- state = TransactionState.ROLLEDBACK
- try:
- await self._privileged_execute(query)
- except BaseException:
- self._state = TransactionState.FAILED
- if extype is None:
- # COMMIT itself may fail; recover in connection
- await self._privileged_execute("ROLLBACK;")
- raise
- else:
- self._state = state
- except errors.EdgeDBError as err:
- if ex is None:
- # On commit we don't know if commit is succeeded before the
- # database have received it or after it have been done but
- # network is dropped before we were able to receive a response.
- # On a TransactionError, though, we know the we need
- # to retry.
- # TODO(tailhook) should other errors have retries?
- if (
- isinstance(err, errors.TransactionError)
- and err.has_tag(errors.SHOULD_RETRY)
- and self.__retry._retry(err)
- ):
- pass
- else:
- raise err
- # If we were going to rollback, look at original error
- # to find out whether we want to retry, regardless of
- # the rollback error.
- # In this case we ignore rollback issue as original error is more
- # important, e.g. in case `CancelledError` it's important
- # to propagate it to cancel the whole task.
- # NOTE: rollback error is always swallowed, should we use
- # on_log_message for it?
- finally:
- await self._client._impl.release(self._connection)
-
- if (
- extype is not None and
- issubclass(extype, errors.EdgeDBError) and
- ex.has_tag(errors.SHOULD_RETRY)
- ):
- return self.__retry._retry(ex)
-
- def _get_query_cache(self) -> abstract.QueryCache:
- return self._client._get_query_cache()
-
- def _get_state(self) -> options.State:
- return self._client._get_state()
-
- def _get_warning_handler(self) -> options.WarningHandler:
- return self._client._get_warning_handler()
-
- async def _query(self, query_context: abstract.QueryContext):
- await self._ensure_transaction()
- return await self._connection.raw_query(query_context)
-
- async def _execute(self, execute_context: abstract.ExecuteContext) -> None:
- await self._ensure_transaction()
- await self._connection._execute(execute_context)
-
- async def _privileged_execute(self, query: str) -> None:
- await self._connection.privileged_execute(abstract.ExecuteContext(
- query=abstract.QueryWithArgs(query, (), {}),
- cache=self._get_query_cache(),
- state=self._get_state(),
- warning_handler=self._get_warning_handler(),
- ))
-
-
-class BaseRetry:
-
- def __init__(self, owner):
- self._owner = owner
- self._iteration = 0
- self._done = False
- self._next_backoff = 0
- self._options = owner._options
-
- def _retry(self, exc):
- self._last_exception = exc
- rule = self._options.retry_options.get_rule_for_exception(exc)
- if self._iteration >= rule.attempts:
- return False
- self._done = False
- self._next_backoff = rule.backoff(self._iteration)
- return True
+# Auto-generated shim
+import gel.transaction as _mod
+import sys as _sys
+_cur = _sys.modules['edgedb.transaction']
+for _k in vars(_mod):
+ if not _k.startswith('__') or _k in ('__all__', '__doc__'):
+ setattr(_cur, _k, getattr(_mod, _k))
+del _cur
+del _sys
+del _mod
+del _k
diff --git a/gel/__init__.py b/gel/__init__.py
new file mode 100644
index 00000000..c19f7778
--- /dev/null
+++ b/gel/__init__.py
@@ -0,0 +1,284 @@
+#
+# This source file is part of the EdgeDB open source project.
+#
+# Copyright 2016-present MagicStack Inc. and the EdgeDB authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+# flake8: noqa
+
+from ._version import __version__
+
+from gel.datatypes.datatypes import (
+ Tuple, NamedTuple, EnumValue, RelativeDuration, DateDuration, ConfigMemory
+)
+from gel.datatypes.datatypes import Set, Object, Array
+from gel.datatypes.range import Range, MultiRange
+
+from .abstract import (
+ Executor, AsyncIOExecutor, ReadOnlyExecutor, AsyncIOReadOnlyExecutor,
+)
+
+from .asyncio_client import (
+ create_async_client,
+ AsyncIOClient
+)
+
+from .blocking_client import create_client, Client
+from .enums import Cardinality, ElementKind
+from .options import RetryCondition, IsolationLevel, default_backoff
+from .options import RetryOptions, TransactionOptions
+from .options import State
+
+from .errors._base import EdgeDBError, EdgeDBMessage
+
+__all__ = [
+ "Array",
+ "AsyncIOClient",
+ "AsyncIOExecutor",
+ "AsyncIOReadOnlyExecutor",
+ "Cardinality",
+ "Client",
+ "ConfigMemory",
+ "DateDuration",
+ "EdgeDBError",
+ "EdgeDBMessage",
+ "ElementKind",
+ "EnumValue",
+ "Executor",
+ "IsolationLevel",
+ "NamedTuple",
+ "Object",
+ "Range",
+ "ReadOnlyExecutor",
+ "RelativeDuration",
+ "RetryCondition",
+ "RetryOptions",
+ "Set",
+ "State",
+ "TransactionOptions",
+ "Tuple",
+ "create_async_client",
+ "create_client",
+ "default_backoff",
+]
+
+
+# The below is generated by `make gen-errors`.
+# DO NOT MODIFY BY HAND.
+#
+#
+from .errors import (
+ InternalServerError,
+ UnsupportedFeatureError,
+ ProtocolError,
+ BinaryProtocolError,
+ UnsupportedProtocolVersionError,
+ TypeSpecNotFoundError,
+ UnexpectedMessageError,
+ InputDataError,
+ ParameterTypeMismatchError,
+ StateMismatchError,
+ ResultCardinalityMismatchError,
+ CapabilityError,
+ UnsupportedCapabilityError,
+ DisabledCapabilityError,
+ QueryError,
+ InvalidSyntaxError,
+ EdgeQLSyntaxError,
+ SchemaSyntaxError,
+ GraphQLSyntaxError,
+ InvalidTypeError,
+ InvalidTargetError,
+ InvalidLinkTargetError,
+ InvalidPropertyTargetError,
+ InvalidReferenceError,
+ UnknownModuleError,
+ UnknownLinkError,
+ UnknownPropertyError,
+ UnknownUserError,
+ UnknownDatabaseError,
+ UnknownParameterError,
+ SchemaError,
+ SchemaDefinitionError,
+ InvalidDefinitionError,
+ InvalidModuleDefinitionError,
+ InvalidLinkDefinitionError,
+ InvalidPropertyDefinitionError,
+ InvalidUserDefinitionError,
+ InvalidDatabaseDefinitionError,
+ InvalidOperatorDefinitionError,
+ InvalidAliasDefinitionError,
+ InvalidFunctionDefinitionError,
+ InvalidConstraintDefinitionError,
+ InvalidCastDefinitionError,
+ DuplicateDefinitionError,
+ DuplicateModuleDefinitionError,
+ DuplicateLinkDefinitionError,
+ DuplicatePropertyDefinitionError,
+ DuplicateUserDefinitionError,
+ DuplicateDatabaseDefinitionError,
+ DuplicateOperatorDefinitionError,
+ DuplicateViewDefinitionError,
+ DuplicateFunctionDefinitionError,
+ DuplicateConstraintDefinitionError,
+ DuplicateCastDefinitionError,
+ DuplicateMigrationError,
+ SessionTimeoutError,
+ IdleSessionTimeoutError,
+ QueryTimeoutError,
+ TransactionTimeoutError,
+ IdleTransactionTimeoutError,
+ ExecutionError,
+ InvalidValueError,
+ DivisionByZeroError,
+ NumericOutOfRangeError,
+ AccessPolicyError,
+ QueryAssertionError,
+ IntegrityError,
+ ConstraintViolationError,
+ CardinalityViolationError,
+ MissingRequiredError,
+ TransactionError,
+ TransactionConflictError,
+ TransactionSerializationError,
+ TransactionDeadlockError,
+ WatchError,
+ ConfigurationError,
+ AccessError,
+ AuthenticationError,
+ AvailabilityError,
+ BackendUnavailableError,
+ ServerOfflineError,
+ BackendError,
+ UnsupportedBackendFeatureError,
+ LogMessage,
+ WarningMessage,
+ ClientError,
+ ClientConnectionError,
+ ClientConnectionFailedError,
+ ClientConnectionFailedTemporarilyError,
+ ClientConnectionTimeoutError,
+ ClientConnectionClosedError,
+ InterfaceError,
+ QueryArgumentError,
+ MissingArgumentError,
+ UnknownArgumentError,
+ InvalidArgumentError,
+ NoDataError,
+ InternalClientError,
+)
+
+__all__.extend([
+ "InternalServerError",
+ "UnsupportedFeatureError",
+ "ProtocolError",
+ "BinaryProtocolError",
+ "UnsupportedProtocolVersionError",
+ "TypeSpecNotFoundError",
+ "UnexpectedMessageError",
+ "InputDataError",
+ "ParameterTypeMismatchError",
+ "StateMismatchError",
+ "ResultCardinalityMismatchError",
+ "CapabilityError",
+ "UnsupportedCapabilityError",
+ "DisabledCapabilityError",
+ "QueryError",
+ "InvalidSyntaxError",
+ "EdgeQLSyntaxError",
+ "SchemaSyntaxError",
+ "GraphQLSyntaxError",
+ "InvalidTypeError",
+ "InvalidTargetError",
+ "InvalidLinkTargetError",
+ "InvalidPropertyTargetError",
+ "InvalidReferenceError",
+ "UnknownModuleError",
+ "UnknownLinkError",
+ "UnknownPropertyError",
+ "UnknownUserError",
+ "UnknownDatabaseError",
+ "UnknownParameterError",
+ "SchemaError",
+ "SchemaDefinitionError",
+ "InvalidDefinitionError",
+ "InvalidModuleDefinitionError",
+ "InvalidLinkDefinitionError",
+ "InvalidPropertyDefinitionError",
+ "InvalidUserDefinitionError",
+ "InvalidDatabaseDefinitionError",
+ "InvalidOperatorDefinitionError",
+ "InvalidAliasDefinitionError",
+ "InvalidFunctionDefinitionError",
+ "InvalidConstraintDefinitionError",
+ "InvalidCastDefinitionError",
+ "DuplicateDefinitionError",
+ "DuplicateModuleDefinitionError",
+ "DuplicateLinkDefinitionError",
+ "DuplicatePropertyDefinitionError",
+ "DuplicateUserDefinitionError",
+ "DuplicateDatabaseDefinitionError",
+ "DuplicateOperatorDefinitionError",
+ "DuplicateViewDefinitionError",
+ "DuplicateFunctionDefinitionError",
+ "DuplicateConstraintDefinitionError",
+ "DuplicateCastDefinitionError",
+ "DuplicateMigrationError",
+ "SessionTimeoutError",
+ "IdleSessionTimeoutError",
+ "QueryTimeoutError",
+ "TransactionTimeoutError",
+ "IdleTransactionTimeoutError",
+ "ExecutionError",
+ "InvalidValueError",
+ "DivisionByZeroError",
+ "NumericOutOfRangeError",
+ "AccessPolicyError",
+ "QueryAssertionError",
+ "IntegrityError",
+ "ConstraintViolationError",
+ "CardinalityViolationError",
+ "MissingRequiredError",
+ "TransactionError",
+ "TransactionConflictError",
+ "TransactionSerializationError",
+ "TransactionDeadlockError",
+ "WatchError",
+ "ConfigurationError",
+ "AccessError",
+ "AuthenticationError",
+ "AvailabilityError",
+ "BackendUnavailableError",
+ "ServerOfflineError",
+ "BackendError",
+ "UnsupportedBackendFeatureError",
+ "LogMessage",
+ "WarningMessage",
+ "ClientError",
+ "ClientConnectionError",
+ "ClientConnectionFailedError",
+ "ClientConnectionFailedTemporarilyError",
+ "ClientConnectionTimeoutError",
+ "ClientConnectionClosedError",
+ "InterfaceError",
+ "QueryArgumentError",
+ "MissingArgumentError",
+ "UnknownArgumentError",
+ "InvalidArgumentError",
+ "NoDataError",
+ "InternalClientError",
+])
+#
diff --git a/gel/_taskgroup.py b/gel/_taskgroup.py
new file mode 100644
index 00000000..0a1859d7
--- /dev/null
+++ b/gel/_taskgroup.py
@@ -0,0 +1,295 @@
+#
+# This source file is part of the EdgeDB open source project.
+#
+# Copyright 2016-present MagicStack Inc. and the EdgeDB authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+import asyncio
+import functools
+import itertools
+import textwrap
+import traceback
+
+
+class TaskGroup:
+
+ def __init__(self, *, name=None):
+ if name is None:
+ self._name = f'tg-{_name_counter()}'
+ else:
+ self._name = str(name)
+
+ self._entered = False
+ self._exiting = False
+ self._aborting = False
+ self._loop = None
+ self._parent_task = None
+ self._parent_cancel_requested = False
+ self._tasks = set()
+ self._unfinished_tasks = 0
+ self._errors = []
+ self._base_error = None
+ self._on_completed_fut = None
+
+ def get_name(self):
+ return self._name
+
+ def __repr__(self):
+ msg = f''
+ return msg
+
+ async def __aenter__(self):
+ if self._entered:
+ raise RuntimeError(
+ f"TaskGroup {self!r} has been already entered")
+ self._entered = True
+
+ if self._loop is None:
+ self._loop = asyncio.get_event_loop()
+
+ if hasattr(asyncio, 'current_task'):
+ self._parent_task = asyncio.current_task(self._loop)
+ else:
+ self._parent_task = asyncio.Task.current_task(self._loop)
+
+ if self._parent_task is None:
+ raise RuntimeError(
+ f'TaskGroup {self!r} cannot determine the parent task')
+ self._patch_task(self._parent_task)
+
+ return self
+
+ async def __aexit__(self, et, exc, tb):
+ self._exiting = True
+ propagate_cancelation = False
+
+ if (exc is not None and
+ self._is_base_error(exc) and
+ self._base_error is None):
+ self._base_error = exc
+
+ if et is asyncio.CancelledError:
+ if self._parent_cancel_requested:
+ # Only if we did request task to cancel ourselves
+ # we mark it as no longer cancelled.
+ self._parent_task.__cancel_requested__ = False
+ else:
+ propagate_cancelation = True
+
+ if et is not None and not self._aborting:
+ # Our parent task is being cancelled:
+ #
+ # async with TaskGroup() as g:
+ # g.create_task(...)
+ # await ... # <- CancelledError
+ #
+ if et is asyncio.CancelledError:
+ propagate_cancelation = True
+
+ # or there's an exception in "async with":
+ #
+ # async with TaskGroup() as g:
+ # g.create_task(...)
+ # 1 / 0
+ #
+ self._abort()
+
+ # We use while-loop here because "self._on_completed_fut"
+ # can be cancelled multiple times if our parent task
+ # is being cancelled repeatedly (or even once, when
+ # our own cancellation is already in progress)
+ while self._unfinished_tasks:
+ if self._on_completed_fut is None:
+ self._on_completed_fut = self._loop.create_future()
+
+ try:
+ await self._on_completed_fut
+ except asyncio.CancelledError:
+ if not self._aborting:
+ # Our parent task is being cancelled:
+ #
+ # async def wrapper():
+ # async with TaskGroup() as g:
+ # g.create_task(foo)
+ #
+ # "wrapper" is being cancelled while "foo" is
+ # still running.
+ propagate_cancelation = True
+ self._abort()
+
+ self._on_completed_fut = None
+
+ assert self._unfinished_tasks == 0
+ self._on_completed_fut = None # no longer needed
+
+ if self._base_error is not None:
+ raise self._base_error
+
+ if propagate_cancelation:
+ # The wrapping task was cancelled; since we're done with
+ # closing all child tasks, just propagate the cancellation
+ # request now.
+ raise asyncio.CancelledError()
+
+ if et is not None and et is not asyncio.CancelledError:
+ self._errors.append(exc)
+
+ if self._errors:
+ # Exceptions are heavy objects that can have object
+ # cycles (bad for GC); let's not keep a reference to
+ # a bunch of them.
+ errors = self._errors
+ self._errors = None
+
+ me = TaskGroupError('unhandled errors in a TaskGroup',
+ errors=errors)
+ raise me from None
+
+ def create_task(self, coro):
+ if not self._entered:
+ raise RuntimeError(f"TaskGroup {self!r} has not been entered")
+ if self._exiting:
+ raise RuntimeError(f"TaskGroup {self!r} is awaiting in exit")
+ task = self._loop.create_task(coro)
+ task.add_done_callback(self._on_task_done)
+ self._unfinished_tasks += 1
+ self._tasks.add(task)
+ return task
+
+ def _is_base_error(self, exc):
+ assert isinstance(exc, BaseException)
+ return not isinstance(exc, Exception)
+
+ def _patch_task(self, task):
+ # In Python 3.8 we'll need proper API on asyncio.Task to
+ # make TaskGroups possible. We need to be able to access
+ # information about task cancellation, more specifically,
+ # we need a flag to say if a task was cancelled or not.
+ # We also need to be able to flip that flag.
+
+ def _task_cancel(task, orig_cancel):
+ task.__cancel_requested__ = True
+ return orig_cancel()
+
+ if hasattr(task, '__cancel_requested__'):
+ return
+
+ task.__cancel_requested__ = False
+ # confirm that we were successful at adding the new attribute:
+ assert not task.__cancel_requested__
+
+ orig_cancel = task.cancel
+ task.cancel = functools.partial(_task_cancel, task, orig_cancel)
+
+ def _abort(self):
+ self._aborting = True
+
+ for t in self._tasks:
+ if not t.done():
+ t.cancel()
+
+ def _on_task_done(self, task):
+ self._unfinished_tasks -= 1
+ assert self._unfinished_tasks >= 0
+
+ if self._exiting and not self._unfinished_tasks:
+ if not self._on_completed_fut.done():
+ self._on_completed_fut.set_result(True)
+
+ if task.cancelled():
+ return
+
+ exc = task.exception()
+ if exc is None:
+ return
+
+ self._errors.append(exc)
+ if self._is_base_error(exc) and self._base_error is None:
+ self._base_error = exc
+
+ if self._parent_task.done():
+ # Not sure if this case is possible, but we want to handle
+ # it anyways.
+ self._loop.call_exception_handler({
+ 'message': f'Task {task!r} has errored out but its parent '
+ f'task {self._parent_task} is already completed',
+ 'exception': exc,
+ 'task': task,
+ })
+ return
+
+ self._abort()
+ if not self._parent_task.__cancel_requested__:
+ # If parent task *is not* being cancelled, it means that we want
+ # to manually cancel it to abort whatever is being run right now
+ # in the TaskGroup. But we want to mark parent task as
+ # "not cancelled" later in __aexit__. Example situation that
+ # we need to handle:
+ #
+ # async def foo():
+ # try:
+ # async with TaskGroup() as g:
+ # g.create_task(crash_soon())
+ # await something # <- this needs to be canceled
+ # # by the TaskGroup, e.g.
+ # # foo() needs to be cancelled
+ # except Exception:
+ # # Ignore any exceptions raised in the TaskGroup
+ # pass
+ # await something_else # this line has to be called
+ # # after TaskGroup is finished.
+ self._parent_cancel_requested = True
+ self._parent_task.cancel()
+
+
+class MultiError(Exception):
+
+ def __init__(self, msg, *args, errors=()):
+ if errors:
+ types = set(type(e).__name__ for e in errors)
+ msg = f'{msg}; {len(errors)} sub errors: ({", ".join(types)})'
+ for er in errors:
+ msg += f'\n + {type(er).__name__}: {er}'
+ if er.__traceback__:
+ er_tb = ''.join(traceback.format_tb(er.__traceback__))
+ er_tb = textwrap.indent(er_tb, ' | ')
+ msg += f'\n{er_tb}\n'
+ super().__init__(msg, *args)
+ self.__errors__ = tuple(errors)
+
+ def get_error_types(self):
+ return {type(e) for e in self.__errors__}
+
+ def __reduce__(self):
+ return (type(self), (self.args,), {'__errors__': self.__errors__})
+
+
+class TaskGroupError(MultiError):
+ pass
+
+
+_name_counter = itertools.count(1).__next__
diff --git a/edgedb/_testbase.py b/gel/_testbase.py
similarity index 96%
rename from edgedb/_testbase.py
rename to gel/_testbase.py
index 167c9fb5..ef97d81b 100644
--- a/edgedb/_testbase.py
+++ b/gel/_testbase.py
@@ -32,9 +32,9 @@
import time
import unittest
-import edgedb
-from edgedb import asyncio_client
-from edgedb import blocking_client
+import gel
+from gel import asyncio_client
+from gel import blocking_client
log = logging.getLogger(__name__)
@@ -89,9 +89,9 @@ def _start_cluster(*, cleanup_atexit=True):
# not interfere with the server's.
env.pop('PYTHONPATH', None)
- edgedb_server = env.get('EDGEDB_SERVER_BINARY', 'edgedb-server')
+ gel_server = env.get('EDGEDB_SERVER_BINARY', 'edgedb-server')
args = [
- edgedb_server,
+ gel_server,
"--temp-dir",
"--testmode",
f"--emit-server-status={status_file_unix}",
@@ -100,7 +100,7 @@ def _start_cluster(*, cleanup_atexit=True):
"--bootstrap-command=ALTER ROLE edgedb { SET password := 'test' }",
]
- help_args = [edgedb_server, "--help"]
+ help_args = [gel_server, "--help"]
if sys.platform == 'win32':
help_args = ['wsl', '-u', 'edgedb'] + help_args
@@ -173,7 +173,7 @@ def _start_cluster(*, cleanup_atexit=True):
else:
con_args['tls_ca_file'] = data['tls_cert_file']
- client = edgedb.create_client(password='test', **con_args)
+ client = gel.create_client(password='test', **con_args)
client.ensure_connected()
client.execute("""
# Set session_idle_transaction_timeout to 5 minutes.
@@ -237,7 +237,7 @@ def wrapper(self, *args, __meth__=meth, **kwargs):
# retry the test.
self.loop.run_until_complete(
__meth__(self, *args, **kwargs))
- except edgedb.TransactionSerializationError:
+ except gel.TransactionSerializationError:
if try_no == 3:
raise
else:
@@ -335,7 +335,7 @@ def setUpClass(cls):
cls.cluster = _start_cluster(cleanup_atexit=True)
-class TestAsyncIOClient(edgedb.AsyncIOClient):
+class TestAsyncIOClient(gel.AsyncIOClient):
def _clear_codecs_cache(self):
self._impl.codecs_registry.clear_cache()
@@ -352,7 +352,7 @@ def is_proto_lt_1_0(self):
return self.connection._protocol.is_legacy
-class TestClient(edgedb.Client):
+class TestClient(gel.Client):
@property
def connection(self):
return self._impl._holders[0]._con
@@ -560,12 +560,12 @@ def tearDownClass(cls):
try:
cls.adapt_call(
cls.admin_client.execute(script))
- except edgedb.errors.ExecutionError:
+ except gel.errors.ExecutionError:
if i < retry - 1:
time.sleep(0.1)
else:
raise
- except edgedb.errors.UnknownDatabaseError:
+ except gel.errors.UnknownDatabaseError:
break
except Exception:
diff --git a/gel/_version.py b/gel/_version.py
new file mode 100644
index 00000000..41998d40
--- /dev/null
+++ b/gel/_version.py
@@ -0,0 +1,31 @@
+#
+# This source file is part of the EdgeDB open source project.
+#
+# Copyright 2016-present MagicStack Inc. and the EdgeDB authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# This file MUST NOT contain anything but the __version__ assignment.
+#
+# When making a release, change the value of __version__
+# to an appropriate value, and open a pull request against
+# the correct branch (master if making a new feature release).
+# The commit message MUST contain a properly formatted release
+# log, and the commit must be signed.
+#
+# The release automation will: build and test the packages for the
+# supported platforms, publish the packages on PyPI, merge the PR
+# to the target branch, create a Git tag pointing to the commit.
+
+__version__ = '3.0.0b2'
diff --git a/gel/abstract.py b/gel/abstract.py
new file mode 100644
index 00000000..0c2c06cc
--- /dev/null
+++ b/gel/abstract.py
@@ -0,0 +1,437 @@
+#
+# This source file is part of the EdgeDB open source project.
+#
+# Copyright 2020-present MagicStack Inc. and the EdgeDB authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+from __future__ import annotations
+import abc
+import dataclasses
+import typing
+
+from . import describe
+from . import enums
+from . import options
+from .protocol import protocol
+
+__all__ = (
+ "QueryWithArgs",
+ "QueryCache",
+ "QueryOptions",
+ "QueryContext",
+ "Executor",
+ "AsyncIOExecutor",
+ "ReadOnlyExecutor",
+ "AsyncIOReadOnlyExecutor",
+ "DescribeContext",
+ "DescribeResult",
+)
+
+
+class QueryWithArgs(typing.NamedTuple):
+ query: str
+ args: typing.Tuple
+ kwargs: typing.Dict[str, typing.Any]
+ input_language: protocol.InputLanguage = protocol.InputLanguage.EDGEQL
+
+
+class QueryCache(typing.NamedTuple):
+ codecs_registry: protocol.CodecsRegistry
+ query_cache: protocol.LRUMapping
+
+
+class QueryOptions(typing.NamedTuple):
+ output_format: protocol.OutputFormat
+ expect_one: bool
+ required_one: bool
+
+
+class QueryContext(typing.NamedTuple):
+ query: QueryWithArgs
+ cache: QueryCache
+ query_options: QueryOptions
+ retry_options: typing.Optional[options.RetryOptions]
+ state: typing.Optional[options.State]
+ warning_handler: options.WarningHandler
+
+ def lower(
+ self, *, allow_capabilities: enums.Capability
+ ) -> protocol.ExecuteContext:
+ return protocol.ExecuteContext(
+ query=self.query.query,
+ args=self.query.args,
+ kwargs=self.query.kwargs,
+ reg=self.cache.codecs_registry,
+ qc=self.cache.query_cache,
+ input_language=self.query.input_language,
+ output_format=self.query_options.output_format,
+ expect_one=self.query_options.expect_one,
+ required_one=self.query_options.required_one,
+ allow_capabilities=allow_capabilities,
+ state=self.state.as_dict() if self.state else None,
+ )
+
+
+class ExecuteContext(typing.NamedTuple):
+ query: QueryWithArgs
+ cache: QueryCache
+ state: typing.Optional[options.State]
+ warning_handler: options.WarningHandler
+
+ def lower(
+ self, *, allow_capabilities: enums.Capability
+ ) -> protocol.ExecuteContext:
+ return protocol.ExecuteContext(
+ query=self.query.query,
+ args=self.query.args,
+ kwargs=self.query.kwargs,
+ reg=self.cache.codecs_registry,
+ qc=self.cache.query_cache,
+ input_language=self.query.input_language,
+ output_format=protocol.OutputFormat.NONE,
+ allow_capabilities=allow_capabilities,
+ state=self.state.as_dict() if self.state else None,
+ )
+
+
+@dataclasses.dataclass
+class DescribeContext:
+ query: str
+ state: typing.Optional[options.State]
+ inject_type_names: bool
+ input_language: protocol.InputLanguage
+ output_format: protocol.OutputFormat
+ expect_one: bool
+
+ def lower(
+ self, *, allow_capabilities: enums.Capability
+ ) -> protocol.ExecuteContext:
+ return protocol.ExecuteContext(
+ query=self.query,
+ args=None,
+ kwargs=None,
+ reg=protocol.CodecsRegistry(),
+ qc=protocol.LRUMapping(maxsize=1),
+ input_language=self.input_language,
+ output_format=self.output_format,
+ expect_one=self.expect_one,
+ inline_typenames=self.inject_type_names,
+ allow_capabilities=allow_capabilities,
+ state=self.state.as_dict() if self.state else None,
+ )
+
+
+@dataclasses.dataclass
+class DescribeResult:
+ input_type: typing.Optional[describe.AnyType]
+ output_type: typing.Optional[describe.AnyType]
+ output_cardinality: enums.Cardinality
+ capabilities: enums.Capability
+
+
+_query_opts = QueryOptions(
+ output_format=protocol.OutputFormat.BINARY,
+ expect_one=False,
+ required_one=False,
+)
+_query_single_opts = QueryOptions(
+ output_format=protocol.OutputFormat.BINARY,
+ expect_one=True,
+ required_one=False,
+)
+_query_required_single_opts = QueryOptions(
+ output_format=protocol.OutputFormat.BINARY,
+ expect_one=True,
+ required_one=True,
+)
+_query_json_opts = QueryOptions(
+ output_format=protocol.OutputFormat.JSON,
+ expect_one=False,
+ required_one=False,
+)
+_query_single_json_opts = QueryOptions(
+ output_format=protocol.OutputFormat.JSON,
+ expect_one=True,
+ required_one=False,
+)
+_query_required_single_json_opts = QueryOptions(
+ output_format=protocol.OutputFormat.JSON,
+ expect_one=True,
+ required_one=True,
+)
+
+
+class BaseReadOnlyExecutor(abc.ABC):
+ __slots__ = ()
+
+ @abc.abstractmethod
+ def _get_query_cache(self) -> QueryCache:
+ ...
+
+ def _get_retry_options(self) -> typing.Optional[options.RetryOptions]:
+ return None
+
+ @abc.abstractmethod
+ def _get_state(self) -> options.State:
+ ...
+
+ @abc.abstractmethod
+ def _get_warning_handler(self) -> options.WarningHandler:
+ ...
+
+
+class ReadOnlyExecutor(BaseReadOnlyExecutor):
+ """Subclasses can execute *at least* read-only queries"""
+
+ __slots__ = ()
+
+ @abc.abstractmethod
+ def _query(self, query_context: QueryContext):
+ ...
+
+ def query(self, query: str, *args, **kwargs) -> list:
+ return self._query(QueryContext(
+ query=QueryWithArgs(query, args, kwargs),
+ cache=self._get_query_cache(),
+ query_options=_query_opts,
+ retry_options=self._get_retry_options(),
+ state=self._get_state(),
+ warning_handler=self._get_warning_handler(),
+ ))
+
+ def query_single(
+ self, query: str, *args, **kwargs
+ ) -> typing.Union[typing.Any, None]:
+ return self._query(QueryContext(
+ query=QueryWithArgs(query, args, kwargs),
+ cache=self._get_query_cache(),
+ query_options=_query_single_opts,
+ retry_options=self._get_retry_options(),
+ state=self._get_state(),
+ warning_handler=self._get_warning_handler(),
+ ))
+
+ def query_required_single(self, query: str, *args, **kwargs) -> typing.Any:
+ return self._query(QueryContext(
+ query=QueryWithArgs(query, args, kwargs),
+ cache=self._get_query_cache(),
+ query_options=_query_required_single_opts,
+ retry_options=self._get_retry_options(),
+ state=self._get_state(),
+ warning_handler=self._get_warning_handler(),
+ ))
+
+ def query_json(self, query: str, *args, **kwargs) -> str:
+ return self._query(QueryContext(
+ query=QueryWithArgs(query, args, kwargs),
+ cache=self._get_query_cache(),
+ query_options=_query_json_opts,
+ retry_options=self._get_retry_options(),
+ state=self._get_state(),
+ warning_handler=self._get_warning_handler(),
+ ))
+
+ def query_single_json(self, query: str, *args, **kwargs) -> str:
+ return self._query(QueryContext(
+ query=QueryWithArgs(query, args, kwargs),
+ cache=self._get_query_cache(),
+ query_options=_query_single_json_opts,
+ retry_options=self._get_retry_options(),
+ state=self._get_state(),
+ warning_handler=self._get_warning_handler(),
+ ))
+
+ def query_required_single_json(self, query: str, *args, **kwargs) -> str:
+ return self._query(QueryContext(
+ query=QueryWithArgs(query, args, kwargs),
+ cache=self._get_query_cache(),
+ query_options=_query_required_single_json_opts,
+ retry_options=self._get_retry_options(),
+ state=self._get_state(),
+ warning_handler=self._get_warning_handler(),
+ ))
+
+ def query_sql(self, query: str, *args, **kwargs) -> typing.Any:
+ return self._query(QueryContext(
+ query=QueryWithArgs(
+ query,
+ args,
+ kwargs,
+ input_language=protocol.InputLanguage.SQL,
+ ),
+ cache=self._get_query_cache(),
+ query_options=_query_opts,
+ retry_options=self._get_retry_options(),
+ state=self._get_state(),
+ warning_handler=self._get_warning_handler(),
+ ))
+
+ @abc.abstractmethod
+ def _execute(self, execute_context: ExecuteContext):
+ ...
+
+ def execute(self, commands: str, *args, **kwargs) -> None:
+ self._execute(ExecuteContext(
+ query=QueryWithArgs(commands, args, kwargs),
+ cache=self._get_query_cache(),
+ state=self._get_state(),
+ warning_handler=self._get_warning_handler(),
+ ))
+
+ def execute_sql(self, commands: str, *args, **kwargs) -> None:
+ self._execute(ExecuteContext(
+ query=QueryWithArgs(
+ commands,
+ args,
+ kwargs,
+ input_language=protocol.InputLanguage.SQL,
+ ),
+ cache=self._get_query_cache(),
+ state=self._get_state(),
+ warning_handler=self._get_warning_handler(),
+ ))
+
+
+class Executor(ReadOnlyExecutor):
+ """Subclasses can execute both read-only and modification queries"""
+
+ __slots__ = ()
+
+
+class AsyncIOReadOnlyExecutor(BaseReadOnlyExecutor):
+ """Subclasses can execute *at least* read-only queries"""
+
+ __slots__ = ()
+
+ @abc.abstractmethod
+ async def _query(self, query_context: QueryContext):
+ ...
+
+ async def query(self, query: str, *args, **kwargs) -> list:
+ return await self._query(QueryContext(
+ query=QueryWithArgs(query, args, kwargs),
+ cache=self._get_query_cache(),
+ query_options=_query_opts,
+ retry_options=self._get_retry_options(),
+ state=self._get_state(),
+ warning_handler=self._get_warning_handler(),
+ ))
+
+ async def query_single(self, query: str, *args, **kwargs) -> typing.Any:
+ return await self._query(QueryContext(
+ query=QueryWithArgs(query, args, kwargs),
+ cache=self._get_query_cache(),
+ query_options=_query_single_opts,
+ retry_options=self._get_retry_options(),
+ state=self._get_state(),
+ warning_handler=self._get_warning_handler(),
+ ))
+
+ async def query_required_single(
+ self,
+ query: str,
+ *args,
+ **kwargs
+ ) -> typing.Any:
+ return await self._query(QueryContext(
+ query=QueryWithArgs(query, args, kwargs),
+ cache=self._get_query_cache(),
+ query_options=_query_required_single_opts,
+ retry_options=self._get_retry_options(),
+ state=self._get_state(),
+ warning_handler=self._get_warning_handler(),
+ ))
+
+ async def query_json(self, query: str, *args, **kwargs) -> str:
+ return await self._query(QueryContext(
+ query=QueryWithArgs(query, args, kwargs),
+ cache=self._get_query_cache(),
+ query_options=_query_json_opts,
+ retry_options=self._get_retry_options(),
+ state=self._get_state(),
+ warning_handler=self._get_warning_handler(),
+ ))
+
+ async def query_single_json(self, query: str, *args, **kwargs) -> str:
+ return await self._query(QueryContext(
+ query=QueryWithArgs(query, args, kwargs),
+ cache=self._get_query_cache(),
+ query_options=_query_single_json_opts,
+ retry_options=self._get_retry_options(),
+ state=self._get_state(),
+ warning_handler=self._get_warning_handler(),
+ ))
+
+ async def query_required_single_json(
+ self,
+ query: str,
+ *args,
+ **kwargs
+ ) -> str:
+ return await self._query(QueryContext(
+ query=QueryWithArgs(query, args, kwargs),
+ cache=self._get_query_cache(),
+ query_options=_query_required_single_json_opts,
+ retry_options=self._get_retry_options(),
+ state=self._get_state(),
+ warning_handler=self._get_warning_handler(),
+ ))
+
+ async def query_sql(self, query: str, *args, **kwargs) -> typing.Any:
+ return await self._query(QueryContext(
+ query=QueryWithArgs(
+ query,
+ args,
+ kwargs,
+ input_language=protocol.InputLanguage.SQL,
+ ),
+ cache=self._get_query_cache(),
+ query_options=_query_opts,
+ retry_options=self._get_retry_options(),
+ state=self._get_state(),
+ warning_handler=self._get_warning_handler(),
+ ))
+
+ @abc.abstractmethod
+ async def _execute(self, execute_context: ExecuteContext) -> None:
+ ...
+
+ async def execute(self, commands: str, *args, **kwargs) -> None:
+ await self._execute(ExecuteContext(
+ query=QueryWithArgs(commands, args, kwargs),
+ cache=self._get_query_cache(),
+ state=self._get_state(),
+ warning_handler=self._get_warning_handler(),
+ ))
+
+ async def execute_sql(self, commands: str, *args, **kwargs) -> None:
+ await self._execute(ExecuteContext(
+ query=QueryWithArgs(
+ commands,
+ args,
+ kwargs,
+ input_language=protocol.InputLanguage.SQL,
+ ),
+ cache=self._get_query_cache(),
+ state=self._get_state(),
+ warning_handler=self._get_warning_handler(),
+ ))
+
+
+class AsyncIOExecutor(AsyncIOReadOnlyExecutor):
+ """Subclasses can execute both read-only and modification queries"""
+
+ __slots__ = ()
diff --git a/gel/ai/__init__.py b/gel/ai/__init__.py
new file mode 100644
index 00000000..96111c2b
--- /dev/null
+++ b/gel/ai/__init__.py
@@ -0,0 +1,32 @@
+#
+# This source file is part of the EdgeDB open source project.
+#
+# Copyright 2024-present MagicStack Inc. and the EdgeDB authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from .types import AIOptions, ChatParticipantRole, Prompt, QueryContext
+from .core import create_ai, EdgeDBAI
+from .core import create_async_ai, AsyncEdgeDBAI
+
+__all__ = [
+ "AIOptions",
+ "ChatParticipantRole",
+ "Prompt",
+ "QueryContext",
+ "create_ai",
+ "EdgeDBAI",
+ "create_async_ai",
+ "AsyncEdgeDBAI",
+]
diff --git a/gel/ai/core.py b/gel/ai/core.py
new file mode 100644
index 00000000..2822bae4
--- /dev/null
+++ b/gel/ai/core.py
@@ -0,0 +1,191 @@
+#
+# This source file is part of the EdgeDB open source project.
+#
+# Copyright 2024-present MagicStack Inc. and the EdgeDB authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from __future__ import annotations
+import typing
+
+import gel
+import httpx
+import httpx_sse
+
+from . import types
+
+
+def create_ai(client: gel.Client, **kwargs) -> EdgeDBAI:
+ client.ensure_connected()
+ return EdgeDBAI(client, types.AIOptions(**kwargs))
+
+
+async def create_async_ai(
+ client: gel.AsyncIOClient, **kwargs
+) -> AsyncEdgeDBAI:
+ await client.ensure_connected()
+ return AsyncEdgeDBAI(client, types.AIOptions(**kwargs))
+
+
+class BaseEdgeDBAI:
+ options: types.AIOptions
+ context: types.QueryContext
+ client_cls = NotImplemented
+
+ def __init__(
+ self,
+ client: typing.Union[gel.Client, gel.AsyncIOClient],
+ options: types.AIOptions,
+ **kwargs,
+ ):
+ pool = client._impl
+ host, port = pool._working_addr
+ params = pool._working_params
+ proto = "http" if params.tls_security == "insecure" else "https"
+ branch = params.branch
+ self.options = options
+ self.context = types.QueryContext(**kwargs)
+ args = dict(
+ base_url=f"{proto}://{host}:{port}/branch/{branch}/ext/ai",
+ verify=params.ssl_ctx,
+ )
+ if params.password is not None:
+ args["auth"] = (params.user, params.password)
+ elif params.secret_key is not None:
+ args["headers"] = {"Authorization": f"Bearer {params.secret_key}"}
+ self._init_client(**args)
+
+ def _init_client(self, **kwargs):
+ raise NotImplementedError
+
+ def with_config(self, **kwargs) -> typing.Self:
+ cls = type(self)
+ rv = cls.__new__(cls)
+ rv.options = self.options.derive(kwargs)
+ rv.context = self.context
+ rv.client = self.client
+ return rv
+
+ def with_context(self, **kwargs) -> typing.Self:
+ cls = type(self)
+ rv = cls.__new__(cls)
+ rv.options = self.options
+ rv.context = self.context.derive(kwargs)
+ rv.client = self.client
+ return rv
+
+ def _make_rag_request(
+ self,
+ *,
+ message: str,
+ context: typing.Optional[types.QueryContext] = None,
+ stream: bool,
+ ) -> types.RAGRequest:
+ if context is None:
+ context = self.context
+ return types.RAGRequest(
+ model=self.options.model,
+ prompt=self.options.prompt,
+ context=context,
+ query=message,
+ stream=stream,
+ )
+
+
+class EdgeDBAI(BaseEdgeDBAI):
+ client: httpx.Client
+
+ def _init_client(self, **kwargs):
+ self.client = httpx.Client(**kwargs)
+
+ def query_rag(
+ self, message: str, context: typing.Optional[types.QueryContext] = None
+ ) -> str:
+ resp = self.client.post(
+ **self._make_rag_request(
+ context=context,
+ message=message,
+ stream=False,
+ ).to_httpx_request()
+ )
+ resp.raise_for_status()
+ return resp.json()["response"]
+
+ def stream_rag(
+ self, message: str, context: typing.Optional[types.QueryContext] = None
+ ) -> typing.Iterator[str]:
+ with httpx_sse.connect_sse(
+ self.client,
+ "post",
+ **self._make_rag_request(
+ context=context,
+ message=message,
+ stream=True,
+ ).to_httpx_request(),
+ ) as event_source:
+ event_source.response.raise_for_status()
+ for sse in event_source.iter_sse():
+ yield sse.data
+
+ def generate_embeddings(self, *inputs: str, model: str) -> list[float]:
+ resp = self.client.post(
+ "/embeddings", json={"input": inputs, "model": model}
+ )
+ resp.raise_for_status()
+ return resp.json()["data"][0]["embedding"]
+
+
+class AsyncEdgeDBAI(BaseEdgeDBAI):
+ client: httpx.AsyncClient
+
+ def _init_client(self, **kwargs):
+ self.client = httpx.AsyncClient(**kwargs)
+
+ async def query_rag(
+ self, message: str, context: typing.Optional[types.QueryContext] = None
+ ) -> str:
+ resp = await self.client.post(
+ **self._make_rag_request(
+ context=context,
+ message=message,
+ stream=False,
+ ).to_httpx_request()
+ )
+ resp.raise_for_status()
+ return resp.json()["response"]
+
+ async def stream_rag(
+ self, message: str, context: typing.Optional[types.QueryContext] = None
+ ) -> typing.Iterator[str]:
+ async with httpx_sse.aconnect_sse(
+ self.client,
+ "post",
+ **self._make_rag_request(
+ context=context,
+ message=message,
+ stream=True,
+ ).to_httpx_request(),
+ ) as event_source:
+ event_source.response.raise_for_status()
+ async for sse in event_source.aiter_sse():
+ yield sse.data
+
+ async def generate_embeddings(
+ self, *inputs: str, model: str
+ ) -> list[float]:
+ resp = await self.client.post(
+ "/embeddings", json={"input": inputs, "model": model}
+ )
+ resp.raise_for_status()
+ return resp.json()["data"][0]["embedding"]
diff --git a/gel/ai/types.py b/gel/ai/types.py
new file mode 100644
index 00000000..41bf24c0
--- /dev/null
+++ b/gel/ai/types.py
@@ -0,0 +1,81 @@
+#
+# This source file is part of the EdgeDB open source project.
+#
+# Copyright 2024-present MagicStack Inc. and the EdgeDB authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import typing
+
+import dataclasses as dc
+import enum
+
+
+class ChatParticipantRole(enum.Enum):
+ SYSTEM = "system"
+ USER = "user"
+ ASSISTANT = "assistant"
+ TOOL = "tool"
+
+
+class Custom(typing.TypedDict):
+ role: ChatParticipantRole
+ content: str
+
+
+class Prompt:
+ name: typing.Optional[str]
+ id: typing.Optional[str]
+ custom: typing.Optional[typing.List[Custom]]
+
+
+@dc.dataclass
+class AIOptions:
+ model: str
+ prompt: typing.Optional[Prompt] = None
+
+ def derive(self, kwargs):
+ return AIOptions(**{**dc.asdict(self), **kwargs})
+
+
+@dc.dataclass
+class QueryContext:
+ query: str = ""
+ variables: typing.Optional[typing.Dict[str, typing.Any]] = None
+ globals: typing.Optional[typing.Dict[str, typing.Any]] = None
+ max_object_count: typing.Optional[int] = None
+
+ def derive(self, kwargs):
+ return QueryContext(**{**dc.asdict(self), **kwargs})
+
+
+@dc.dataclass
+class RAGRequest:
+ model: str
+ prompt: typing.Optional[Prompt]
+ context: QueryContext
+ query: str
+ stream: typing.Optional[bool]
+
+ def to_httpx_request(self) -> typing.Dict[str, typing.Any]:
+ return dict(
+ url="/rag",
+ headers={
+ "Content-Type": "application/json",
+ "Accept": (
+ "text/event-stream" if self.stream else "application/json"
+ ),
+ },
+ json=dc.asdict(self),
+ )
diff --git a/gel/asyncio_client.py b/gel/asyncio_client.py
new file mode 100644
index 00000000..81cbc808
--- /dev/null
+++ b/gel/asyncio_client.py
@@ -0,0 +1,448 @@
+#
+# This source file is part of the EdgeDB open source project.
+#
+# Copyright 2019-present MagicStack Inc. and the EdgeDB authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+import asyncio
+import contextlib
+import logging
+import socket
+import ssl
+import typing
+
+from . import abstract
+from . import base_client
+from . import con_utils
+from . import errors
+from . import transaction
+from .protocol import asyncio_proto
+from .protocol.protocol import InputLanguage, OutputFormat
+
+
+__all__ = (
+ 'create_async_client', 'AsyncIOClient'
+)
+
+
+logger = logging.getLogger(__name__)
+
+
+class AsyncIOConnection(base_client.BaseConnection):
+ __slots__ = ("_loop",)
+
+ def __init__(self, loop, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._loop = loop
+
+ def is_closed(self):
+ protocol = self._protocol
+ return protocol is None or not protocol.connected
+
+ async def connect_addr(self, addr, timeout):
+ try:
+ await asyncio.wait_for(self._connect_addr(addr), timeout)
+ except asyncio.TimeoutError as e:
+ raise TimeoutError from e
+
+ async def sleep(self, seconds):
+ await asyncio.sleep(seconds)
+
+ async def aclose(self):
+ """Send graceful termination message wait for connection to drop."""
+ if not self.is_closed():
+ try:
+ self._protocol.terminate()
+ await self._protocol.wait_for_disconnect()
+ except (Exception, asyncio.CancelledError):
+ self.terminate()
+ raise
+ finally:
+ self._cleanup()
+
+ def _protocol_factory(self):
+ return asyncio_proto.AsyncIOProtocol(self._params, self._loop)
+
+ async def _connect_addr(self, addr):
+ tr = None
+
+ try:
+ if isinstance(addr, str):
+ # UNIX socket
+ tr, pr = await self._loop.create_unix_connection(
+ self._protocol_factory, addr
+ )
+ else:
+ try:
+ tr, pr = await self._loop.create_connection(
+ self._protocol_factory,
+ *addr,
+ ssl=self._params.ssl_ctx,
+ server_hostname=(
+ self._params.tls_server_name or addr[0]
+ ),
+ )
+ except ssl.CertificateError as e:
+ raise con_utils.wrap_error(e) from e
+ except ssl.SSLError as e:
+ raise con_utils.wrap_error(e) from e
+ else:
+ con_utils.check_alpn_protocol(
+ tr.get_extra_info('ssl_object')
+ )
+ except socket.gaierror as e:
+ # All name resolution errors are considered temporary
+ raise errors.ClientConnectionFailedTemporarilyError(str(e)) from e
+ except OSError as e:
+ raise con_utils.wrap_error(e) from e
+ except Exception:
+ if tr is not None:
+ tr.close()
+ raise
+
+ pr.set_connection(self)
+
+ try:
+ await pr.connect()
+ except OSError as e:
+ if tr is not None:
+ tr.close()
+ raise con_utils.wrap_error(e) from e
+ except BaseException:
+ if tr is not None:
+ tr.close()
+ raise
+
+ self._protocol = pr
+ self._addr = addr
+
+ def _dispatch_log_message(self, msg):
+ for cb in self._log_listeners:
+ self._loop.call_soon(cb, self, msg)
+
+
+class _PoolConnectionHolder(base_client.PoolConnectionHolder):
+ __slots__ = ()
+ _event_class = asyncio.Event
+
+ async def close(self, *, wait=True):
+ if self._con is None:
+ return
+ if wait:
+ await self._con.aclose()
+ else:
+ self._pool._loop.create_task(self._con.aclose())
+
+ async def wait_until_released(self, timeout=None):
+ await self._release_event.wait()
+
+
+class _AsyncIOPoolImpl(base_client.BasePoolImpl):
+ __slots__ = ('_loop',)
+ _holder_class = _PoolConnectionHolder
+
+ def __init__(
+ self,
+ connect_args,
+ *,
+ max_concurrency: typing.Optional[int],
+ connection_class,
+ ):
+ if not issubclass(connection_class, AsyncIOConnection):
+ raise TypeError(
+ f'connection_class is expected to be a subclass of '
+ f'gel.asyncio_client.AsyncIOConnection, '
+ f'got {connection_class}')
+ self._loop = None
+ super().__init__(
+ connect_args,
+ lambda *args: connection_class(self._loop, *args),
+ max_concurrency=max_concurrency,
+ )
+
+ def _ensure_initialized(self):
+ if self._loop is None:
+ self._loop = asyncio.get_event_loop()
+ self._queue = asyncio.LifoQueue(maxsize=self._max_concurrency)
+ self._first_connect_lock = asyncio.Lock()
+ self._resize_holder_pool()
+
+ def _set_queue_maxsize(self, maxsize):
+ self._queue._maxsize = maxsize
+
+ async def _maybe_get_first_connection(self):
+ async with self._first_connect_lock:
+ if self._working_addr is None:
+ return await self._get_first_connection()
+
+ async def acquire(self, timeout=None):
+ self._ensure_initialized()
+
+ async def _acquire_impl():
+ ch = await self._queue.get() # type: _PoolConnectionHolder
+ try:
+ proxy = await ch.acquire() # type: AsyncIOConnection
+ except (Exception, asyncio.CancelledError):
+ self._queue.put_nowait(ch)
+ raise
+ else:
+ # Record the timeout, as we will apply it by default
+ # in release().
+ ch._timeout = timeout
+ return proxy
+
+ if self._closing:
+ raise errors.InterfaceError('pool is closing')
+
+ if timeout is None:
+ return await _acquire_impl()
+ else:
+ return await asyncio.wait_for(
+ _acquire_impl(), timeout=timeout)
+
+ async def _release(self, holder):
+
+ if not isinstance(holder._con, AsyncIOConnection):
+ raise errors.InterfaceError(
+ f'release() received invalid connection: '
+ f'{holder._con!r} does not belong to any connection pool'
+ )
+
+ timeout = None
+
+ # Use asyncio.shield() to guarantee that task cancellation
+ # does not prevent the connection from being returned to the
+ # pool properly.
+ return await asyncio.shield(holder.release(timeout))
+
+ async def aclose(self):
+ """Attempt to gracefully close all connections in the pool.
+
+ Wait until all pool connections are released, close them and
+ shut down the pool. If any error (including cancellation) occurs
+ in ``close()`` the pool will terminate by calling
+ _AsyncIOPoolImpl.terminate() .
+
+ It is advisable to use :func:`python:asyncio.wait_for` to set
+ a timeout.
+ """
+ if self._closed:
+ return
+
+ if not self._loop:
+ self._closed = True
+ return
+
+ self._closing = True
+
+ try:
+ warning_callback = self._loop.call_later(
+ 60, self._warn_on_long_close)
+
+ release_coros = [
+ ch.wait_until_released() for ch in self._holders]
+ await asyncio.gather(*release_coros)
+
+ close_coros = [
+ ch.close() for ch in self._holders]
+ await asyncio.gather(*close_coros)
+
+ except (Exception, asyncio.CancelledError):
+ self.terminate()
+ raise
+
+ finally:
+ warning_callback.cancel()
+ self._closed = True
+ self._closing = False
+
+ def _warn_on_long_close(self):
+ logger.warning(
+ 'AsyncIOClient.aclose() is taking over 60 seconds to complete. '
+ 'Check if you have any unreleased connections left. '
+ 'Use asyncio.wait_for() to set a timeout for '
+ 'AsyncIOClient.aclose().')
+
+
+class AsyncIOIteration(transaction.BaseTransaction, abstract.AsyncIOExecutor):
+
+ __slots__ = ("_managed", "_locked")
+
+ def __init__(self, retry, client, iteration):
+ super().__init__(retry, client, iteration)
+ self._managed = False
+ self._locked = False
+
+ async def __aenter__(self):
+ if self._managed:
+ raise errors.InterfaceError(
+ 'cannot enter context: already in an `async with` block')
+ self._managed = True
+ return self
+
+ async def __aexit__(self, extype, ex, tb):
+ with self._exclusive():
+ self._managed = False
+ return await self._exit(extype, ex)
+
+ async def _ensure_transaction(self):
+ if not self._managed:
+ raise errors.InterfaceError(
+ "Only managed retriable transactions are supported. "
+ "Use `async with transaction:`"
+ )
+ await super()._ensure_transaction()
+
+ async def _query(self, query_context: abstract.QueryContext):
+ with self._exclusive():
+ return await super()._query(query_context)
+
+ async def _execute(self, execute_context: abstract.ExecuteContext) -> None:
+ with self._exclusive():
+ await super()._execute(execute_context)
+
+ @contextlib.contextmanager
+ def _exclusive(self):
+ if self._locked:
+ raise errors.InterfaceError(
+ "concurrent queries within the same transaction "
+ "are not allowed"
+ )
+ self._locked = True
+ try:
+ yield
+ finally:
+ self._locked = False
+
+
+class AsyncIORetry(transaction.BaseRetry):
+
+ def __aiter__(self):
+ return self
+
+ async def __anext__(self):
+ # Note: when changing this code consider also
+ # updating Retry.__next__.
+ if self._done:
+ raise StopAsyncIteration
+ if self._next_backoff:
+ await asyncio.sleep(self._next_backoff)
+ self._done = True
+ iteration = AsyncIOIteration(self, self._owner, self._iteration)
+ self._iteration += 1
+ return iteration
+
+
+class AsyncIOClient(base_client.BaseClient, abstract.AsyncIOExecutor):
+ """A lazy connection pool.
+
+ A Client can be used to manage a set of connections to the database.
+ Connections are first acquired from the pool, then used, and then released
+ back to the pool. Once a connection is released, it's reset to close all
+ open cursors and other resources *except* prepared statements.
+
+ Clients are created by calling
+ :func:`~gel.asyncio_client.create_async_client`.
+ """
+
+ __slots__ = ()
+ _impl_class = _AsyncIOPoolImpl
+
+ async def ensure_connected(self):
+ await self._impl.ensure_connected()
+ return self
+
+ async def aclose(self):
+ """Attempt to gracefully close all connections in the pool.
+
+ Wait until all pool connections are released, close them and
+ shut down the pool. If any error (including cancellation) occurs
+ in ``aclose()`` the pool will terminate by calling
+ AsyncIOClient.terminate() .
+
+ It is advisable to use :func:`python:asyncio.wait_for` to set
+ a timeout.
+ """
+ await self._impl.aclose()
+
+ def transaction(self) -> AsyncIORetry:
+ return AsyncIORetry(self)
+
+ async def __aenter__(self):
+ return await self.ensure_connected()
+
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
+ await self.aclose()
+
+ async def _describe_query(
+ self,
+ query: str,
+ *,
+ inject_type_names: bool = False,
+ input_language: InputLanguage = InputLanguage.EDGEQL,
+ output_format: OutputFormat = OutputFormat.BINARY,
+ expect_one: bool = False,
+ ) -> abstract.DescribeResult:
+ return await self._describe(abstract.DescribeContext(
+ query=query,
+ state=self._get_state(),
+ inject_type_names=inject_type_names,
+ input_language=input_language,
+ output_format=output_format,
+ expect_one=expect_one,
+ ))
+
+
+def create_async_client(
+ dsn=None,
+ *,
+ max_concurrency=None,
+ host: str = None,
+ port: int = None,
+ credentials: str = None,
+ credentials_file: str = None,
+ user: str = None,
+ password: str = None,
+ secret_key: str = None,
+ database: str = None,
+ branch: str = None,
+ tls_ca: str = None,
+ tls_ca_file: str = None,
+ tls_security: str = None,
+ wait_until_available: int = 30,
+ timeout: int = 10,
+):
+ return AsyncIOClient(
+ connection_class=AsyncIOConnection,
+ max_concurrency=max_concurrency,
+
+ # connect arguments
+ dsn=dsn,
+ host=host,
+ port=port,
+ credentials=credentials,
+ credentials_file=credentials_file,
+ user=user,
+ password=password,
+ secret_key=secret_key,
+ database=database,
+ branch=branch,
+ tls_ca=tls_ca,
+ tls_ca_file=tls_ca_file,
+ tls_security=tls_security,
+ wait_until_available=wait_until_available,
+ timeout=timeout,
+ )
diff --git a/gel/base_client.py b/gel/base_client.py
new file mode 100644
index 00000000..5bc1a86a
--- /dev/null
+++ b/gel/base_client.py
@@ -0,0 +1,734 @@
+#
+# This source file is part of the EdgeDB open source project.
+#
+# Copyright 2016-present MagicStack Inc. and the EdgeDB authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+import abc
+import random
+import time
+import typing
+
+from . import abstract
+from . import con_utils
+from . import enums
+from . import errors
+from . import options as _options
+from .protocol import protocol
+
+
+BaseConnection_T = typing.TypeVar('BaseConnection_T', bound='BaseConnection')
+QUERY_CACHE_SIZE = 1000
+
+
+class BaseConnection(metaclass=abc.ABCMeta):
+ _protocol: typing.Any
+ _addr: typing.Optional[typing.Union[str, typing.Tuple[str, int]]]
+ _addrs: typing.Iterable[typing.Union[str, typing.Tuple[str, int]]]
+ _config: con_utils.ClientConfiguration
+ _params: con_utils.ResolvedConnectConfig
+ _log_listeners: typing.Set[
+ typing.Callable[[BaseConnection_T, errors.EdgeDBMessage], None]
+ ]
+ __slots__ = (
+ "__weakref__",
+ "_protocol",
+ "_addr",
+ "_addrs",
+ "_config",
+ "_params",
+ "_log_listeners",
+ "_holder",
+ )
+
+ def __init__(
+ self,
+ addrs: typing.Iterable[typing.Union[str, typing.Tuple[str, int]]],
+ config: con_utils.ClientConfiguration,
+ params: con_utils.ResolvedConnectConfig,
+ ):
+ self._addr = None
+ self._protocol = None
+ self._addrs = addrs
+ self._config = config
+ self._params = params
+ self._log_listeners = set()
+ self._holder = None
+
+ @abc.abstractmethod
+ def _dispatch_log_message(self, msg):
+ ...
+
+ def _on_log_message(self, msg):
+ if self._log_listeners:
+ self._dispatch_log_message(msg)
+
+ def connected_addr(self):
+ return self._addr
+
+ def _get_last_status(self) -> typing.Optional[str]:
+ if self._protocol is None:
+ return None
+ status = self._protocol.last_status
+ if status is not None:
+ status = status.decode()
+ return status
+
+ def _cleanup(self):
+ self._log_listeners.clear()
+ if self._holder:
+ self._holder._release_on_close()
+ self._holder = None
+
+ def add_log_listener(
+ self: BaseConnection_T,
+ callback: typing.Callable[[BaseConnection_T, errors.EdgeDBMessage],
+ None]
+ ) -> None:
+ """Add a listener for EdgeDB log messages.
+
+ :param callable callback:
+ A callable receiving the following arguments:
+ **connection**: a Connection the callback is registered with;
+ **message**: the `gel.EdgeDBMessage` message.
+ """
+ self._log_listeners.add(callback)
+
+ def remove_log_listener(
+ self: BaseConnection_T,
+ callback: typing.Callable[[BaseConnection_T, errors.EdgeDBMessage],
+ None]
+ ) -> None:
+ """Remove a listening callback for log messages."""
+ self._log_listeners.discard(callback)
+
+ @property
+ def dbname(self) -> str:
+ return self._params.database
+
+ @property
+ def branch(self) -> str:
+ return self._params.branch
+
+ @abc.abstractmethod
+ def is_closed(self) -> bool:
+ ...
+
+ @abc.abstractmethod
+ async def connect_addr(self, addr, timeout):
+ ...
+
+ @abc.abstractmethod
+ async def sleep(self, seconds):
+ ...
+
+ async def connect(self, *, single_attempt=False):
+ start = time.monotonic()
+ if single_attempt:
+ max_time = 0
+ else:
+ max_time = start + self._config.wait_until_available
+ iteration = 1
+
+ while True:
+ for addr in self._addrs:
+ try:
+ await self.connect_addr(addr, self._config.connect_timeout)
+ except TimeoutError as e:
+ if iteration == 1 or time.monotonic() < max_time:
+ continue
+ else:
+ raise errors.ClientConnectionTimeoutError(
+ f"connecting to {addr} failed in"
+ f" {self._config.connect_timeout} sec"
+ ) from e
+ except errors.ClientConnectionError as e:
+ if (
+ e.has_tag(errors.SHOULD_RECONNECT) and
+ (iteration == 1 or time.monotonic() < max_time)
+ ):
+ continue
+ nice_err = e.__class__(
+ con_utils.render_client_no_connection_error(
+ e,
+ addr,
+ attempts=iteration,
+ duration=time.monotonic() - start,
+ ))
+ raise nice_err from e.__cause__
+ else:
+ return
+
+ iteration += 1
+ await self.sleep(0.01 + random.random() * 0.2)
+
+ async def privileged_execute(
+ self, execute_context: abstract.ExecuteContext
+ ):
+ if self._protocol.is_legacy:
+ await self._protocol.legacy_simple_query(
+ execute_context.query.query, enums.Capability.ALL
+ )
+ else:
+ await self._protocol.execute(
+ execute_context.lower(allow_capabilities=enums.Capability.ALL)
+ )
+
+ def is_in_transaction(self) -> bool:
+ """Return True if Connection is currently inside a transaction.
+
+ :return bool: True if inside transaction, False otherwise.
+ """
+ return self._protocol.is_in_transaction()
+
+ def get_settings(self) -> typing.Dict[str, typing.Any]:
+ return self._protocol.get_settings()
+
+ async def raw_query(self, query_context: abstract.QueryContext):
+ if self.is_closed():
+ await self.connect()
+
+ reconnect = False
+ i = 0
+ if self._protocol.is_legacy:
+ allow_capabilities = enums.Capability.LEGACY_EXECUTE
+ else:
+ allow_capabilities = enums.Capability.EXECUTE
+ ctx = query_context.lower(allow_capabilities=allow_capabilities)
+ while True:
+ i += 1
+ try:
+ if reconnect:
+ await self.connect(single_attempt=True)
+ if self._protocol.is_legacy:
+ return await self._protocol.legacy_execute_anonymous(ctx)
+ else:
+ res = await self._protocol.query(ctx)
+ if ctx.warnings:
+ res = query_context.warning_handler(ctx.warnings, res)
+ return res
+
+ except errors.EdgeDBError as e:
+ if query_context.retry_options is None:
+ raise
+ if not e.has_tag(errors.SHOULD_RETRY):
+ raise e
+ # A query is read-only if it has no capabilities i.e.
+ # capabilities == 0. Read-only queries are safe to retry.
+ # Explicit transaction conflicts as well.
+ if (
+ ctx.capabilities != 0
+ and not isinstance(e, errors.TransactionConflictError)
+ ):
+ raise e
+ rule = query_context.retry_options.get_rule_for_exception(e)
+ if i >= rule.attempts:
+ raise e
+ await self.sleep(rule.backoff(i))
+ reconnect = self.is_closed()
+
+ async def _execute(self, execute_context: abstract.ExecuteContext) -> None:
+ if self._protocol.is_legacy:
+ if execute_context.query.args or execute_context.query.kwargs:
+ raise errors.InterfaceError(
+ "Legacy protocol doesn't support arguments in execute()"
+ )
+ await self._protocol.legacy_simple_query(
+ execute_context.query.query, enums.Capability.LEGACY_EXECUTE
+ )
+ else:
+ ctx = execute_context.lower(
+ allow_capabilities=enums.Capability.EXECUTE
+ )
+ res = await self._protocol.execute(ctx)
+ if ctx.warnings:
+ res = execute_context.warning_handler(ctx.warnings, res)
+
+ async def describe(
+ self, describe_context: abstract.DescribeContext
+ ) -> abstract.DescribeResult:
+ ctx = describe_context.lower(
+ allow_capabilities=enums.Capability.EXECUTE
+ )
+ await self._protocol._parse(ctx)
+ return abstract.DescribeResult(
+ input_type=ctx.in_dc.make_type(describe_context),
+ output_type=ctx.out_dc.make_type(describe_context),
+ output_cardinality=enums.Cardinality(ctx.cardinality[0]),
+ capabilities=ctx.capabilities,
+ )
+
+ def terminate(self):
+ if not self.is_closed():
+ try:
+ self._protocol.abort()
+ finally:
+ self._cleanup()
+
+ def __repr__(self):
+ if self.is_closed():
+ return '<{classname} [closed] {id:#x}>'.format(
+ classname=self.__class__.__name__, id=id(self))
+ else:
+ return '<{classname} [connected to {addr}] {id:#x}>'.format(
+ classname=self.__class__.__name__,
+ addr=self.connected_addr(),
+ id=id(self))
+
+
+class PoolConnectionHolder(abc.ABC):
+ __slots__ = (
+ "_con",
+ "_pool",
+ "_release_event",
+ "_timeout",
+ "_generation",
+ )
+ _event_class = NotImplemented
+
+ def __init__(self, pool):
+
+ self._pool = pool
+ self._con = None
+
+ self._timeout = None
+ self._generation = None
+
+ self._release_event = self._event_class()
+ self._release_event.set()
+
+ @abc.abstractmethod
+ async def close(self, *, wait=True):
+ ...
+
+ @abc.abstractmethod
+ async def wait_until_released(self, timeout=None):
+ ...
+
+ async def connect(self):
+ if self._con is not None:
+ raise errors.InternalClientError(
+ 'PoolConnectionHolder.connect() called while another '
+ 'connection already exists')
+
+ self._con = await self._pool._get_new_connection()
+ assert self._con._holder is None
+ self._con._holder = self
+ self._generation = self._pool._generation
+
+ async def acquire(self) -> BaseConnection:
+ if self._con is None or self._con.is_closed():
+ self._con = None
+ await self.connect()
+
+ elif self._generation != self._pool._generation:
+ # Connections have been expired, re-connect the holder.
+ self._con._holder = None # don't release the connection
+ await self.close(wait=False)
+ self._con = None
+ await self.connect()
+
+ self._release_event.clear()
+
+ return self._con
+
+ async def release(self, timeout):
+ if self._release_event.is_set():
+ raise errors.InternalClientError(
+ 'PoolConnectionHolder.release() called on '
+ 'a free connection holder')
+
+ if self._con.is_closed():
+ # This is usually the case when the connection is broken rather
+ # than closed by the user, so we need to call _release_on_close()
+ # here to release the holder back to the queue, because
+ # self._con._cleanup() was never called. On the other hand, it is
+ # safe to call self._release() twice - the second call is no-op.
+ self._release_on_close()
+ return
+
+ self._timeout = None
+
+ if self._generation != self._pool._generation:
+ # The connection has expired because it belongs to
+ # an older generation (BasePoolImpl.expire_connections() has
+ # been called.)
+ await self.close()
+ return
+
+ # Free this connection holder and invalidate the
+ # connection proxy.
+ self._release()
+
+ def terminate(self):
+ if self._con is not None:
+ # AsyncIOConnection.terminate() will call _release_on_close() to
+ # finish holder cleanup.
+ self._con.terminate()
+
+ def _release_on_close(self):
+ self._release()
+ self._con = None
+
+ def _release(self):
+ """Release this connection holder."""
+ if self._release_event.is_set():
+ # The holder is not checked out.
+ return
+
+ self._release_event.set()
+
+ # Put ourselves back to the pool queue.
+ self._pool._queue.put_nowait(self)
+
+
+class BasePoolImpl(abc.ABC):
+ __slots__ = (
+ "_connect_args",
+ "_codecs_registry",
+ "_query_cache",
+ "_connection_factory",
+ "_queue",
+ "_user_max_concurrency",
+ "_max_concurrency",
+ "_first_connect_lock",
+ "_working_addr",
+ "_working_config",
+ "_working_params",
+ "_holders",
+ "_initialized",
+ "_initializing",
+ "_closing",
+ "_closed",
+ "_generation",
+ )
+
+ _holder_class = NotImplemented
+
+ def __init__(
+ self,
+ connect_args,
+ connection_factory,
+ *,
+ max_concurrency: typing.Optional[int],
+ ):
+ self._connection_factory = connection_factory
+ self._connect_args = connect_args
+ self._codecs_registry = protocol.CodecsRegistry()
+ self._query_cache = protocol.LRUMapping(maxsize=QUERY_CACHE_SIZE)
+
+ if max_concurrency is not None and max_concurrency <= 0:
+ raise ValueError(
+ 'max_concurrency is expected to be greater than zero'
+ )
+
+ self._user_max_concurrency = max_concurrency
+ self._max_concurrency = max_concurrency if max_concurrency else 1
+
+ self._holders = []
+ self._queue = None
+
+ self._first_connect_lock = None
+ self._working_addr = None
+ self._working_config = None
+ self._working_params = None
+
+ self._closing = False
+ self._closed = False
+ self._generation = 0
+
+ @abc.abstractmethod
+ def _ensure_initialized(self):
+ ...
+
+ @abc.abstractmethod
+ def _set_queue_maxsize(self, maxsize):
+ ...
+
+ @abc.abstractmethod
+ async def _maybe_get_first_connection(self):
+ ...
+
+ @abc.abstractmethod
+ async def acquire(self, timeout=None):
+ ...
+
+ @abc.abstractmethod
+ async def _release(self, connection):
+ ...
+
+ @property
+ def codecs_registry(self):
+ return self._codecs_registry
+
+ @property
+ def query_cache(self):
+ return self._query_cache
+
+ def _resize_holder_pool(self):
+ resize_diff = self._max_concurrency - len(self._holders)
+
+ if (resize_diff > 0):
+ if self._queue.maxsize != self._max_concurrency:
+ self._set_queue_maxsize(self._max_concurrency)
+
+ for _ in range(resize_diff):
+ ch = self._holder_class(self)
+
+ self._holders.append(ch)
+ self._queue.put_nowait(ch)
+ elif resize_diff < 0:
+ # TODO: shrink the pool
+ pass
+
+ def get_max_concurrency(self):
+ return self._max_concurrency
+
+ def get_free_size(self):
+ if self._queue is None:
+ # Queue has not been initialized yet
+ return self._max_concurrency
+
+ return self._queue.qsize()
+
+ def set_connect_args(self, dsn=None, **connect_kwargs):
+ r"""Set the new connection arguments for this pool.
+
+ The new connection arguments will be used for all subsequent
+ new connection attempts. Existing connections will remain until
+ they expire. Use BasePoolImpl.expire_connections() to expedite
+ the connection expiry.
+
+ :param str dsn:
+ Connection arguments specified using as a single string in
+ the following format:
+ ``gel://user:pass@host:port/database?option=value``.
+
+ :param \*\*connect_kwargs:
+ Keyword arguments for the
+ :func:`~gel.asyncio_client.create_async_client` function.
+ """
+
+ connect_kwargs["dsn"] = dsn
+ self._connect_args = connect_kwargs
+ self._codecs_registry = protocol.CodecsRegistry()
+ self._query_cache = protocol.LRUMapping(maxsize=QUERY_CACHE_SIZE)
+ self._working_addr = None
+ self._working_config = None
+ self._working_params = None
+
+ async def _get_first_connection(self):
+ # First connection attempt on this pool.
+ connect_config, client_config = con_utils.parse_connect_arguments(
+ **self._connect_args,
+ # ToDos
+ command_timeout=None,
+ server_settings=None,
+ )
+ con = self._connection_factory(
+ [connect_config.address], client_config, connect_config
+ )
+ await con.connect()
+ self._working_addr = con.connected_addr()
+ self._working_config = client_config
+ self._working_params = connect_config
+
+ if self._user_max_concurrency is None:
+ suggested_concurrency = con.get_settings().get(
+ 'suggested_pool_concurrency')
+ if suggested_concurrency:
+ self._max_concurrency = suggested_concurrency
+ self._resize_holder_pool()
+ return con
+
+ async def _get_new_connection(self):
+ con = None
+ if self._working_addr is None:
+ con = await self._maybe_get_first_connection()
+ if con is None:
+ assert self._working_addr is not None
+ # We've connected before and have a resolved address,
+ # and parsed options and config.
+ con = self._connection_factory(
+ [self._working_addr],
+ self._working_config,
+ self._working_params,
+ )
+ await con.connect()
+
+ return con
+
+ async def release(self, connection):
+
+ if not isinstance(connection, BaseConnection):
+ raise errors.InterfaceError(
+ f'BasePoolImpl.release() received invalid connection: '
+ f'{connection!r} does not belong to any connection pool'
+ )
+
+ ch = connection._holder
+ if ch is None:
+ # Already released, do nothing.
+ return
+
+ if ch._pool is not self:
+ raise errors.InterfaceError(
+ f'BasePoolImpl.release() received invalid connection: '
+ f'{connection!r} is not a member of this pool'
+ )
+
+ return await self._release(ch)
+
+ def terminate(self):
+ """Terminate all connections in the pool."""
+ if self._closed:
+ return
+ for ch in self._holders:
+ ch.terminate()
+ self._closed = True
+
+ def expire_connections(self):
+ """Expire all currently open connections.
+
+ Cause all currently open connections to get replaced on the
+ next query.
+ """
+ self._generation += 1
+
+ async def ensure_connected(self):
+ self._ensure_initialized()
+
+ for ch in self._holders:
+ if ch._con is not None and not ch._con.is_closed():
+ return
+
+ ch = self._holders[0]
+ ch._con = None
+ await ch.connect()
+
+
+class BaseClient(abstract.BaseReadOnlyExecutor, _options._OptionsMixin):
+ __slots__ = ("_impl", "_options")
+ _impl_class = NotImplemented
+
+ def __init__(
+ self,
+ *,
+ connection_class,
+ max_concurrency: typing.Optional[int],
+ dsn=None,
+ host: str = None,
+ port: int = None,
+ credentials: str = None,
+ credentials_file: str = None,
+ user: str = None,
+ password: str = None,
+ secret_key: str = None,
+ database: str = None,
+ branch: str = None,
+ tls_ca: str = None,
+ tls_ca_file: str = None,
+ tls_security: str = None,
+ tls_server_name: str = None,
+ wait_until_available: int = 30,
+ timeout: int = 10,
+ **kwargs,
+ ):
+ super().__init__()
+ connect_args = {
+ "dsn": dsn,
+ "host": host,
+ "port": port,
+ "credentials": credentials,
+ "credentials_file": credentials_file,
+ "user": user,
+ "password": password,
+ "secret_key": secret_key,
+ "database": database,
+ "branch": branch,
+ "timeout": timeout,
+ "tls_ca": tls_ca,
+ "tls_ca_file": tls_ca_file,
+ "tls_security": tls_security,
+ "tls_server_name": tls_server_name,
+ "wait_until_available": wait_until_available,
+ }
+
+ self._impl = self._impl_class(
+ connect_args,
+ connection_class=connection_class,
+ max_concurrency=max_concurrency,
+ **kwargs,
+ )
+
+ def _shallow_clone(self):
+ new_client = self.__class__.__new__(self.__class__)
+ new_client._impl = self._impl
+ return new_client
+
+ def _get_query_cache(self) -> abstract.QueryCache:
+ return abstract.QueryCache(
+ codecs_registry=self._impl.codecs_registry,
+ query_cache=self._impl.query_cache,
+ )
+
+ def _get_retry_options(self) -> typing.Optional[_options.RetryOptions]:
+ return self._options.retry_options
+
+ def _get_state(self) -> _options.State:
+ return self._options.state
+
+ def _get_warning_handler(self) -> _options.WarningHandler:
+ return self._options.warning_handler
+
+ @property
+ def max_concurrency(self) -> int:
+ """Max number of connections in the pool."""
+
+ return self._impl.get_max_concurrency()
+
+ @property
+ def free_size(self) -> int:
+ """Number of available connections in the pool."""
+
+ return self._impl.get_free_size()
+
+ async def _query(self, query_context: abstract.QueryContext):
+ con = await self._impl.acquire()
+ try:
+ return await con.raw_query(query_context)
+ finally:
+ await self._impl.release(con)
+
+ async def _execute(self, execute_context: abstract.ExecuteContext) -> None:
+ con = await self._impl.acquire()
+ try:
+ await con._execute(execute_context)
+ finally:
+ await self._impl.release(con)
+
+ async def _describe(
+ self, describe_context: abstract.DescribeContext
+ ) -> abstract.DescribeResult:
+ con = await self._impl.acquire()
+ try:
+ return await con.describe(describe_context)
+ finally:
+ await self._impl.release(con)
+
+ def terminate(self):
+ """Terminate all connections in the pool."""
+ self._impl.terminate()
diff --git a/gel/blocking_client.py b/gel/blocking_client.py
new file mode 100644
index 00000000..797cdcb0
--- /dev/null
+++ b/gel/blocking_client.py
@@ -0,0 +1,490 @@
+#
+# This source file is part of the EdgeDB open source project.
+#
+# Copyright 2022-present MagicStack Inc. and the EdgeDB authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+import contextlib
+import datetime
+import queue
+import socket
+import ssl
+import threading
+import time
+import typing
+
+from . import abstract
+from . import base_client
+from . import con_utils
+from . import errors
+from . import transaction
+from .protocol import blocking_proto
+from .protocol.protocol import InputLanguage, OutputFormat
+
+
+DEFAULT_PING_BEFORE_IDLE_TIMEOUT = datetime.timedelta(seconds=5)
+MINIMUM_PING_WAIT_TIME = datetime.timedelta(seconds=1)
+
+
+class BlockingIOConnection(base_client.BaseConnection):
+ __slots__ = ("_ping_wait_time",)
+
+ async def connect_addr(self, addr, timeout):
+ deadline = time.monotonic() + timeout
+
+ if isinstance(addr, str):
+ # UNIX socket
+ res_list = [(socket.AF_UNIX, socket.SOCK_STREAM, -1, None, addr)]
+ else:
+ host, port = addr
+ try:
+ # getaddrinfo() doesn't take timeout!!
+ res_list = socket.getaddrinfo(
+ host, port, socket.AF_UNSPEC, socket.SOCK_STREAM
+ )
+ except socket.gaierror as e:
+ # All name resolution errors are considered temporary
+ err = errors.ClientConnectionFailedTemporarilyError(str(e))
+ raise err from e
+
+ for i, res in enumerate(res_list):
+ af, socktype, proto, _, sa = res
+ try:
+ sock = socket.socket(af, socktype, proto)
+ except OSError as e:
+ sock.close()
+ if i < len(res_list) - 1:
+ continue
+ else:
+ raise con_utils.wrap_error(e) from e
+ try:
+ await self._connect_addr(sock, addr, sa, deadline)
+ except TimeoutError:
+ raise
+ except Exception:
+ if i < len(res_list) - 1:
+ continue
+ else:
+ raise
+ else:
+ break
+
+ async def _connect_addr(self, sock, addr, sa, deadline):
+ try:
+ time_left = deadline - time.monotonic()
+ if time_left <= 0:
+ raise TimeoutError
+ try:
+ sock.settimeout(time_left)
+ sock.connect(sa)
+ except OSError as e:
+ raise con_utils.wrap_error(e) from e
+
+ if not isinstance(addr, str):
+ time_left = deadline - time.monotonic()
+ if time_left <= 0:
+ raise TimeoutError
+ try:
+ # Upgrade to TLS
+ sock.settimeout(time_left)
+ try:
+ sock = self._params.ssl_ctx.wrap_socket(
+ sock,
+ server_hostname=(
+ self._params.tls_server_name or addr[0]
+ ),
+ )
+ except ssl.CertificateError as e:
+ raise con_utils.wrap_error(e) from e
+ except ssl.SSLError as e:
+ raise con_utils.wrap_error(e) from e
+ else:
+ con_utils.check_alpn_protocol(sock)
+ except OSError as e:
+ raise con_utils.wrap_error(e) from e
+
+ time_left = deadline - time.monotonic()
+ if time_left <= 0:
+ raise TimeoutError
+
+ if not isinstance(addr, str):
+ sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
+
+ proto = blocking_proto.BlockingIOProtocol(self._params, sock)
+ proto.set_connection(self)
+
+ try:
+ await proto.wait_for(proto.connect(), time_left)
+ except TimeoutError:
+ raise
+ except OSError as e:
+ raise con_utils.wrap_error(e) from e
+
+ self._protocol = proto
+ self._addr = addr
+ self._ping_wait_time = max(
+ (
+ self.get_settings()
+ .get("system_config")
+ .session_idle_timeout
+ - DEFAULT_PING_BEFORE_IDLE_TIMEOUT
+ ),
+ MINIMUM_PING_WAIT_TIME,
+ ).total_seconds()
+
+ except Exception:
+ sock.close()
+ raise
+
+ async def sleep(self, seconds):
+ time.sleep(seconds)
+
+ def is_closed(self):
+ proto = self._protocol
+ return not (proto and proto.sock is not None and
+ proto.sock.fileno() >= 0 and proto.connected)
+
+ async def close(self, timeout=None):
+ """Send graceful termination message wait for connection to drop."""
+ if not self.is_closed():
+ try:
+ self._protocol.terminate()
+ if timeout is None:
+ await self._protocol.wait_for_disconnect()
+ else:
+ await self._protocol.wait_for(
+ self._protocol.wait_for_disconnect(), timeout
+ )
+ except TimeoutError:
+ self.terminate()
+ raise errors.QueryTimeoutError()
+ except Exception:
+ self.terminate()
+ raise
+ finally:
+ self._cleanup()
+
+ def _dispatch_log_message(self, msg):
+ for cb in self._log_listeners:
+ cb(self, msg)
+
+ async def raw_query(self, query_context: abstract.QueryContext):
+ try:
+ if (
+ time.monotonic() - self._protocol.last_active_timestamp
+ > self._ping_wait_time
+ ):
+ await self._protocol.ping()
+ except (errors.IdleSessionTimeoutError, errors.ClientConnectionError):
+ await self.connect()
+
+ return await super().raw_query(query_context)
+
+
+class _PoolConnectionHolder(base_client.PoolConnectionHolder):
+ __slots__ = ()
+ _event_class = threading.Event
+
+ async def close(self, *, wait=True, timeout=None):
+ if self._con is None:
+ return
+ await self._con.close(timeout=timeout)
+
+ async def wait_until_released(self, timeout=None):
+ return self._release_event.wait(timeout)
+
+
+class _PoolImpl(base_client.BasePoolImpl):
+ _holder_class = _PoolConnectionHolder
+
+ def __init__(
+ self,
+ connect_args,
+ *,
+ max_concurrency: typing.Optional[int],
+ connection_class,
+ ):
+ if not issubclass(connection_class, BlockingIOConnection):
+ raise TypeError(
+ f'connection_class is expected to be a subclass of '
+ f'gel.blocking_client.BlockingIOConnection, '
+ f'got {connection_class}')
+ super().__init__(
+ connect_args,
+ connection_class,
+ max_concurrency=max_concurrency,
+ )
+
+ def _ensure_initialized(self):
+ if self._queue is None:
+ self._queue = queue.LifoQueue(maxsize=self._max_concurrency)
+ self._first_connect_lock = threading.Lock()
+ self._resize_holder_pool()
+
+ def _set_queue_maxsize(self, maxsize):
+ with self._queue.mutex:
+ self._queue.maxsize = maxsize
+
+ async def _maybe_get_first_connection(self):
+ with self._first_connect_lock:
+ if self._working_addr is None:
+ return await self._get_first_connection()
+
+ async def acquire(self, timeout=None):
+ self._ensure_initialized()
+
+ if self._closing:
+ raise errors.InterfaceError('pool is closing')
+
+ ch = self._queue.get(timeout=timeout)
+ try:
+ con = await ch.acquire()
+ except Exception:
+ self._queue.put_nowait(ch)
+ raise
+ else:
+ # Record the timeout, as we will apply it by default
+ # in release().
+ ch._timeout = timeout
+ return con
+
+ async def _release(self, holder):
+ if not isinstance(holder._con, BlockingIOConnection):
+ raise errors.InterfaceError(
+ f'release() received invalid connection: '
+ f'{holder._con!r} does not belong to any connection pool'
+ )
+
+ timeout = None
+ return await holder.release(timeout)
+
+ async def close(self, timeout=None):
+ if self._closed:
+ return
+ self._closing = True
+ try:
+ if timeout is None:
+ for ch in self._holders:
+ await ch.wait_until_released()
+ for ch in self._holders:
+ await ch.close()
+ else:
+ deadline = time.monotonic() + timeout
+ for ch in self._holders:
+ secs = deadline - time.monotonic()
+ if secs <= 0:
+ raise TimeoutError
+ if not await ch.wait_until_released(secs):
+ raise TimeoutError
+ for ch in self._holders:
+ secs = deadline - time.monotonic()
+ if secs <= 0:
+ raise TimeoutError
+ await ch.close(timeout=secs)
+ except TimeoutError as e:
+ self.terminate()
+ raise errors.InterfaceError(
+ "client is not fully closed in {} seconds; "
+ "terminating now.".format(timeout)
+ ) from e
+ except Exception:
+ self.terminate()
+ raise
+ finally:
+ self._closed = True
+ self._closing = False
+
+
+class Iteration(transaction.BaseTransaction, abstract.Executor):
+
+ __slots__ = ("_managed", "_lock")
+
+ def __init__(self, retry, client, iteration):
+ super().__init__(retry, client, iteration)
+ self._managed = False
+ self._lock = threading.Lock()
+
+ def __enter__(self):
+ with self._exclusive():
+ if self._managed:
+ raise errors.InterfaceError(
+ 'cannot enter context: already in a `with` block')
+ self._managed = True
+ return self
+
+ def __exit__(self, extype, ex, tb):
+ with self._exclusive():
+ self._managed = False
+ return self._client._iter_coroutine(self._exit(extype, ex))
+
+ async def _ensure_transaction(self):
+ if not self._managed:
+ raise errors.InterfaceError(
+ "Only managed retriable transactions are supported. "
+ "Use `with transaction:`"
+ )
+ await super()._ensure_transaction()
+
+ def _query(self, query_context: abstract.QueryContext):
+ with self._exclusive():
+ return self._client._iter_coroutine(super()._query(query_context))
+
+ def _execute(self, execute_context: abstract.ExecuteContext) -> None:
+ with self._exclusive():
+ self._client._iter_coroutine(super()._execute(execute_context))
+
+ @contextlib.contextmanager
+ def _exclusive(self):
+ if not self._lock.acquire(blocking=False):
+ raise errors.InterfaceError(
+ "concurrent queries within the same transaction "
+ "are not allowed"
+ )
+ try:
+ yield
+ finally:
+ self._lock.release()
+
+
+class Retry(transaction.BaseRetry):
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ # Note: when changing this code consider also
+ # updating AsyncIORetry.__anext__.
+ if self._done:
+ raise StopIteration
+ if self._next_backoff:
+ time.sleep(self._next_backoff)
+ self._done = True
+ iteration = Iteration(self, self._owner, self._iteration)
+ self._iteration += 1
+ return iteration
+
+
+class Client(base_client.BaseClient, abstract.Executor):
+ """A lazy connection pool.
+
+ A Client can be used to manage a set of connections to the database.
+ Connections are first acquired from the pool, then used, and then released
+ back to the pool. Once a connection is released, it's reset to close all
+ open cursors and other resources *except* prepared statements.
+
+ Clients are created by calling
+ :func:`~gel.blocking_client.create_client`.
+ """
+
+ __slots__ = ()
+ _impl_class = _PoolImpl
+
+ def _iter_coroutine(self, coro):
+ try:
+ coro.send(None)
+ except StopIteration as ex:
+ return ex.value
+ finally:
+ coro.close()
+
+ def _query(self, query_context: abstract.QueryContext):
+ return self._iter_coroutine(super()._query(query_context))
+
+ def _execute(self, execute_context: abstract.ExecuteContext) -> None:
+ self._iter_coroutine(super()._execute(execute_context))
+
+ def ensure_connected(self):
+ self._iter_coroutine(self._impl.ensure_connected())
+ return self
+
+ def transaction(self) -> Retry:
+ return Retry(self)
+
+ def close(self, timeout=None):
+ """Attempt to gracefully close all connections in the client.
+
+ Wait until all pool connections are released, close them and
+ shut down the pool. If any error (including cancellation) occurs
+ in ``close()`` the pool will terminate by calling
+ Client.terminate() .
+ """
+ self._iter_coroutine(self._impl.close(timeout))
+
+ def __enter__(self):
+ return self.ensure_connected()
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.close()
+
+ def _describe_query(
+ self,
+ query: str,
+ *,
+ inject_type_names: bool = False,
+ input_language: InputLanguage = InputLanguage.EDGEQL,
+ output_format: OutputFormat = OutputFormat.BINARY,
+ expect_one: bool = False,
+ ) -> abstract.DescribeResult:
+ return self._iter_coroutine(self._describe(abstract.DescribeContext(
+ query=query,
+ state=self._get_state(),
+ inject_type_names=inject_type_names,
+ input_language=input_language,
+ output_format=output_format,
+ expect_one=expect_one,
+ )))
+
+
+def create_client(
+ dsn=None,
+ *,
+ max_concurrency=None,
+ host: str = None,
+ port: int = None,
+ credentials: str = None,
+ credentials_file: str = None,
+ user: str = None,
+ password: str = None,
+ secret_key: str = None,
+ database: str = None,
+ branch: str = None,
+ tls_ca: str = None,
+ tls_ca_file: str = None,
+ tls_security: str = None,
+ wait_until_available: int = 30,
+ timeout: int = 10,
+):
+ return Client(
+ connection_class=BlockingIOConnection,
+ max_concurrency=max_concurrency,
+
+ # connect arguments
+ dsn=dsn,
+ host=host,
+ port=port,
+ credentials=credentials,
+ credentials_file=credentials_file,
+ user=user,
+ password=password,
+ secret_key=secret_key,
+ database=database,
+ branch=branch,
+ tls_ca=tls_ca,
+ tls_ca_file=tls_ca_file,
+ tls_security=tls_security,
+ wait_until_available=wait_until_available,
+ timeout=timeout,
+ )
diff --git a/edgedb/codegen/__init__.py b/gel/codegen/__init__.py
similarity index 100%
rename from edgedb/codegen/__init__.py
rename to gel/codegen/__init__.py
diff --git a/edgedb/codegen/__main__.py b/gel/codegen/__main__.py
similarity index 100%
rename from edgedb/codegen/__main__.py
rename to gel/codegen/__main__.py
diff --git a/edgedb/codegen/cli.py b/gel/codegen/cli.py
similarity index 100%
rename from edgedb/codegen/cli.py
rename to gel/codegen/cli.py
diff --git a/edgedb/codegen/generator.py b/gel/codegen/generator.py
similarity index 94%
rename from edgedb/codegen/generator.py
rename to gel/codegen/generator.py
index 08fbdc6e..b5c70ddb 100644
--- a/edgedb/codegen/generator.py
+++ b/gel/codegen/generator.py
@@ -25,11 +25,11 @@
import textwrap
import typing
-import edgedb
-from edgedb import abstract
-from edgedb import describe
-from edgedb.con_utils import find_edgedb_project_dir
-from edgedb.color import get_color
+import gel
+from gel import abstract
+from gel import describe
+from gel.con_utils import find_gel_project_dir
+from gel.color import get_color
C = get_color()
@@ -64,9 +64,9 @@
"cal::local_date": "datetime.date",
"cal::local_time": "datetime.time",
"cal::local_datetime": "datetime.datetime",
- "cal::relative_duration": "edgedb.RelativeDuration",
- "cal::date_duration": "edgedb.DateDuration",
- "cfg::memory": "edgedb.ConfigMemory",
+ "cal::relative_duration": "gel.RelativeDuration",
+ "cal::date_duration": "gel.DateDuration",
+ "cfg::memory": "gel.ConfigMemory",
"ext::pgvector::vector": "array.array",
}
@@ -158,15 +158,15 @@ def __init__(self, args: argparse.Namespace):
self._skip_pydantic_validation = args.skip_pydantic_validation
self._async = False
try:
- self._project_dir = pathlib.Path(find_edgedb_project_dir())
- except edgedb.ClientConnectionError:
+ self._project_dir = pathlib.Path(find_gel_project_dir())
+ except gel.ClientConnectionError:
print(
- "Cannot find edgedb.toml: "
+ "Cannot find gel.toml: "
"codegen must be run under an EdgeDB project dir"
)
sys.exit(2)
print_msg(f"Found EdgeDB project: {C.BOLD}{self._project_dir}{C.ENDC}")
- self._client = edgedb.create_client(**_get_conn_args(args))
+ self._client = gel.create_client(**_get_conn_args(args))
self._single_mode_files = args.file
self._search_dirs = []
for search_dir in args.dir or []:
@@ -203,7 +203,7 @@ def _new_file(self):
def run(self):
try:
self._client.ensure_connected()
- except edgedb.EdgeDBError as e:
+ except gel.EdgeDBError as e:
print(f"Failed to connect to EdgeDB instance: {e}")
sys.exit(61)
with self._client:
@@ -301,7 +301,7 @@ def _write_comments(
cmd = []
if sys.argv[0].endswith("__main__.py"):
cmd.append(pathlib.Path(sys.executable).name)
- cmd.extend(["-m", "edgedb.codegen"])
+ cmd.extend(["-m", "gel.codegen"])
else:
cmd.append(pathlib.Path(sys.argv[0]).name)
cmd.extend(sys.argv[1:])
@@ -346,7 +346,7 @@ def _generate(
else:
self._imports.add("typing")
out_type = f"typing.List[{out_type}]"
- elif dr.output_cardinality == edgedb.Cardinality.AT_MOST_ONE:
+ elif dr.output_cardinality == gel.Cardinality.AT_MOST_ONE:
if SYS_VERSION_INFO >= (3, 10):
out_type = f"{out_type} | None"
else:
@@ -377,11 +377,11 @@ def _generate(
print(f"async def {name}(", file=buf)
else:
print(f"def {name}(", file=buf)
- self._imports.add("edgedb")
+ self._imports.add("gel")
if self._async:
- print(f"{INDENT}executor: edgedb.AsyncIOExecutor,", file=buf)
+ print(f"{INDENT}executor: gel.AsyncIOExecutor,", file=buf)
else:
- print(f"{INDENT}executor: edgedb.Executor,", file=buf)
+ print(f"{INDENT}executor: gel.Executor,", file=buf)
if kw_only:
print(f"{INDENT}*,", file=buf)
for name, arg in args.items():
@@ -390,7 +390,7 @@ def _generate(
if dr.output_cardinality.is_multi():
method = "query"
rt = "return "
- elif dr.output_cardinality == edgedb.Cardinality.NO_RESULT:
+ elif dr.output_cardinality == gel.Cardinality.NO_RESULT:
method = "execute"
rt = ""
else:
@@ -485,7 +485,7 @@ def _generate_code(
el_code = self._generate_code_with_cardinality(
element.type, name_hint, element.cardinality
)
- if element.kind == edgedb.ElementKind.LINK_PROPERTY:
+ if element.kind == gel.ElementKind.LINK_PROPERTY:
link_props.append((el_name, el_code))
else:
print(f"{INDENT}{el_name}: {el_code}", file=buf)
@@ -536,7 +536,7 @@ def _generate_code(
elif isinstance(type_, describe.RangeType):
value = self._generate_code(type_.value_type, name_hint, is_input)
- rv = f"edgedb.Range[{value}]"
+ rv = f"gel.Range[{value}]"
else:
rv = "??"
@@ -548,12 +548,12 @@ def _generate_code_with_cardinality(
self,
type_: typing.Optional[describe.AnyType],
name_hint: str,
- cardinality: edgedb.Cardinality,
+ cardinality: gel.Cardinality,
keyword_argument: bool = False,
is_input: bool = False,
):
rv = self._generate_code(type_, name_hint, is_input)
- if cardinality == edgedb.Cardinality.AT_MOST_ONE:
+ if cardinality == gel.Cardinality.AT_MOST_ONE:
if SYS_VERSION_INFO >= (3, 10):
rv = f"{rv} | None"
else:
diff --git a/gel/color.py b/gel/color.py
new file mode 100644
index 00000000..2b972aaa
--- /dev/null
+++ b/gel/color.py
@@ -0,0 +1,61 @@
+import os
+import sys
+import warnings
+
+COLOR = None
+
+
+class Color:
+ HEADER = ""
+ BLUE = ""
+ CYAN = ""
+ GREEN = ""
+ WARNING = ""
+ FAIL = ""
+ ENDC = ""
+ BOLD = ""
+ UNDERLINE = ""
+
+
+def get_color() -> Color:
+ global COLOR
+
+ if COLOR is None:
+ COLOR = Color()
+ if type(USE_COLOR) is bool:
+ use_color = USE_COLOR
+ else:
+ try:
+ use_color = USE_COLOR()
+ except Exception:
+ use_color = False
+ if use_color:
+ COLOR.HEADER = '\033[95m'
+ COLOR.BLUE = '\033[94m'
+ COLOR.CYAN = '\033[96m'
+ COLOR.GREEN = '\033[92m'
+ COLOR.WARNING = '\033[93m'
+ COLOR.FAIL = '\033[91m'
+ COLOR.ENDC = '\033[0m'
+ COLOR.BOLD = '\033[1m'
+ COLOR.UNDERLINE = '\033[4m'
+
+ return COLOR
+
+
+try:
+ USE_COLOR = {
+ "default": lambda: sys.stderr.isatty(),
+ "auto": lambda: sys.stderr.isatty(),
+ "enabled": True,
+ "disabled": False,
+ }[
+ os.getenv("EDGEDB_COLOR_OUTPUT", "default")
+ ]
+except KeyError:
+ warnings.warn(
+ "EDGEDB_COLOR_OUTPUT can only be one of: "
+ "default, auto, enabled or disabled",
+ stacklevel=1,
+ )
+ USE_COLOR = False
diff --git a/gel/con_utils.py b/gel/con_utils.py
new file mode 100644
index 00000000..9a593d9f
--- /dev/null
+++ b/gel/con_utils.py
@@ -0,0 +1,1278 @@
+#
+# This source file is part of the EdgeDB open source project.
+#
+# Copyright 2016-present MagicStack Inc. and the EdgeDB authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+import base64
+import binascii
+import errno
+import json
+import os
+import re
+import ssl
+import typing
+import urllib.parse
+import warnings
+import hashlib
+
+from . import errors
+from . import credentials as cred_utils
+from . import platform
+
+
+EDGEDB_PORT = 5656
+ERRNO_RE = re.compile(r"\[Errno (\d+)\]")
+TEMPORARY_ERRORS = (
+ ConnectionAbortedError,
+ ConnectionRefusedError,
+ ConnectionResetError,
+ FileNotFoundError,
+)
+TEMPORARY_ERROR_CODES = frozenset({
+ errno.ECONNREFUSED,
+ errno.ECONNABORTED,
+ errno.ECONNRESET,
+ errno.ENOENT,
+})
+
+ISO_SECONDS_RE = re.compile(r'(-?\d+|-?\d+\.\d*|-?\d*\.\d+)S')
+ISO_MINUTES_RE = re.compile(r'(-?\d+|-?\d+\.\d*|-?\d*\.\d+)M')
+ISO_HOURS_RE = re.compile(r'(-?\d+|-?\d+\.\d*|-?\d*\.\d+)H')
+ISO_UNITLESS_HOURS_RE = re.compile(r'^(-?\d+|-?\d+\.\d*|-?\d*\.\d+)$')
+ISO_DAYS_RE = re.compile(r'(-?\d+|-?\d+\.\d*|-?\d*\.\d+)D')
+ISO_WEEKS_RE = re.compile(r'(-?\d+|-?\d+\.\d*|-?\d*\.\d+)W')
+ISO_MONTHS_RE = re.compile(r'(-?\d+|-?\d+\.\d*|-?\d*\.\d+)M')
+ISO_YEARS_RE = re.compile(r'(-?\d+|-?\d+\.\d*|-?\d*\.\d+)Y')
+
+HUMAN_HOURS_RE = re.compile(
+ r'((?:(?:\s|^)-\s*)?\d*\.?\d*)\s*(?i:h(\s|\d|\.|$)|hours?(\s|$))',
+)
+HUMAN_MINUTES_RE = re.compile(
+ r'((?:(?:\s|^)-\s*)?\d*\.?\d*)\s*(?i:m(\s|\d|\.|$)|minutes?(\s|$))',
+)
+HUMAN_SECONDS_RE = re.compile(
+ r'((?:(?:\s|^)-\s*)?\d*\.?\d*)\s*(?i:s(\s|\d|\.|$)|seconds?(\s|$))',
+)
+HUMAN_MS_RE = re.compile(
+ r'((?:(?:\s|^)-\s*)?\d*\.?\d*)\s*(?i:ms(\s|\d|\.|$)|milliseconds?(\s|$))',
+)
+HUMAN_US_RE = re.compile(
+ r'((?:(?:\s|^)-\s*)?\d*\.?\d*)\s*(?i:us(\s|\d|\.|$)|microseconds?(\s|$))',
+)
+INSTANCE_NAME_RE = re.compile(
+ r'^(\w(?:-?\w)*)$',
+ re.ASCII,
+)
+CLOUD_INSTANCE_NAME_RE = re.compile(
+ r'^([A-Za-z0-9_-](?:-?[A-Za-z0-9_])*)/([A-Za-z0-9](?:-?[A-Za-z0-9])*)$',
+ re.ASCII,
+)
+DSN_RE = re.compile(
+ r'^[a-z]+://',
+ re.IGNORECASE,
+)
+DOMAIN_LABEL_MAX_LENGTH = 63
+
+
+class ClientConfiguration(typing.NamedTuple):
+
+ connect_timeout: float
+ command_timeout: float
+ wait_until_available: float
+
+
+def _validate_port_spec(hosts, port):
+ if isinstance(port, list):
+ # If there is a list of ports, its length must
+ # match that of the host list.
+ if len(port) != len(hosts):
+ raise errors.InterfaceError(
+ 'could not match {} port numbers to {} hosts'.format(
+ len(port), len(hosts)))
+ else:
+ port = [port for _ in range(len(hosts))]
+
+ return port
+
+
+def _parse_hostlist(hostlist, port):
+ if ',' in hostlist:
+ # A comma-separated list of host addresses.
+ hostspecs = hostlist.split(',')
+ else:
+ hostspecs = [hostlist]
+
+ hosts = []
+ hostlist_ports = []
+
+ if not port:
+ portspec = _getenv('PORT')
+ if portspec:
+ if ',' in portspec:
+ default_port = [int(p) for p in portspec.split(',')]
+ else:
+ default_port = int(portspec)
+ else:
+ default_port = EDGEDB_PORT
+
+ default_port = _validate_port_spec(hostspecs, default_port)
+
+ else:
+ port = _validate_port_spec(hostspecs, port)
+
+ for i, hostspec in enumerate(hostspecs):
+ addr, _, hostspec_port = hostspec.partition(':')
+ hosts.append(addr)
+
+ if not port:
+ if hostspec_port:
+ hostlist_ports.append(int(hostspec_port))
+ else:
+ hostlist_ports.append(default_port[i])
+
+ if not port:
+ port = hostlist_ports
+
+ return hosts, port
+
+
+def _hash_path(path):
+ path = os.path.realpath(path)
+ if platform.IS_WINDOWS and not path.startswith('\\\\'):
+ path = '\\\\?\\' + path
+ return hashlib.sha1(str(path).encode('utf-8')).hexdigest()
+
+
+def _stash_path(path):
+ base_name = os.path.basename(path)
+ dir_name = base_name + '-' + _hash_path(path)
+ return platform.search_config_dir('projects', dir_name)
+
+
+def _validate_tls_security(val: str) -> str:
+ val = val.lower()
+ if val not in {"insecure", "no_host_verification", "strict", "default"}:
+ raise ValueError(
+ "tls_security can only be one of "
+ "`insecure`, `no_host_verification`, `strict` or `default`"
+ )
+
+ return val
+
+
+def _getenv_and_key(key: str) -> typing.Tuple[typing.Optional[str], str]:
+ edgedb_key = f'EDGEDB_{key}'
+ edgedb_val = os.getenv(edgedb_key)
+ gel_key = f'GEL_{key}'
+ gel_val = os.getenv(gel_key)
+ if edgedb_val is not None and gel_val is not None:
+ warnings.warn(
+ f'Both {gel_key} and {edgedb_key} are set; '
+ f'{edgedb_key} will be ignored',
+ stacklevel=1,
+ )
+
+ if gel_val is None and edgedb_val is not None:
+ return edgedb_val, edgedb_key
+ else:
+ return gel_val, gel_key
+
+
+def _getenv(key: str) -> typing.Optional[str]:
+ return _getenv_and_key(key)[0]
+
+
+class ResolvedConnectConfig:
+ _host = None
+ _host_source = None
+
+ _port = None
+ _port_source = None
+
+ # We keep track of database and branch separately, because we want to make
+ # sure that we don't use both at the same time on the same configuration
+ # level.
+ _database = None
+ _database_source = None
+
+ _branch = None
+ _branch_source = None
+
+ _user = None
+ _user_source = None
+
+ _password = None
+ _password_source = None
+
+ _secret_key = None
+ _secret_key_source = None
+
+ _tls_ca_data = None
+ _tls_ca_data_source = None
+
+ _tls_server_name = None
+ _tls_security = None
+ _tls_security_source = None
+
+ _wait_until_available = None
+
+ _cloud_profile = None
+ _cloud_profile_source = None
+
+ server_settings = {}
+
+ def _set_param(self, param, value, source, validator=None):
+ param_name = '_' + param
+ if getattr(self, param_name) is None:
+ setattr(self, param_name + '_source', source)
+ if value is not None:
+ setattr(
+ self,
+ param_name,
+ validator(value) if validator else value
+ )
+
+ def set_host(self, host, source):
+ self._set_param('host', host, source, _validate_host)
+
+ def set_port(self, port, source):
+ self._set_param('port', port, source, _validate_port)
+
+ def set_database(self, database, source):
+ self._set_param('database', database, source, _validate_database)
+
+ def set_branch(self, branch, source):
+ self._set_param('branch', branch, source, _validate_branch)
+
+ def set_user(self, user, source):
+ self._set_param('user', user, source, _validate_user)
+
+ def set_password(self, password, source):
+ self._set_param('password', password, source)
+
+ def set_secret_key(self, secret_key, source):
+ self._set_param('secret_key', secret_key, source)
+
+ def set_tls_ca_data(self, ca_data, source):
+ self._set_param('tls_ca_data', ca_data, source)
+
+ def set_tls_ca_file(self, ca_file, source):
+ def read_ca_file(file_path):
+ with open(file_path) as f:
+ return f.read()
+
+ self._set_param('tls_ca_data', ca_file, source, read_ca_file)
+
+ def set_tls_server_name(self, ca_data, source):
+ self._set_param('tls_server_name', ca_data, source)
+
+ def set_tls_security(self, security, source):
+ self._set_param('tls_security', security, source,
+ _validate_tls_security)
+
+ def set_wait_until_available(self, wait_until_available, source):
+ self._set_param(
+ 'wait_until_available',
+ wait_until_available,
+ source,
+ _validate_wait_until_available,
+ )
+
+ def add_server_settings(self, server_settings):
+ _validate_server_settings(server_settings)
+ self.server_settings = {**server_settings, **self.server_settings}
+
+ @property
+ def address(self):
+ return (
+ self._host if self._host else 'localhost',
+ self._port if self._port else 5656
+ )
+
+ # The properties actually merge database and branch, but "default" is
+ # different. If you need to know the underlying config use the _database
+ # and _branch.
+ @property
+ def database(self):
+ return (
+ self._database if self._database else
+ self._branch if self._branch else
+ 'edgedb'
+ )
+
+ @property
+ def branch(self):
+ return (
+ self._database if self._database else
+ self._branch if self._branch else
+ '__default__'
+ )
+
+ @property
+ def user(self):
+ return self._user if self._user else 'edgedb'
+
+ @property
+ def password(self):
+ return self._password
+
+ @property
+ def secret_key(self):
+ return self._secret_key
+
+ @property
+ def tls_server_name(self):
+ return self._tls_server_name
+
+ @property
+ def tls_security(self):
+ tls_security = self._tls_security or 'default'
+ security, security_key = _getenv_and_key('CLIENT_SECURITY')
+ security = security or 'default'
+ if security not in {'default', 'insecure_dev_mode', 'strict'}:
+ raise ValueError(
+ f'environment variable {security_key} should be '
+ f'one of strict, insecure_dev_mode or default, '
+ f'got: {security!r}')
+
+ if security == 'default':
+ pass
+ elif security == 'insecure_dev_mode':
+ if tls_security == 'default':
+ tls_security = 'insecure'
+ elif security == 'strict':
+ if tls_security == 'default':
+ tls_security = 'strict'
+ elif tls_security in {'no_host_verification', 'insecure'}:
+ raise ValueError(
+ f'{security_key}=strict but '
+ f'tls_security={tls_security}, tls_security must be '
+ f'set to strict when {security_key} is strict')
+
+ if tls_security != 'default':
+ return tls_security
+
+ if self._tls_ca_data is not None:
+ return "no_host_verification"
+
+ return "strict"
+
+ _ssl_ctx = None
+
+ @property
+ def ssl_ctx(self):
+ if (self._ssl_ctx):
+ return self._ssl_ctx
+
+ self._ssl_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
+
+ if self._tls_ca_data:
+ self._ssl_ctx.load_verify_locations(
+ cadata=self._tls_ca_data
+ )
+ else:
+ self._ssl_ctx.load_default_certs(ssl.Purpose.SERVER_AUTH)
+ if platform.IS_WINDOWS:
+ import certifi
+ self._ssl_ctx.load_verify_locations(cafile=certifi.where())
+
+ tls_security = self.tls_security
+ self._ssl_ctx.check_hostname = tls_security == "strict"
+
+ if tls_security in {"strict", "no_host_verification"}:
+ self._ssl_ctx.verify_mode = ssl.CERT_REQUIRED
+ else:
+ self._ssl_ctx.verify_mode = ssl.CERT_NONE
+
+ self._ssl_ctx.set_alpn_protocols(['edgedb-binary'])
+
+ return self._ssl_ctx
+
+ @property
+ def wait_until_available(self):
+ return (
+ self._wait_until_available
+ if self._wait_until_available is not None
+ else 30
+ )
+
+
+def _validate_host(host):
+ if '/' in host:
+ raise ValueError('unix socket paths not supported')
+ if host == '' or ',' in host:
+ raise ValueError(f'invalid host: "{host}"')
+ return host
+
+
+def _prepare_host_for_dsn(host):
+ host = _validate_host(host)
+ if ':' in host:
+ # IPv6
+ host = f'[{host}]'
+ return host
+
+
+def _validate_port(port):
+ try:
+ if isinstance(port, str):
+ port = int(port)
+ if not isinstance(port, int):
+ raise ValueError()
+ except Exception:
+ raise ValueError(f'invalid port: {port}, not an integer')
+ if port < 1 or port > 65535:
+ raise ValueError(f'invalid port: {port}, must be between 1 and 65535')
+ return port
+
+
+def _validate_database(database):
+ if database == '':
+ raise ValueError(f'invalid database name: {database}')
+ return database
+
+
+def _validate_branch(branch):
+ if branch == '':
+ raise ValueError(f'invalid branch name: {branch}')
+ return branch
+
+
+def _validate_user(user):
+ if user == '':
+ raise ValueError(f'invalid user name: {user}')
+ return user
+
+
+def _pop_iso_unit(rgex: re.Pattern, string: str) -> typing.Tuple[float, str]:
+ s = string
+ total = 0
+ match = rgex.search(string)
+ if match:
+ total += float(match.group(1))
+ s = s.replace(match.group(0), "", 1)
+
+ return (total, s)
+
+
+def _parse_iso_duration(string: str) -> typing.Union[float, int]:
+ if not string.startswith("PT"):
+ raise ValueError(f"invalid duration {string!r}")
+
+ time = string[2:]
+ match = ISO_UNITLESS_HOURS_RE.search(time)
+ if match:
+ hours = float(match.group(0))
+ return 3600 * hours
+
+ hours, time = _pop_iso_unit(ISO_HOURS_RE, time)
+ minutes, time = _pop_iso_unit(ISO_MINUTES_RE, time)
+ seconds, time = _pop_iso_unit(ISO_SECONDS_RE, time)
+
+ if time:
+ raise ValueError(f'invalid duration {string!r}')
+
+ return 3600 * hours + 60 * minutes + seconds
+
+
+def _remove_white_space(s: str) -> str:
+ return ''.join(c for c in s if not c.isspace())
+
+
+def _pop_human_duration_unit(
+ rgex: re.Pattern,
+ string: str,
+) -> typing.Tuple[float, bool, str]:
+ match = rgex.search(string)
+ if not match:
+ return 0, False, string
+
+ number = 0
+ if match.group(1):
+ literal = _remove_white_space(match.group(1))
+ if literal.endswith('.'):
+ return 0, False, string
+
+ if literal.startswith('-.'):
+ return 0, False, string
+
+ number = float(literal)
+ string = string.replace(
+ match.group(0),
+ match.group(2) or match.group(3) or "",
+ 1,
+ )
+
+ return number, True, string
+
+
+def _parse_human_duration(string: str) -> float:
+ found = False
+
+ hour, f, s = _pop_human_duration_unit(HUMAN_HOURS_RE, string)
+ found |= f
+
+ minute, f, s = _pop_human_duration_unit(HUMAN_MINUTES_RE, s)
+ found |= f
+
+ second, f, s = _pop_human_duration_unit(HUMAN_SECONDS_RE, s)
+ found |= f
+
+ ms, f, s = _pop_human_duration_unit(HUMAN_MS_RE, s)
+ found |= f
+
+ us, f, s = _pop_human_duration_unit(HUMAN_US_RE, s)
+ found |= f
+
+ if s.strip() or not found:
+ raise ValueError(f'invalid duration {string!r}')
+
+ return 3600 * hour + 60 * minute + second + 0.001 * ms + 0.000001 * us
+
+
+def _parse_duration_str(string: str) -> float:
+ if string.startswith('PT'):
+ return _parse_iso_duration(string)
+ return _parse_human_duration(string)
+
+
+def _validate_wait_until_available(wait_until_available):
+ if isinstance(wait_until_available, str):
+ return _parse_duration_str(wait_until_available)
+
+ if isinstance(wait_until_available, (int, float)):
+ return wait_until_available
+
+ raise ValueError(f"invalid duration {wait_until_available!r}")
+
+
+def _validate_server_settings(server_settings):
+ if (
+ not isinstance(server_settings, dict) or
+ not all(isinstance(k, str) for k in server_settings) or
+ not all(isinstance(v, str) for v in server_settings.values())
+ ):
+ raise ValueError(
+ 'server_settings is expected to be None or '
+ 'a Dict[str, str]')
+
+
+def _parse_connect_dsn_and_args(
+ *,
+ dsn,
+ host,
+ port,
+ credentials,
+ credentials_file,
+ user,
+ password,
+ secret_key,
+ database,
+ branch,
+ tls_ca,
+ tls_ca_file,
+ tls_security,
+ tls_server_name,
+ server_settings,
+ wait_until_available,
+):
+ resolved_config = ResolvedConnectConfig()
+
+ if dsn and DSN_RE.match(dsn):
+ instance_name = None
+ else:
+ instance_name, dsn = dsn, None
+
+ def _get(key: str) -> typing.Optional[typing.Tuple[str, str]]:
+ val, env = _getenv_and_key(key)
+ return (
+ (val, f'"{env}" environment variable')
+ if val is not None else None
+ )
+
+ # The cloud profile is potentially relevant to resolving credentials at
+ # any stage, including the config stage when other environment variables
+ # are not yet read.
+ cloud_profile_tuple = _get('CLOUD_PROFILE')
+ cloud_profile = cloud_profile_tuple[0] if cloud_profile_tuple else None
+
+ has_compound_options = _resolve_config_options(
+ resolved_config,
+ 'Cannot have more than one of the following connection options: '
+ + '"dsn", "credentials", "credentials_file" or "host"/"port"',
+ dsn=(dsn, '"dsn" option') if dsn is not None else None,
+ instance_name=(
+ (instance_name, '"dsn" option (parsed as instance name)')
+ if instance_name is not None else None
+ ),
+ credentials=(
+ (credentials, '"credentials" option')
+ if credentials is not None else None
+ ),
+ credentials_file=(
+ (credentials_file, '"credentials_file" option')
+ if credentials_file is not None else None
+ ),
+ host=(host, '"host" option') if host is not None else None,
+ port=(port, '"port" option') if port is not None else None,
+ database=(
+ (database, '"database" option')
+ if database is not None else None
+ ),
+ branch=(
+ (branch, '"branch" option')
+ if branch is not None else None
+ ),
+ user=(user, '"user" option') if user is not None else None,
+ password=(
+ (password, '"password" option')
+ if password is not None else None
+ ),
+ secret_key=(
+ (secret_key, '"secret_key" option')
+ if secret_key is not None else None
+ ),
+ tls_ca=(
+ (tls_ca, '"tls_ca" option')
+ if tls_ca is not None else None
+ ),
+ tls_ca_file=(
+ (tls_ca_file, '"tls_ca_file" option')
+ if tls_ca_file is not None else None
+ ),
+ tls_security=(
+ (tls_security, '"tls_security" option')
+ if tls_security is not None else None
+ ),
+ tls_server_name=(
+ (tls_server_name, '"tls_server_name" option')
+ if tls_server_name is not None else None
+ ),
+ server_settings=(
+ (server_settings, '"server_settings" option')
+ if server_settings is not None else None
+ ),
+ wait_until_available=(
+ (wait_until_available, '"wait_until_available" option')
+ if wait_until_available is not None else None
+ ),
+ cloud_profile=cloud_profile_tuple,
+ )
+
+ if has_compound_options is False:
+ env_port_tuple = _get("PORT")
+ if (
+ resolved_config._port is None
+ and env_port_tuple
+ and env_port_tuple[0].startswith('tcp://')
+ ):
+ # EDGEDB_PORT is set by 'docker --link' so ignore and warn
+ warnings.warn('EDGEDB_PORT in "tcp://host:port" format, ' +
+ 'so will be ignored', stacklevel=1)
+ env_port_tuple = None
+
+ has_compound_options = _resolve_config_options(
+ resolved_config,
+ # XXX
+ 'Cannot have more than one of the following connection '
+ + 'environment variables: "EDGEDB_DSN", "EDGEDB_INSTANCE", '
+ + '"EDGEDB_CREDENTIALS_FILE" or "EDGEDB_HOST"/"EDGEDB_PORT"',
+ dsn=_get('DSN'),
+ instance_name=_get('INSTANCE'),
+ credentials_file=_get('CREDENTIALS_FILE'),
+ host=_get('HOST'),
+ port=env_port_tuple,
+ database=_get('DATABASE'),
+ branch=_get('BRANCH'),
+ user=_get('USER'),
+ password=_get('PASSWORD'),
+ secret_key=_get('SECRET_KEY'),
+ tls_ca=_get('TLS_CA'),
+ tls_ca_file=_get('TLS_CA_FILE'),
+ tls_security=_get('CLIENT_TLS_SECURITY'),
+ tls_server_name=_get('TLS_SERVER_NAME'),
+ wait_until_available=_get('WAIT_UNTIL_AVAILABLE'),
+ )
+
+ if not has_compound_options:
+ dir = find_gel_project_dir()
+ stash_dir = _stash_path(dir)
+ if os.path.exists(stash_dir):
+ with open(os.path.join(stash_dir, 'instance-name'), 'rt') as f:
+ instance_name = f.read().strip()
+ cloud_profile_file = os.path.join(stash_dir, 'cloud-profile')
+ if os.path.exists(cloud_profile_file):
+ with open(cloud_profile_file, 'rt') as f:
+ cloud_profile = f.read().strip()
+ else:
+ cloud_profile = None
+
+ _resolve_config_options(
+ resolved_config,
+ '',
+ instance_name=(
+ instance_name,
+ f'project linked instance ("{instance_name}")'
+ ),
+ cloud_profile=(
+ cloud_profile,
+ f'project defined cloud profile ("{cloud_profile}")'
+ ),
+ )
+
+ opt_database_file = os.path.join(stash_dir, 'database')
+ if os.path.exists(opt_database_file):
+ with open(opt_database_file, 'rt') as f:
+ database = f.read().strip()
+ resolved_config.set_database(database, "project")
+ else:
+ raise errors.ClientConnectionError(
+ f'Found `gel.toml` but the project is not initialized. '
+ f'Run `gel project init`.'
+ )
+
+ return resolved_config
+
+
+def _parse_dsn_into_config(
+ resolved_config: ResolvedConnectConfig,
+ dsn: typing.Tuple[str, str]
+):
+ dsn_str, source = dsn
+
+ try:
+ parsed = urllib.parse.urlparse(dsn_str)
+ host = (
+ urllib.parse.unquote(parsed.hostname) if parsed.hostname else None
+ )
+ port = parsed.port
+ database = parsed.path
+ user = parsed.username
+ password = parsed.password
+ except Exception as e:
+ raise ValueError(f'invalid DSN or instance name: {str(e)}')
+
+ if parsed.scheme != 'edgedb':
+ raise ValueError(
+ f'invalid DSN: scheme is expected to be '
+ f'"edgedb", got {parsed.scheme!r}')
+
+ query = (
+ urllib.parse.parse_qs(parsed.query, keep_blank_values=True)
+ if parsed.query != ''
+ else {}
+ )
+ for key, val in query.items():
+ if isinstance(val, list):
+ if len(val) > 1:
+ raise ValueError(
+ f'invalid DSN: duplicate query parameter {key}')
+ query[key] = val[-1]
+
+ def handle_dsn_part(
+ paramName, value, currentValue, setter,
+ formatter=lambda val: val
+ ):
+ param_values = [
+ (value if value != '' else None),
+ query.get(paramName),
+ query.get(paramName + '_env'),
+ query.get(paramName + '_file')
+ ]
+ if len([p for p in param_values if p is not None]) > 1:
+ raise ValueError(
+ f'invalid DSN: more than one of ' +
+ f'{(paramName + ", ") if value else ""}' +
+ f'?{paramName}=, ?{paramName}_env=, ?{paramName}_file= ' +
+ f'was specified'
+ )
+
+ if currentValue is None:
+ param = (
+ value if (value is not None and value != '')
+ else query.get(paramName)
+ )
+ paramSource = source
+
+ if param is None:
+ env = query.get(paramName + '_env')
+ if env is not None:
+ param = os.getenv(env)
+ if param is None:
+ raise ValueError(
+ f'{paramName}_env environment variable "{env}" ' +
+ f'doesn\'t exist')
+ paramSource = paramSource + f' ({paramName}_env: {env})'
+ if param is None:
+ filename = query.get(paramName + '_file')
+ if filename is not None:
+ with open(filename) as f:
+ param = f.read()
+ paramSource = (
+ paramSource + f' ({paramName}_file: {filename})'
+ )
+
+ param = formatter(param) if param is not None else None
+
+ setter(param, paramSource)
+
+ query.pop(paramName, None)
+ query.pop(paramName + '_env', None)
+ query.pop(paramName + '_file', None)
+
+ handle_dsn_part(
+ 'host', host, resolved_config._host, resolved_config.set_host
+ )
+
+ handle_dsn_part(
+ 'port', port, resolved_config._port, resolved_config.set_port
+ )
+
+ def strip_leading_slash(str):
+ return str[1:] if str.startswith('/') else str
+
+ if (
+ 'branch' in query or
+ 'branch_env' in query or
+ 'branch_file' in query
+ ):
+ if (
+ 'database' in query or
+ 'database_env' in query or
+ 'database_file' in query
+ ):
+ raise ValueError(
+ f"invalid DSN: `database` and `branch` cannot be present "
+ f"at the same time"
+ )
+ if resolved_config._database is None:
+ # Only update the config if 'database' has not been already
+ # resolved.
+ handle_dsn_part(
+ 'branch', strip_leading_slash(database),
+ resolved_config._branch, resolved_config.set_branch,
+ strip_leading_slash
+ )
+ else:
+ # Clean up the query, if config already has 'database'
+ query.pop('branch', None)
+ query.pop('branch_env', None)
+ query.pop('branch_file', None)
+
+ else:
+ if resolved_config._branch is None:
+ # Only update the config if 'branch' has not been already
+ # resolved.
+ handle_dsn_part(
+ 'database', strip_leading_slash(database),
+ resolved_config._database, resolved_config.set_database,
+ strip_leading_slash
+ )
+ else:
+ # Clean up the query, if config already has 'branch'
+ query.pop('database', None)
+ query.pop('database_env', None)
+ query.pop('database_file', None)
+
+ handle_dsn_part(
+ 'user', user, resolved_config._user, resolved_config.set_user
+ )
+
+ handle_dsn_part(
+ 'password', password,
+ resolved_config._password, resolved_config.set_password
+ )
+
+ handle_dsn_part(
+ 'secret_key', None,
+ resolved_config._secret_key, resolved_config.set_secret_key
+ )
+
+ handle_dsn_part(
+ 'tls_ca_file', None,
+ resolved_config._tls_ca_data, resolved_config.set_tls_ca_file
+ )
+
+ handle_dsn_part(
+ 'tls_server_name', None,
+ resolved_config._tls_server_name,
+ resolved_config.set_tls_server_name
+ )
+
+ handle_dsn_part(
+ 'tls_security', None,
+ resolved_config._tls_security,
+ resolved_config.set_tls_security
+ )
+
+ handle_dsn_part(
+ 'wait_until_available', None,
+ resolved_config._wait_until_available,
+ resolved_config.set_wait_until_available
+ )
+
+ resolved_config.add_server_settings(query)
+
+
+def _jwt_base64_decode(payload):
+ remainder = len(payload) % 4
+ if remainder == 2:
+ payload += '=='
+ elif remainder == 3:
+ payload += '='
+ elif remainder != 0:
+ raise errors.ClientConnectionError("Invalid secret key")
+ payload = base64.urlsafe_b64decode(payload.encode("utf-8"))
+ return json.loads(payload.decode("utf-8"))
+
+
+def _parse_cloud_instance_name_into_config(
+ resolved_config: ResolvedConnectConfig,
+ source: str,
+ org_slug: str,
+ instance_name: str,
+):
+ org_slug = org_slug.lower()
+ instance_name = instance_name.lower()
+
+ label = f"{instance_name}--{org_slug}"
+ if len(label) > DOMAIN_LABEL_MAX_LENGTH:
+ raise ValueError(
+ f"invalid instance name: cloud instance name length cannot exceed "
+ f"{DOMAIN_LABEL_MAX_LENGTH - 1} characters: "
+ f"{org_slug}/{instance_name}"
+ )
+ secret_key = resolved_config.secret_key
+ if secret_key is None:
+ try:
+ config_dir = platform.config_dir()
+ if resolved_config._cloud_profile is None:
+ profile = profile_src = "default"
+ else:
+ profile = resolved_config._cloud_profile
+ profile_src = resolved_config._cloud_profile_source
+ path = config_dir / "cloud-credentials" / f"{profile}.json"
+ with open(path, "rt") as f:
+ secret_key = json.load(f)["secret_key"]
+ except Exception:
+ raise errors.ClientConnectionError(
+ "Cannot connect to cloud instances without secret key."
+ )
+ resolved_config.set_secret_key(
+ secret_key,
+ f"cloud-credentials/{profile}.json specified by {profile_src}",
+ )
+ try:
+ dns_zone = _jwt_base64_decode(secret_key.split(".", 2)[1])["iss"]
+ except errors.EdgeDBError:
+ raise
+ except Exception:
+ raise errors.ClientConnectionError("Invalid secret key")
+ payload = f"{org_slug}/{instance_name}".encode("utf-8")
+ dns_bucket = binascii.crc_hqx(payload, 0) % 100
+ host = f"{label}.c-{dns_bucket:02d}.i.{dns_zone}"
+ resolved_config.set_host(host, source)
+
+
+def _resolve_config_options(
+ resolved_config: ResolvedConnectConfig,
+ compound_error: str,
+ *,
+ dsn=None,
+ instance_name=None,
+ credentials=None,
+ credentials_file=None,
+ host=None,
+ port=None,
+ database=None,
+ branch=None,
+ user=None,
+ password=None,
+ secret_key=None,
+ tls_ca=None,
+ tls_ca_file=None,
+ tls_security=None,
+ tls_server_name=None,
+ server_settings=None,
+ wait_until_available=None,
+ cloud_profile=None,
+):
+ if database is not None:
+ if branch is not None:
+ raise errors.ClientConnectionError(
+ f"{database[1]} and {branch[1]} are mutually exclusive"
+ )
+ if resolved_config._branch is None:
+ # Only update the config if 'branch' has not been already
+ # resolved.
+ resolved_config.set_database(*database)
+ if branch is not None:
+ if resolved_config._database is None:
+ # Only update the config if 'database' has not been already
+ # resolved.
+ resolved_config.set_branch(*branch)
+ if user is not None:
+ resolved_config.set_user(*user)
+ if password is not None:
+ resolved_config.set_password(*password)
+ if secret_key is not None:
+ resolved_config.set_secret_key(*secret_key)
+ if tls_ca_file is not None:
+ if tls_ca is not None:
+ raise errors.ClientConnectionError(
+ f"{tls_ca[1]} and {tls_ca_file[1]} are mutually exclusive"
+ )
+ resolved_config.set_tls_ca_file(*tls_ca_file)
+ if tls_ca is not None:
+ resolved_config.set_tls_ca_data(*tls_ca)
+ if tls_security is not None:
+ resolved_config.set_tls_security(*tls_security)
+ if tls_server_name is not None:
+ resolved_config.set_tls_server_name(*tls_server_name)
+ if server_settings is not None:
+ resolved_config.add_server_settings(server_settings[0])
+ if wait_until_available is not None:
+ resolved_config.set_wait_until_available(*wait_until_available)
+ if cloud_profile is not None:
+ resolved_config._set_param('cloud_profile', *cloud_profile)
+
+ compound_params = [
+ dsn,
+ instance_name,
+ credentials,
+ credentials_file,
+ host or port,
+ ]
+ compound_params_count = len([p for p in compound_params if p is not None])
+
+ if compound_params_count > 1:
+ raise errors.ClientConnectionError(compound_error)
+
+ elif compound_params_count == 1:
+ if dsn is not None or host is not None or port is not None:
+ if port is not None:
+ resolved_config.set_port(*port)
+ if dsn is None:
+ dsn = (
+ 'edgedb://' +
+ (_prepare_host_for_dsn(host[0]) if host else ''),
+ host[1] if host is not None else port[1]
+ )
+ _parse_dsn_into_config(resolved_config, dsn)
+ else:
+ if credentials_file is not None:
+ creds = cred_utils.read_credentials(credentials_file[0])
+ source = "credentials"
+ elif credentials is not None:
+ try:
+ cred_data = json.loads(credentials[0])
+ except ValueError as e:
+ raise RuntimeError(f"cannot read credentials") from e
+ else:
+ creds = cred_utils.validate_credentials(cred_data)
+ source = "credentials"
+ elif INSTANCE_NAME_RE.match(instance_name[0]):
+ source = instance_name[1]
+ creds = cred_utils.read_credentials(
+ cred_utils.get_credentials_path(instance_name[0]),
+ )
+ else:
+ name_match = CLOUD_INSTANCE_NAME_RE.match(instance_name[0])
+ if name_match is None:
+ raise ValueError(
+ f'invalid DSN or instance name: "{instance_name[0]}"'
+ )
+ source = instance_name[1]
+ org, inst = name_match.groups()
+ _parse_cloud_instance_name_into_config(
+ resolved_config, source, org, inst
+ )
+ return True
+
+ resolved_config.set_host(creds.get('host'), source)
+ resolved_config.set_port(creds.get('port'), source)
+ if 'database' in creds and resolved_config._branch is None:
+ # Only update the config if 'branch' has not been already
+ # resolved.
+ resolved_config.set_database(creds.get('database'), source)
+
+ elif 'branch' in creds and resolved_config._database is None:
+ # Only update the config if 'database' has not been already
+ # resolved.
+ resolved_config.set_branch(creds.get('branch'), source)
+ resolved_config.set_user(creds.get('user'), source)
+ resolved_config.set_password(creds.get('password'), source)
+ resolved_config.set_tls_ca_data(creds.get('tls_ca'), source)
+ resolved_config.set_tls_security(
+ creds.get('tls_security'),
+ source
+ )
+
+ return True
+
+ else:
+ return False
+
+
+def find_gel_project_dir():
+ dir = os.getcwd()
+ dev = os.stat(dir).st_dev
+
+ while True:
+ gel_toml = os.path.join(dir, 'gel.toml')
+ edgedb_toml = os.path.join(dir, 'edgedb.toml')
+ if not os.path.isfile(gel_toml) and not os.path.isfile(edgedb_toml):
+ parent = os.path.dirname(dir)
+ if parent == dir:
+ raise errors.ClientConnectionError(
+ f'no `gel.toml` found and '
+ f'no connection options specified'
+ )
+ parent_dev = os.stat(parent).st_dev
+ if parent_dev != dev:
+ raise errors.ClientConnectionError(
+ f'no `gel.toml` found and '
+ f'no connection options specified'
+ f'(stopped searching for `edgedb.toml` at file system'
+ f'boundary {dir!r})'
+ )
+ dir = parent
+ dev = parent_dev
+ continue
+ return dir
+
+
+def parse_connect_arguments(
+ *,
+ dsn,
+ host,
+ port,
+ credentials,
+ credentials_file,
+ database,
+ branch,
+ user,
+ password,
+ secret_key,
+ tls_ca,
+ tls_ca_file,
+ tls_security,
+ tls_server_name,
+ timeout,
+ command_timeout,
+ wait_until_available,
+ server_settings,
+) -> typing.Tuple[ResolvedConnectConfig, ClientConfiguration]:
+
+ if command_timeout is not None:
+ try:
+ if isinstance(command_timeout, bool):
+ raise ValueError
+ command_timeout = float(command_timeout)
+ if command_timeout <= 0:
+ raise ValueError
+ except ValueError:
+ raise ValueError(
+ 'invalid command_timeout value: '
+ 'expected greater than 0 float (got {!r})'.format(
+ command_timeout)) from None
+
+ connect_config = _parse_connect_dsn_and_args(
+ dsn=dsn,
+ host=host,
+ port=port,
+ credentials=credentials,
+ credentials_file=credentials_file,
+ database=database,
+ branch=branch,
+ user=user,
+ password=password,
+ secret_key=secret_key,
+ tls_ca=tls_ca,
+ tls_ca_file=tls_ca_file,
+ tls_security=tls_security,
+ tls_server_name=tls_server_name,
+ server_settings=server_settings,
+ wait_until_available=wait_until_available,
+ )
+
+ client_config = ClientConfiguration(
+ connect_timeout=timeout,
+ command_timeout=command_timeout,
+ wait_until_available=connect_config.wait_until_available,
+ )
+
+ return connect_config, client_config
+
+
+def check_alpn_protocol(ssl_obj):
+ if ssl_obj.selected_alpn_protocol() != 'edgedb-binary':
+ raise errors.ClientConnectionFailedError(
+ "The server doesn't support the edgedb-binary protocol."
+ )
+
+
+def render_client_no_connection_error(prefix, addr, attempts, duration):
+ if isinstance(addr, str):
+ msg = (
+ f'{prefix}'
+ f'\n\tAfter {attempts} attempts in {duration:.1f} sec'
+ f'\n\tIs the server running locally and accepting '
+ f'\n\tconnections on Unix domain socket {addr!r}?'
+ )
+ else:
+ msg = (
+ f'{prefix}'
+ f'\n\tAfter {attempts} attempts in {duration:.1f} sec'
+ f'\n\tIs the server running on host {addr[0]!r} '
+ f'and accepting '
+ f'\n\tTCP/IP connections on port {addr[1]}?'
+ )
+ return msg
+
+
+def _extract_errno(s):
+ """Extract multiple errnos from error string
+
+ When we connect to a host that has multiple underlying IP addresses, say
+ ``localhost`` having ``::1`` and ``127.0.0.1``, we get
+ ``OSError("Multiple exceptions:...")`` error without ``.errno`` attribute
+ set. There are multiple ones in the text, so we extract all of them.
+ """
+ result = []
+ for match in ERRNO_RE.finditer(s):
+ result.append(int(match.group(1)))
+ if result:
+ return result
+
+
+def wrap_error(e):
+ message = str(e)
+ if e.errno is None:
+ errnos = _extract_errno(message)
+ else:
+ errnos = [e.errno]
+
+ if errnos:
+ is_temp = any((code in TEMPORARY_ERROR_CODES for code in errnos))
+ else:
+ is_temp = isinstance(e, TEMPORARY_ERRORS)
+
+ if is_temp:
+ return errors.ClientConnectionFailedTemporarilyError(message)
+ else:
+ return errors.ClientConnectionFailedError(message)
diff --git a/edgedb/connresource.py b/gel/connresource.py
similarity index 100%
rename from edgedb/connresource.py
rename to gel/connresource.py
diff --git a/gel/credentials.py b/gel/credentials.py
new file mode 100644
index 00000000..71146fd9
--- /dev/null
+++ b/gel/credentials.py
@@ -0,0 +1,119 @@
+import os
+import pathlib
+import typing
+import json
+
+from . import platform
+
+
+class RequiredCredentials(typing.TypedDict, total=True):
+ port: int
+ user: str
+
+
+class Credentials(RequiredCredentials, total=False):
+ host: typing.Optional[str]
+ password: typing.Optional[str]
+ # It's OK for database and branch to appear in credentials, as long as
+ # they match.
+ database: typing.Optional[str]
+ branch: typing.Optional[str]
+ tls_ca: typing.Optional[str]
+ tls_security: typing.Optional[str]
+
+
+def get_credentials_path(instance_name: str) -> pathlib.Path:
+ return platform.search_config_dir("credentials", instance_name + ".json")
+
+
+def read_credentials(path: os.PathLike) -> Credentials:
+ try:
+ with open(path, encoding='utf-8') as f:
+ credentials = json.load(f)
+ return validate_credentials(credentials)
+ except Exception as e:
+ raise RuntimeError(
+ f"cannot read credentials at {path}"
+ ) from e
+
+
+def validate_credentials(data: dict) -> Credentials:
+ port = data.get('port')
+ if port is None:
+ port = 5656
+ if not isinstance(port, int) or port < 1 or port > 65535:
+ raise ValueError("invalid `port` value")
+
+ user = data.get('user')
+ if user is None:
+ raise ValueError("`user` key is required")
+ if not isinstance(user, str):
+ raise ValueError("`user` must be a string")
+
+ result = { # required keys
+ "user": user,
+ "port": port,
+ }
+
+ host = data.get('host')
+ if host is not None:
+ if not isinstance(host, str):
+ raise ValueError("`host` must be a string")
+ result['host'] = host
+
+ database = data.get('database')
+ if database is not None:
+ if not isinstance(database, str):
+ raise ValueError("`database` must be a string")
+ result['database'] = database
+
+ branch = data.get('branch')
+ if branch is not None:
+ if not isinstance(branch, str):
+ raise ValueError("`branch` must be a string")
+ if database is not None and branch != database:
+ raise ValueError(
+ f"`database` and `branch` cannot be different")
+ result['branch'] = branch
+
+ password = data.get('password')
+ if password is not None:
+ if not isinstance(password, str):
+ raise ValueError("`password` must be a string")
+ result['password'] = password
+
+ ca = data.get('tls_ca')
+ if ca is not None:
+ if not isinstance(ca, str):
+ raise ValueError("`tls_ca` must be a string")
+ result['tls_ca'] = ca
+
+ cert_data = data.get('tls_cert_data')
+ if cert_data is not None:
+ if not isinstance(cert_data, str):
+ raise ValueError("`tls_cert_data` must be a string")
+ if ca is not None and ca != cert_data:
+ raise ValueError(
+ f"tls_ca and tls_cert_data are both set and disagree")
+ result['tls_ca'] = cert_data
+
+ verify = data.get('tls_verify_hostname')
+ if verify is not None:
+ if not isinstance(verify, bool):
+ raise ValueError("`tls_verify_hostname` must be a bool")
+ result['tls_security'] = 'strict' if verify else 'no_host_verification'
+
+ tls_security = data.get('tls_security')
+ if tls_security is not None:
+ if not isinstance(tls_security, str):
+ raise ValueError("`tls_security` must be a string")
+ result['tls_security'] = tls_security
+
+ missmatch = ValueError(f"tls_verify_hostname={verify} and "
+ f"tls_security={tls_security} are incompatible")
+ if tls_security == "strict" and verify is False:
+ raise missmatch
+ if tls_security in {"no_host_verification", "insecure"} and verify is True:
+ raise missmatch
+
+ return result
diff --git a/edgedb/datatypes/.gitignore b/gel/datatypes/.gitignore
similarity index 100%
rename from edgedb/datatypes/.gitignore
rename to gel/datatypes/.gitignore
diff --git a/edgedb/datatypes/NOTES b/gel/datatypes/NOTES
similarity index 86%
rename from edgedb/datatypes/NOTES
rename to gel/datatypes/NOTES
index 3a18a430..26068364 100644
--- a/edgedb/datatypes/NOTES
+++ b/gel/datatypes/NOTES
@@ -5,7 +5,7 @@ Objects don't have tp_clear; that's because there shouldn't be
a situation when there's a ref cycle between them -- the way
the data is serialized on the wire precludes that.
-Furthermore, all edgedb.datatypes objects are immutable, so
+Furthermore, all gel.datatypes objects are immutable, so
it's not possible to create a ref cycle with only them in the
reference chain by using the public API.
diff --git a/edgedb/datatypes/__init__.pxd b/gel/datatypes/__init__.pxd
similarity index 100%
rename from edgedb/datatypes/__init__.pxd
rename to gel/datatypes/__init__.pxd
diff --git a/gel/datatypes/__init__.py b/gel/datatypes/__init__.py
new file mode 100644
index 00000000..73609285
--- /dev/null
+++ b/gel/datatypes/__init__.py
@@ -0,0 +1,17 @@
+#
+# This source file is part of the EdgeDB open source project.
+#
+# Copyright 2019-present MagicStack Inc. and the EdgeDB authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
diff --git a/edgedb/datatypes/args.c b/gel/datatypes/args.c
similarity index 100%
rename from edgedb/datatypes/args.c
rename to gel/datatypes/args.c
diff --git a/edgedb/datatypes/comp.c b/gel/datatypes/comp.c
similarity index 100%
rename from edgedb/datatypes/comp.c
rename to gel/datatypes/comp.c
diff --git a/edgedb/datatypes/config_memory.pxd b/gel/datatypes/config_memory.pxd
similarity index 100%
rename from edgedb/datatypes/config_memory.pxd
rename to gel/datatypes/config_memory.pxd
diff --git a/edgedb/datatypes/config_memory.pyx b/gel/datatypes/config_memory.pyx
similarity index 97%
rename from edgedb/datatypes/config_memory.pyx
rename to gel/datatypes/config_memory.pyx
index c28193e8..6ebf1ee4 100644
--- a/edgedb/datatypes/config_memory.pyx
+++ b/gel/datatypes/config_memory.pyx
@@ -43,7 +43,7 @@ cdef class ConfigMemory:
return hash((ConfigMemory, self._bytes))
def __repr__(self):
- return f''
+ return f''
@cython.cdivision(True)
def __str__(self):
diff --git a/edgedb/datatypes/datatypes.h b/gel/datatypes/datatypes.h
similarity index 94%
rename from edgedb/datatypes/datatypes.h
rename to gel/datatypes/datatypes.h
index c6a7589b..d1077477 100644
--- a/edgedb/datatypes/datatypes.h
+++ b/gel/datatypes/datatypes.h
@@ -32,7 +32,7 @@
#define EDGE_POINTER_IS_LINK (1 << 2)
-/* === edgedb.RecordDesc ==================================== */
+/* === gel.RecordDesc ==================================== */
extern PyTypeObject EdgeRecordDesc_Type;
@@ -86,7 +86,7 @@ PyObject * EdgeRecordDesc_List(PyObject *, uint8_t, uint8_t);
PyObject * EdgeRecordDesc_GetDataclassFields(PyObject *);
-/* === edgedb.NamedTuple ==================================== */
+/* === gel.NamedTuple ==================================== */
#define EDGE_NAMEDTUPLE_FREELIST_SIZE 500
#define EDGE_NAMEDTUPLE_FREELIST_MAXSAVE 20
@@ -98,7 +98,7 @@ PyObject * EdgeNamedTuple_Type_New(PyObject *);
PyObject * EdgeNamedTuple_New(PyObject *);
-/* === edgedb.Object ======================================== */
+/* === gel.Object ======================================== */
#define EDGE_OBJECT_FREELIST_SIZE 2000
#define EDGE_OBJECT_FREELIST_MAXSAVE 20
diff --git a/edgedb/datatypes/datatypes.pxd b/gel/datatypes/datatypes.pxd
similarity index 100%
rename from edgedb/datatypes/datatypes.pxd
rename to gel/datatypes/datatypes.pxd
diff --git a/edgedb/datatypes/datatypes.pyx b/gel/datatypes/datatypes.pyx
similarity index 100%
rename from edgedb/datatypes/datatypes.pyx
rename to gel/datatypes/datatypes.pyx
diff --git a/edgedb/datatypes/date_duration.pxd b/gel/datatypes/date_duration.pxd
similarity index 100%
rename from edgedb/datatypes/date_duration.pxd
rename to gel/datatypes/date_duration.pxd
diff --git a/edgedb/datatypes/date_duration.pyx b/gel/datatypes/date_duration.pyx
similarity index 97%
rename from edgedb/datatypes/date_duration.pyx
rename to gel/datatypes/date_duration.pyx
index 728075f4..a8d6e6fc 100644
--- a/edgedb/datatypes/date_duration.pyx
+++ b/gel/datatypes/date_duration.pyx
@@ -43,7 +43,7 @@ cdef class DateDuration:
return hash((DateDuration, self.days, self.months))
def __repr__(self):
- return f''
+ return f''
@cython.cdivision(True)
def __str__(self):
diff --git a/edgedb/datatypes/enum.pyx b/gel/datatypes/enum.pyx
similarity index 97%
rename from edgedb/datatypes/enum.pyx
rename to gel/datatypes/enum.pyx
index 9ca4b155..a9155ee1 100644
--- a/edgedb/datatypes/enum.pyx
+++ b/gel/datatypes/enum.pyx
@@ -25,7 +25,7 @@ class EnumValue(enum.Enum):
return self._value_
def __repr__(self):
- return f''
+ return f''
@classmethod
def _try_from(cls, value):
diff --git a/edgedb/datatypes/freelist.h b/gel/datatypes/freelist.h
similarity index 100%
rename from edgedb/datatypes/freelist.h
rename to gel/datatypes/freelist.h
diff --git a/edgedb/datatypes/hash.c b/gel/datatypes/hash.c
similarity index 100%
rename from edgedb/datatypes/hash.c
rename to gel/datatypes/hash.c
diff --git a/edgedb/datatypes/internal.h b/gel/datatypes/internal.h
similarity index 100%
rename from edgedb/datatypes/internal.h
rename to gel/datatypes/internal.h
diff --git a/edgedb/datatypes/namedtuple.c b/gel/datatypes/namedtuple.c
similarity index 96%
rename from edgedb/datatypes/namedtuple.c
rename to gel/datatypes/namedtuple.c
index de0fa8ea..251a0e92 100644
--- a/edgedb/datatypes/namedtuple.c
+++ b/gel/datatypes/namedtuple.c
@@ -159,7 +159,7 @@ namedtuple_new(PyTypeObject *type, PyObject *args, PyObject *kwargs) {
{
PyErr_SetString(
PyExc_ValueError,
- "edgedb.NamedTuple requires at least one field/value");
+ "gel.NamedTuple requires at least one field/value");
goto fail;
}
@@ -254,7 +254,7 @@ namedtuple_derived_new(PyTypeObject *type, PyObject *args, PyObject *kwargs) {
if (args_size > size) {
PyErr_Format(
PyExc_ValueError,
- "edgedb.NamedTuple only needs %zd arguments, %zd given",
+ "gel.NamedTuple only needs %zd arguments, %zd given",
size, args_size);
goto fail;
}
@@ -270,7 +270,7 @@ namedtuple_derived_new(PyTypeObject *type, PyObject *args, PyObject *kwargs) {
} else {
PyErr_Format(
PyExc_ValueError,
- "edgedb.NamedTuple requires %zd arguments, %zd given",
+ "gel.NamedTuple requires %zd arguments, %zd given",
size, args_size);
goto fail;
}
@@ -278,7 +278,7 @@ namedtuple_derived_new(PyTypeObject *type, PyObject *args, PyObject *kwargs) {
if (PyDict_Size(kwargs) > size - args_size) {
PyErr_SetString(
PyExc_ValueError,
- "edgedb.NamedTuple got extra keyword arguments");
+ "gel.NamedTuple got extra keyword arguments");
goto fail;
}
for (Py_ssize_t i = args_size; i < size; i++) {
@@ -294,7 +294,7 @@ namedtuple_derived_new(PyTypeObject *type, PyObject *args, PyObject *kwargs) {
} else {
PyErr_Format(
PyExc_ValueError,
- "edgedb.NamedTuple missing required argument: %U",
+ "gel.NamedTuple missing required argument: %U",
key);
Py_CLEAR(key);
goto fail;
@@ -407,7 +407,7 @@ static PyType_Slot namedtuple_slots[] = {
static PyType_Spec namedtuple_spec = {
- "edgedb.DerivedNamedTuple",
+ "gel.DerivedNamedTuple",
sizeof(PyTupleObject) - sizeof(PyObject *),
sizeof(PyObject *),
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC,
@@ -417,7 +417,7 @@ static PyType_Spec namedtuple_spec = {
PyTypeObject EdgeNamedTuple_Type = {
PyVarObject_HEAD_INIT(NULL, 0)
- "edgedb.NamedTuple",
+ "gel.NamedTuple",
.tp_basicsize = sizeof(PyTupleObject) - sizeof(PyObject *),
.tp_itemsize = sizeof(PyObject *),
.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE,
diff --git a/edgedb/datatypes/object.c b/gel/datatypes/object.c
similarity index 99%
rename from edgedb/datatypes/object.c
rename to gel/datatypes/object.c
index 1d3029be..a5965bc1 100644
--- a/edgedb/datatypes/object.c
+++ b/gel/datatypes/object.c
@@ -86,7 +86,7 @@ EdgeObject_GetRecordDesc(PyObject *o)
if (!EdgeObject_Check(o)) {
PyErr_Format(
PyExc_TypeError,
- "an instance of edgedb.Object expected");
+ "an instance of gel.Object expected");
return NULL;
}
@@ -324,7 +324,7 @@ static PyMappingMethods object_as_mapping = {
PyTypeObject EdgeObject_Type = {
PyVarObject_HEAD_INIT(NULL, 0)
- "edgedb.Object",
+ "gel.Object",
.tp_basicsize = sizeof(EdgeObject) - sizeof(PyObject *),
.tp_itemsize = sizeof(PyObject *),
.tp_dealloc = (destructor)object_dealloc,
diff --git a/edgedb/datatypes/pythoncapi_compat.h b/gel/datatypes/pythoncapi_compat.h
similarity index 100%
rename from edgedb/datatypes/pythoncapi_compat.h
rename to gel/datatypes/pythoncapi_compat.h
diff --git a/gel/datatypes/range.py b/gel/datatypes/range.py
new file mode 100644
index 00000000..e3fd3d1e
--- /dev/null
+++ b/gel/datatypes/range.py
@@ -0,0 +1,165 @@
+#
+# This source file is part of the EdgeDB open source project.
+#
+# Copyright 2022-present MagicStack Inc. and the EdgeDB authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import (TypeVar, Any, Generic, Optional, Iterable, Iterator,
+ Sequence)
+
+T = TypeVar("T")
+
+
+class Range(Generic[T]):
+
+ __slots__ = ("_lower", "_upper", "_inc_lower", "_inc_upper", "_empty")
+
+ def __init__(
+ self,
+ lower: Optional[T] = None,
+ upper: Optional[T] = None,
+ *,
+ inc_lower: bool = True,
+ inc_upper: bool = False,
+ empty: bool = False,
+ ) -> None:
+ self._empty = empty
+
+ if empty:
+ if (
+ lower != upper
+ or lower is not None and inc_upper and inc_lower
+ ):
+ raise ValueError(
+ "conflicting arguments in range constructor: "
+ "\"empty\" is `true` while the specified bounds "
+ "suggest otherwise"
+ )
+
+ self._lower = self._upper = None
+ self._inc_lower = self._inc_upper = False
+ else:
+ self._lower = lower
+ self._upper = upper
+ self._inc_lower = lower is not None and inc_lower
+ self._inc_upper = upper is not None and inc_upper
+
+ @property
+ def lower(self) -> Optional[T]:
+ return self._lower
+
+ @property
+ def inc_lower(self) -> bool:
+ return self._inc_lower
+
+ @property
+ def upper(self) -> Optional[T]:
+ return self._upper
+
+ @property
+ def inc_upper(self) -> bool:
+ return self._inc_upper
+
+ def is_empty(self) -> bool:
+ return self._empty
+
+ def __bool__(self):
+ return not self.is_empty()
+
+ def __eq__(self, other) -> bool:
+ if isinstance(other, Range):
+ o = other
+ else:
+ return NotImplemented
+
+ return (
+ self._lower,
+ self._upper,
+ self._inc_lower,
+ self._inc_upper,
+ self._empty,
+ ) == (
+ o._lower,
+ o._upper,
+ o._inc_lower,
+ o._inc_upper,
+ o._empty,
+ )
+
+ def __hash__(self) -> int:
+ return hash((
+ self._lower,
+ self._upper,
+ self._inc_lower,
+ self._inc_upper,
+ self._empty,
+ ))
+
+ def __str__(self) -> str:
+ if self._empty:
+ desc = "empty"
+ else:
+ lb = "(" if not self._inc_lower else "["
+ if self._lower is not None:
+ lb += repr(self._lower)
+
+ if self._upper is not None:
+ ub = repr(self._upper)
+ else:
+ ub = ""
+
+ ub += ")" if self._inc_upper else "]"
+
+ desc = f"{lb}, {ub}"
+
+ return f""
+
+ __repr__ = __str__
+
+
+# TODO: maybe we should implement range and multirange operations as well as
+# normalization of the sub-ranges?
+class MultiRange(Iterable[T]):
+
+ _ranges: Sequence[T]
+
+ def __init__(self, iterable: Optional[Iterable[T]] = None) -> None:
+ if iterable is not None:
+ self._ranges = tuple(iterable)
+ else:
+ self._ranges = tuple()
+
+ def __len__(self) -> int:
+ return len(self._ranges)
+
+ def __iter__(self) -> Iterator[T]:
+ return iter(self._ranges)
+
+ def __reversed__(self) -> Iterator[T]:
+ return reversed(self._ranges)
+
+ def __str__(self) -> str:
+ return f''
+
+ __repr__ = __str__
+
+ def __eq__(self, other: Any) -> bool:
+ if isinstance(other, MultiRange):
+ return set(self._ranges) == set(other._ranges)
+ else:
+ return NotImplemented
+
+ def __hash__(self) -> int:
+ return hash(self._ranges)
diff --git a/edgedb/datatypes/record_desc.c b/gel/datatypes/record_desc.c
similarity index 99%
rename from edgedb/datatypes/record_desc.c
rename to gel/datatypes/record_desc.c
index 69d24512..0dc98192 100644
--- a/edgedb/datatypes/record_desc.c
+++ b/gel/datatypes/record_desc.c
@@ -202,7 +202,7 @@ static PyMethodDef record_desc_methods[] = {
PyTypeObject EdgeRecordDesc_Type = {
PyVarObject_HEAD_INIT(NULL, 0)
- .tp_name = "edgedb.RecordDescriptor",
+ .tp_name = "gel.RecordDescriptor",
.tp_basicsize = sizeof(EdgeRecordDescObject),
.tp_dealloc = (destructor)record_desc_dealloc,
.tp_getattro = PyObject_GenericGetAttr,
diff --git a/edgedb/datatypes/relative_duration.pxd b/gel/datatypes/relative_duration.pxd
similarity index 100%
rename from edgedb/datatypes/relative_duration.pxd
rename to gel/datatypes/relative_duration.pxd
diff --git a/edgedb/datatypes/relative_duration.pyx b/gel/datatypes/relative_duration.pyx
similarity index 98%
rename from edgedb/datatypes/relative_duration.pyx
rename to gel/datatypes/relative_duration.pyx
index 27c0c648..3489b315 100644
--- a/edgedb/datatypes/relative_duration.pyx
+++ b/gel/datatypes/relative_duration.pyx
@@ -50,7 +50,7 @@ cdef class RelativeDuration:
return hash((RelativeDuration, self.microseconds, self.days, self.months))
def __repr__(self):
- return f''
+ return f''
@cython.cdivision(True)
def __str__(self):
diff --git a/edgedb/datatypes/repr.c b/gel/datatypes/repr.c
similarity index 100%
rename from edgedb/datatypes/repr.c
rename to gel/datatypes/repr.c
diff --git a/gel/describe.py b/gel/describe.py
new file mode 100644
index 00000000..92e38854
--- /dev/null
+++ b/gel/describe.py
@@ -0,0 +1,98 @@
+#
+# This source file is part of the EdgeDB open source project.
+#
+# Copyright 2022-present MagicStack Inc. and the EdgeDB authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+import dataclasses
+import typing
+import uuid
+
+from . import enums
+
+
+@dataclasses.dataclass(frozen=True)
+class AnyType:
+ desc_id: uuid.UUID
+ name: typing.Optional[str]
+
+
+@dataclasses.dataclass(frozen=True)
+class Element:
+ type: AnyType
+ cardinality: enums.Cardinality
+ is_implicit: bool
+ kind: enums.ElementKind
+
+
+@dataclasses.dataclass(frozen=True)
+class SequenceType(AnyType):
+ element_type: AnyType
+
+
+@dataclasses.dataclass(frozen=True)
+class SetType(SequenceType):
+ pass
+
+
+@dataclasses.dataclass(frozen=True)
+class ObjectType(AnyType):
+ elements: typing.Dict[str, Element]
+
+
+@dataclasses.dataclass(frozen=True)
+class BaseScalarType(AnyType):
+ pass
+
+
+@dataclasses.dataclass(frozen=True)
+class ScalarType(AnyType):
+ base_type: BaseScalarType
+
+
+@dataclasses.dataclass(frozen=True)
+class TupleType(AnyType):
+ element_types: typing.Tuple[AnyType]
+
+
+@dataclasses.dataclass(frozen=True)
+class NamedTupleType(AnyType):
+ element_types: typing.Dict[str, AnyType]
+
+
+@dataclasses.dataclass(frozen=True)
+class ArrayType(SequenceType):
+ pass
+
+
+@dataclasses.dataclass(frozen=True)
+class EnumType(AnyType):
+ members: typing.Tuple[str]
+
+
+@dataclasses.dataclass(frozen=True)
+class SparseObjectType(ObjectType):
+ pass
+
+
+@dataclasses.dataclass(frozen=True)
+class RangeType(AnyType):
+ value_type: AnyType
+
+
+@dataclasses.dataclass(frozen=True)
+class MultiRangeType(AnyType):
+ value_type: AnyType
diff --git a/gel/enums.py b/gel/enums.py
new file mode 100644
index 00000000..6312f596
--- /dev/null
+++ b/gel/enums.py
@@ -0,0 +1,74 @@
+#
+# This source file is part of the EdgeDB open source project.
+#
+# Copyright 2021-present MagicStack Inc. and the EdgeDB authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+import enum
+
+
+class Capability(enum.IntFlag):
+
+ NONE = 0 # noqa
+ MODIFICATIONS = 1 << 0 # noqa
+ SESSION_CONFIG = 1 << 1 # noqa
+ TRANSACTION = 1 << 2 # noqa
+ DDL = 1 << 3 # noqa
+ PERSISTENT_CONFIG = 1 << 4 # noqa
+
+ ALL = 0xFFFF_FFFF_FFFF_FFFF # noqa
+ EXECUTE = ALL & ~TRANSACTION & ~SESSION_CONFIG # noqa
+ LEGACY_EXECUTE = ALL & ~TRANSACTION # noqa
+
+
+class CompilationFlag(enum.IntFlag):
+
+ INJECT_OUTPUT_TYPE_IDS = 1 << 0 # noqa
+ INJECT_OUTPUT_TYPE_NAMES = 1 << 1 # noqa
+ INJECT_OUTPUT_OBJECT_IDS = 1 << 2 # noqa
+
+
+class Cardinality(enum.Enum):
+ # Cardinality isn't applicable for the query:
+ # * the query is a command like CONFIGURE that
+ # does not return any data;
+ # * the query is composed of multiple queries.
+ NO_RESULT = 0x6e
+
+ # Cardinality is 1 or 0
+ AT_MOST_ONE = 0x6f
+
+ # Cardinality is 1
+ ONE = 0x41
+
+ # Cardinality is >= 0
+ MANY = 0x6d
+
+ # Cardinality is >= 1
+ AT_LEAST_ONE = 0x4d
+
+ def is_single(self) -> bool:
+ return self in {Cardinality.AT_MOST_ONE, Cardinality.ONE}
+
+ def is_multi(self) -> bool:
+ return self in {Cardinality.AT_LEAST_ONE, Cardinality.MANY}
+
+
+class ElementKind(enum.Enum):
+
+ LINK = 1 # noqa
+ PROPERTY = 2 # noqa
+ LINK_PROPERTY = 3 # noqa
diff --git a/gel/errors/__init__.py b/gel/errors/__init__.py
new file mode 100644
index 00000000..87f3711d
--- /dev/null
+++ b/gel/errors/__init__.py
@@ -0,0 +1,518 @@
+# AUTOGENERATED FROM "edb/api/errors.txt" WITH
+# $ edb gen-errors \
+# --import 'from gel.errors._base import *\nfrom gel.errors.tags import *' \
+# --extra-all "_base.__all__" \
+# --stdout \
+# --client
+
+
+# flake8: noqa
+
+
+from gel.errors._base import *
+from gel.errors.tags import *
+
+
+__all__ = _base.__all__ + ( # type: ignore
+ 'InternalServerError',
+ 'UnsupportedFeatureError',
+ 'ProtocolError',
+ 'BinaryProtocolError',
+ 'UnsupportedProtocolVersionError',
+ 'TypeSpecNotFoundError',
+ 'UnexpectedMessageError',
+ 'InputDataError',
+ 'ParameterTypeMismatchError',
+ 'StateMismatchError',
+ 'ResultCardinalityMismatchError',
+ 'CapabilityError',
+ 'UnsupportedCapabilityError',
+ 'DisabledCapabilityError',
+ 'QueryError',
+ 'InvalidSyntaxError',
+ 'EdgeQLSyntaxError',
+ 'SchemaSyntaxError',
+ 'GraphQLSyntaxError',
+ 'InvalidTypeError',
+ 'InvalidTargetError',
+ 'InvalidLinkTargetError',
+ 'InvalidPropertyTargetError',
+ 'InvalidReferenceError',
+ 'UnknownModuleError',
+ 'UnknownLinkError',
+ 'UnknownPropertyError',
+ 'UnknownUserError',
+ 'UnknownDatabaseError',
+ 'UnknownParameterError',
+ 'SchemaError',
+ 'SchemaDefinitionError',
+ 'InvalidDefinitionError',
+ 'InvalidModuleDefinitionError',
+ 'InvalidLinkDefinitionError',
+ 'InvalidPropertyDefinitionError',
+ 'InvalidUserDefinitionError',
+ 'InvalidDatabaseDefinitionError',
+ 'InvalidOperatorDefinitionError',
+ 'InvalidAliasDefinitionError',
+ 'InvalidFunctionDefinitionError',
+ 'InvalidConstraintDefinitionError',
+ 'InvalidCastDefinitionError',
+ 'DuplicateDefinitionError',
+ 'DuplicateModuleDefinitionError',
+ 'DuplicateLinkDefinitionError',
+ 'DuplicatePropertyDefinitionError',
+ 'DuplicateUserDefinitionError',
+ 'DuplicateDatabaseDefinitionError',
+ 'DuplicateOperatorDefinitionError',
+ 'DuplicateViewDefinitionError',
+ 'DuplicateFunctionDefinitionError',
+ 'DuplicateConstraintDefinitionError',
+ 'DuplicateCastDefinitionError',
+ 'DuplicateMigrationError',
+ 'SessionTimeoutError',
+ 'IdleSessionTimeoutError',
+ 'QueryTimeoutError',
+ 'TransactionTimeoutError',
+ 'IdleTransactionTimeoutError',
+ 'ExecutionError',
+ 'InvalidValueError',
+ 'DivisionByZeroError',
+ 'NumericOutOfRangeError',
+ 'AccessPolicyError',
+ 'QueryAssertionError',
+ 'IntegrityError',
+ 'ConstraintViolationError',
+ 'CardinalityViolationError',
+ 'MissingRequiredError',
+ 'TransactionError',
+ 'TransactionConflictError',
+ 'TransactionSerializationError',
+ 'TransactionDeadlockError',
+ 'WatchError',
+ 'ConfigurationError',
+ 'AccessError',
+ 'AuthenticationError',
+ 'AvailabilityError',
+ 'BackendUnavailableError',
+ 'ServerOfflineError',
+ 'BackendError',
+ 'UnsupportedBackendFeatureError',
+ 'LogMessage',
+ 'WarningMessage',
+ 'ClientError',
+ 'ClientConnectionError',
+ 'ClientConnectionFailedError',
+ 'ClientConnectionFailedTemporarilyError',
+ 'ClientConnectionTimeoutError',
+ 'ClientConnectionClosedError',
+ 'InterfaceError',
+ 'QueryArgumentError',
+ 'MissingArgumentError',
+ 'UnknownArgumentError',
+ 'InvalidArgumentError',
+ 'NoDataError',
+ 'InternalClientError',
+)
+
+
+class InternalServerError(EdgeDBError):
+ _code = 0x_01_00_00_00
+
+
+class UnsupportedFeatureError(EdgeDBError):
+ _code = 0x_02_00_00_00
+
+
+class ProtocolError(EdgeDBError):
+ _code = 0x_03_00_00_00
+
+
+class BinaryProtocolError(ProtocolError):
+ _code = 0x_03_01_00_00
+
+
+class UnsupportedProtocolVersionError(BinaryProtocolError):
+ _code = 0x_03_01_00_01
+
+
+class TypeSpecNotFoundError(BinaryProtocolError):
+ _code = 0x_03_01_00_02
+
+
+class UnexpectedMessageError(BinaryProtocolError):
+ _code = 0x_03_01_00_03
+
+
+class InputDataError(ProtocolError):
+ _code = 0x_03_02_00_00
+
+
+class ParameterTypeMismatchError(InputDataError):
+ _code = 0x_03_02_01_00
+
+
+class StateMismatchError(InputDataError):
+ _code = 0x_03_02_02_00
+ tags = frozenset({SHOULD_RETRY})
+
+
+class ResultCardinalityMismatchError(ProtocolError):
+ _code = 0x_03_03_00_00
+
+
+class CapabilityError(ProtocolError):
+ _code = 0x_03_04_00_00
+
+
+class UnsupportedCapabilityError(CapabilityError):
+ _code = 0x_03_04_01_00
+
+
+class DisabledCapabilityError(CapabilityError):
+ _code = 0x_03_04_02_00
+
+
+class QueryError(EdgeDBError):
+ _code = 0x_04_00_00_00
+
+
+class InvalidSyntaxError(QueryError):
+ _code = 0x_04_01_00_00
+
+
+class EdgeQLSyntaxError(InvalidSyntaxError):
+ _code = 0x_04_01_01_00
+
+
+class SchemaSyntaxError(InvalidSyntaxError):
+ _code = 0x_04_01_02_00
+
+
+class GraphQLSyntaxError(InvalidSyntaxError):
+ _code = 0x_04_01_03_00
+
+
+class InvalidTypeError(QueryError):
+ _code = 0x_04_02_00_00
+
+
+class InvalidTargetError(InvalidTypeError):
+ _code = 0x_04_02_01_00
+
+
+class InvalidLinkTargetError(InvalidTargetError):
+ _code = 0x_04_02_01_01
+
+
+class InvalidPropertyTargetError(InvalidTargetError):
+ _code = 0x_04_02_01_02
+
+
+class InvalidReferenceError(QueryError):
+ _code = 0x_04_03_00_00
+
+
+class UnknownModuleError(InvalidReferenceError):
+ _code = 0x_04_03_00_01
+
+
+class UnknownLinkError(InvalidReferenceError):
+ _code = 0x_04_03_00_02
+
+
+class UnknownPropertyError(InvalidReferenceError):
+ _code = 0x_04_03_00_03
+
+
+class UnknownUserError(InvalidReferenceError):
+ _code = 0x_04_03_00_04
+
+
+class UnknownDatabaseError(InvalidReferenceError):
+ _code = 0x_04_03_00_05
+
+
+class UnknownParameterError(InvalidReferenceError):
+ _code = 0x_04_03_00_06
+
+
+class SchemaError(QueryError):
+ _code = 0x_04_04_00_00
+
+
+class SchemaDefinitionError(QueryError):
+ _code = 0x_04_05_00_00
+
+
+class InvalidDefinitionError(SchemaDefinitionError):
+ _code = 0x_04_05_01_00
+
+
+class InvalidModuleDefinitionError(InvalidDefinitionError):
+ _code = 0x_04_05_01_01
+
+
+class InvalidLinkDefinitionError(InvalidDefinitionError):
+ _code = 0x_04_05_01_02
+
+
+class InvalidPropertyDefinitionError(InvalidDefinitionError):
+ _code = 0x_04_05_01_03
+
+
+class InvalidUserDefinitionError(InvalidDefinitionError):
+ _code = 0x_04_05_01_04
+
+
+class InvalidDatabaseDefinitionError(InvalidDefinitionError):
+ _code = 0x_04_05_01_05
+
+
+class InvalidOperatorDefinitionError(InvalidDefinitionError):
+ _code = 0x_04_05_01_06
+
+
+class InvalidAliasDefinitionError(InvalidDefinitionError):
+ _code = 0x_04_05_01_07
+
+
+class InvalidFunctionDefinitionError(InvalidDefinitionError):
+ _code = 0x_04_05_01_08
+
+
+class InvalidConstraintDefinitionError(InvalidDefinitionError):
+ _code = 0x_04_05_01_09
+
+
+class InvalidCastDefinitionError(InvalidDefinitionError):
+ _code = 0x_04_05_01_0A
+
+
+class DuplicateDefinitionError(SchemaDefinitionError):
+ _code = 0x_04_05_02_00
+
+
+class DuplicateModuleDefinitionError(DuplicateDefinitionError):
+ _code = 0x_04_05_02_01
+
+
+class DuplicateLinkDefinitionError(DuplicateDefinitionError):
+ _code = 0x_04_05_02_02
+
+
+class DuplicatePropertyDefinitionError(DuplicateDefinitionError):
+ _code = 0x_04_05_02_03
+
+
+class DuplicateUserDefinitionError(DuplicateDefinitionError):
+ _code = 0x_04_05_02_04
+
+
+class DuplicateDatabaseDefinitionError(DuplicateDefinitionError):
+ _code = 0x_04_05_02_05
+
+
+class DuplicateOperatorDefinitionError(DuplicateDefinitionError):
+ _code = 0x_04_05_02_06
+
+
+class DuplicateViewDefinitionError(DuplicateDefinitionError):
+ _code = 0x_04_05_02_07
+
+
+class DuplicateFunctionDefinitionError(DuplicateDefinitionError):
+ _code = 0x_04_05_02_08
+
+
+class DuplicateConstraintDefinitionError(DuplicateDefinitionError):
+ _code = 0x_04_05_02_09
+
+
+class DuplicateCastDefinitionError(DuplicateDefinitionError):
+ _code = 0x_04_05_02_0A
+
+
+class DuplicateMigrationError(DuplicateDefinitionError):
+ _code = 0x_04_05_02_0B
+
+
+class SessionTimeoutError(QueryError):
+ _code = 0x_04_06_00_00
+
+
+class IdleSessionTimeoutError(SessionTimeoutError):
+ _code = 0x_04_06_01_00
+ tags = frozenset({SHOULD_RETRY})
+
+
+class QueryTimeoutError(SessionTimeoutError):
+ _code = 0x_04_06_02_00
+
+
+class TransactionTimeoutError(SessionTimeoutError):
+ _code = 0x_04_06_0A_00
+
+
+class IdleTransactionTimeoutError(TransactionTimeoutError):
+ _code = 0x_04_06_0A_01
+
+
+class ExecutionError(EdgeDBError):
+ _code = 0x_05_00_00_00
+
+
+class InvalidValueError(ExecutionError):
+ _code = 0x_05_01_00_00
+
+
+class DivisionByZeroError(InvalidValueError):
+ _code = 0x_05_01_00_01
+
+
+class NumericOutOfRangeError(InvalidValueError):
+ _code = 0x_05_01_00_02
+
+
+class AccessPolicyError(InvalidValueError):
+ _code = 0x_05_01_00_03
+
+
+class QueryAssertionError(InvalidValueError):
+ _code = 0x_05_01_00_04
+
+
+class IntegrityError(ExecutionError):
+ _code = 0x_05_02_00_00
+
+
+class ConstraintViolationError(IntegrityError):
+ _code = 0x_05_02_00_01
+
+
+class CardinalityViolationError(IntegrityError):
+ _code = 0x_05_02_00_02
+
+
+class MissingRequiredError(IntegrityError):
+ _code = 0x_05_02_00_03
+
+
+class TransactionError(ExecutionError):
+ _code = 0x_05_03_00_00
+
+
+class TransactionConflictError(TransactionError):
+ _code = 0x_05_03_01_00
+ tags = frozenset({SHOULD_RETRY})
+
+
+class TransactionSerializationError(TransactionConflictError):
+ _code = 0x_05_03_01_01
+ tags = frozenset({SHOULD_RETRY})
+
+
+class TransactionDeadlockError(TransactionConflictError):
+ _code = 0x_05_03_01_02
+ tags = frozenset({SHOULD_RETRY})
+
+
+class WatchError(ExecutionError):
+ _code = 0x_05_04_00_00
+
+
+class ConfigurationError(EdgeDBError):
+ _code = 0x_06_00_00_00
+
+
+class AccessError(EdgeDBError):
+ _code = 0x_07_00_00_00
+
+
+class AuthenticationError(AccessError):
+ _code = 0x_07_01_00_00
+
+
+class AvailabilityError(EdgeDBError):
+ _code = 0x_08_00_00_00
+
+
+class BackendUnavailableError(AvailabilityError):
+ _code = 0x_08_00_00_01
+ tags = frozenset({SHOULD_RETRY})
+
+
+class ServerOfflineError(AvailabilityError):
+ _code = 0x_08_00_00_02
+ tags = frozenset({SHOULD_RECONNECT, SHOULD_RETRY})
+
+
+class BackendError(EdgeDBError):
+ _code = 0x_09_00_00_00
+
+
+class UnsupportedBackendFeatureError(BackendError):
+ _code = 0x_09_00_01_00
+
+
+class LogMessage(EdgeDBMessage):
+ _code = 0x_F0_00_00_00
+
+
+class WarningMessage(LogMessage):
+ _code = 0x_F0_01_00_00
+
+
+class ClientError(EdgeDBError):
+ _code = 0x_FF_00_00_00
+
+
+class ClientConnectionError(ClientError):
+ _code = 0x_FF_01_00_00
+
+
+class ClientConnectionFailedError(ClientConnectionError):
+ _code = 0x_FF_01_01_00
+
+
+class ClientConnectionFailedTemporarilyError(ClientConnectionFailedError):
+ _code = 0x_FF_01_01_01
+ tags = frozenset({SHOULD_RECONNECT, SHOULD_RETRY})
+
+
+class ClientConnectionTimeoutError(ClientConnectionError):
+ _code = 0x_FF_01_02_00
+ tags = frozenset({SHOULD_RECONNECT, SHOULD_RETRY})
+
+
+class ClientConnectionClosedError(ClientConnectionError):
+ _code = 0x_FF_01_03_00
+ tags = frozenset({SHOULD_RECONNECT, SHOULD_RETRY})
+
+
+class InterfaceError(ClientError):
+ _code = 0x_FF_02_00_00
+
+
+class QueryArgumentError(InterfaceError):
+ _code = 0x_FF_02_01_00
+
+
+class MissingArgumentError(QueryArgumentError):
+ _code = 0x_FF_02_01_01
+
+
+class UnknownArgumentError(QueryArgumentError):
+ _code = 0x_FF_02_01_02
+
+
+class InvalidArgumentError(QueryArgumentError):
+ _code = 0x_FF_02_01_03
+
+
+class NoDataError(ClientError):
+ _code = 0x_FF_03_00_00
+
+
+class InternalClientError(ClientError):
+ _code = 0x_FF_04_00_00
+
diff --git a/gel/errors/_base.py b/gel/errors/_base.py
new file mode 100644
index 00000000..3028f829
--- /dev/null
+++ b/gel/errors/_base.py
@@ -0,0 +1,363 @@
+#
+# This source file is part of the EdgeDB open source project.
+#
+# Copyright 2016-present MagicStack Inc. and the EdgeDB authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+import io
+import os
+import traceback
+import unicodedata
+import warnings
+
+__all__ = (
+ 'EdgeDBError', 'EdgeDBMessage',
+)
+
+
+class Meta(type):
+
+ def __new__(mcls, name, bases, dct):
+ cls = super().__new__(mcls, name, bases, dct)
+
+ code = dct.get('_code')
+ if code is not None:
+ mcls._index[code] = cls
+
+ # If it's a base class add it to the base class index
+ b1, b2, b3, b4 = _decode(code)
+ if b1 == 0 or b2 == 0 or b3 == 0 or b4 == 0:
+ mcls._base_class_index[(b1, b2, b3, b4)] = cls
+
+ return cls
+
+
+class EdgeDBMessageMeta(Meta):
+
+ _base_class_index = {}
+ _index = {}
+
+
+class EdgeDBMessage(Warning, metaclass=EdgeDBMessageMeta):
+
+ _code = None
+
+ def __init__(self, severity, message):
+ super().__init__(message)
+ self._severity = severity
+
+ def get_severity(self):
+ return self._severity
+
+ def get_severity_name(self):
+ return _severity_name(self._severity)
+
+ def get_code(self):
+ return self._code
+
+ @staticmethod
+ def _from_code(code, severity, message, *args, **kwargs):
+ cls = _lookup_message_cls(code)
+ exc = cls(severity, message, *args, **kwargs)
+ exc._code = code
+ return exc
+
+
+class EdgeDBErrorMeta(Meta):
+
+ _base_class_index = {}
+ _index = {}
+
+
+class EdgeDBError(Exception, metaclass=EdgeDBErrorMeta):
+
+ _code = None
+ _query = None
+ tags = frozenset()
+
+ def __init__(self, *args, **kwargs):
+ self._attrs = {}
+ super().__init__(*args, **kwargs)
+
+ def has_tag(self, tag):
+ return tag in self.tags
+
+ @property
+ def _position(self):
+ # not a stable API method
+ return int(self._read_str_field(FIELD_POSITION_START, -1))
+
+ @property
+ def _position_start(self):
+ # not a stable API method
+ return int(self._read_str_field(FIELD_CHARACTER_START, -1))
+
+ @property
+ def _position_end(self):
+ # not a stable API method
+ return int(self._read_str_field(FIELD_CHARACTER_END, -1))
+
+ @property
+ def _line(self):
+ # not a stable API method
+ return int(self._read_str_field(FIELD_LINE_START, -1))
+
+ @property
+ def _col(self):
+ # not a stable API method
+ return int(self._read_str_field(FIELD_COLUMN_START, -1))
+
+ @property
+ def _hint(self):
+ # not a stable API method
+ return self._read_str_field(FIELD_HINT)
+
+ @property
+ def _details(self):
+ # not a stable API method
+ return self._read_str_field(FIELD_DETAILS)
+
+ def _read_str_field(self, key, default=None):
+ val = self._attrs.get(key)
+ if isinstance(val, bytes):
+ return val.decode('utf-8')
+ elif val is not None:
+ return val
+ return default
+
+ def get_code(self):
+ return self._code
+
+ def get_server_context(self):
+ return self._read_str_field(FIELD_SERVER_TRACEBACK)
+
+ @staticmethod
+ def _from_code(code, *args, **kwargs):
+ cls = _lookup_error_cls(code)
+ exc = cls(*args, **kwargs)
+ exc._code = code
+ return exc
+
+ @staticmethod
+ def _from_json(data):
+ exc = EdgeDBError._from_code(data['code'], data['message'])
+ exc._attrs = {
+ field: data[name]
+ for name, field in _JSON_FIELDS.items()
+ if name in data
+ }
+ return exc
+
+ def __str__(self):
+ msg = super().__str__()
+ if SHOW_HINT and self._query and self._position_start >= 0:
+ try:
+ return _format_error(
+ msg,
+ self._query,
+ self._position_start,
+ max(1, self._position_end - self._position_start),
+ self._line if self._line > 0 else "?",
+ self._col if self._col > 0 else "?",
+ self._hint or "error",
+ self._details,
+ )
+ except Exception:
+ return "".join(
+ (
+ msg,
+ LINESEP,
+ LINESEP,
+ "During formatting of the above exception, "
+ "another exception occurred:",
+ LINESEP,
+ LINESEP,
+ traceback.format_exc(),
+ )
+ )
+ else:
+ return msg
+
+
+def _lookup_cls(code: int, *, meta: type, default: type):
+ try:
+ return meta._index[code]
+ except KeyError:
+ pass
+
+ b1, b2, b3, _ = _decode(code)
+
+ try:
+ return meta._base_class_index[(b1, b2, b3, 0)]
+ except KeyError:
+ pass
+ try:
+ return meta._base_class_index[(b1, b2, 0, 0)]
+ except KeyError:
+ pass
+ try:
+ return meta._base_class_index[(b1, 0, 0, 0)]
+ except KeyError:
+ pass
+
+ return default
+
+
+def _lookup_error_cls(code: int):
+ return _lookup_cls(code, meta=EdgeDBErrorMeta, default=EdgeDBError)
+
+
+def _lookup_message_cls(code: int):
+ return _lookup_cls(code, meta=EdgeDBMessageMeta, default=EdgeDBMessage)
+
+
+def _decode(code: int):
+ return tuple(code.to_bytes(4, 'big'))
+
+
+def _severity_name(severity):
+ if severity <= EDGE_SEVERITY_DEBUG:
+ return 'DEBUG'
+ if severity <= EDGE_SEVERITY_INFO:
+ return 'INFO'
+ if severity <= EDGE_SEVERITY_NOTICE:
+ return 'NOTICE'
+ if severity <= EDGE_SEVERITY_WARNING:
+ return 'WARNING'
+ if severity <= EDGE_SEVERITY_ERROR:
+ return 'ERROR'
+ if severity <= EDGE_SEVERITY_FATAL:
+ return 'FATAL'
+ return 'PANIC'
+
+
+def _format_error(msg, query, start, offset, line, col, hint, details):
+ c = get_color()
+ rv = io.StringIO()
+ rv.write(f"{c.BOLD}{msg}{c.ENDC}{LINESEP}")
+ lines = query.splitlines(keepends=True)
+ num_len = len(str(len(lines)))
+ rv.write(f"{c.BLUE}{'':>{num_len}} ┌─{c.ENDC} query:{line}:{col}{LINESEP}")
+ rv.write(f"{c.BLUE}{'':>{num_len}} │ {c.ENDC}{LINESEP}")
+ for num, line in enumerate(lines):
+ length = len(line)
+ line = line.rstrip() # we'll use our own line separator
+ if start >= length:
+ # skip lines before the error
+ start -= length
+ continue
+
+ if start >= 0:
+ # Error starts in current line, write the line before the error
+ first_half = repr(line[:start])[1:-1]
+ line = line[start:]
+ length -= start
+ rv.write(f"{c.BLUE}{num + 1:>{num_len}} │ {c.ENDC}{first_half}")
+ start = _unicode_width(first_half)
+ else:
+ # Multi-line error continues
+ rv.write(f"{c.BLUE}{num + 1:>{num_len}} │ {c.FAIL}│ {c.ENDC}")
+
+ if offset > length:
+ # Error is ending beyond current line
+ line = repr(line)[1:-1]
+ rv.write(f"{c.FAIL}{line}{c.ENDC}{LINESEP}")
+ if start >= 0:
+ # Multi-line error starts
+ rv.write(f"{c.BLUE}{'':>{num_len}} │ "
+ f"{c.FAIL}â•â”€{'─' * start}^{c.ENDC}{LINESEP}")
+ offset -= length
+ start = -1 # mark multi-line
+ else:
+ # Error is ending within current line
+ first_half = repr(line[:offset])[1:-1]
+ line = repr(line[offset:])[1:-1]
+ rv.write(f"{c.FAIL}{first_half}{c.ENDC}{line}{LINESEP}")
+ size = _unicode_width(first_half)
+ if start >= 0:
+ # Mark single-line error
+ rv.write(f"{c.BLUE}{'':>{num_len}} │ {' ' * start}"
+ f"{c.FAIL}{'^' * size} {hint}{c.ENDC}")
+ else:
+ # End of multi-line error
+ rv.write(f"{c.BLUE}{'':>{num_len}} │ "
+ f"{c.FAIL}╰─{'─' * (size - 1)}^ {hint}{c.ENDC}")
+ break
+
+ if details:
+ rv.write(f"{LINESEP}Details: {details}")
+
+ return rv.getvalue()
+
+
+def _unicode_width(text):
+ return sum(0 if unicodedata.category(c) in ('Mn', 'Cf') else
+ 2 if unicodedata.east_asian_width(c) == "W" else 1
+ for c in text)
+
+
+FIELD_HINT = 0x_00_01
+FIELD_DETAILS = 0x_00_02
+FIELD_SERVER_TRACEBACK = 0x_01_01
+
+# XXX: Subject to be changed/deprecated.
+FIELD_POSITION_START = 0x_FF_F1
+FIELD_POSITION_END = 0x_FF_F2
+FIELD_LINE_START = 0x_FF_F3
+FIELD_COLUMN_START = 0x_FF_F4
+FIELD_UTF16_COLUMN_START = 0x_FF_F5
+FIELD_LINE_END = 0x_FF_F6
+FIELD_COLUMN_END = 0x_FF_F7
+FIELD_UTF16_COLUMN_END = 0x_FF_F8
+FIELD_CHARACTER_START = 0x_FF_F9
+FIELD_CHARACTER_END = 0x_FF_FA
+
+
+EDGE_SEVERITY_DEBUG = 20
+EDGE_SEVERITY_INFO = 40
+EDGE_SEVERITY_NOTICE = 60
+EDGE_SEVERITY_WARNING = 80
+EDGE_SEVERITY_ERROR = 120
+EDGE_SEVERITY_FATAL = 200
+EDGE_SEVERITY_PANIC = 255
+
+
+# Fields to include in the json dump of the type
+_JSON_FIELDS = {
+ 'hint': FIELD_HINT,
+ 'details': FIELD_DETAILS,
+ 'start': FIELD_CHARACTER_START,
+ 'end': FIELD_CHARACTER_END,
+ 'line': FIELD_LINE_START,
+ 'col': FIELD_COLUMN_START,
+}
+
+
+LINESEP = os.linesep
+
+try:
+ SHOW_HINT = {"default": True, "enabled": True, "disabled": False}[
+ os.getenv("EDGEDB_ERROR_HINT", "default")
+ ]
+except KeyError:
+ warnings.warn(
+ "EDGEDB_ERROR_HINT can only be one of: default, enabled or disabled",
+ stacklevel=1,
+ )
+ SHOW_HINT = False
+
+
+from gel.color import get_color
diff --git a/gel/errors/tags.py b/gel/errors/tags.py
new file mode 100644
index 00000000..275b31ac
--- /dev/null
+++ b/gel/errors/tags.py
@@ -0,0 +1,25 @@
+__all__ = [
+ 'Tag',
+ 'SHOULD_RECONNECT',
+ 'SHOULD_RETRY',
+]
+
+
+class Tag(object):
+ """Error tag
+
+ Tags are used to differentiate certain properties of errors that apply to
+ error classes across hierarchy.
+
+ Use ``error.has_tag(tag_name)`` to check for a tag.
+ """
+
+ def __init__(self, name):
+ self.name = name
+
+ def __repr__(self):
+ return f''
+
+
+SHOULD_RECONNECT = Tag('SHOULD_RECONNECT')
+SHOULD_RETRY = Tag('SHOULD_RETRY')
diff --git a/gel/introspect.py b/gel/introspect.py
new file mode 100644
index 00000000..4f7884d6
--- /dev/null
+++ b/gel/introspect.py
@@ -0,0 +1,66 @@
+#
+# This source file is part of the EdgeDB open source project.
+#
+# Copyright 2019-present MagicStack Inc. and the EdgeDB authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+# IMPORTANT: this private API is subject to change.
+
+
+import functools
+import typing
+
+from gel.datatypes import datatypes as dt
+from gel.enums import ElementKind
+
+
+class PointerDescription(typing.NamedTuple):
+
+ name: str
+ kind: ElementKind
+ implicit: bool
+
+
+class ObjectDescription(typing.NamedTuple):
+
+ pointers: typing.Tuple[PointerDescription, ...]
+
+
+@functools.lru_cache()
+def _introspect_object_desc(desc) -> ObjectDescription:
+ pointers = []
+ # Call __dir__ directly as dir() scrambles the order.
+ for name in desc.__dir__():
+ if desc.is_link(name):
+ kind = ElementKind.LINK
+ elif desc.is_linkprop(name):
+ continue
+ else:
+ kind = ElementKind.PROPERTY
+
+ pointers.append(
+ PointerDescription(
+ name=name,
+ kind=kind,
+ implicit=desc.is_implicit(name)))
+
+ return ObjectDescription(
+ pointers=tuple(pointers))
+
+
+def introspect_object(obj) -> ObjectDescription:
+ return _introspect_object_desc(
+ dt.get_object_descriptor(obj))
diff --git a/gel/options.py b/gel/options.py
new file mode 100644
index 00000000..6f83d421
--- /dev/null
+++ b/gel/options.py
@@ -0,0 +1,492 @@
+import abc
+import enum
+import logging
+import random
+import typing
+import sys
+from collections import namedtuple
+
+from . import errors
+
+
+logger = logging.getLogger('gel')
+
+
+_RetryRule = namedtuple("_RetryRule", ["attempts", "backoff"])
+
+
+def default_backoff(attempt):
+ return (2 ** attempt) * 0.1 + random.randrange(100) * 0.001
+
+
+WarningHandler = typing.Callable[
+ [typing.Tuple[errors.EdgeDBError, ...], typing.Any],
+ typing.Any,
+]
+
+
+def raise_warnings(warnings, res):
+ if (
+ len(warnings) > 1
+ and sys.version_info >= (3, 11)
+ ):
+ raise ExceptionGroup( # noqa
+ "Query produced warnings", warnings
+ )
+ else:
+ raise warnings[0]
+
+
+def log_warnings(warnings, res):
+ for w in warnings:
+ logger.warning("EdgeDB warning: %s", str(w))
+ return res
+
+
+class RetryCondition:
+ """Specific condition to retry on for fine-grained control"""
+ TransactionConflict = enum.auto()
+ NetworkError = enum.auto()
+
+
+class IsolationLevel:
+ """Isolation level for transaction"""
+ Serializable = "SERIALIZABLE"
+
+
+class RetryOptions:
+ """An immutable class that contains rules for `transaction()`"""
+ __slots__ = ['_default', '_overrides']
+
+ def __init__(self, attempts: int, backoff=default_backoff):
+ self._default = _RetryRule(attempts, backoff)
+ self._overrides = None
+
+ def with_rule(self, condition, attempts=None, backoff=None):
+ default = self._default
+ overrides = self._overrides
+ if overrides is None:
+ overrides = {}
+ else:
+ overrides = overrides.copy()
+ overrides[condition] = _RetryRule(
+ default.attempts if attempts is None else attempts,
+ default.backoff if backoff is None else backoff,
+ )
+ result = RetryOptions.__new__(RetryOptions)
+ result._default = default
+ result._overrides = overrides
+ return result
+
+ @classmethod
+ def defaults(cls):
+ return cls(
+ attempts=3,
+ backoff=default_backoff,
+ )
+
+ def get_rule_for_exception(self, exception):
+ default = self._default
+ overrides = self._overrides
+ res = default
+ if overrides:
+ if isinstance(exception, errors.TransactionConflictError):
+ res = overrides.get(RetryCondition.TransactionConflict, res)
+ elif isinstance(exception, errors.ClientError):
+ res = overrides.get(RetryCondition.NetworkError, res)
+ return res
+
+
+class TransactionOptions:
+ """Options for `transaction()`"""
+ __slots__ = ['_isolation', '_readonly', '_deferrable']
+
+ def __init__(
+ self,
+ isolation: IsolationLevel=IsolationLevel.Serializable,
+ readonly: bool = False,
+ deferrable: bool = False,
+ ):
+ self._isolation = isolation
+ self._readonly = readonly
+ self._deferrable = deferrable
+
+ @classmethod
+ def defaults(cls):
+ return cls()
+
+ def start_transaction_query(self):
+ isolation = str(self._isolation)
+ if self._readonly:
+ mode = 'READ ONLY'
+ else:
+ mode = 'READ WRITE'
+
+ if self._deferrable:
+ defer = 'DEFERRABLE'
+ else:
+ defer = 'NOT DEFERRABLE'
+
+ return f'START TRANSACTION ISOLATION {isolation}, {mode}, {defer};'
+
+ def __repr__(self):
+ return (
+ f'<{self.__class__.__name__} '
+ f'isolation:{self._isolation}, '
+ f'readonly:{self._readonly}, '
+ f'deferrable:{self._deferrable}>'
+ )
+
+
+class State:
+ __slots__ = ['_module', '_aliases', '_config', '_globals']
+
+ def __init__(
+ self,
+ default_module: typing.Optional[str] = None,
+ module_aliases: typing.Mapping[str, str] = None,
+ config: typing.Mapping[str, typing.Any] = None,
+ globals_: typing.Mapping[str, typing.Any] = None,
+ ):
+ self._module = default_module
+ self._aliases = {} if module_aliases is None else dict(module_aliases)
+ self._config = {} if config is None else dict(config)
+ self._globals = (
+ {} if globals_ is None else self.with_globals(globals_)._globals
+ )
+
+ @classmethod
+ def _new(cls, default_module, module_aliases, config, globals_):
+ rv = cls.__new__(cls)
+ rv._module = default_module
+ rv._aliases = module_aliases
+ rv._config = config
+ rv._globals = globals_
+ return rv
+
+ @classmethod
+ def defaults(cls):
+ return cls()
+
+ def with_default_module(self, module: typing.Optional[str] = None):
+ return self._new(
+ default_module=module,
+ module_aliases=self._aliases,
+ config=self._config,
+ globals_=self._globals,
+ )
+
+ def with_module_aliases(self, *args, **aliases):
+ if len(args) > 1:
+ raise errors.InvalidArgumentError(
+ "with_module_aliases() takes from 0 to 1 positional arguments "
+ "but {} were given".format(len(args))
+ )
+ aliases_dict = args[0] if args else {}
+ aliases_dict.update(aliases)
+ new_aliases = self._aliases.copy()
+ new_aliases.update(aliases_dict)
+ return self._new(
+ default_module=self._module,
+ module_aliases=new_aliases,
+ config=self._config,
+ globals_=self._globals,
+ )
+
+ def with_config(self, *args, **config):
+ if len(args) > 1:
+ raise errors.InvalidArgumentError(
+ "with_config() takes from 0 to 1 positional arguments "
+ "but {} were given".format(len(args))
+ )
+ config_dict = args[0] if args else {}
+ config_dict.update(config)
+ new_config = self._config.copy()
+ new_config.update(config_dict)
+ return self._new(
+ default_module=self._module,
+ module_aliases=self._aliases,
+ config=new_config,
+ globals_=self._globals,
+ )
+
+ def resolve(self, name: str) -> str:
+ parts = name.split("::", 1)
+ if len(parts) == 1:
+ return f"{self._module or 'default'}::{name}"
+ elif len(parts) == 2:
+ mod, name = parts
+ mod = self._aliases.get(mod, mod)
+ return f"{mod}::{name}"
+ else:
+ raise AssertionError('broken split')
+
+ def with_globals(self, *args, **globals_):
+ if len(args) > 1:
+ raise errors.InvalidArgumentError(
+ "with_globals() takes from 0 to 1 positional arguments "
+ "but {} were given".format(len(args))
+ )
+ new_globals = self._globals.copy()
+ if args:
+ for k, v in args[0].items():
+ new_globals[self.resolve(k)] = v
+ for k, v in globals_.items():
+ new_globals[self.resolve(k)] = v
+ return self._new(
+ default_module=self._module,
+ module_aliases=self._aliases,
+ config=self._config,
+ globals_=new_globals,
+ )
+
+ def without_module_aliases(self, *aliases):
+ if not aliases:
+ new_aliases = {}
+ else:
+ new_aliases = self._aliases.copy()
+ for alias in aliases:
+ new_aliases.pop(alias, None)
+ return self._new(
+ default_module=self._module,
+ module_aliases=new_aliases,
+ config=self._config,
+ globals_=self._globals,
+ )
+
+ def without_config(self, *config_names):
+ if not config_names:
+ new_config = {}
+ else:
+ new_config = self._config.copy()
+ for name in config_names:
+ new_config.pop(name, None)
+ return self._new(
+ default_module=self._module,
+ module_aliases=self._aliases,
+ config=new_config,
+ globals_=self._globals,
+ )
+
+ def without_globals(self, *global_names):
+ if not global_names:
+ new_globals = {}
+ else:
+ new_globals = self._globals.copy()
+ for name in global_names:
+ new_globals.pop(self.resolve(name), None)
+ return self._new(
+ default_module=self._module,
+ module_aliases=self._aliases,
+ config=self._config,
+ globals_=new_globals,
+ )
+
+ def as_dict(self):
+ rv = {}
+ if self._module is not None:
+ rv["module"] = self._module
+ if self._aliases:
+ rv["aliases"] = list(self._aliases.items())
+ if self._config:
+ rv["config"] = self._config
+ if self._globals:
+ rv["globals"] = self._globals
+ return rv
+
+
+class _OptionsMixin:
+ def __init__(self, *args, **kwargs):
+ self._options = _Options.defaults()
+ super().__init__(*args, **kwargs)
+
+ @abc.abstractmethod
+ def _shallow_clone(self):
+ pass
+
+ def with_transaction_options(self, options: TransactionOptions = None):
+ """Returns object with adjusted options for future transactions.
+
+ :param options TransactionOptions:
+ Object that encapsulates transaction options.
+
+ This method returns a "shallow copy" of the current object
+ with modified transaction options.
+
+ Both ``self`` and returned object can be used after, but when using
+ them transaction options applied will be different.
+
+ Transaction options are used by the ``transaction`` method.
+ """
+ result = self._shallow_clone()
+ result._options = self._options.with_transaction_options(options)
+ return result
+
+ def with_retry_options(self, options: RetryOptions=None):
+ """Returns object with adjusted options for future retrying
+ transactions.
+
+ :param options RetryOptions:
+ Object that encapsulates retry options.
+
+ This method returns a "shallow copy" of the current object
+ with modified retry options.
+
+ Both ``self`` and returned object can be used after, but when using
+ them retry options applied will be different.
+ """
+
+ result = self._shallow_clone()
+ result._options = self._options.with_retry_options(options)
+ return result
+
+ def with_warning_handler(self, warning_handler: WarningHandler=None):
+ """Returns object with adjusted options for handling warnings.
+
+ :param warning_handler WarningHandler:
+ Function for handling warnings. It is passed a tuple of warnings
+ and the query result and returns a potentially updated query
+ result.
+
+ This method returns a "shallow copy" of the current object
+ with modified retry options.
+
+ Both ``self`` and returned object can be used after, but when using
+ them retry options applied will be different.
+ """
+
+ result = self._shallow_clone()
+ result._options = self._options.with_warning_handler(warning_handler)
+ return result
+
+ def with_state(self, state: State):
+ result = self._shallow_clone()
+ result._options = self._options.with_state(state)
+ return result
+
+ def with_default_module(self, module: typing.Optional[str] = None):
+ result = self._shallow_clone()
+ result._options = self._options.with_state(
+ self._options.state.with_default_module(module)
+ )
+ return result
+
+ def with_module_aliases(self, *args, **aliases):
+ result = self._shallow_clone()
+ result._options = self._options.with_state(
+ self._options.state.with_module_aliases(*args, **aliases)
+ )
+ return result
+
+ def with_config(self, *args, **config):
+ result = self._shallow_clone()
+ result._options = self._options.with_state(
+ self._options.state.with_config(*args, **config)
+ )
+ return result
+
+ def with_globals(self, *args, **globals_):
+ result = self._shallow_clone()
+ result._options = self._options.with_state(
+ self._options.state.with_globals(*args, **globals_)
+ )
+ return result
+
+ def without_module_aliases(self, *aliases):
+ result = self._shallow_clone()
+ result._options = self._options.with_state(
+ self._options.state.without_module_aliases(*aliases)
+ )
+ return result
+
+ def without_config(self, *config_names):
+ result = self._shallow_clone()
+ result._options = self._options.with_state(
+ self._options.state.without_config(*config_names)
+ )
+ return result
+
+ def without_globals(self, *global_names):
+ result = self._shallow_clone()
+ result._options = self._options.with_state(
+ self._options.state.without_globals(*global_names)
+ )
+ return result
+
+
+class _Options:
+ """Internal class for storing connection options"""
+
+ __slots__ = [
+ '_retry_options', '_transaction_options', '_state',
+ '_warning_handler'
+ ]
+
+ def __init__(
+ self,
+ retry_options: RetryOptions,
+ transaction_options: TransactionOptions,
+ state: State,
+ warning_handler: WarningHandler,
+ ):
+ self._retry_options = retry_options
+ self._transaction_options = transaction_options
+ self._state = state
+ self._warning_handler = warning_handler
+
+ @property
+ def retry_options(self):
+ return self._retry_options
+
+ @property
+ def transaction_options(self):
+ return self._transaction_options
+
+ @property
+ def state(self):
+ return self._state
+
+ @property
+ def warning_handler(self):
+ return self._warning_handler
+
+ def with_retry_options(self, options: RetryOptions):
+ return _Options(
+ options,
+ self._transaction_options,
+ self._state,
+ self._warning_handler,
+ )
+
+ def with_transaction_options(self, options: TransactionOptions):
+ return _Options(
+ self._retry_options,
+ options,
+ self._state,
+ self._warning_handler,
+ )
+
+ def with_state(self, state: State):
+ return _Options(
+ self._retry_options,
+ self._transaction_options,
+ state,
+ self._warning_handler,
+ )
+
+ def with_warning_handler(self, warning_handler: WarningHandler):
+ return _Options(
+ self._retry_options,
+ self._transaction_options,
+ self._state,
+ warning_handler,
+ )
+
+ @classmethod
+ def defaults(cls):
+ return cls(
+ RetryOptions.defaults(),
+ TransactionOptions.defaults(),
+ State.defaults(),
+ log_warnings,
+ )
diff --git a/edgedb/pgproto b/gel/pgproto
similarity index 100%
rename from edgedb/pgproto
rename to gel/pgproto
diff --git a/gel/platform.py b/gel/platform.py
new file mode 100644
index 00000000..55410532
--- /dev/null
+++ b/gel/platform.py
@@ -0,0 +1,52 @@
+import functools
+import os
+import pathlib
+import sys
+
+if sys.platform == "darwin":
+ def config_dir() -> pathlib.Path:
+ return (
+ pathlib.Path.home() / "Library" / "Application Support" / "edgedb"
+ )
+
+ IS_WINDOWS = False
+
+elif sys.platform == "win32":
+ import ctypes
+ from ctypes import windll
+
+ def config_dir() -> pathlib.Path:
+ path_buf = ctypes.create_unicode_buffer(255)
+ csidl = 28 # CSIDL_LOCAL_APPDATA
+ windll.shell32.SHGetFolderPathW(0, csidl, 0, 0, path_buf)
+ return pathlib.Path(path_buf.value) / "EdgeDB" / "config"
+
+ IS_WINDOWS = True
+
+else:
+ def config_dir() -> pathlib.Path:
+ xdg_conf_dir = pathlib.Path(os.environ.get("XDG_CONFIG_HOME", "."))
+ if not xdg_conf_dir.is_absolute():
+ xdg_conf_dir = pathlib.Path.home() / ".config"
+ return xdg_conf_dir / "edgedb"
+
+ IS_WINDOWS = False
+
+
+def old_config_dir() -> pathlib.Path:
+ return pathlib.Path.home() / ".edgedb"
+
+
+def search_config_dir(*suffix):
+ rv = functools.reduce(lambda p1, p2: p1 / p2, [config_dir(), *suffix])
+ if rv.exists():
+ return rv
+
+ fallback = functools.reduce(
+ lambda p1, p2: p1 / p2, [old_config_dir(), *suffix]
+ )
+ if fallback.exists():
+ return fallback
+
+ # None of the searched files exists, return the new path.
+ return rv
diff --git a/edgedb/protocol/.gitignore b/gel/protocol/.gitignore
similarity index 100%
rename from edgedb/protocol/.gitignore
rename to gel/protocol/.gitignore
diff --git a/gel/protocol/__init__.py b/gel/protocol/__init__.py
new file mode 100644
index 00000000..46ceb625
--- /dev/null
+++ b/gel/protocol/__init__.py
@@ -0,0 +1,17 @@
+#
+# This source file is part of the EdgeDB open source project.
+#
+# Copyright 2016-present MagicStack Inc. and the EdgeDB authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
diff --git a/edgedb/protocol/asyncio_proto.pxd b/gel/protocol/asyncio_proto.pxd
similarity index 95%
rename from edgedb/protocol/asyncio_proto.pxd
rename to gel/protocol/asyncio_proto.pxd
index 4eeb68e9..8919c24b 100644
--- a/edgedb/protocol/asyncio_proto.pxd
+++ b/gel/protocol/asyncio_proto.pxd
@@ -19,7 +19,7 @@
from . cimport protocol
-from edgedb.pgproto.debug cimport PG_DEBUG
+from gel.pgproto.debug cimport PG_DEBUG
cdef class AsyncIOProtocol(protocol.SansIOProtocolBackwardsCompatible):
diff --git a/edgedb/protocol/asyncio_proto.pyx b/gel/protocol/asyncio_proto.pyx
similarity index 98%
rename from edgedb/protocol/asyncio_proto.pyx
rename to gel/protocol/asyncio_proto.pyx
index db3f616f..39d3b62a 100644
--- a/edgedb/protocol/asyncio_proto.pyx
+++ b/gel/protocol/asyncio_proto.pyx
@@ -19,8 +19,8 @@
import asyncio
-from edgedb import errors
-from edgedb.pgproto.pgproto cimport (
+from gel import errors
+from gel.pgproto.pgproto cimport (
WriteBuffer,
ReadBuffer,
)
diff --git a/edgedb/protocol/blocking_proto.pxd b/gel/protocol/blocking_proto.pxd
similarity index 95%
rename from edgedb/protocol/blocking_proto.pxd
rename to gel/protocol/blocking_proto.pxd
index cd19c2bb..13e4707e 100644
--- a/edgedb/protocol/blocking_proto.pxd
+++ b/gel/protocol/blocking_proto.pxd
@@ -19,7 +19,7 @@
from . cimport protocol
-from edgedb.pgproto.debug cimport PG_DEBUG
+from gel.pgproto.debug cimport PG_DEBUG
cdef class BlockingIOProtocol(protocol.SansIOProtocolBackwardsCompatible):
diff --git a/edgedb/protocol/blocking_proto.pyx b/gel/protocol/blocking_proto.pyx
similarity index 99%
rename from edgedb/protocol/blocking_proto.pyx
rename to gel/protocol/blocking_proto.pyx
index ea4c1c16..839fbe42 100644
--- a/edgedb/protocol/blocking_proto.pyx
+++ b/gel/protocol/blocking_proto.pyx
@@ -19,7 +19,7 @@
import socket
import time
-from edgedb.pgproto.pgproto cimport (
+from gel.pgproto.pgproto cimport (
WriteBuffer,
ReadBuffer,
)
diff --git a/edgedb/protocol/codecs/array.pxd b/gel/protocol/codecs/array.pxd
similarity index 100%
rename from edgedb/protocol/codecs/array.pxd
rename to gel/protocol/codecs/array.pxd
diff --git a/edgedb/protocol/codecs/array.pyx b/gel/protocol/codecs/array.pyx
similarity index 100%
rename from edgedb/protocol/codecs/array.pyx
rename to gel/protocol/codecs/array.pyx
diff --git a/edgedb/protocol/codecs/base.pxd b/gel/protocol/codecs/base.pxd
similarity index 100%
rename from edgedb/protocol/codecs/base.pxd
rename to gel/protocol/codecs/base.pxd
diff --git a/edgedb/protocol/codecs/base.pyx b/gel/protocol/codecs/base.pyx
similarity index 100%
rename from edgedb/protocol/codecs/base.pyx
rename to gel/protocol/codecs/base.pyx
diff --git a/edgedb/protocol/codecs/codecs.pxd b/gel/protocol/codecs/codecs.pxd
similarity index 100%
rename from edgedb/protocol/codecs/codecs.pxd
rename to gel/protocol/codecs/codecs.pxd
diff --git a/edgedb/protocol/codecs/codecs.pyx b/gel/protocol/codecs/codecs.pyx
similarity index 99%
rename from edgedb/protocol/codecs/codecs.pyx
rename to gel/protocol/codecs/codecs.pyx
index 2887320b..afa3b83f 100644
--- a/edgedb/protocol/codecs/codecs.pyx
+++ b/gel/protocol/codecs/codecs.pyx
@@ -21,9 +21,9 @@ import array
import decimal
import uuid
import datetime
-from edgedb import describe
-from edgedb import enums
-from edgedb.datatypes import datatypes
+from gel import describe
+from gel import enums
+from gel.datatypes import datatypes
from libc.string cimport memcpy
from cpython.bytes cimport PyBytes_FromStringAndSize
diff --git a/edgedb/protocol/codecs/edb_types.pxi b/gel/protocol/codecs/edb_types.pxi
similarity index 100%
rename from edgedb/protocol/codecs/edb_types.pxi
rename to gel/protocol/codecs/edb_types.pxi
diff --git a/edgedb/protocol/codecs/enum.pxd b/gel/protocol/codecs/enum.pxd
similarity index 100%
rename from edgedb/protocol/codecs/enum.pxd
rename to gel/protocol/codecs/enum.pxd
diff --git a/edgedb/protocol/codecs/enum.pyx b/gel/protocol/codecs/enum.pyx
similarity index 92%
rename from edgedb/protocol/codecs/enum.pyx
rename to gel/protocol/codecs/enum.pyx
index d7fe5620..0020397f 100644
--- a/edgedb/protocol/codecs/enum.pyx
+++ b/gel/protocol/codecs/enum.pyx
@@ -28,7 +28,7 @@ cdef class EnumCodec(BaseCodec):
obj = self.cls._try_from(obj)
except (TypeError, ValueError):
raise TypeError(
- f'a str or edgedb.EnumValue(__tid__={self.cls.__tid__}) '
+ f'a str or gel.EnumValue(__tid__={self.cls.__tid__}) '
f'is expected as a valid enum argument, '
f'got {type(obj).__name__}') from None
pgproto.text_encode(DEFAULT_CODEC_CONTEXT, buf, str(obj))
@@ -49,8 +49,8 @@ cdef class EnumCodec(BaseCodec):
cls = "DerivedEnumValue"
bases = (datatypes.EnumValue,)
classdict = enum.EnumMeta.__prepare__(cls, bases)
- classdict["__module__"] = "edgedb"
- classdict["__qualname__"] = "edgedb.DerivedEnumValue"
+ classdict["__module__"] = "gel"
+ classdict["__qualname__"] = "gel.DerivedEnumValue"
classdict["__tid__"] = pgproto.UUID(tid)
for label in enum_labels:
classdict[label.upper()] = label
diff --git a/edgedb/protocol/codecs/namedtuple.pxd b/gel/protocol/codecs/namedtuple.pxd
similarity index 100%
rename from edgedb/protocol/codecs/namedtuple.pxd
rename to gel/protocol/codecs/namedtuple.pxd
diff --git a/edgedb/protocol/codecs/namedtuple.pyx b/gel/protocol/codecs/namedtuple.pyx
similarity index 100%
rename from edgedb/protocol/codecs/namedtuple.pyx
rename to gel/protocol/codecs/namedtuple.pyx
diff --git a/edgedb/protocol/codecs/object.pxd b/gel/protocol/codecs/object.pxd
similarity index 100%
rename from edgedb/protocol/codecs/object.pxd
rename to gel/protocol/codecs/object.pxd
diff --git a/edgedb/protocol/codecs/object.pyx b/gel/protocol/codecs/object.pyx
similarity index 100%
rename from edgedb/protocol/codecs/object.pyx
rename to gel/protocol/codecs/object.pyx
diff --git a/edgedb/protocol/codecs/range.pxd b/gel/protocol/codecs/range.pxd
similarity index 100%
rename from edgedb/protocol/codecs/range.pxd
rename to gel/protocol/codecs/range.pxd
diff --git a/edgedb/protocol/codecs/range.pyx b/gel/protocol/codecs/range.pyx
similarity index 99%
rename from edgedb/protocol/codecs/range.pyx
rename to gel/protocol/codecs/range.pyx
index ea573b89..52a656c4 100644
--- a/edgedb/protocol/codecs/range.pyx
+++ b/gel/protocol/codecs/range.pyx
@@ -17,7 +17,7 @@
#
-from edgedb.datatypes import range as range_mod
+from gel.datatypes import range as range_mod
cdef uint8_t RANGE_EMPTY = 0x01
diff --git a/edgedb/protocol/codecs/scalar.pxd b/gel/protocol/codecs/scalar.pxd
similarity index 100%
rename from edgedb/protocol/codecs/scalar.pxd
rename to gel/protocol/codecs/scalar.pxd
diff --git a/edgedb/protocol/codecs/scalar.pyx b/gel/protocol/codecs/scalar.pyx
similarity index 100%
rename from edgedb/protocol/codecs/scalar.pyx
rename to gel/protocol/codecs/scalar.pyx
diff --git a/edgedb/protocol/codecs/set.pxd b/gel/protocol/codecs/set.pxd
similarity index 100%
rename from edgedb/protocol/codecs/set.pxd
rename to gel/protocol/codecs/set.pxd
diff --git a/edgedb/protocol/codecs/set.pyx b/gel/protocol/codecs/set.pyx
similarity index 100%
rename from edgedb/protocol/codecs/set.pyx
rename to gel/protocol/codecs/set.pyx
diff --git a/edgedb/protocol/codecs/tuple.pxd b/gel/protocol/codecs/tuple.pxd
similarity index 100%
rename from edgedb/protocol/codecs/tuple.pxd
rename to gel/protocol/codecs/tuple.pxd
diff --git a/edgedb/protocol/codecs/tuple.pyx b/gel/protocol/codecs/tuple.pyx
similarity index 100%
rename from edgedb/protocol/codecs/tuple.pyx
rename to gel/protocol/codecs/tuple.pyx
diff --git a/edgedb/protocol/consts.pxi b/gel/protocol/consts.pxi
similarity index 100%
rename from edgedb/protocol/consts.pxi
rename to gel/protocol/consts.pxi
diff --git a/edgedb/protocol/cpythonx.pxd b/gel/protocol/cpythonx.pxd
similarity index 100%
rename from edgedb/protocol/cpythonx.pxd
rename to gel/protocol/cpythonx.pxd
diff --git a/edgedb/protocol/lru.pxd b/gel/protocol/lru.pxd
similarity index 100%
rename from edgedb/protocol/lru.pxd
rename to gel/protocol/lru.pxd
diff --git a/edgedb/protocol/lru.pyx b/gel/protocol/lru.pyx
similarity index 100%
rename from edgedb/protocol/lru.pyx
rename to gel/protocol/lru.pyx
diff --git a/edgedb/protocol/protocol.pxd b/gel/protocol/protocol.pxd
similarity index 97%
rename from edgedb/protocol/protocol.pxd
rename to gel/protocol/protocol.pxd
index b6e143ae..7c5f943f 100644
--- a/edgedb/protocol/protocol.pxd
+++ b/gel/protocol/protocol.pxd
@@ -23,14 +23,14 @@ cimport cpython
from libc.stdint cimport int16_t, int32_t, uint16_t, \
uint32_t, int64_t, uint64_t
-from edgedb.pgproto.pgproto cimport (
+from gel.pgproto.pgproto cimport (
WriteBuffer,
ReadBuffer,
FRBuffer,
)
-from edgedb.pgproto cimport pgproto
-from edgedb.pgproto.debug cimport PG_DEBUG
+from gel.pgproto cimport pgproto
+from gel.pgproto.debug cimport PG_DEBUG
include "./lru.pxd"
diff --git a/edgedb/protocol/protocol.pyx b/gel/protocol/protocol.pyx
similarity index 99%
rename from edgedb/protocol/protocol.pyx
rename to gel/protocol/protocol.pyx
index e464a039..e22efef7 100644
--- a/edgedb/protocol/protocol.pyx
+++ b/gel/protocol/protocol.pyx
@@ -30,7 +30,7 @@ import types
import typing
import weakref
-from edgedb.pgproto.pgproto cimport (
+from gel.pgproto.pgproto cimport (
WriteBuffer,
ReadBuffer,
@@ -44,22 +44,22 @@ from edgedb.pgproto.pgproto cimport (
frb_get_len,
)
-from edgedb.pgproto import pgproto
-from edgedb.pgproto cimport pgproto
-from edgedb.pgproto cimport hton
-from edgedb.pgproto.pgproto import UUID
+from gel.pgproto import pgproto
+from gel.pgproto cimport pgproto
+from gel.pgproto cimport hton
+from gel.pgproto.pgproto import UUID
from libc.stdint cimport int8_t, uint8_t, int16_t, uint16_t, \
int32_t, uint32_t, int64_t, uint64_t, \
UINT32_MAX
-from edgedb.datatypes cimport datatypes
+from gel.datatypes cimport datatypes
from . cimport cpythonx
-from edgedb import enums
-from edgedb import errors
-from edgedb import scram
+from gel import enums
+from gel import errors
+from gel import scram
include "./consts.pxi"
diff --git a/edgedb/protocol/protocol_v0.pxd b/gel/protocol/protocol_v0.pxd
similarity index 100%
rename from edgedb/protocol/protocol_v0.pxd
rename to gel/protocol/protocol_v0.pxd
diff --git a/edgedb/protocol/protocol_v0.pyx b/gel/protocol/protocol_v0.pyx
similarity index 99%
rename from edgedb/protocol/protocol_v0.pyx
rename to gel/protocol/protocol_v0.pyx
index 2a4cb80b..1ba6852c 100644
--- a/edgedb/protocol/protocol_v0.pyx
+++ b/gel/protocol/protocol_v0.pyx
@@ -17,7 +17,7 @@
#
-from edgedb import enums
+from gel import enums
DEF QUERY_OPT_IMPLICIT_LIMIT = 0xFF01
diff --git a/gel/py.typed b/gel/py.typed
new file mode 100644
index 00000000..e69de29b
diff --git a/gel/scram/__init__.py b/gel/scram/__init__.py
new file mode 100644
index 00000000..62f3e712
--- /dev/null
+++ b/gel/scram/__init__.py
@@ -0,0 +1,434 @@
+#
+# This source file is part of the EdgeDB open source project.
+#
+# Copyright 2019-present MagicStack Inc. and the EdgeDB authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Helpers for SCRAM authentication."""
+
+import base64
+import hashlib
+import hmac
+import os
+import typing
+
+from .saslprep import saslprep
+
+
+RAW_NONCE_LENGTH = 18
+
+# Per recommendations in RFC 7677.
+DEFAULT_SALT_LENGTH = 16
+DEFAULT_ITERATIONS = 4096
+
+
+def generate_salt(length: int = DEFAULT_SALT_LENGTH) -> bytes:
+ return os.urandom(length)
+
+
+def generate_nonce(length: int = RAW_NONCE_LENGTH) -> str:
+ return B64(os.urandom(length))
+
+
+def build_verifier(password: str, *, salt: typing.Optional[bytes] = None,
+ iterations: int = DEFAULT_ITERATIONS) -> str:
+ """Build the SCRAM verifier for the given password.
+
+ Returns a string in the following format:
+
+ "$:$:"
+
+ The salt and keys are base64-encoded values.
+ """
+ password = saslprep(password).encode('utf-8')
+
+ if salt is None:
+ salt = generate_salt()
+
+ salted_password = get_salted_password(password, salt, iterations)
+ client_key = get_client_key(salted_password)
+ stored_key = H(client_key)
+ server_key = get_server_key(salted_password)
+
+ return (f'SCRAM-SHA-256${iterations}:{B64(salt)}$'
+ f'{B64(stored_key)}:{B64(server_key)}')
+
+
+class SCRAMVerifier(typing.NamedTuple):
+
+ mechanism: str
+ iterations: int
+ salt: bytes
+ stored_key: bytes
+ server_key: bytes
+
+
+def parse_verifier(verifier: str) -> SCRAMVerifier:
+
+ parts = verifier.split('$')
+ if len(parts) != 3:
+ raise ValueError('invalid SCRAM verifier')
+
+ mechanism = parts[0]
+ if mechanism != 'SCRAM-SHA-256':
+ raise ValueError('invalid SCRAM verifier')
+
+ iterations, _, salt = parts[1].partition(':')
+ stored_key, _, server_key = parts[2].partition(':')
+ if not salt or not server_key:
+ raise ValueError('invalid SCRAM verifier')
+
+ try:
+ iterations = int(iterations)
+ except ValueError:
+ raise ValueError('invalid SCRAM verifier') from None
+
+ return SCRAMVerifier(
+ mechanism=mechanism,
+ iterations=iterations,
+ salt=base64.b64decode(salt),
+ stored_key=base64.b64decode(stored_key),
+ server_key=base64.b64decode(server_key),
+ )
+
+
+def parse_client_first_message(resp: bytes):
+
+ # Relevant bits of RFC 5802:
+ #
+ # saslname = 1*(value-safe-char / "=2C" / "=3D")
+ # ;; Conforms to .
+ #
+ # authzid = "a=" saslname
+ # ;; Protocol specific.
+ #
+ # cb-name = 1*(ALPHA / DIGIT / "." / "-")
+ # ;; See RFC 5056, Section 7.
+ # ;; E.g., "tls-server-end-point" or
+ # ;; "tls-unique".
+ #
+ # gs2-cbind-flag = ("p=" cb-name) / "n" / "y"
+ # ;; "n" -> client doesn't support channel binding.
+ # ;; "y" -> client does support channel binding
+ # ;; but thinks the server does not.
+ # ;; "p" -> client requires channel binding.
+ # ;; The selected channel binding follows "p=".
+ #
+ # gs2-header = gs2-cbind-flag "," [ authzid ] ","
+ # ;; GS2 header for SCRAM
+ # ;; (the actual GS2 header includes an optional
+ # ;; flag to indicate that the GSS mechanism is not
+ # ;; "standard", but since SCRAM is "standard", we
+ # ;; don't include that flag).
+ #
+ # username = "n=" saslname
+ # ;; Usernames are prepared using SASLprep.
+ #
+ # reserved-mext = "m=" 1*(value-char)
+ # ;; Reserved for signaling mandatory extensions.
+ # ;; The exact syntax will be defined in
+ # ;; the future.
+ #
+ # nonce = "r=" c-nonce [s-nonce]
+ # ;; Second part provided by server.
+ #
+ # c-nonce = printable
+ #
+ # client-first-message-bare =
+ # [reserved-mext ","]
+ # username "," nonce ["," extensions]
+ #
+ # client-first-message =
+ # gs2-header client-first-message-bare
+
+ attrs = resp.split(b',')
+
+ cb_attr = attrs[0]
+ if cb_attr == b'y':
+ cb = True
+ elif cb_attr == b'n':
+ cb = False
+ elif cb_attr[0:1] == b'p':
+ _, _, cb = cb_attr.partition(b'=')
+ if not cb:
+ raise ValueError('malformed SCRAM message')
+ else:
+ raise ValueError('malformed SCRAM message')
+
+ authzid_attr = attrs[1]
+ if authzid_attr:
+ if authzid_attr[0:1] != b'a':
+ raise ValueError('malformed SCRAM message')
+ _, _, authzid = authzid_attr.partition(b'=')
+ else:
+ authzid = None
+
+ user_attr = attrs[2]
+ if user_attr[0:1] == b'm':
+ raise ValueError('unsupported SCRAM extensions in message')
+ elif user_attr[0:1] != b'n':
+ raise ValueError('malformed SCRAM message')
+
+ _, _, user = user_attr.partition(b'=')
+
+ nonce_attr = attrs[3]
+ if nonce_attr[0:1] != b'r':
+ raise ValueError('malformed SCRAM message')
+
+ _, _, nonce_bin = nonce_attr.partition(b'=')
+ nonce = nonce_bin.decode('ascii')
+ if not nonce.isprintable():
+ raise ValueError('invalid characters in client nonce')
+
+ # ["," extensions] are ignored
+
+ return len(cb_attr) + 2, cb, authzid, user, nonce
+
+
+def parse_client_final_message(
+ msg: bytes, client_nonce: str, server_nonce: str):
+
+ # Relevant bits of RFC 5802:
+ #
+ # gs2-header = gs2-cbind-flag "," [ authzid ] ","
+ # ;; GS2 header for SCRAM
+ # ;; (the actual GS2 header includes an optional
+ # ;; flag to indicate that the GSS mechanism is not
+ # ;; "standard", but since SCRAM is "standard", we
+ # ;; don't include that flag).
+ #
+ # cbind-input = gs2-header [ cbind-data ]
+ # ;; cbind-data MUST be present for
+ # ;; gs2-cbind-flag of "p" and MUST be absent
+ # ;; for "y" or "n".
+ #
+ # channel-binding = "c=" base64
+ # ;; base64 encoding of cbind-input.
+ #
+ # proof = "p=" base64
+ #
+ # client-final-message-without-proof =
+ # channel-binding "," nonce [","
+ # extensions]
+ #
+ # client-final-message =
+ # client-final-message-without-proof "," proof
+
+ attrs = msg.split(b',')
+
+ cb_attr = attrs[0]
+ if cb_attr[0:1] != b'c':
+ raise ValueError('malformed SCRAM message')
+
+ _, _, cb_data = cb_attr.partition(b'=')
+
+ nonce_attr = attrs[1]
+ if nonce_attr[0:1] != b'r':
+ raise ValueError('malformed SCRAM message')
+
+ _, _, nonce_bin = nonce_attr.partition(b'=')
+ nonce = nonce_bin.decode('ascii')
+
+ expected_nonce = f'{client_nonce}{server_nonce}'
+
+ if nonce != expected_nonce:
+ raise ValueError(
+ 'invalid SCRAM client-final message: nonce does not match')
+
+ proof = None
+
+ for attr in attrs[2:]:
+ if attr[0:1] == b'p':
+ _, _, proof = attr.partition(b'=')
+ proof_attr_len = len(attr)
+ proof = base64.b64decode(proof)
+ elif proof is not None:
+ raise ValueError('malformed SCRAM message')
+
+ if proof is None:
+ raise ValueError('malformed SCRAM message')
+
+ return cb_data, proof, proof_attr_len + 1
+
+
+def build_client_first_message(client_nonce: str, username: str) -> str:
+
+ bare = f'n={saslprep(username)},r={client_nonce}'
+ return f'n,,{bare}', bare
+
+
+def build_server_first_message(server_nonce: str, client_nonce: str,
+ salt: bytes, iterations: int) -> str:
+
+ return (
+ f'r={client_nonce}{server_nonce},'
+ f's={B64(salt)},i={iterations}'
+ )
+
+
+def build_auth_message(
+ client_first_bare: bytes,
+ server_first: bytes, client_final: bytes) -> bytes:
+
+ return b'%b,%b,%b' % (client_first_bare, server_first, client_final)
+
+
+def build_client_final_message(
+ password: str,
+ salt: bytes,
+ iterations: int,
+ client_first_bare: bytes,
+ server_first: bytes,
+ server_nonce: str) -> str:
+
+ client_final = f'c=biws,r={server_nonce}'
+
+ AuthMessage = build_auth_message(
+ client_first_bare, server_first, client_final.encode('utf-8'))
+
+ SaltedPassword = get_salted_password(
+ saslprep(password).encode('utf-8'),
+ salt,
+ iterations)
+
+ ClientKey = get_client_key(SaltedPassword)
+ StoredKey = H(ClientKey)
+ ClientSignature = HMAC(StoredKey, AuthMessage)
+ ClientProof = XOR(ClientKey, ClientSignature)
+
+ ServerKey = get_server_key(SaltedPassword)
+ ServerProof = HMAC(ServerKey, AuthMessage)
+
+ return f'{client_final},p={B64(ClientProof)}', ServerProof
+
+
+def build_server_final_message(
+ client_first_bare: bytes, server_first: bytes,
+ client_final: bytes, server_key: bytes) -> str:
+
+ AuthMessage = build_auth_message(
+ client_first_bare, server_first, client_final)
+ ServerSignature = HMAC(server_key, AuthMessage)
+ return f'v={B64(ServerSignature)}'
+
+
+def parse_server_first_message(msg: bytes):
+
+ attrs = msg.split(b',')
+
+ nonce_attr = attrs[0]
+ if nonce_attr[0:1] != b'r':
+ raise ValueError('malformed SCRAM message')
+
+ _, _, nonce_bin = nonce_attr.partition(b'=')
+ nonce = nonce_bin.decode('ascii')
+ if not nonce.isprintable():
+ raise ValueError('malformed SCRAM message')
+
+ salt_attr = attrs[1]
+ if salt_attr[0:1] != b's':
+ raise ValueError('malformed SCRAM message')
+
+ _, _, salt_b64 = salt_attr.partition(b'=')
+ salt = base64.b64decode(salt_b64)
+
+ iter_attr = attrs[2]
+ if iter_attr[0:1] != b'i':
+ raise ValueError('malformed SCRAM message')
+
+ _, _, iterations = iter_attr.partition(b'=')
+
+ try:
+ itercount = int(iterations)
+ except ValueError:
+ raise ValueError('malformed SCRAM message') from None
+
+ return nonce, salt, itercount
+
+
+def parse_server_final_message(msg: bytes):
+
+ attrs = msg.split(b',')
+
+ nonce_attr = attrs[0]
+ if nonce_attr[0:1] != b'v':
+ raise ValueError('malformed SCRAM message')
+
+ _, _, signature_b64 = nonce_attr.partition(b'=')
+ signature = base64.b64decode(signature_b64)
+
+ return signature
+
+
+def verify_password(password: bytes, verifier: str) -> bool:
+ """Check the given password against a verifier.
+
+ Returns True if the password is OK, False otherwise.
+ """
+
+ password = saslprep(password).encode('utf-8')
+ v = parse_verifier(verifier)
+ salted_password = get_salted_password(password, v.salt, v.iterations)
+ computed_key = get_server_key(salted_password)
+ return v.server_key == computed_key
+
+
+def verify_client_proof(client_first: bytes, server_first: bytes,
+ client_final: bytes, StoredKey: bytes,
+ ClientProof: bytes) -> bool:
+ AuthMessage = build_auth_message(client_first, server_first, client_final)
+ ClientSignature = HMAC(StoredKey, AuthMessage)
+ ClientKey = XOR(ClientProof, ClientSignature)
+ return H(ClientKey) == StoredKey
+
+
+def B64(val: bytes) -> str:
+ """Return base64-encoded string representation of input binary data."""
+ return base64.b64encode(val).decode()
+
+
+def HMAC(key: bytes, msg: bytes) -> bytes:
+ return hmac.new(key, msg, digestmod=hashlib.sha256).digest()
+
+
+def XOR(a: bytes, b: bytes) -> bytes:
+ if len(a) != len(b):
+ raise ValueError('scram.XOR received operands of unequal length')
+ xint = int.from_bytes(a, 'big') ^ int.from_bytes(b, 'big')
+ return xint.to_bytes(len(a), 'big')
+
+
+def H(s: bytes) -> bytes:
+ return hashlib.sha256(s).digest()
+
+
+def get_salted_password(password: bytes, salt: bytes,
+ iterations: int) -> bytes:
+ # U1 := HMAC(str, salt + INT(1))
+ H_i = U_i = HMAC(password, salt + b'\x00\x00\x00\x01')
+
+ for _ in range(iterations - 1):
+ U_i = HMAC(password, U_i)
+ H_i = XOR(H_i, U_i)
+
+ return H_i
+
+
+def get_client_key(salted_password: bytes) -> bytes:
+ return HMAC(salted_password, b'Client Key')
+
+
+def get_server_key(salted_password: bytes) -> bytes:
+ return HMAC(salted_password, b'Server Key')
diff --git a/gel/scram/saslprep.py b/gel/scram/saslprep.py
new file mode 100644
index 00000000..79eb84d8
--- /dev/null
+++ b/gel/scram/saslprep.py
@@ -0,0 +1,82 @@
+# Copyright 2016-present MongoDB, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import stringprep
+import unicodedata
+
+
+# RFC4013 section 2.3 prohibited output.
+_PROHIBITED = (
+ # A strict reading of RFC 4013 requires table c12 here, but
+ # characters from it are mapped to SPACE in the Map step. Can
+ # normalization reintroduce them somehow?
+ stringprep.in_table_c12,
+ stringprep.in_table_c21_c22,
+ stringprep.in_table_c3,
+ stringprep.in_table_c4,
+ stringprep.in_table_c5,
+ stringprep.in_table_c6,
+ stringprep.in_table_c7,
+ stringprep.in_table_c8,
+ stringprep.in_table_c9)
+
+
+def saslprep(data: str, prohibit_unassigned_code_points=True):
+ """An implementation of RFC4013 SASLprep."""
+
+ if data == '':
+ return data
+
+ if prohibit_unassigned_code_points:
+ prohibited = _PROHIBITED + (stringprep.in_table_a1,)
+ else:
+ prohibited = _PROHIBITED
+
+ # RFC3454 section 2, step 1 - Map
+ # RFC4013 section 2.1 mappings
+ # Map Non-ASCII space characters to SPACE (U+0020). Map
+ # commonly mapped to nothing characters to, well, nothing.
+ in_table_c12 = stringprep.in_table_c12
+ in_table_b1 = stringprep.in_table_b1
+ data = u"".join(
+ [u"\u0020" if in_table_c12(elt) else elt
+ for elt in data if not in_table_b1(elt)])
+
+ # RFC3454 section 2, step 2 - Normalize
+ # RFC4013 section 2.2 normalization
+ data = unicodedata.ucd_3_2_0.normalize('NFKC', data)
+
+ in_table_d1 = stringprep.in_table_d1
+ if in_table_d1(data[0]):
+ if not in_table_d1(data[-1]):
+ # RFC3454, Section 6, #3. If a string contains any
+ # RandALCat character, the first and last characters
+ # MUST be RandALCat characters.
+ raise ValueError("SASLprep: failed bidirectional check")
+ # RFC3454, Section 6, #2. If a string contains any RandALCat
+ # character, it MUST NOT contain any LCat character.
+ prohibited = prohibited + (stringprep.in_table_d2,)
+ else:
+ # RFC3454, Section 6, #3. Following the logic of #3, if
+ # the first character is not a RandALCat, no other character
+ # can be either.
+ prohibited = prohibited + (in_table_d1,)
+
+ # RFC3454 section 2, step 3 and 4 - Prohibit and check bidi
+ for char in data:
+ if any(in_table(char) for in_table in prohibited):
+ raise ValueError(
+ "SASLprep: failed prohibited character check")
+
+ return data
diff --git a/gel/transaction.py b/gel/transaction.py
new file mode 100644
index 00000000..8f3ab29f
--- /dev/null
+++ b/gel/transaction.py
@@ -0,0 +1,224 @@
+#
+# This source file is part of the EdgeDB open source project.
+#
+# Copyright 2016-present MagicStack Inc. and the EdgeDB authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+import enum
+
+from . import abstract
+from . import errors
+from . import options
+
+
+class TransactionState(enum.Enum):
+ NEW = 0
+ STARTED = 1
+ COMMITTED = 2
+ ROLLEDBACK = 3
+ FAILED = 4
+
+
+class BaseTransaction:
+
+ __slots__ = (
+ '_client',
+ '_connection',
+ '_options',
+ '_state',
+ '__retry',
+ '__iteration',
+ '__started',
+ )
+
+ def __init__(self, retry, client, iteration):
+ self._client = client
+ self._connection = None
+ self._options = retry._options.transaction_options
+ self._state = TransactionState.NEW
+ self.__retry = retry
+ self.__iteration = iteration
+ self.__started = False
+
+ def is_active(self) -> bool:
+ return self._state is TransactionState.STARTED
+
+ def __check_state_base(self, opname):
+ if self._state is TransactionState.COMMITTED:
+ raise errors.InterfaceError(
+ 'cannot {}; the transaction is already committed'.format(
+ opname))
+ if self._state is TransactionState.ROLLEDBACK:
+ raise errors.InterfaceError(
+ 'cannot {}; the transaction is already rolled back'.format(
+ opname))
+ if self._state is TransactionState.FAILED:
+ raise errors.InterfaceError(
+ 'cannot {}; the transaction is in error state'.format(
+ opname))
+
+ def __check_state(self, opname):
+ if self._state is not TransactionState.STARTED:
+ if self._state is TransactionState.NEW:
+ raise errors.InterfaceError(
+ 'cannot {}; the transaction is not yet started'.format(
+ opname))
+ self.__check_state_base(opname)
+
+ def _make_start_query(self):
+ self.__check_state_base('start')
+ if self._state is TransactionState.STARTED:
+ raise errors.InterfaceError(
+ 'cannot start; the transaction is already started')
+
+ return self._options.start_transaction_query()
+
+ def _make_commit_query(self):
+ self.__check_state('commit')
+ return 'COMMIT;'
+
+ def _make_rollback_query(self):
+ self.__check_state('rollback')
+ return 'ROLLBACK;'
+
+ def __repr__(self):
+ attrs = []
+ attrs.append('state:{}'.format(self._state.name.lower()))
+ attrs.append(repr(self._options))
+
+ if self.__class__.__module__.startswith('gel.'):
+ mod = 'gel'
+ else:
+ mod = self.__class__.__module__
+
+ return '<{}.{} {} {:#x}>'.format(
+ mod, self.__class__.__name__, ' '.join(attrs), id(self))
+
+ async def _ensure_transaction(self):
+ if not self.__started:
+ self.__started = True
+ query = self._make_start_query()
+ self._connection = await self._client._impl.acquire()
+ if self._connection.is_closed():
+ await self._connection.connect(
+ single_attempt=self.__iteration != 0
+ )
+ try:
+ await self._privileged_execute(query)
+ except BaseException:
+ self._state = TransactionState.FAILED
+ raise
+ else:
+ self._state = TransactionState.STARTED
+
+ async def _exit(self, extype, ex):
+ if not self.__started:
+ return False
+
+ try:
+ if extype is None:
+ query = self._make_commit_query()
+ state = TransactionState.COMMITTED
+ else:
+ query = self._make_rollback_query()
+ state = TransactionState.ROLLEDBACK
+ try:
+ await self._privileged_execute(query)
+ except BaseException:
+ self._state = TransactionState.FAILED
+ if extype is None:
+ # COMMIT itself may fail; recover in connection
+ await self._privileged_execute("ROLLBACK;")
+ raise
+ else:
+ self._state = state
+ except errors.EdgeDBError as err:
+ if ex is None:
+ # On commit we don't know if commit is succeeded before the
+ # database have received it or after it have been done but
+ # network is dropped before we were able to receive a response.
+ # On a TransactionError, though, we know the we need
+ # to retry.
+ # TODO(tailhook) should other errors have retries?
+ if (
+ isinstance(err, errors.TransactionError)
+ and err.has_tag(errors.SHOULD_RETRY)
+ and self.__retry._retry(err)
+ ):
+ pass
+ else:
+ raise err
+ # If we were going to rollback, look at original error
+ # to find out whether we want to retry, regardless of
+ # the rollback error.
+ # In this case we ignore rollback issue as original error is more
+ # important, e.g. in case `CancelledError` it's important
+ # to propagate it to cancel the whole task.
+ # NOTE: rollback error is always swallowed, should we use
+ # on_log_message for it?
+ finally:
+ await self._client._impl.release(self._connection)
+
+ if (
+ extype is not None and
+ issubclass(extype, errors.EdgeDBError) and
+ ex.has_tag(errors.SHOULD_RETRY)
+ ):
+ return self.__retry._retry(ex)
+
+ def _get_query_cache(self) -> abstract.QueryCache:
+ return self._client._get_query_cache()
+
+ def _get_state(self) -> options.State:
+ return self._client._get_state()
+
+ def _get_warning_handler(self) -> options.WarningHandler:
+ return self._client._get_warning_handler()
+
+ async def _query(self, query_context: abstract.QueryContext):
+ await self._ensure_transaction()
+ return await self._connection.raw_query(query_context)
+
+ async def _execute(self, execute_context: abstract.ExecuteContext) -> None:
+ await self._ensure_transaction()
+ await self._connection._execute(execute_context)
+
+ async def _privileged_execute(self, query: str) -> None:
+ await self._connection.privileged_execute(abstract.ExecuteContext(
+ query=abstract.QueryWithArgs(query, (), {}),
+ cache=self._get_query_cache(),
+ state=self._get_state(),
+ warning_handler=self._get_warning_handler(),
+ ))
+
+
+class BaseRetry:
+
+ def __init__(self, owner):
+ self._owner = owner
+ self._iteration = 0
+ self._done = False
+ self._next_backoff = 0
+ self._options = owner._options
+
+ def _retry(self, exc):
+ self._last_exception = exc
+ rule = self._options.retry_options.get_rule_for_exception(exc)
+ if self._iteration >= rule.attempts:
+ return False
+ self._done = False
+ self._next_backoff = rule.backoff(self._iteration)
+ return True
diff --git a/setup.py b/setup.py
index 293030af..f4f445a8 100644
--- a/setup.py
+++ b/setup.py
@@ -98,7 +98,7 @@
readme = f.read()
-with open(str(_ROOT / 'edgedb' / '_version.py')) as f:
+with open(str(_ROOT / 'gel' / '_version.py')) as f:
for line in f:
if line.startswith('__version__ ='):
_, _, version = line.partition('=')
@@ -106,7 +106,7 @@
break
else:
raise RuntimeError(
- 'unable to read the version from edgedb/_version.py')
+ 'unable to read the version from gel/_version.py')
if (_ROOT / '.git').is_dir() and 'dev' in VERSION:
@@ -267,8 +267,8 @@ def finalize_options(self):
INCLUDE_DIRS = [
- 'edgedb/pgproto/',
- 'edgedb/datatypes',
+ 'gel/pgproto/',
+ 'gel/datatypes',
]
@@ -295,47 +295,50 @@ def finalize_options(self):
url='https://github.com/edgedb/edgedb-python',
license='Apache License, Version 2.0',
packages=setuptools.find_packages(),
- provides=['edgedb'],
+ provides=['edgedb', 'gel'],
zip_safe=False,
include_package_data=True,
- package_data={'edgedb': ['py.typed']},
+ package_data={
+ 'edgedb': ['py.typed'],
+ 'gel': ['py.typed'],
+ },
ext_modules=[
distutils_extension.Extension(
- "edgedb.pgproto.pgproto",
- ["edgedb/pgproto/pgproto.pyx"],
+ "gel.pgproto.pgproto",
+ ["gel/pgproto/pgproto.pyx"],
extra_compile_args=CFLAGS,
extra_link_args=LDFLAGS),
distutils_extension.Extension(
- "edgedb.datatypes.datatypes",
- ["edgedb/datatypes/args.c",
- "edgedb/datatypes/record_desc.c",
- "edgedb/datatypes/namedtuple.c",
- "edgedb/datatypes/object.c",
- "edgedb/datatypes/hash.c",
- "edgedb/datatypes/repr.c",
- "edgedb/datatypes/comp.c",
- "edgedb/datatypes/datatypes.pyx"],
+ "gel.datatypes.datatypes",
+ ["gel/datatypes/args.c",
+ "gel/datatypes/record_desc.c",
+ "gel/datatypes/namedtuple.c",
+ "gel/datatypes/object.c",
+ "gel/datatypes/hash.c",
+ "gel/datatypes/repr.c",
+ "gel/datatypes/comp.c",
+ "gel/datatypes/datatypes.pyx"],
extra_compile_args=CFLAGS,
extra_link_args=LDFLAGS),
distutils_extension.Extension(
- "edgedb.protocol.protocol",
- ["edgedb/protocol/protocol.pyx"],
+ "gel.protocol.protocol",
+ ["gel/protocol/protocol.pyx"],
extra_compile_args=CFLAGS,
extra_link_args=LDFLAGS,
include_dirs=INCLUDE_DIRS),
distutils_extension.Extension(
- "edgedb.protocol.asyncio_proto",
- ["edgedb/protocol/asyncio_proto.pyx"],
+ "gel.protocol.asyncio_proto",
+ ["gel/protocol/asyncio_proto.pyx"],
extra_compile_args=CFLAGS,
extra_link_args=LDFLAGS,
include_dirs=INCLUDE_DIRS),
distutils_extension.Extension(
- "edgedb.protocol.blocking_proto",
- ["edgedb/protocol/blocking_proto.pyx"],
+ "gel.protocol.blocking_proto",
+ ["gel/protocol/blocking_proto.pyx"],
extra_compile_args=CFLAGS,
extra_link_args=LDFLAGS,
include_dirs=INCLUDE_DIRS),
@@ -349,7 +352,8 @@ def finalize_options(self):
setup_requires=setup_requires,
entry_points={
"console_scripts": [
- "edgedb-py=edgedb.codegen.cli:main",
+ "edgedb-py=gel.codegen.cli:main",
+ "gel-py=gel.codegen.cli:main",
]
}
)
diff --git a/tests/codegen/linked/test_linked_async_edgeql.py.assert b/tests/codegen/linked/test_linked_async_edgeql.py.assert
index e54e2622..33058af6 100644
--- a/tests/codegen/linked/test_linked_async_edgeql.py.assert
+++ b/tests/codegen/linked/test_linked_async_edgeql.py.assert
@@ -1,13 +1,13 @@
# AUTOGENERATED FROM 'linked/test_linked.edgeql' WITH:
-# $ edgedb-py
+# $ gel-py
from __future__ import annotations
-import edgedb
+import gel
async def test_linked(
- executor: edgedb.AsyncIOExecutor,
+ executor: gel.AsyncIOExecutor,
) -> int:
return await executor.query_single(
"""\
diff --git a/tests/codegen/linked/test_linked_edgeql.py.assert b/tests/codegen/linked/test_linked_edgeql.py.assert
index 00a1b169..d8d7aa04 100644
--- a/tests/codegen/linked/test_linked_edgeql.py.assert
+++ b/tests/codegen/linked/test_linked_edgeql.py.assert
@@ -1,13 +1,13 @@
# AUTOGENERATED FROM 'linked/test_linked.edgeql' WITH:
-# $ edgedb-py --target blocking --no-skip-pydantic-validation
+# $ gel-py --target blocking --no-skip-pydantic-validation
from __future__ import annotations
-import edgedb
+import gel
def test_linked(
- executor: edgedb.Executor,
+ executor: gel.Executor,
) -> int:
return executor.query_single(
"""\
diff --git a/tests/codegen/test-project1/generated_async_edgeql.py.assert b/tests/codegen/test-project1/generated_async_edgeql.py.assert
index f4e8fea7..22ba6c8d 100644
--- a/tests/codegen/test-project1/generated_async_edgeql.py.assert
+++ b/tests/codegen/test-project1/generated_async_edgeql.py.assert
@@ -3,12 +3,12 @@
# 'select_scalar.edgeql'
# 'linked/test_linked.edgeql'
# WITH:
-# $ edgedb-py --target async --file --no-skip-pydantic-validation
+# $ gel-py --target async --file --no-skip-pydantic-validation
from __future__ import annotations
import dataclasses
-import edgedb
+import gel
import uuid
@@ -24,7 +24,7 @@ class SelectOptionalJsonResultItemSnakeCase:
async def select_optional_json(
- executor: edgedb.AsyncIOExecutor,
+ executor: gel.AsyncIOExecutor,
arg0: str | None,
) -> list[tuple[str, SelectOptionalJsonResultItem]]:
return await executor.query(
@@ -40,7 +40,7 @@ async def select_optional_json(
async def select_scalar(
- executor: edgedb.AsyncIOExecutor,
+ executor: gel.AsyncIOExecutor,
) -> int:
return await executor.query_single(
"""\
@@ -50,7 +50,7 @@ async def select_scalar(
async def test_linked(
- executor: edgedb.AsyncIOExecutor,
+ executor: gel.AsyncIOExecutor,
) -> int:
return await executor.query_single(
"""\
diff --git a/tests/codegen/test-project1/generated_async_edgeql.py.assert5 b/tests/codegen/test-project1/generated_async_edgeql.py.assert5
index f4e8fea7..22ba6c8d 100644
--- a/tests/codegen/test-project1/generated_async_edgeql.py.assert5
+++ b/tests/codegen/test-project1/generated_async_edgeql.py.assert5
@@ -3,12 +3,12 @@
# 'select_scalar.edgeql'
# 'linked/test_linked.edgeql'
# WITH:
-# $ edgedb-py --target async --file --no-skip-pydantic-validation
+# $ gel-py --target async --file --no-skip-pydantic-validation
from __future__ import annotations
import dataclasses
-import edgedb
+import gel
import uuid
@@ -24,7 +24,7 @@ class SelectOptionalJsonResultItemSnakeCase:
async def select_optional_json(
- executor: edgedb.AsyncIOExecutor,
+ executor: gel.AsyncIOExecutor,
arg0: str | None,
) -> list[tuple[str, SelectOptionalJsonResultItem]]:
return await executor.query(
@@ -40,7 +40,7 @@ async def select_optional_json(
async def select_scalar(
- executor: edgedb.AsyncIOExecutor,
+ executor: gel.AsyncIOExecutor,
) -> int:
return await executor.query_single(
"""\
@@ -50,7 +50,7 @@ async def select_scalar(
async def test_linked(
- executor: edgedb.AsyncIOExecutor,
+ executor: gel.AsyncIOExecutor,
) -> int:
return await executor.query_single(
"""\
diff --git a/tests/codegen/test-project1/select_optional_json_async_edgeql.py.assert b/tests/codegen/test-project1/select_optional_json_async_edgeql.py.assert
index f70213b0..f294ffd2 100644
--- a/tests/codegen/test-project1/select_optional_json_async_edgeql.py.assert
+++ b/tests/codegen/test-project1/select_optional_json_async_edgeql.py.assert
@@ -1,10 +1,10 @@
# AUTOGENERATED FROM 'select_optional_json.edgeql' WITH:
-# $ edgedb-py
+# $ gel-py
from __future__ import annotations
import dataclasses
-import edgedb
+import gel
import typing
import uuid
@@ -37,7 +37,7 @@ class SelectOptionalJsonResultItemSnakeCase(NoPydanticValidation):
async def select_optional_json(
- executor: edgedb.AsyncIOExecutor,
+ executor: gel.AsyncIOExecutor,
arg0: typing.Optional[str],
) -> typing.List[typing.Tuple[str, SelectOptionalJsonResultItem]]:
return await executor.query(
diff --git a/tests/codegen/test-project1/select_optional_json_edgeql.py.assert b/tests/codegen/test-project1/select_optional_json_edgeql.py.assert
index 004dd27c..61d2cd0d 100644
--- a/tests/codegen/test-project1/select_optional_json_edgeql.py.assert
+++ b/tests/codegen/test-project1/select_optional_json_edgeql.py.assert
@@ -1,10 +1,10 @@
# AUTOGENERATED FROM 'select_optional_json.edgeql' WITH:
-# $ edgedb-py --target blocking --no-skip-pydantic-validation
+# $ gel-py --target blocking --no-skip-pydantic-validation
from __future__ import annotations
import dataclasses
-import edgedb
+import gel
import typing
import uuid
@@ -21,7 +21,7 @@ class SelectOptionalJsonResultItemSnakeCase:
def select_optional_json(
- executor: edgedb.Executor,
+ executor: gel.Executor,
arg0: typing.Optional[str],
) -> list[tuple[str, SelectOptionalJsonResultItem]]:
return executor.query(
diff --git a/tests/codegen/test-project1/select_scalar_async_edgeql.py.assert b/tests/codegen/test-project1/select_scalar_async_edgeql.py.assert
index 6312fe0e..2a6dc130 100644
--- a/tests/codegen/test-project1/select_scalar_async_edgeql.py.assert
+++ b/tests/codegen/test-project1/select_scalar_async_edgeql.py.assert
@@ -1,13 +1,13 @@
# AUTOGENERATED FROM 'select_scalar.edgeql' WITH:
-# $ edgedb-py
+# $ gel-py
from __future__ import annotations
-import edgedb
+import gel
async def select_scalar(
- executor: edgedb.AsyncIOExecutor,
+ executor: gel.AsyncIOExecutor,
) -> int:
return await executor.query_single(
"""\
diff --git a/tests/codegen/test-project1/select_scalar_edgeql.py.assert b/tests/codegen/test-project1/select_scalar_edgeql.py.assert
index 4c870f1b..d8b16a53 100644
--- a/tests/codegen/test-project1/select_scalar_edgeql.py.assert
+++ b/tests/codegen/test-project1/select_scalar_edgeql.py.assert
@@ -1,13 +1,13 @@
# AUTOGENERATED FROM 'select_scalar.edgeql' WITH:
-# $ edgedb-py --target blocking --no-skip-pydantic-validation
+# $ gel-py --target blocking --no-skip-pydantic-validation
from __future__ import annotations
-import edgedb
+import gel
def select_scalar(
- executor: edgedb.Executor,
+ executor: gel.Executor,
) -> int:
return executor.query_single(
"""\
diff --git a/tests/codegen/test-project2/argnames/query_one_async_edgeql.py.assert b/tests/codegen/test-project2/argnames/query_one_async_edgeql.py.assert
index cbfb0961..ecf680f7 100644
--- a/tests/codegen/test-project2/argnames/query_one_async_edgeql.py.assert
+++ b/tests/codegen/test-project2/argnames/query_one_async_edgeql.py.assert
@@ -1,13 +1,13 @@
# AUTOGENERATED FROM 'argnames/query_one.edgeql' WITH:
-# $ edgedb-py
+# $ gel-py
from __future__ import annotations
-import edgedb
+import gel
async def query_one(
- executor: edgedb.AsyncIOExecutor,
+ executor: gel.AsyncIOExecutor,
*,
arg_name_with_underscores: int,
) -> int:
diff --git a/tests/codegen/test-project2/argnames/query_one_edgeql.py.assert b/tests/codegen/test-project2/argnames/query_one_edgeql.py.assert
index c7cfe601..56701a68 100644
--- a/tests/codegen/test-project2/argnames/query_one_edgeql.py.assert
+++ b/tests/codegen/test-project2/argnames/query_one_edgeql.py.assert
@@ -1,13 +1,13 @@
# AUTOGENERATED FROM 'argnames/query_one.edgeql' WITH:
-# $ edgedb-py --target blocking --no-skip-pydantic-validation
+# $ gel-py --target blocking --no-skip-pydantic-validation
from __future__ import annotations
-import edgedb
+import gel
def query_one(
- executor: edgedb.Executor,
+ executor: gel.Executor,
*,
arg_name_with_underscores: int,
) -> int:
diff --git a/tests/codegen/test-project2/generated_async_edgeql.py.assert b/tests/codegen/test-project2/generated_async_edgeql.py.assert
index 007f2af1..4efdfad8 100644
--- a/tests/codegen/test-project2/generated_async_edgeql.py.assert
+++ b/tests/codegen/test-project2/generated_async_edgeql.py.assert
@@ -9,15 +9,15 @@
# 'scalar/select_scalar.edgeql'
# 'scalar/select_scalars.edgeql'
# WITH:
-# $ edgedb-py --target async --file --no-skip-pydantic-validation
+# $ gel-py --target async --file --no-skip-pydantic-validation
from __future__ import annotations
import array
import dataclasses
import datetime
-import edgedb
import enum
+import gel
import typing
import uuid
@@ -91,26 +91,26 @@ class MyQueryResult:
ab: datetime.timedelta | None
ac: int
ad: int | None
- ae: edgedb.RelativeDuration
- af: edgedb.RelativeDuration | None
- ag: edgedb.DateDuration
- ah: edgedb.DateDuration | None
- ai: edgedb.ConfigMemory
- aj: edgedb.ConfigMemory | None
- ak: edgedb.Range[int]
- al: edgedb.Range[int] | None
- am: edgedb.Range[int]
- an: edgedb.Range[int] | None
- ao: edgedb.Range[float]
- ap: edgedb.Range[float] | None
- aq: edgedb.Range[float]
- ar: edgedb.Range[float] | None
- as_: edgedb.Range[datetime.datetime]
- at: edgedb.Range[datetime.datetime] | None
- au: edgedb.Range[datetime.datetime]
- av: edgedb.Range[datetime.datetime] | None
- aw: edgedb.Range[datetime.date]
- ax: edgedb.Range[datetime.date] | None
+ ae: gel.RelativeDuration
+ af: gel.RelativeDuration | None
+ ag: gel.DateDuration
+ ah: gel.DateDuration | None
+ ai: gel.ConfigMemory
+ aj: gel.ConfigMemory | None
+ ak: gel.Range[int]
+ al: gel.Range[int] | None
+ am: gel.Range[int]
+ an: gel.Range[int] | None
+ ao: gel.Range[float]
+ ap: gel.Range[float] | None
+ aq: gel.Range[float]
+ ar: gel.Range[float] | None
+ as_: gel.Range[datetime.datetime]
+ at: gel.Range[datetime.datetime] | None
+ au: gel.Range[datetime.datetime]
+ av: gel.Range[datetime.datetime] | None
+ aw: gel.Range[datetime.date]
+ ax: gel.Range[datetime.date] | None
ay: MyScalar
az: MyScalar | None
ba: MyEnum
@@ -141,7 +141,7 @@ class SelectObjectResultParamsItem:
async def custom_vector_input(
- executor: edgedb.AsyncIOExecutor,
+ executor: gel.AsyncIOExecutor,
*,
input: V3 | None = None,
) -> int | None:
@@ -154,7 +154,7 @@ async def custom_vector_input(
async def link_prop(
- executor: edgedb.AsyncIOExecutor,
+ executor: gel.AsyncIOExecutor,
) -> list[LinkPropResult]:
return await executor.query(
"""\
@@ -181,7 +181,7 @@ async def link_prop(
async def my_query(
- executor: edgedb.AsyncIOExecutor,
+ executor: gel.AsyncIOExecutor,
*,
a: uuid.UUID,
b: uuid.UUID | None = None,
@@ -213,26 +213,26 @@ async def my_query(
ab: datetime.timedelta | None = None,
ac: int,
ad: int | None = None,
- ae: edgedb.RelativeDuration,
- af: edgedb.RelativeDuration | None = None,
- ag: edgedb.DateDuration,
- ah: edgedb.DateDuration | None = None,
- ai: edgedb.ConfigMemory,
- aj: edgedb.ConfigMemory | None = None,
- ak: edgedb.Range[int],
- al: edgedb.Range[int] | None = None,
- am: edgedb.Range[int],
- an: edgedb.Range[int] | None = None,
- ao: edgedb.Range[float],
- ap: edgedb.Range[float] | None = None,
- aq: edgedb.Range[float],
- ar: edgedb.Range[float] | None = None,
- as_: edgedb.Range[datetime.datetime],
- at: edgedb.Range[datetime.datetime] | None = None,
- au: edgedb.Range[datetime.datetime],
- av: edgedb.Range[datetime.datetime] | None = None,
- aw: edgedb.Range[datetime.date],
- ax: edgedb.Range[datetime.date] | None = None,
+ ae: gel.RelativeDuration,
+ af: gel.RelativeDuration | None = None,
+ ag: gel.DateDuration,
+ ah: gel.DateDuration | None = None,
+ ai: gel.ConfigMemory,
+ aj: gel.ConfigMemory | None = None,
+ ak: gel.Range[int],
+ al: gel.Range[int] | None = None,
+ am: gel.Range[int],
+ an: gel.Range[int] | None = None,
+ ao: gel.Range[float],
+ ap: gel.Range[float] | None = None,
+ aq: gel.Range[float],
+ ar: gel.Range[float] | None = None,
+ as_: gel.Range[datetime.datetime],
+ at: gel.Range[datetime.datetime] | None = None,
+ au: gel.Range[datetime.datetime],
+ av: gel.Range[datetime.datetime] | None = None,
+ aw: gel.Range[datetime.date],
+ ax: gel.Range[datetime.date] | None = None,
bc: typing.Sequence[float],
bd: typing.Sequence[float] | None = None,
) -> MyQueryResult:
@@ -356,7 +356,7 @@ async def my_query(
async def query_one(
- executor: edgedb.AsyncIOExecutor,
+ executor: gel.AsyncIOExecutor,
*,
arg_name_with_underscores: int,
) -> int:
@@ -369,7 +369,7 @@ async def query_one(
async def select_args(
- executor: edgedb.AsyncIOExecutor,
+ executor: gel.AsyncIOExecutor,
*,
arg_str: str,
arg_datetime: datetime.datetime,
@@ -387,7 +387,7 @@ async def select_args(
async def select_object(
- executor: edgedb.AsyncIOExecutor,
+ executor: gel.AsyncIOExecutor,
) -> SelectObjectResult | None:
return await executor.query_single(
"""\
@@ -405,7 +405,7 @@ async def select_object(
async def select_objects(
- executor: edgedb.AsyncIOExecutor,
+ executor: gel.AsyncIOExecutor,
) -> list[SelectObjectResult]:
return await executor.query(
"""\
@@ -422,7 +422,7 @@ async def select_objects(
async def select_scalar(
- executor: edgedb.AsyncIOExecutor,
+ executor: gel.AsyncIOExecutor,
) -> int:
return await executor.query_single(
"""\
@@ -432,8 +432,8 @@ async def select_scalar(
async def select_scalars(
- executor: edgedb.AsyncIOExecutor,
-) -> list[edgedb.ConfigMemory]:
+ executor: gel.AsyncIOExecutor,
+) -> list[gel.ConfigMemory]:
return await executor.query(
"""\
select {1, 2, 3};\
diff --git a/tests/codegen/test-project2/generated_async_edgeql.py.assert3 b/tests/codegen/test-project2/generated_async_edgeql.py.assert3
index 2095cc9a..4a1a7444 100644
--- a/tests/codegen/test-project2/generated_async_edgeql.py.assert3
+++ b/tests/codegen/test-project2/generated_async_edgeql.py.assert3
@@ -9,15 +9,15 @@
# 'scalar/select_scalar.edgeql'
# 'scalar/select_scalars.edgeql'
# WITH:
-# $ edgedb-py --target async --file --no-skip-pydantic-validation
+# $ gel-py --target async --file --no-skip-pydantic-validation
from __future__ import annotations
import array
import dataclasses
import datetime
-import edgedb
import enum
+import gel
import typing
import uuid
@@ -92,26 +92,26 @@ class MyQueryResult:
ab: datetime.timedelta | None
ac: int
ad: int | None
- ae: edgedb.RelativeDuration
- af: edgedb.RelativeDuration | None
- ag: edgedb.DateDuration
- ah: edgedb.DateDuration | None
- ai: edgedb.ConfigMemory
- aj: edgedb.ConfigMemory | None
- ak: edgedb.Range[int]
- al: edgedb.Range[int] | None
- am: edgedb.Range[int]
- an: edgedb.Range[int] | None
- ao: edgedb.Range[float]
- ap: edgedb.Range[float] | None
- aq: edgedb.Range[float]
- ar: edgedb.Range[float] | None
- as_: edgedb.Range[datetime.datetime]
- at: edgedb.Range[datetime.datetime] | None
- au: edgedb.Range[datetime.datetime]
- av: edgedb.Range[datetime.datetime] | None
- aw: edgedb.Range[datetime.date]
- ax: edgedb.Range[datetime.date] | None
+ ae: gel.RelativeDuration
+ af: gel.RelativeDuration | None
+ ag: gel.DateDuration
+ ah: gel.DateDuration | None
+ ai: gel.ConfigMemory
+ aj: gel.ConfigMemory | None
+ ak: gel.Range[int]
+ al: gel.Range[int] | None
+ am: gel.Range[int]
+ an: gel.Range[int] | None
+ ao: gel.Range[float]
+ ap: gel.Range[float] | None
+ aq: gel.Range[float]
+ ar: gel.Range[float] | None
+ as_: gel.Range[datetime.datetime]
+ at: gel.Range[datetime.datetime] | None
+ au: gel.Range[datetime.datetime]
+ av: gel.Range[datetime.datetime] | None
+ aw: gel.Range[datetime.date]
+ ax: gel.Range[datetime.date] | None
ay: MyScalar
az: MyScalar | None
ba: MyEnum
@@ -143,7 +143,7 @@ class SelectObjectResultParamsItem:
async def custom_vector_input(
- executor: edgedb.AsyncIOExecutor,
+ executor: gel.AsyncIOExecutor,
*,
input: Input | None = None,
) -> int | None:
@@ -156,7 +156,7 @@ async def custom_vector_input(
async def link_prop(
- executor: edgedb.AsyncIOExecutor,
+ executor: gel.AsyncIOExecutor,
) -> list[LinkPropResult]:
return await executor.query(
"""\
@@ -183,7 +183,7 @@ async def link_prop(
async def my_query(
- executor: edgedb.AsyncIOExecutor,
+ executor: gel.AsyncIOExecutor,
*,
a: uuid.UUID,
b: uuid.UUID | None = None,
@@ -215,26 +215,26 @@ async def my_query(
ab: datetime.timedelta | None = None,
ac: int,
ad: int | None = None,
- ae: edgedb.RelativeDuration,
- af: edgedb.RelativeDuration | None = None,
- ag: edgedb.DateDuration,
- ah: edgedb.DateDuration | None = None,
- ai: edgedb.ConfigMemory,
- aj: edgedb.ConfigMemory | None = None,
- ak: edgedb.Range[int],
- al: edgedb.Range[int] | None = None,
- am: edgedb.Range[int],
- an: edgedb.Range[int] | None = None,
- ao: edgedb.Range[float],
- ap: edgedb.Range[float] | None = None,
- aq: edgedb.Range[float],
- ar: edgedb.Range[float] | None = None,
- as_: edgedb.Range[datetime.datetime],
- at: edgedb.Range[datetime.datetime] | None = None,
- au: edgedb.Range[datetime.datetime],
- av: edgedb.Range[datetime.datetime] | None = None,
- aw: edgedb.Range[datetime.date],
- ax: edgedb.Range[datetime.date] | None = None,
+ ae: gel.RelativeDuration,
+ af: gel.RelativeDuration | None = None,
+ ag: gel.DateDuration,
+ ah: gel.DateDuration | None = None,
+ ai: gel.ConfigMemory,
+ aj: gel.ConfigMemory | None = None,
+ ak: gel.Range[int],
+ al: gel.Range[int] | None = None,
+ am: gel.Range[int],
+ an: gel.Range[int] | None = None,
+ ao: gel.Range[float],
+ ap: gel.Range[float] | None = None,
+ aq: gel.Range[float],
+ ar: gel.Range[float] | None = None,
+ as_: gel.Range[datetime.datetime],
+ at: gel.Range[datetime.datetime] | None = None,
+ au: gel.Range[datetime.datetime],
+ av: gel.Range[datetime.datetime] | None = None,
+ aw: gel.Range[datetime.date],
+ ax: gel.Range[datetime.date] | None = None,
bc: typing.Sequence[float],
bd: typing.Sequence[float] | None = None,
) -> MyQueryResult:
@@ -358,7 +358,7 @@ async def my_query(
async def query_one(
- executor: edgedb.AsyncIOExecutor,
+ executor: gel.AsyncIOExecutor,
*,
arg_name_with_underscores: int,
) -> int:
@@ -371,7 +371,7 @@ async def query_one(
async def select_args(
- executor: edgedb.AsyncIOExecutor,
+ executor: gel.AsyncIOExecutor,
*,
arg_str: str,
arg_datetime: datetime.datetime,
@@ -389,7 +389,7 @@ async def select_args(
async def select_object(
- executor: edgedb.AsyncIOExecutor,
+ executor: gel.AsyncIOExecutor,
) -> SelectObjectResult | None:
return await executor.query_single(
"""\
@@ -407,7 +407,7 @@ async def select_object(
async def select_objects(
- executor: edgedb.AsyncIOExecutor,
+ executor: gel.AsyncIOExecutor,
) -> list[SelectObjectResult]:
return await executor.query(
"""\
@@ -424,7 +424,7 @@ async def select_objects(
async def select_scalar(
- executor: edgedb.AsyncIOExecutor,
+ executor: gel.AsyncIOExecutor,
) -> int:
return await executor.query_single(
"""\
@@ -434,8 +434,8 @@ async def select_scalar(
async def select_scalars(
- executor: edgedb.AsyncIOExecutor,
-) -> list[edgedb.ConfigMemory]:
+ executor: gel.AsyncIOExecutor,
+) -> list[gel.ConfigMemory]:
return await executor.query(
"""\
select {1, 2, 3};\
diff --git a/tests/codegen/test-project2/generated_async_edgeql.py.assert5 b/tests/codegen/test-project2/generated_async_edgeql.py.assert5
index 75d52ea7..8907bf73 100644
--- a/tests/codegen/test-project2/generated_async_edgeql.py.assert5
+++ b/tests/codegen/test-project2/generated_async_edgeql.py.assert5
@@ -9,15 +9,15 @@
# 'scalar/select_scalar.edgeql'
# 'scalar/select_scalars.edgeql'
# WITH:
-# $ edgedb-py --target async --file --no-skip-pydantic-validation
+# $ gel-py --target async --file --no-skip-pydantic-validation
from __future__ import annotations
import array
import dataclasses
import datetime
-import edgedb
import enum
+import gel
import typing
import uuid
@@ -92,26 +92,26 @@ class MyQueryResult:
ab: datetime.timedelta | None
ac: int
ad: int | None
- ae: edgedb.RelativeDuration
- af: edgedb.RelativeDuration | None
- ag: edgedb.DateDuration
- ah: edgedb.DateDuration | None
- ai: edgedb.ConfigMemory
- aj: edgedb.ConfigMemory | None
- ak: edgedb.Range[int]
- al: edgedb.Range[int] | None
- am: edgedb.Range[int]
- an: edgedb.Range[int] | None
- ao: edgedb.Range[float]
- ap: edgedb.Range[float] | None
- aq: edgedb.Range[float]
- ar: edgedb.Range[float] | None
- as_: edgedb.Range[datetime.datetime]
- at: edgedb.Range[datetime.datetime] | None
- au: edgedb.Range[datetime.datetime]
- av: edgedb.Range[datetime.datetime] | None
- aw: edgedb.Range[datetime.date]
- ax: edgedb.Range[datetime.date] | None
+ ae: gel.RelativeDuration
+ af: gel.RelativeDuration | None
+ ag: gel.DateDuration
+ ah: gel.DateDuration | None
+ ai: gel.ConfigMemory
+ aj: gel.ConfigMemory | None
+ ak: gel.Range[int]
+ al: gel.Range[int] | None
+ am: gel.Range[int]
+ an: gel.Range[int] | None
+ ao: gel.Range[float]
+ ap: gel.Range[float] | None
+ aq: gel.Range[float]
+ ar: gel.Range[float] | None
+ as_: gel.Range[datetime.datetime]
+ at: gel.Range[datetime.datetime] | None
+ au: gel.Range[datetime.datetime]
+ av: gel.Range[datetime.datetime] | None
+ aw: gel.Range[datetime.date]
+ ax: gel.Range[datetime.date] | None
ay: MyScalar
az: MyScalar | None
ba: MyEnum
@@ -143,7 +143,7 @@ class SelectObjectResultParamsItem:
async def custom_vector_input(
- executor: edgedb.AsyncIOExecutor,
+ executor: gel.AsyncIOExecutor,
*,
input: V3 | None = None,
) -> int | None:
@@ -156,7 +156,7 @@ async def custom_vector_input(
async def link_prop(
- executor: edgedb.AsyncIOExecutor,
+ executor: gel.AsyncIOExecutor,
) -> list[LinkPropResult]:
return await executor.query(
"""\
@@ -183,7 +183,7 @@ async def link_prop(
async def my_query(
- executor: edgedb.AsyncIOExecutor,
+ executor: gel.AsyncIOExecutor,
*,
a: uuid.UUID,
b: uuid.UUID | None = None,
@@ -215,26 +215,26 @@ async def my_query(
ab: datetime.timedelta | None = None,
ac: int,
ad: int | None = None,
- ae: edgedb.RelativeDuration,
- af: edgedb.RelativeDuration | None = None,
- ag: edgedb.DateDuration,
- ah: edgedb.DateDuration | None = None,
- ai: edgedb.ConfigMemory,
- aj: edgedb.ConfigMemory | None = None,
- ak: edgedb.Range[int],
- al: edgedb.Range[int] | None = None,
- am: edgedb.Range[int],
- an: edgedb.Range[int] | None = None,
- ao: edgedb.Range[float],
- ap: edgedb.Range[float] | None = None,
- aq: edgedb.Range[float],
- ar: edgedb.Range[float] | None = None,
- as_: edgedb.Range[datetime.datetime],
- at: edgedb.Range[datetime.datetime] | None = None,
- au: edgedb.Range[datetime.datetime],
- av: edgedb.Range[datetime.datetime] | None = None,
- aw: edgedb.Range[datetime.date],
- ax: edgedb.Range[datetime.date] | None = None,
+ ae: gel.RelativeDuration,
+ af: gel.RelativeDuration | None = None,
+ ag: gel.DateDuration,
+ ah: gel.DateDuration | None = None,
+ ai: gel.ConfigMemory,
+ aj: gel.ConfigMemory | None = None,
+ ak: gel.Range[int],
+ al: gel.Range[int] | None = None,
+ am: gel.Range[int],
+ an: gel.Range[int] | None = None,
+ ao: gel.Range[float],
+ ap: gel.Range[float] | None = None,
+ aq: gel.Range[float],
+ ar: gel.Range[float] | None = None,
+ as_: gel.Range[datetime.datetime],
+ at: gel.Range[datetime.datetime] | None = None,
+ au: gel.Range[datetime.datetime],
+ av: gel.Range[datetime.datetime] | None = None,
+ aw: gel.Range[datetime.date],
+ ax: gel.Range[datetime.date] | None = None,
bc: typing.Sequence[float],
bd: typing.Sequence[float] | None = None,
) -> MyQueryResult:
@@ -358,7 +358,7 @@ async def my_query(
async def query_one(
- executor: edgedb.AsyncIOExecutor,
+ executor: gel.AsyncIOExecutor,
*,
arg_name_with_underscores: int,
) -> int:
@@ -371,7 +371,7 @@ async def query_one(
async def select_args(
- executor: edgedb.AsyncIOExecutor,
+ executor: gel.AsyncIOExecutor,
*,
arg_str: str,
arg_datetime: datetime.datetime,
@@ -389,7 +389,7 @@ async def select_args(
async def select_object(
- executor: edgedb.AsyncIOExecutor,
+ executor: gel.AsyncIOExecutor,
) -> SelectObjectResult | None:
return await executor.query_single(
"""\
@@ -407,7 +407,7 @@ async def select_object(
async def select_objects(
- executor: edgedb.AsyncIOExecutor,
+ executor: gel.AsyncIOExecutor,
) -> list[SelectObjectResult]:
return await executor.query(
"""\
@@ -424,7 +424,7 @@ async def select_objects(
async def select_scalar(
- executor: edgedb.AsyncIOExecutor,
+ executor: gel.AsyncIOExecutor,
) -> int:
return await executor.query_single(
"""\
@@ -434,8 +434,8 @@ async def select_scalar(
async def select_scalars(
- executor: edgedb.AsyncIOExecutor,
-) -> list[edgedb.ConfigMemory]:
+ executor: gel.AsyncIOExecutor,
+) -> list[gel.ConfigMemory]:
return await executor.query(
"""\
select {1, 2, 3};\
diff --git a/tests/codegen/test-project2/object/link_prop_async_edgeql.py.assert b/tests/codegen/test-project2/object/link_prop_async_edgeql.py.assert
index 95263c5b..bc323374 100644
--- a/tests/codegen/test-project2/object/link_prop_async_edgeql.py.assert
+++ b/tests/codegen/test-project2/object/link_prop_async_edgeql.py.assert
@@ -1,11 +1,11 @@
# AUTOGENERATED FROM 'object/link_prop.edgeql' WITH:
-# $ edgedb-py
+# $ gel-py
from __future__ import annotations
import dataclasses
import datetime
-import edgedb
+import gel
import typing
import uuid
@@ -52,7 +52,7 @@ class LinkPropResultFriendsItem(NoPydanticValidation):
async def link_prop(
- executor: edgedb.AsyncIOExecutor,
+ executor: gel.AsyncIOExecutor,
) -> typing.List[LinkPropResult]:
return await executor.query(
"""\
diff --git a/tests/codegen/test-project2/object/link_prop_edgeql.py.assert b/tests/codegen/test-project2/object/link_prop_edgeql.py.assert
index 3eaf4ddb..d3912dac 100644
--- a/tests/codegen/test-project2/object/link_prop_edgeql.py.assert
+++ b/tests/codegen/test-project2/object/link_prop_edgeql.py.assert
@@ -1,11 +1,11 @@
# AUTOGENERATED FROM 'object/link_prop.edgeql' WITH:
-# $ edgedb-py --target blocking --no-skip-pydantic-validation
+# $ gel-py --target blocking --no-skip-pydantic-validation
from __future__ import annotations
import dataclasses
import datetime
-import edgedb
+import gel
import typing
import uuid
@@ -36,7 +36,7 @@ class LinkPropResultFriendsItem:
def link_prop(
- executor: edgedb.Executor,
+ executor: gel.Executor,
) -> list[LinkPropResult]:
return executor.query(
"""\
diff --git a/tests/codegen/test-project2/object/select_object_async_edgeql.py.assert b/tests/codegen/test-project2/object/select_object_async_edgeql.py.assert
index 3e121c7a..37e1ec36 100644
--- a/tests/codegen/test-project2/object/select_object_async_edgeql.py.assert
+++ b/tests/codegen/test-project2/object/select_object_async_edgeql.py.assert
@@ -1,10 +1,10 @@
# AUTOGENERATED FROM 'object/select_object.edgeql' WITH:
-# $ edgedb-py
+# $ gel-py
from __future__ import annotations
import dataclasses
-import edgedb
+import gel
import typing
import uuid
@@ -41,7 +41,7 @@ class SelectObjectResultParamsItem(NoPydanticValidation):
async def select_object(
- executor: edgedb.AsyncIOExecutor,
+ executor: gel.AsyncIOExecutor,
) -> typing.Optional[SelectObjectResult]:
return await executor.query_single(
"""\
diff --git a/tests/codegen/test-project2/object/select_object_edgeql.py.assert b/tests/codegen/test-project2/object/select_object_edgeql.py.assert
index c3ebb60e..2fa0df0a 100644
--- a/tests/codegen/test-project2/object/select_object_edgeql.py.assert
+++ b/tests/codegen/test-project2/object/select_object_edgeql.py.assert
@@ -1,10 +1,10 @@
# AUTOGENERATED FROM 'object/select_object.edgeql' WITH:
-# $ edgedb-py --target blocking --no-skip-pydantic-validation
+# $ gel-py --target blocking --no-skip-pydantic-validation
from __future__ import annotations
import dataclasses
-import edgedb
+import gel
import typing
import uuid
@@ -25,7 +25,7 @@ class SelectObjectResultParamsItem:
def select_object(
- executor: edgedb.Executor,
+ executor: gel.Executor,
) -> typing.Optional[SelectObjectResult]:
return executor.query_single(
"""\
diff --git a/tests/codegen/test-project2/object/select_objects_async_edgeql.py.assert b/tests/codegen/test-project2/object/select_objects_async_edgeql.py.assert
index 25e8ec92..f4e97aba 100644
--- a/tests/codegen/test-project2/object/select_objects_async_edgeql.py.assert
+++ b/tests/codegen/test-project2/object/select_objects_async_edgeql.py.assert
@@ -1,10 +1,10 @@
# AUTOGENERATED FROM 'object/select_objects.edgeql' WITH:
-# $ edgedb-py
+# $ gel-py
from __future__ import annotations
import dataclasses
-import edgedb
+import gel
import typing
import uuid
@@ -41,7 +41,7 @@ class SelectObjectsResultParamsItem(NoPydanticValidation):
async def select_objects(
- executor: edgedb.AsyncIOExecutor,
+ executor: gel.AsyncIOExecutor,
) -> typing.List[SelectObjectsResult]:
return await executor.query(
"""\
diff --git a/tests/codegen/test-project2/object/select_objects_edgeql.py.assert b/tests/codegen/test-project2/object/select_objects_edgeql.py.assert
index fed335d7..875b3cfa 100644
--- a/tests/codegen/test-project2/object/select_objects_edgeql.py.assert
+++ b/tests/codegen/test-project2/object/select_objects_edgeql.py.assert
@@ -1,10 +1,10 @@
# AUTOGENERATED FROM 'object/select_objects.edgeql' WITH:
-# $ edgedb-py --target blocking --no-skip-pydantic-validation
+# $ gel-py --target blocking --no-skip-pydantic-validation
from __future__ import annotations
import dataclasses
-import edgedb
+import gel
import typing
import uuid
@@ -25,7 +25,7 @@ class SelectObjectsResultParamsItem:
def select_objects(
- executor: edgedb.Executor,
+ executor: gel.Executor,
) -> list[SelectObjectsResult]:
return executor.query(
"""\
diff --git a/tests/codegen/test-project2/parpkg/select_args_async_edgeql.py.assert b/tests/codegen/test-project2/parpkg/select_args_async_edgeql.py.assert
index 7643f4f4..4f3ece14 100644
--- a/tests/codegen/test-project2/parpkg/select_args_async_edgeql.py.assert
+++ b/tests/codegen/test-project2/parpkg/select_args_async_edgeql.py.assert
@@ -1,11 +1,11 @@
# AUTOGENERATED FROM 'parpkg/select_args.edgeql' WITH:
-# $ edgedb-py
+# $ gel-py
from __future__ import annotations
import dataclasses
import datetime
-import edgedb
+import gel
class NoPydanticValidation:
@@ -31,7 +31,7 @@ class SelectArgsResult(NoPydanticValidation):
async def select_args(
- executor: edgedb.AsyncIOExecutor,
+ executor: gel.AsyncIOExecutor,
*,
arg_str: str,
arg_datetime: datetime.datetime,
diff --git a/tests/codegen/test-project2/parpkg/select_args_async_edgeql.py.assert5 b/tests/codegen/test-project2/parpkg/select_args_async_edgeql.py.assert5
index b318bb5c..0e48f41d 100644
--- a/tests/codegen/test-project2/parpkg/select_args_async_edgeql.py.assert5
+++ b/tests/codegen/test-project2/parpkg/select_args_async_edgeql.py.assert5
@@ -1,11 +1,11 @@
# AUTOGENERATED FROM 'parpkg/select_args.edgeql' WITH:
-# $ edgedb-py
+# $ gel-py
from __future__ import annotations
import dataclasses
import datetime
-import edgedb
+import gel
import uuid
@@ -33,7 +33,7 @@ class SelectArgsResult(NoPydanticValidation):
async def select_args(
- executor: edgedb.AsyncIOExecutor,
+ executor: gel.AsyncIOExecutor,
*,
arg_str: str,
arg_datetime: datetime.datetime,
diff --git a/tests/codegen/test-project2/parpkg/select_args_edgeql.py.assert b/tests/codegen/test-project2/parpkg/select_args_edgeql.py.assert
index 88fd6cae..5adeb922 100644
--- a/tests/codegen/test-project2/parpkg/select_args_edgeql.py.assert
+++ b/tests/codegen/test-project2/parpkg/select_args_edgeql.py.assert
@@ -1,11 +1,11 @@
# AUTOGENERATED FROM 'parpkg/select_args.edgeql' WITH:
-# $ edgedb-py --target blocking --no-skip-pydantic-validation
+# $ gel-py --target blocking --no-skip-pydantic-validation
from __future__ import annotations
import dataclasses
import datetime
-import edgedb
+import gel
@dataclasses.dataclass
@@ -15,7 +15,7 @@ class SelectArgsResult:
def select_args(
- executor: edgedb.Executor,
+ executor: gel.Executor,
*,
arg_str: str,
arg_datetime: datetime.datetime,
diff --git a/tests/codegen/test-project2/parpkg/select_args_edgeql.py.assert5 b/tests/codegen/test-project2/parpkg/select_args_edgeql.py.assert5
index cc0d8cfc..ca9cfddb 100644
--- a/tests/codegen/test-project2/parpkg/select_args_edgeql.py.assert5
+++ b/tests/codegen/test-project2/parpkg/select_args_edgeql.py.assert5
@@ -1,11 +1,11 @@
# AUTOGENERATED FROM 'parpkg/select_args.edgeql' WITH:
-# $ edgedb-py --target blocking --no-skip-pydantic-validation
+# $ gel-py --target blocking --no-skip-pydantic-validation
from __future__ import annotations
import dataclasses
import datetime
-import edgedb
+import gel
import uuid
@@ -17,7 +17,7 @@ class SelectArgsResult:
def select_args(
- executor: edgedb.Executor,
+ executor: gel.Executor,
*,
arg_str: str,
arg_datetime: datetime.datetime,
diff --git a/tests/codegen/test-project2/parpkg/subpkg/my_query_async_edgeql.py.assert b/tests/codegen/test-project2/parpkg/subpkg/my_query_async_edgeql.py.assert
index 79b133af..ddb0d695 100644
--- a/tests/codegen/test-project2/parpkg/subpkg/my_query_async_edgeql.py.assert
+++ b/tests/codegen/test-project2/parpkg/subpkg/my_query_async_edgeql.py.assert
@@ -1,13 +1,13 @@
# AUTOGENERATED FROM 'parpkg/subpkg/my_query.edgeql' WITH:
-# $ edgedb-py
+# $ gel-py
from __future__ import annotations
import array
import dataclasses
import datetime
-import edgedb
import enum
+import gel
import typing
import uuid
@@ -71,26 +71,26 @@ class MyQueryResult(NoPydanticValidation):
ab: typing.Optional[datetime.timedelta]
ac: int
ad: typing.Optional[int]
- ae: edgedb.RelativeDuration
- af: typing.Optional[edgedb.RelativeDuration]
- ag: edgedb.DateDuration
- ah: typing.Optional[edgedb.DateDuration]
- ai: edgedb.ConfigMemory
- aj: typing.Optional[edgedb.ConfigMemory]
- ak: edgedb.Range[int]
- al: typing.Optional[edgedb.Range[int]]
- am: edgedb.Range[int]
- an: typing.Optional[edgedb.Range[int]]
- ao: edgedb.Range[float]
- ap: typing.Optional[edgedb.Range[float]]
- aq: edgedb.Range[float]
- ar: typing.Optional[edgedb.Range[float]]
- as_: edgedb.Range[datetime.datetime]
- at: typing.Optional[edgedb.Range[datetime.datetime]]
- au: edgedb.Range[datetime.datetime]
- av: typing.Optional[edgedb.Range[datetime.datetime]]
- aw: edgedb.Range[datetime.date]
- ax: typing.Optional[edgedb.Range[datetime.date]]
+ ae: gel.RelativeDuration
+ af: typing.Optional[gel.RelativeDuration]
+ ag: gel.DateDuration
+ ah: typing.Optional[gel.DateDuration]
+ ai: gel.ConfigMemory
+ aj: typing.Optional[gel.ConfigMemory]
+ ak: gel.Range[int]
+ al: typing.Optional[gel.Range[int]]
+ am: gel.Range[int]
+ an: typing.Optional[gel.Range[int]]
+ ao: gel.Range[float]
+ ap: typing.Optional[gel.Range[float]]
+ aq: gel.Range[float]
+ ar: typing.Optional[gel.Range[float]]
+ as_: gel.Range[datetime.datetime]
+ at: typing.Optional[gel.Range[datetime.datetime]]
+ au: gel.Range[datetime.datetime]
+ av: typing.Optional[gel.Range[datetime.datetime]]
+ aw: gel.Range[datetime.date]
+ ax: typing.Optional[gel.Range[datetime.date]]
ay: MyScalar
az: typing.Optional[MyScalar]
ba: MyEnum
@@ -100,7 +100,7 @@ class MyQueryResult(NoPydanticValidation):
async def my_query(
- executor: edgedb.AsyncIOExecutor,
+ executor: gel.AsyncIOExecutor,
*,
a: uuid.UUID,
b: typing.Optional[uuid.UUID] = None,
@@ -132,26 +132,26 @@ async def my_query(
ab: typing.Optional[datetime.timedelta] = None,
ac: int,
ad: typing.Optional[int] = None,
- ae: edgedb.RelativeDuration,
- af: typing.Optional[edgedb.RelativeDuration] = None,
- ag: edgedb.DateDuration,
- ah: typing.Optional[edgedb.DateDuration] = None,
- ai: edgedb.ConfigMemory,
- aj: typing.Optional[edgedb.ConfigMemory] = None,
- ak: edgedb.Range[int],
- al: typing.Optional[edgedb.Range[int]] = None,
- am: edgedb.Range[int],
- an: typing.Optional[edgedb.Range[int]] = None,
- ao: edgedb.Range[float],
- ap: typing.Optional[edgedb.Range[float]] = None,
- aq: edgedb.Range[float],
- ar: typing.Optional[edgedb.Range[float]] = None,
- as_: edgedb.Range[datetime.datetime],
- at: typing.Optional[edgedb.Range[datetime.datetime]] = None,
- au: edgedb.Range[datetime.datetime],
- av: typing.Optional[edgedb.Range[datetime.datetime]] = None,
- aw: edgedb.Range[datetime.date],
- ax: typing.Optional[edgedb.Range[datetime.date]] = None,
+ ae: gel.RelativeDuration,
+ af: typing.Optional[gel.RelativeDuration] = None,
+ ag: gel.DateDuration,
+ ah: typing.Optional[gel.DateDuration] = None,
+ ai: gel.ConfigMemory,
+ aj: typing.Optional[gel.ConfigMemory] = None,
+ ak: gel.Range[int],
+ al: typing.Optional[gel.Range[int]] = None,
+ am: gel.Range[int],
+ an: typing.Optional[gel.Range[int]] = None,
+ ao: gel.Range[float],
+ ap: typing.Optional[gel.Range[float]] = None,
+ aq: gel.Range[float],
+ ar: typing.Optional[gel.Range[float]] = None,
+ as_: gel.Range[datetime.datetime],
+ at: typing.Optional[gel.Range[datetime.datetime]] = None,
+ au: gel.Range[datetime.datetime],
+ av: typing.Optional[gel.Range[datetime.datetime]] = None,
+ aw: gel.Range[datetime.date],
+ ax: typing.Optional[gel.Range[datetime.date]] = None,
bc: typing.Sequence[float],
bd: typing.Optional[typing.Sequence[float]] = None,
) -> MyQueryResult:
diff --git a/tests/codegen/test-project2/parpkg/subpkg/my_query_async_edgeql.py.assert5 b/tests/codegen/test-project2/parpkg/subpkg/my_query_async_edgeql.py.assert5
index 05e4a7e2..7b7138cb 100644
--- a/tests/codegen/test-project2/parpkg/subpkg/my_query_async_edgeql.py.assert5
+++ b/tests/codegen/test-project2/parpkg/subpkg/my_query_async_edgeql.py.assert5
@@ -1,13 +1,13 @@
# AUTOGENERATED FROM 'parpkg/subpkg/my_query.edgeql' WITH:
-# $ edgedb-py
+# $ gel-py
from __future__ import annotations
import array
import dataclasses
import datetime
-import edgedb
import enum
+import gel
import typing
import uuid
@@ -72,26 +72,26 @@ class MyQueryResult(NoPydanticValidation):
ab: typing.Optional[datetime.timedelta]
ac: int
ad: typing.Optional[int]
- ae: edgedb.RelativeDuration
- af: typing.Optional[edgedb.RelativeDuration]
- ag: edgedb.DateDuration
- ah: typing.Optional[edgedb.DateDuration]
- ai: edgedb.ConfigMemory
- aj: typing.Optional[edgedb.ConfigMemory]
- ak: edgedb.Range[int]
- al: typing.Optional[edgedb.Range[int]]
- am: edgedb.Range[int]
- an: typing.Optional[edgedb.Range[int]]
- ao: edgedb.Range[float]
- ap: typing.Optional[edgedb.Range[float]]
- aq: edgedb.Range[float]
- ar: typing.Optional[edgedb.Range[float]]
- as_: edgedb.Range[datetime.datetime]
- at: typing.Optional[edgedb.Range[datetime.datetime]]
- au: edgedb.Range[datetime.datetime]
- av: typing.Optional[edgedb.Range[datetime.datetime]]
- aw: edgedb.Range[datetime.date]
- ax: typing.Optional[edgedb.Range[datetime.date]]
+ ae: gel.RelativeDuration
+ af: typing.Optional[gel.RelativeDuration]
+ ag: gel.DateDuration
+ ah: typing.Optional[gel.DateDuration]
+ ai: gel.ConfigMemory
+ aj: typing.Optional[gel.ConfigMemory]
+ ak: gel.Range[int]
+ al: typing.Optional[gel.Range[int]]
+ am: gel.Range[int]
+ an: typing.Optional[gel.Range[int]]
+ ao: gel.Range[float]
+ ap: typing.Optional[gel.Range[float]]
+ aq: gel.Range[float]
+ ar: typing.Optional[gel.Range[float]]
+ as_: gel.Range[datetime.datetime]
+ at: typing.Optional[gel.Range[datetime.datetime]]
+ au: gel.Range[datetime.datetime]
+ av: typing.Optional[gel.Range[datetime.datetime]]
+ aw: gel.Range[datetime.date]
+ ax: typing.Optional[gel.Range[datetime.date]]
ay: MyScalar
az: typing.Optional[MyScalar]
ba: MyEnum
@@ -101,7 +101,7 @@ class MyQueryResult(NoPydanticValidation):
async def my_query(
- executor: edgedb.AsyncIOExecutor,
+ executor: gel.AsyncIOExecutor,
*,
a: uuid.UUID,
b: typing.Optional[uuid.UUID] = None,
@@ -133,26 +133,26 @@ async def my_query(
ab: typing.Optional[datetime.timedelta] = None,
ac: int,
ad: typing.Optional[int] = None,
- ae: edgedb.RelativeDuration,
- af: typing.Optional[edgedb.RelativeDuration] = None,
- ag: edgedb.DateDuration,
- ah: typing.Optional[edgedb.DateDuration] = None,
- ai: edgedb.ConfigMemory,
- aj: typing.Optional[edgedb.ConfigMemory] = None,
- ak: edgedb.Range[int],
- al: typing.Optional[edgedb.Range[int]] = None,
- am: edgedb.Range[int],
- an: typing.Optional[edgedb.Range[int]] = None,
- ao: edgedb.Range[float],
- ap: typing.Optional[edgedb.Range[float]] = None,
- aq: edgedb.Range[float],
- ar: typing.Optional[edgedb.Range[float]] = None,
- as_: edgedb.Range[datetime.datetime],
- at: typing.Optional[edgedb.Range[datetime.datetime]] = None,
- au: edgedb.Range[datetime.datetime],
- av: typing.Optional[edgedb.Range[datetime.datetime]] = None,
- aw: edgedb.Range[datetime.date],
- ax: typing.Optional[edgedb.Range[datetime.date]] = None,
+ ae: gel.RelativeDuration,
+ af: typing.Optional[gel.RelativeDuration] = None,
+ ag: gel.DateDuration,
+ ah: typing.Optional[gel.DateDuration] = None,
+ ai: gel.ConfigMemory,
+ aj: typing.Optional[gel.ConfigMemory] = None,
+ ak: gel.Range[int],
+ al: typing.Optional[gel.Range[int]] = None,
+ am: gel.Range[int],
+ an: typing.Optional[gel.Range[int]] = None,
+ ao: gel.Range[float],
+ ap: typing.Optional[gel.Range[float]] = None,
+ aq: gel.Range[float],
+ ar: typing.Optional[gel.Range[float]] = None,
+ as_: gel.Range[datetime.datetime],
+ at: typing.Optional[gel.Range[datetime.datetime]] = None,
+ au: gel.Range[datetime.datetime],
+ av: typing.Optional[gel.Range[datetime.datetime]] = None,
+ aw: gel.Range[datetime.date],
+ ax: typing.Optional[gel.Range[datetime.date]] = None,
bc: typing.Sequence[float],
bd: typing.Optional[typing.Sequence[float]] = None,
) -> MyQueryResult:
diff --git a/tests/codegen/test-project2/parpkg/subpkg/my_query_edgeql.py.assert b/tests/codegen/test-project2/parpkg/subpkg/my_query_edgeql.py.assert
index fa79a94c..b22baa49 100644
--- a/tests/codegen/test-project2/parpkg/subpkg/my_query_edgeql.py.assert
+++ b/tests/codegen/test-project2/parpkg/subpkg/my_query_edgeql.py.assert
@@ -1,13 +1,13 @@
# AUTOGENERATED FROM 'parpkg/subpkg/my_query.edgeql' WITH:
-# $ edgedb-py --target blocking --no-skip-pydantic-validation
+# $ gel-py --target blocking --no-skip-pydantic-validation
from __future__ import annotations
import array
import dataclasses
import datetime
-import edgedb
import enum
+import gel
import typing
import uuid
@@ -55,26 +55,26 @@ class MyQueryResult:
ab: typing.Optional[datetime.timedelta]
ac: int
ad: typing.Optional[int]
- ae: edgedb.RelativeDuration
- af: typing.Optional[edgedb.RelativeDuration]
- ag: edgedb.DateDuration
- ah: typing.Optional[edgedb.DateDuration]
- ai: edgedb.ConfigMemory
- aj: typing.Optional[edgedb.ConfigMemory]
- ak: edgedb.Range[int]
- al: typing.Optional[edgedb.Range[int]]
- am: edgedb.Range[int]
- an: typing.Optional[edgedb.Range[int]]
- ao: edgedb.Range[float]
- ap: typing.Optional[edgedb.Range[float]]
- aq: edgedb.Range[float]
- ar: typing.Optional[edgedb.Range[float]]
- as_: edgedb.Range[datetime.datetime]
- at: typing.Optional[edgedb.Range[datetime.datetime]]
- au: edgedb.Range[datetime.datetime]
- av: typing.Optional[edgedb.Range[datetime.datetime]]
- aw: edgedb.Range[datetime.date]
- ax: typing.Optional[edgedb.Range[datetime.date]]
+ ae: gel.RelativeDuration
+ af: typing.Optional[gel.RelativeDuration]
+ ag: gel.DateDuration
+ ah: typing.Optional[gel.DateDuration]
+ ai: gel.ConfigMemory
+ aj: typing.Optional[gel.ConfigMemory]
+ ak: gel.Range[int]
+ al: typing.Optional[gel.Range[int]]
+ am: gel.Range[int]
+ an: typing.Optional[gel.Range[int]]
+ ao: gel.Range[float]
+ ap: typing.Optional[gel.Range[float]]
+ aq: gel.Range[float]
+ ar: typing.Optional[gel.Range[float]]
+ as_: gel.Range[datetime.datetime]
+ at: typing.Optional[gel.Range[datetime.datetime]]
+ au: gel.Range[datetime.datetime]
+ av: typing.Optional[gel.Range[datetime.datetime]]
+ aw: gel.Range[datetime.date]
+ ax: typing.Optional[gel.Range[datetime.date]]
ay: MyScalar
az: typing.Optional[MyScalar]
ba: MyEnum
@@ -84,7 +84,7 @@ class MyQueryResult:
def my_query(
- executor: edgedb.Executor,
+ executor: gel.Executor,
*,
a: uuid.UUID,
b: typing.Optional[uuid.UUID] = None,
@@ -116,26 +116,26 @@ def my_query(
ab: typing.Optional[datetime.timedelta] = None,
ac: int,
ad: typing.Optional[int] = None,
- ae: edgedb.RelativeDuration,
- af: typing.Optional[edgedb.RelativeDuration] = None,
- ag: edgedb.DateDuration,
- ah: typing.Optional[edgedb.DateDuration] = None,
- ai: edgedb.ConfigMemory,
- aj: typing.Optional[edgedb.ConfigMemory] = None,
- ak: edgedb.Range[int],
- al: typing.Optional[edgedb.Range[int]] = None,
- am: edgedb.Range[int],
- an: typing.Optional[edgedb.Range[int]] = None,
- ao: edgedb.Range[float],
- ap: typing.Optional[edgedb.Range[float]] = None,
- aq: edgedb.Range[float],
- ar: typing.Optional[edgedb.Range[float]] = None,
- as_: edgedb.Range[datetime.datetime],
- at: typing.Optional[edgedb.Range[datetime.datetime]] = None,
- au: edgedb.Range[datetime.datetime],
- av: typing.Optional[edgedb.Range[datetime.datetime]] = None,
- aw: edgedb.Range[datetime.date],
- ax: typing.Optional[edgedb.Range[datetime.date]] = None,
+ ae: gel.RelativeDuration,
+ af: typing.Optional[gel.RelativeDuration] = None,
+ ag: gel.DateDuration,
+ ah: typing.Optional[gel.DateDuration] = None,
+ ai: gel.ConfigMemory,
+ aj: typing.Optional[gel.ConfigMemory] = None,
+ ak: gel.Range[int],
+ al: typing.Optional[gel.Range[int]] = None,
+ am: gel.Range[int],
+ an: typing.Optional[gel.Range[int]] = None,
+ ao: gel.Range[float],
+ ap: typing.Optional[gel.Range[float]] = None,
+ aq: gel.Range[float],
+ ar: typing.Optional[gel.Range[float]] = None,
+ as_: gel.Range[datetime.datetime],
+ at: typing.Optional[gel.Range[datetime.datetime]] = None,
+ au: gel.Range[datetime.datetime],
+ av: typing.Optional[gel.Range[datetime.datetime]] = None,
+ aw: gel.Range[datetime.date],
+ ax: typing.Optional[gel.Range[datetime.date]] = None,
bc: typing.Sequence[float],
bd: typing.Optional[typing.Sequence[float]] = None,
) -> MyQueryResult:
diff --git a/tests/codegen/test-project2/parpkg/subpkg/my_query_edgeql.py.assert5 b/tests/codegen/test-project2/parpkg/subpkg/my_query_edgeql.py.assert5
index c5a85ea8..2d60d18a 100644
--- a/tests/codegen/test-project2/parpkg/subpkg/my_query_edgeql.py.assert5
+++ b/tests/codegen/test-project2/parpkg/subpkg/my_query_edgeql.py.assert5
@@ -1,13 +1,13 @@
# AUTOGENERATED FROM 'parpkg/subpkg/my_query.edgeql' WITH:
-# $ edgedb-py --target blocking --no-skip-pydantic-validation
+# $ gel-py --target blocking --no-skip-pydantic-validation
from __future__ import annotations
import array
import dataclasses
import datetime
-import edgedb
import enum
+import gel
import typing
import uuid
@@ -56,26 +56,26 @@ class MyQueryResult:
ab: typing.Optional[datetime.timedelta]
ac: int
ad: typing.Optional[int]
- ae: edgedb.RelativeDuration
- af: typing.Optional[edgedb.RelativeDuration]
- ag: edgedb.DateDuration
- ah: typing.Optional[edgedb.DateDuration]
- ai: edgedb.ConfigMemory
- aj: typing.Optional[edgedb.ConfigMemory]
- ak: edgedb.Range[int]
- al: typing.Optional[edgedb.Range[int]]
- am: edgedb.Range[int]
- an: typing.Optional[edgedb.Range[int]]
- ao: edgedb.Range[float]
- ap: typing.Optional[edgedb.Range[float]]
- aq: edgedb.Range[float]
- ar: typing.Optional[edgedb.Range[float]]
- as_: edgedb.Range[datetime.datetime]
- at: typing.Optional[edgedb.Range[datetime.datetime]]
- au: edgedb.Range[datetime.datetime]
- av: typing.Optional[edgedb.Range[datetime.datetime]]
- aw: edgedb.Range[datetime.date]
- ax: typing.Optional[edgedb.Range[datetime.date]]
+ ae: gel.RelativeDuration
+ af: typing.Optional[gel.RelativeDuration]
+ ag: gel.DateDuration
+ ah: typing.Optional[gel.DateDuration]
+ ai: gel.ConfigMemory
+ aj: typing.Optional[gel.ConfigMemory]
+ ak: gel.Range[int]
+ al: typing.Optional[gel.Range[int]]
+ am: gel.Range[int]
+ an: typing.Optional[gel.Range[int]]
+ ao: gel.Range[float]
+ ap: typing.Optional[gel.Range[float]]
+ aq: gel.Range[float]
+ ar: typing.Optional[gel.Range[float]]
+ as_: gel.Range[datetime.datetime]
+ at: typing.Optional[gel.Range[datetime.datetime]]
+ au: gel.Range[datetime.datetime]
+ av: typing.Optional[gel.Range[datetime.datetime]]
+ aw: gel.Range[datetime.date]
+ ax: typing.Optional[gel.Range[datetime.date]]
ay: MyScalar
az: typing.Optional[MyScalar]
ba: MyEnum
@@ -85,7 +85,7 @@ class MyQueryResult:
def my_query(
- executor: edgedb.Executor,
+ executor: gel.Executor,
*,
a: uuid.UUID,
b: typing.Optional[uuid.UUID] = None,
@@ -117,26 +117,26 @@ def my_query(
ab: typing.Optional[datetime.timedelta] = None,
ac: int,
ad: typing.Optional[int] = None,
- ae: edgedb.RelativeDuration,
- af: typing.Optional[edgedb.RelativeDuration] = None,
- ag: edgedb.DateDuration,
- ah: typing.Optional[edgedb.DateDuration] = None,
- ai: edgedb.ConfigMemory,
- aj: typing.Optional[edgedb.ConfigMemory] = None,
- ak: edgedb.Range[int],
- al: typing.Optional[edgedb.Range[int]] = None,
- am: edgedb.Range[int],
- an: typing.Optional[edgedb.Range[int]] = None,
- ao: edgedb.Range[float],
- ap: typing.Optional[edgedb.Range[float]] = None,
- aq: edgedb.Range[float],
- ar: typing.Optional[edgedb.Range[float]] = None,
- as_: edgedb.Range[datetime.datetime],
- at: typing.Optional[edgedb.Range[datetime.datetime]] = None,
- au: edgedb.Range[datetime.datetime],
- av: typing.Optional[edgedb.Range[datetime.datetime]] = None,
- aw: edgedb.Range[datetime.date],
- ax: typing.Optional[edgedb.Range[datetime.date]] = None,
+ ae: gel.RelativeDuration,
+ af: typing.Optional[gel.RelativeDuration] = None,
+ ag: gel.DateDuration,
+ ah: typing.Optional[gel.DateDuration] = None,
+ ai: gel.ConfigMemory,
+ aj: typing.Optional[gel.ConfigMemory] = None,
+ ak: gel.Range[int],
+ al: typing.Optional[gel.Range[int]] = None,
+ am: gel.Range[int],
+ an: typing.Optional[gel.Range[int]] = None,
+ ao: gel.Range[float],
+ ap: typing.Optional[gel.Range[float]] = None,
+ aq: gel.Range[float],
+ ar: typing.Optional[gel.Range[float]] = None,
+ as_: gel.Range[datetime.datetime],
+ at: typing.Optional[gel.Range[datetime.datetime]] = None,
+ au: gel.Range[datetime.datetime],
+ av: typing.Optional[gel.Range[datetime.datetime]] = None,
+ aw: gel.Range[datetime.date],
+ ax: typing.Optional[gel.Range[datetime.date]] = None,
bc: typing.Sequence[float],
bd: typing.Optional[typing.Sequence[float]] = None,
) -> MyQueryResult:
diff --git a/tests/codegen/test-project2/scalar/custom_vector_input_async_edgeql.py.assert b/tests/codegen/test-project2/scalar/custom_vector_input_async_edgeql.py.assert
index 39272b4f..277e472a 100644
--- a/tests/codegen/test-project2/scalar/custom_vector_input_async_edgeql.py.assert
+++ b/tests/codegen/test-project2/scalar/custom_vector_input_async_edgeql.py.assert
@@ -1,9 +1,9 @@
# AUTOGENERATED FROM 'scalar/custom_vector_input.edgeql' WITH:
-# $ edgedb-py
+# $ gel-py
from __future__ import annotations
-import edgedb
+import gel
import typing
@@ -11,7 +11,7 @@ V3 = typing.Sequence[float]
async def custom_vector_input(
- executor: edgedb.AsyncIOExecutor,
+ executor: gel.AsyncIOExecutor,
*,
input: typing.Optional[V3] = None,
) -> typing.Optional[int]:
diff --git a/tests/codegen/test-project2/scalar/custom_vector_input_async_edgeql.py.assert3 b/tests/codegen/test-project2/scalar/custom_vector_input_async_edgeql.py.assert3
index a16beb5e..1a1f0e0d 100644
--- a/tests/codegen/test-project2/scalar/custom_vector_input_async_edgeql.py.assert3
+++ b/tests/codegen/test-project2/scalar/custom_vector_input_async_edgeql.py.assert3
@@ -1,9 +1,9 @@
# AUTOGENERATED FROM 'scalar/custom_vector_input.edgeql' WITH:
-# $ edgedb-py
+# $ gel-py
from __future__ import annotations
-import edgedb
+import gel
import typing
@@ -11,7 +11,7 @@ Input = typing.Sequence[float]
async def custom_vector_input(
- executor: edgedb.AsyncIOExecutor,
+ executor: gel.AsyncIOExecutor,
*,
input: typing.Optional[Input] = None,
) -> typing.Optional[int]:
diff --git a/tests/codegen/test-project2/scalar/custom_vector_input_edgeql.py.assert b/tests/codegen/test-project2/scalar/custom_vector_input_edgeql.py.assert
index ed2947c0..6ccca1a6 100644
--- a/tests/codegen/test-project2/scalar/custom_vector_input_edgeql.py.assert
+++ b/tests/codegen/test-project2/scalar/custom_vector_input_edgeql.py.assert
@@ -1,9 +1,9 @@
# AUTOGENERATED FROM 'scalar/custom_vector_input.edgeql' WITH:
-# $ edgedb-py --target blocking --no-skip-pydantic-validation
+# $ gel-py --target blocking --no-skip-pydantic-validation
from __future__ import annotations
-import edgedb
+import gel
import typing
@@ -11,7 +11,7 @@ V3 = typing.Sequence[float]
def custom_vector_input(
- executor: edgedb.Executor,
+ executor: gel.Executor,
*,
input: typing.Optional[V3] = None,
) -> typing.Optional[int]:
diff --git a/tests/codegen/test-project2/scalar/custom_vector_input_edgeql.py.assert3 b/tests/codegen/test-project2/scalar/custom_vector_input_edgeql.py.assert3
index 4cf4659a..509bd484 100644
--- a/tests/codegen/test-project2/scalar/custom_vector_input_edgeql.py.assert3
+++ b/tests/codegen/test-project2/scalar/custom_vector_input_edgeql.py.assert3
@@ -1,9 +1,9 @@
# AUTOGENERATED FROM 'scalar/custom_vector_input.edgeql' WITH:
-# $ edgedb-py --target blocking --no-skip-pydantic-validation
+# $ gel-py --target blocking --no-skip-pydantic-validation
from __future__ import annotations
-import edgedb
+import gel
import typing
@@ -11,7 +11,7 @@ Input = typing.Sequence[float]
def custom_vector_input(
- executor: edgedb.Executor,
+ executor: gel.Executor,
*,
input: typing.Optional[Input] = None,
) -> typing.Optional[int]:
diff --git a/tests/codegen/test-project2/scalar/select_scalar_async_edgeql.py.assert b/tests/codegen/test-project2/scalar/select_scalar_async_edgeql.py.assert
index a9fc5917..327be5e3 100644
--- a/tests/codegen/test-project2/scalar/select_scalar_async_edgeql.py.assert
+++ b/tests/codegen/test-project2/scalar/select_scalar_async_edgeql.py.assert
@@ -1,13 +1,13 @@
# AUTOGENERATED FROM 'scalar/select_scalar.edgeql' WITH:
-# $ edgedb-py
+# $ gel-py
from __future__ import annotations
-import edgedb
+import gel
async def select_scalar(
- executor: edgedb.AsyncIOExecutor,
+ executor: gel.AsyncIOExecutor,
) -> int:
return await executor.query_single(
"""\
diff --git a/tests/codegen/test-project2/scalar/select_scalar_edgeql.py.assert b/tests/codegen/test-project2/scalar/select_scalar_edgeql.py.assert
index c8cfeb1b..0337455b 100644
--- a/tests/codegen/test-project2/scalar/select_scalar_edgeql.py.assert
+++ b/tests/codegen/test-project2/scalar/select_scalar_edgeql.py.assert
@@ -1,13 +1,13 @@
# AUTOGENERATED FROM 'scalar/select_scalar.edgeql' WITH:
-# $ edgedb-py --target blocking --no-skip-pydantic-validation
+# $ gel-py --target blocking --no-skip-pydantic-validation
from __future__ import annotations
-import edgedb
+import gel
def select_scalar(
- executor: edgedb.Executor,
+ executor: gel.Executor,
) -> int:
return executor.query_single(
"""\
diff --git a/tests/codegen/test-project2/scalar/select_scalars_async_edgeql.py.assert b/tests/codegen/test-project2/scalar/select_scalars_async_edgeql.py.assert
index d7e9b233..a7291850 100644
--- a/tests/codegen/test-project2/scalar/select_scalars_async_edgeql.py.assert
+++ b/tests/codegen/test-project2/scalar/select_scalars_async_edgeql.py.assert
@@ -1,15 +1,15 @@
# AUTOGENERATED FROM 'scalar/select_scalars.edgeql' WITH:
-# $ edgedb-py
+# $ gel-py
from __future__ import annotations
-import edgedb
+import gel
import typing
async def select_scalars(
- executor: edgedb.AsyncIOExecutor,
-) -> typing.List[edgedb.ConfigMemory]:
+ executor: gel.AsyncIOExecutor,
+) -> typing.List[gel.ConfigMemory]:
return await executor.query(
"""\
select {1, 2, 3};\
diff --git a/tests/codegen/test-project2/scalar/select_scalars_edgeql.py.assert b/tests/codegen/test-project2/scalar/select_scalars_edgeql.py.assert
index 74016c24..9cf2ebb4 100644
--- a/tests/codegen/test-project2/scalar/select_scalars_edgeql.py.assert
+++ b/tests/codegen/test-project2/scalar/select_scalars_edgeql.py.assert
@@ -1,14 +1,14 @@
# AUTOGENERATED FROM 'scalar/select_scalars.edgeql' WITH:
-# $ edgedb-py --target blocking --no-skip-pydantic-validation
+# $ gel-py --target blocking --no-skip-pydantic-validation
from __future__ import annotations
-import edgedb
+import gel
def select_scalars(
- executor: edgedb.Executor,
-) -> list[edgedb.ConfigMemory]:
+ executor: gel.Executor,
+) -> list[gel.ConfigMemory]:
return executor.query(
"""\
select {1, 2, 3};\
diff --git a/tests/test_async_query.py b/tests/test_async_query.py
index da731bb3..851afd02 100644
--- a/tests/test_async_query.py
+++ b/tests/test_async_query.py
@@ -29,7 +29,7 @@
import edgedb
from edgedb import abstract
-from edgedb import _testbase as tb
+from gel import _testbase as tb
from edgedb.options import RetryOptions
from edgedb.protocol import protocol
@@ -988,7 +988,7 @@ async def test_enum_argument_01(self):
await tx.query_single('SELECT $0', 'Oups')
with self.assertRaisesRegex(
- edgedb.InvalidArgumentError, 'a str or edgedb.EnumValue'):
+ edgedb.InvalidArgumentError, 'a str or gel.EnumValue'):
await self.client.query_single('SELECT $0', 123)
async def test_enum_argument_02(self):
diff --git a/tests/test_async_retry.py b/tests/test_async_retry.py
index d47c5603..0fb13d08 100644
--- a/tests/test_async_retry.py
+++ b/tests/test_async_retry.py
@@ -24,7 +24,7 @@
import edgedb
from edgedb import errors
from edgedb import RetryOptions
-from edgedb import _testbase as tb
+from gel import _testbase as tb
log = logging.getLogger(__name__)
diff --git a/tests/test_async_tx.py b/tests/test_async_tx.py
index 8ceeb239..030b4672 100644
--- a/tests/test_async_tx.py
+++ b/tests/test_async_tx.py
@@ -21,7 +21,7 @@
import edgedb
-from edgedb import _testbase as tb
+from gel import _testbase as tb
from edgedb import TransactionOptions
from edgedb.options import RetryOptions
diff --git a/tests/test_asyncio_client.py b/tests/test_asyncio_client.py
index b4d5b03d..3971029c 100644
--- a/tests/test_asyncio_client.py
+++ b/tests/test_asyncio_client.py
@@ -21,7 +21,7 @@
import edgedb
-from edgedb import _testbase as tb
+from gel import _testbase as tb
from edgedb import errors
from edgedb import asyncio_client
diff --git a/tests/test_blocking_client.py b/tests/test_blocking_client.py
index 099baa71..4ebc153d 100644
--- a/tests/test_blocking_client.py
+++ b/tests/test_blocking_client.py
@@ -24,7 +24,7 @@
import edgedb
-from edgedb import _testbase as tb
+from gel import _testbase as tb
from edgedb import errors
from edgedb import blocking_client
diff --git a/tests/test_codegen.py b/tests/test_codegen.py
index 801b00ae..3124f550 100644
--- a/tests/test_codegen.py
+++ b/tests/test_codegen.py
@@ -24,7 +24,7 @@
import os
import tempfile
-from edgedb import _testbase as tb
+from gel import _testbase as tb
ASSERT_SUFFIX = os.environ.get("EDGEDB_TEST_CODEGEN_ASSERT_SUFFIX", ".assert")
@@ -91,7 +91,7 @@ async def run(*args, extra_env=None):
p.returncode, args, output=await p.stdout.read(),
)
- cmd = env.get("EDGEDB_PYTHON_TEST_CODEGEN_CMD", "edgedb-py")
+ cmd = env.get("EDGEDB_PYTHON_TEST_CODEGEN_CMD", "gel-py")
await run(
cmd, extra_env={"EDGEDB_PYTHON_CODEGEN_PY_VER": "3.8.5"}
)
diff --git a/tests/test_con_utils.py b/tests/test_con_utils.py
index 24027f06..13454671 100644
--- a/tests/test_con_utils.py
+++ b/tests/test_con_utils.py
@@ -29,7 +29,7 @@
from unittest import mock
-from edgedb import con_utils
+from gel import con_utils
from edgedb import errors
@@ -389,9 +389,9 @@ def test_connect_params(self):
self.run_testcase(testcase)
- @mock.patch("edgedb.platform.config_dir",
+ @mock.patch("gel.platform.config_dir",
lambda: pathlib.Path("/home/user/.config/edgedb"))
- @mock.patch("edgedb.platform.IS_WINDOWS", False)
+ @mock.patch("gel.platform.IS_WINDOWS", False)
@mock.patch("pathlib.Path.exists", lambda p: True)
@mock.patch("os.path.realpath", lambda p: p)
def test_stash_path(self):
@@ -423,7 +423,7 @@ def test_project_config(self):
"database": "inst1_db",
}))
- with mock.patch('edgedb.platform.config_dir',
+ with mock.patch('gel.platform.config_dir',
lambda: home / '.edgedb'), \
mock.patch('os.getcwd', lambda: str(project)):
stash_path = con_utils._stash_path(project)
diff --git a/tests/test_connect.py b/tests/test_connect.py
index e22f1958..48cebade 100644
--- a/tests/test_connect.py
+++ b/tests/test_connect.py
@@ -21,7 +21,7 @@
import edgedb
-from edgedb import _testbase as tb
+from gel import _testbase as tb
class TestConnect(tb.AsyncQueryTestCase):
diff --git a/tests/test_datetime.py b/tests/test_datetime.py
index ff7dfbc0..bb19a5b1 100644
--- a/tests/test_datetime.py
+++ b/tests/test_datetime.py
@@ -20,7 +20,7 @@
import random
from datetime import timedelta
-from edgedb import _testbase as tb
+from gel import _testbase as tb
from edgedb import errors
from edgedb.datatypes.datatypes import RelativeDuration, DateDuration
@@ -178,7 +178,7 @@ async def test_relative_duration_02(self):
self.assertEqual(d1.months, 0)
self.assertEqual(d1.microseconds, 1)
- self.assertEqual(repr(d1), '')
+ self.assertEqual(repr(d1), '')
async def test_relative_duration_03(self):
# Make sure that when we break down the microseconds into the bigger
diff --git a/tests/test_enum.py b/tests/test_enum.py
index 1f3b6fae..fa0040c4 100644
--- a/tests/test_enum.py
+++ b/tests/test_enum.py
@@ -22,7 +22,7 @@
import uuid
import edgedb
-from edgedb import _testbase as tb
+from gel import _testbase as tb
class TestEnum(tb.AsyncQueryTestCase):
@@ -40,7 +40,7 @@ async def test_enum_01(self):
self.assertTrue(isinstance(ct_red, edgedb.EnumValue))
self.assertTrue(isinstance(ct_red.__tid__, uuid.UUID))
- self.assertEqual(repr(ct_red), "")
+ self.assertEqual(repr(ct_red), "")
self.assertEqual(str(ct_red), 'red')
with self.assertRaises(TypeError):
diff --git a/tests/test_errors.py b/tests/test_errors.py
index f76efa28..209c80e3 100644
--- a/tests/test_errors.py
+++ b/tests/test_errors.py
@@ -20,8 +20,8 @@
import unittest
-from edgedb import errors
-from edgedb.errors import _base as base_errors
+from gel import errors
+from gel.errors import _base as base_errors
class TestErrors(unittest.TestCase):
diff --git a/tests/test_globals.py b/tests/test_globals.py
index f4c72029..ee9a2bb0 100644
--- a/tests/test_globals.py
+++ b/tests/test_globals.py
@@ -17,7 +17,7 @@
#
-from edgedb import _testbase as tb
+from gel import _testbase as tb
from edgedb import errors
diff --git a/tests/test_memory.py b/tests/test_memory.py
index 63c032e2..0004d1a7 100644
--- a/tests/test_memory.py
+++ b/tests/test_memory.py
@@ -16,7 +16,7 @@
# limitations under the License.
#
-from edgedb import _testbase as tb
+from gel import _testbase as tb
class TestConfigMemory(tb.SyncQueryTestCase):
diff --git a/tests/test_namedtuples.py b/tests/test_namedtuples.py
index 04a74ddc..3a525449 100644
--- a/tests/test_namedtuples.py
+++ b/tests/test_namedtuples.py
@@ -20,7 +20,7 @@
from collections import namedtuple, UserDict
import edgedb
-from edgedb import _testbase as tb
+from gel import _testbase as tb
class TestNamedTupleTypes(tb.SyncQueryTestCase):
diff --git a/tests/test_postgis.py b/tests/test_postgis.py
index 9d6583ef..029ff57d 100644
--- a/tests/test_postgis.py
+++ b/tests/test_postgis.py
@@ -18,7 +18,7 @@
from collections import namedtuple
-from edgedb import _testbase as tb
+from gel import _testbase as tb
Geo = namedtuple('Geo', ['wkb'])
diff --git a/tests/test_proto.py b/tests/test_proto.py
index 48be7f15..9b984ce5 100644
--- a/tests/test_proto.py
+++ b/tests/test_proto.py
@@ -20,7 +20,7 @@
import edgedb
-from edgedb import _testbase as tb
+from gel import _testbase as tb
class TestProto(tb.SyncQueryTestCase):
diff --git a/tests/test_state.py b/tests/test_state.py
index d59ef1d5..e7047765 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -17,7 +17,7 @@
#
-from edgedb import _testbase as tb
+from gel import _testbase as tb
from edgedb import State
diff --git a/tests/test_sync_query.py b/tests/test_sync_query.py
index 36566282..f6d27f60 100644
--- a/tests/test_sync_query.py
+++ b/tests/test_sync_query.py
@@ -27,7 +27,7 @@
import edgedb
from edgedb import abstract
-from edgedb import _testbase as tb
+from gel import _testbase as tb
from edgedb.protocol import protocol
@@ -767,7 +767,7 @@ def test_enum_argument_01(self):
tx.query_single('SELECT $0', 'Oups')
with self.assertRaisesRegex(
- edgedb.InvalidArgumentError, 'a str or edgedb.EnumValue'
+ edgedb.InvalidArgumentError, 'a str or gel.EnumValue'
):
self.client.query_single('SELECT $0', 123)
diff --git a/tests/test_sync_retry.py b/tests/test_sync_retry.py
index ae32c633..8b1ef1a6 100644
--- a/tests/test_sync_retry.py
+++ b/tests/test_sync_retry.py
@@ -23,10 +23,8 @@
import unittest.mock
from concurrent import futures
-import edgedb
-from edgedb import _testbase as tb
-from edgedb import errors
-from edgedb import RetryOptions
+import gel
+from gel import _testbase as tb
class Barrier:
@@ -82,7 +80,7 @@ def test_sync_retry_02(self):
};
''')
1 / 0
- with self.assertRaises(edgedb.NoDataError):
+ with self.assertRaises(gel.NoDataError):
self.client.query_required_single('''
SELECT test::Counter
FILTER .name = 'counter_retry_02'
@@ -109,9 +107,9 @@ def cleanup():
self.addCleanup(cleanup)
- _start.side_effect = errors.BackendUnavailableError()
+ _start.side_effect = gel.BackendUnavailableError()
- with self.assertRaises(errors.BackendUnavailableError):
+ with self.assertRaises(gel.BackendUnavailableError):
for tx in self.client.transaction():
with tx:
tx.execute('''
@@ -119,7 +117,7 @@ def cleanup():
name := 'counter_retry_begin'
};
''')
- with self.assertRaises(edgedb.NoDataError):
+ with self.assertRaises(gel.NoDataError):
self.client.query_required_single('''
SELECT test::Counter
FILTER .name = 'counter_retry_begin'
@@ -134,7 +132,7 @@ def cleanup():
def recover_after_first_error(*_, **__):
patcher.stop()
- raise errors.BackendUnavailableError()
+ raise gel.BackendUnavailableError()
_start.side_effect = recover_after_first_error
call_count = _start.call_count
@@ -156,16 +154,16 @@ def test_sync_retry_conflict(self):
self.execute_conflict('counter2')
def test_sync_conflict_no_retry(self):
- with self.assertRaises(edgedb.TransactionSerializationError):
+ with self.assertRaises(gel.TransactionSerializationError):
self.execute_conflict(
'counter3',
- RetryOptions(attempts=1, backoff=edgedb.default_backoff)
+ gel.RetryOptions(attempts=1, backoff=gel.default_backoff)
)
def execute_conflict(self, name='counter2', options=None):
con_args = self.get_connect_args().copy()
con_args.update(database=self.get_database_name())
- client2 = edgedb.create_client(**con_args)
+ client2 = gel.create_client(**con_args)
self.addCleanup(client2.close)
barrier = Barrier(2)
@@ -243,13 +241,13 @@ def test_sync_transaction_interface_errors(self):
for tx in self.client.transaction():
tx.start()
- with self.assertRaisesRegex(edgedb.InterfaceError,
+ with self.assertRaisesRegex(gel.InterfaceError,
r'.*Use `with transaction:`'):
for tx in self.client.transaction():
tx.execute("SELECT 123")
with self.assertRaisesRegex(
- edgedb.InterfaceError,
+ gel.InterfaceError,
r"already in a `with` block",
):
for tx in self.client.transaction():
diff --git a/tests/test_sync_tx.py b/tests/test_sync_tx.py
index 497af782..e7774f8f 100644
--- a/tests/test_sync_tx.py
+++ b/tests/test_sync_tx.py
@@ -21,7 +21,7 @@
import edgedb
-from edgedb import _testbase as tb
+from gel import _testbase as tb
from edgedb import TransactionOptions
diff --git a/tests/test_vector.py b/tests/test_vector.py
index 514ded14..b4708281 100644
--- a/tests/test_vector.py
+++ b/tests/test_vector.py
@@ -16,7 +16,7 @@
# limitations under the License.
#
-from edgedb import _testbase as tb
+from gel import _testbase as tb
import edgedb
import array
diff --git a/tools/make_import_shims.py b/tools/make_import_shims.py
new file mode 100644
index 00000000..60325af8
--- /dev/null
+++ b/tools/make_import_shims.py
@@ -0,0 +1,42 @@
+import os
+import sys
+
+MODS = sorted(['gel', 'gel._taskgroup', 'gel._version', 'gel.abstract', 'gel.ai', 'gel.ai.core', 'gel.ai.types', 'gel.asyncio_client', 'gel.base_client', 'gel.blocking_client', 'gel.codegen', 'gel.color', 'gel.con_utils', 'gel.credentials', 'gel.datatypes', 'gel.datatypes.datatypes', 'gel.datatypes.range', 'gel.describe', 'gel.enums', 'gel.errors', 'gel.errors._base', 'gel.errors.tags', 'gel.introspect', 'gel.options', 'gel.pgproto', 'gel.pgproto.pgproto', 'gel.pgproto.types', 'gel.platform', 'gel.protocol', 'gel.protocol.asyncio_proto', 'gel.protocol.blocking_proto', 'gel.protocol.protocol', 'gel.scram', 'gel.scram.saslprep', 'gel.transaction'])
+
+
+
+def main():
+ for mod in MODS:
+ is_package = any(k.startswith(mod + '.') for k in MODS)
+
+ nmod = 'edgedb' + mod[len('gel'):]
+ slash_name = nmod.replace('.', '/')
+ if is_package:
+ os.mkdir(slash_name)
+ fname = slash_name + '/__init__.py'
+ else:
+ fname = slash_name + '.py'
+
+ # import * skips things not in __all__ or with underscores at
+ # the start, so we have to do some nonsense.
+ with open(fname, 'w') as f:
+ f.write(f'''\
+# Auto-generated shim
+import {mod} as _mod
+import sys as _sys
+_cur = _sys.modules['{nmod}']
+for _k in vars(_mod):
+ if not _k.startswith('__') or _k in ('__all__', '__doc__'):
+ setattr(_cur, _k, getattr(_mod, _k))
+del _cur
+del _sys
+del _mod
+del _k
+''')
+
+ with open('edgedb/py.typed', 'w') as f:
+ pass
+
+
+if __name__ == '__main__':
+ main()