diff --git a/asyncpg/connect_utils.py b/asyncpg/connect_utils.py index c65f68a6..d6c4f051 100644 --- a/asyncpg/connect_utils.py +++ b/asyncpg/connect_utils.py @@ -168,13 +168,15 @@ def _read_password_from_pgpass( def _validate_port_spec(hosts, port): - if isinstance(port, list): + if isinstance(port, list) and len(port) > 1: # If there is a list of ports, its length must # match that of the host list. if len(port) != len(hosts): raise exceptions.ClientConfigurationError( 'could not match {} port numbers to {} hosts'.format( len(port), len(hosts))) + elif isinstance(port, list) and len(port) == 1: + port = [port[0] for _ in range(len(hosts))] else: port = [port for _ in range(len(hosts))] diff --git a/asyncpg/connection.py b/asyncpg/connection.py index 89aeb21c..3cd73173 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -312,7 +312,12 @@ def is_in_transaction(self): """ return self._protocol.is_in_transaction() - async def execute(self, query: str, *args, timeout: float=None) -> str: + async def execute( + self, + query: str, + *args, + timeout: typing.Optional[float]=None, + ) -> str: """Execute an SQL command (or commands). This method can execute many SQL commands at once, when no arguments @@ -359,7 +364,13 @@ async def execute(self, query: str, *args, timeout: float=None) -> str: ) return status.decode() - async def executemany(self, command: str, args, *, timeout: float=None): + async def executemany( + self, + command: str, + args, + *, + timeout: typing.Optional[float]=None, + ): """Execute an SQL *command* for each sequence of arguments in *args*. Example: @@ -395,7 +406,7 @@ async def _get_statement( query, timeout, *, - named=False, + named: typing.Union[str, bool, None] = False, use_cache=True, ignore_custom_codec=False, record_class=None @@ -535,26 +546,18 @@ async def _introspect_types(self, typeoids, timeout): return result async def _introspect_type(self, typename, schema): - if ( - schema == 'pg_catalog' - and typename.lower() in protocol.BUILTIN_TYPE_NAME_MAP - ): - typeoid = protocol.BUILTIN_TYPE_NAME_MAP[typename.lower()] - rows = await self._execute( - introspection.TYPE_BY_OID, - [typeoid], - limit=0, - timeout=None, - ignore_custom_codec=True, - ) - else: - rows = await self._execute( - introspection.TYPE_BY_NAME, - [typename, schema], - limit=1, - timeout=None, - ignore_custom_codec=True, - ) + if schema == 'pg_catalog' and not typename.endswith("[]"): + typeoid = protocol.BUILTIN_TYPE_NAME_MAP.get(typename.lower()) + if typeoid is not None: + return introspection.TypeRecord((typeoid, None, b"b")) + + rows = await self._execute( + introspection.TYPE_BY_NAME, + [typename, schema], + limit=1, + timeout=None, + ignore_custom_codec=True, + ) if not rows: raise ValueError( @@ -637,7 +640,6 @@ async def prepare( query, name=name, timeout=timeout, - use_cache=False, record_class=record_class, ) @@ -645,16 +647,18 @@ async def _prepare( self, query, *, - name=None, + name: typing.Union[str, bool, None] = None, timeout=None, use_cache: bool=False, record_class=None ): self._check_open() + if name is None: + name = self._stmt_cache_enabled stmt = await self._get_statement( query, timeout, - named=True if name is None else name, + named=name, use_cache=use_cache, record_class=record_class, ) @@ -758,7 +762,12 @@ async def fetchrow( return data[0] async def fetchmany( - self, query, args, *, timeout: float=None, record_class=None + self, + query, + args, + *, + timeout: typing.Optional[float]=None, + record_class=None, ): """Run a query for each sequence of arguments in *args* and return the results as a list of :class:`Record`. @@ -1108,7 +1117,7 @@ async def copy_records_to_table(self, table_name, *, records, intro_query = 'SELECT {cols} FROM {tab} LIMIT 1'.format( tab=tabname, cols=col_list) - intro_ps = await self._prepare(intro_query, use_cache=True) + intro_ps = await self.prepare(intro_query) cond = self._format_copy_where(where) opts = '(FORMAT binary)' diff --git a/asyncpg/introspection.py b/asyncpg/introspection.py index 641cf700..c3b4e60c 100644 --- a/asyncpg/introspection.py +++ b/asyncpg/introspection.py @@ -7,10 +7,12 @@ from __future__ import annotations import typing +from .protocol.protocol import _create_record # type: ignore if typing.TYPE_CHECKING: from . import protocol + _TYPEINFO_13: typing.Final = '''\ ( SELECT @@ -267,16 +269,12 @@ ''' -TYPE_BY_OID = '''\ -SELECT - t.oid, - t.typelem AS elemtype, - t.typtype AS kind -FROM - pg_catalog.pg_type AS t -WHERE - t.oid = $1 -''' +def TypeRecord( + rec: typing.Tuple[int, typing.Optional[int], bytes], +) -> protocol.Record: + assert len(rec) == 3 + return _create_record( # type: ignore + {"oid": 0, "elemtype": 1, "kind": 2}, rec) # 'b' for a base type, 'd' for a domain, 'e' for enum. diff --git a/asyncpg/pool.py b/asyncpg/pool.py index 2e4a7b4f..5c7ea9ca 100644 --- a/asyncpg/pool.py +++ b/asyncpg/pool.py @@ -574,7 +574,12 @@ async def _get_new_connection(self): return con - async def execute(self, query: str, *args, timeout: float=None) -> str: + async def execute( + self, + query: str, + *args, + timeout: Optional[float]=None, + ) -> str: """Execute an SQL command (or commands). Pool performs this operation using one of its connections. Other than @@ -586,7 +591,13 @@ async def execute(self, query: str, *args, timeout: float=None) -> str: async with self.acquire() as con: return await con.execute(query, *args, timeout=timeout) - async def executemany(self, command: str, args, *, timeout: float=None): + async def executemany( + self, + command: str, + args, + *, + timeout: Optional[float]=None, + ): """Execute an SQL *command* for each sequence of arguments in *args*. Pool performs this operation using one of its connections. Other than diff --git a/asyncpg/prepared_stmt.py b/asyncpg/prepared_stmt.py index d66a5ad3..0c2d335e 100644 --- a/asyncpg/prepared_stmt.py +++ b/asyncpg/prepared_stmt.py @@ -6,6 +6,7 @@ import json +import typing from . import connresource from . import cursor @@ -232,7 +233,7 @@ async def fetchmany(self, args, *, timeout=None): ) @connresource.guarded - async def executemany(self, args, *, timeout: float=None): + async def executemany(self, args, *, timeout: typing.Optional[float]=None): """Execute the statement for each sequence of arguments in *args*. :param args: An iterable containing sequences of arguments. diff --git a/tests/test_connect.py b/tests/test_connect.py index 0037ee5e..62cabc47 100644 --- a/tests/test_connect.py +++ b/tests/test_connect.py @@ -846,25 +846,26 @@ class TestConnectParams(tb.TestCase): ), }, - { - 'name': 'dsn_ipv6_multi_host', - 'dsn': 'postgresql://user@[2001:db8::1234%25eth0],[::1]/db', - 'result': ([('2001:db8::1234%eth0', 5432), ('::1', 5432)], { - 'database': 'db', - 'user': 'user', - 'target_session_attrs': 'any', - }) - }, - - { - 'name': 'dsn_ipv6_multi_host_port', - 'dsn': 'postgresql://user@[2001:db8::1234]:1111,[::1]:2222/db', - 'result': ([('2001:db8::1234', 1111), ('::1', 2222)], { - 'database': 'db', - 'user': 'user', - 'target_session_attrs': 'any', - }) - }, + # broken by https://github.com/python/cpython/pull/129418 + # { + # 'name': 'dsn_ipv6_multi_host', + # 'dsn': 'postgresql://user@[2001:db8::1234%25eth0],[::1]/db', + # 'result': ([('2001:db8::1234%eth0', 5432), ('::1', 5432)], { + # 'database': 'db', + # 'user': 'user', + # 'target_session_attrs': 'any', + # }) + # }, + + # { + # 'name': 'dsn_ipv6_multi_host_port', + # 'dsn': 'postgresql://user@[2001:db8::1234]:1111,[::1]:2222/db', + # 'result': ([('2001:db8::1234', 1111), ('::1', 2222)], { + # 'database': 'db', + # 'user': 'user', + # 'target_session_attrs': 'any', + # }) + # }, { 'name': 'dsn_ipv6_multi_host_query_part', @@ -1087,6 +1088,21 @@ class TestConnectParams(tb.TestCase): } ) }, + { + 'name': 'multi_host_single_port', + 'dsn': 'postgres:///postgres?host=127.0.0.1,127.0.0.2&port=5432' + '&user=postgres', + 'result': ( + [ + ('127.0.0.1', 5432), + ('127.0.0.2', 5432) + ], { + 'user': 'postgres', + 'database': 'postgres', + 'target_session_attrs': 'any', + } + ) + }, ] @contextlib.contextmanager