Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 21 additions & 20 deletions aiosqlite/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""
Core implementation of aiosqlite proxies
"""
from __future__ import annotations

import asyncio
import logging
Expand All @@ -13,7 +14,7 @@
from pathlib import Path
from queue import Empty, Queue, SimpleQueue
from threading import Thread
from typing import Any, Callable, Literal, Optional, Union
from typing import Any, Callable, Literal, Optional
from warnings import warn

from .context import contextmanager
Expand Down Expand Up @@ -47,11 +48,11 @@ def __init__(
self,
connector: Callable[[], sqlite3.Connection],
iter_chunk_size: int,
loop: Optional[asyncio.AbstractEventLoop] = None,
loop: asyncio.AbstractEventLoop | None = None,
) -> None:
super().__init__()
self._running = True
self._connection: Optional[sqlite3.Connection] = None
self._connection: sqlite3.Connection | None = None
self._connector = connector
self._tx: SimpleQueue[tuple[asyncio.Future, Callable[[], Any]]] = SimpleQueue()
self._iter_chunk_size = iter_chunk_size
Expand All @@ -74,7 +75,7 @@ def _conn(self) -> sqlite3.Connection:

return self._connection

def _execute_insert(self, sql: str, parameters: Any) -> Optional[sqlite3.Row]:
def _execute_insert(self, sql: str, parameters: Any) -> sqlite3.Row | None:
cursor = self._conn.execute(sql, parameters)
cursor.execute("SELECT last_insert_rowid()")
return cursor.fetchone()
Expand Down Expand Up @@ -121,7 +122,7 @@ async def _execute(self, fn, *args, **kwargs):

return await future

async def _connect(self) -> "Connection":
async def _connect(self) -> Connection:
"""Connect to the actual sqlite database."""
if self._connection is None:
try:
Expand All @@ -135,11 +136,11 @@ async def _connect(self) -> "Connection":

return self

def __await__(self) -> Generator[Any, None, "Connection"]:
def __await__(self) -> Generator[Any, None, Connection]:
self.start()
return self._connect().__await__()

async def __aenter__(self) -> "Connection":
async def __aenter__(self) -> Connection:
return await self

async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
Expand Down Expand Up @@ -175,7 +176,7 @@ async def close(self) -> None:

@contextmanager
async def execute(
self, sql: str, parameters: Optional[Iterable[Any]] = None
self, sql: str, parameters: Iterable[Any] | None = None
) -> Cursor:
"""Helper to create a cursor and execute the given query."""
if parameters is None:
Expand All @@ -185,16 +186,16 @@ async def execute(

@contextmanager
async def execute_insert(
self, sql: str, parameters: Optional[Iterable[Any]] = None
) -> Optional[sqlite3.Row]:
self, sql: str, parameters: Iterable[Any] | None = None
) -> sqlite3.Row | None:
"""Helper to insert and get the last_insert_rowid."""
if parameters is None:
parameters = []
return await self._execute(self._execute_insert, sql, parameters)

@contextmanager
async def execute_fetchall(
self, sql: str, parameters: Optional[Iterable[Any]] = None
self, sql: str, parameters: Iterable[Any] | None = None
) -> Iterable[sqlite3.Row]:
"""Helper to execute a query and return all the data."""
if parameters is None:
Expand Down Expand Up @@ -246,19 +247,19 @@ def in_transaction(self) -> bool:
return self._conn.in_transaction

@property
def isolation_level(self) -> Optional[str]:
def isolation_level(self) -> str | None:
return self._conn.isolation_level

@isolation_level.setter
def isolation_level(self, value: IsolationLevel) -> None:
self._conn.isolation_level = value

@property
def row_factory(self) -> Optional[type]:
def row_factory(self) -> type | None:
return self._conn.row_factory

@row_factory.setter
def row_factory(self, factory: Optional[type]) -> None:
def row_factory(self, factory: type | None) -> None:
self._conn.row_factory = factory

@property
Expand All @@ -280,7 +281,7 @@ async def load_extension(self, path: str):
await self._execute(self._conn.load_extension, path) # type: ignore

async def set_progress_handler(
self, handler: Callable[[], Optional[int]], n: int
self, handler: Callable[[], int | None], n: int
) -> None:
await self._execute(self._conn.set_progress_handler, handler, n)

Expand Down Expand Up @@ -315,7 +316,7 @@ def dumper():

while True:
try:
line: Optional[str] = dump_queue.get_nowait()
line: str | None = dump_queue.get_nowait()
if line is None:
break
yield line
Expand All @@ -331,10 +332,10 @@ def dumper():

async def backup(
self,
target: Union["Connection", sqlite3.Connection],
target: Connection | sqlite3.Connection,
*,
pages: int = 0,
progress: Optional[Callable[[int, int, int], None]] = None,
progress: Callable[[int, int, int], None] | None = None,
name: str = "main",
sleep: float = 0.250,
) -> None:
Expand All @@ -357,10 +358,10 @@ async def backup(


def connect(
database: Union[str, Path],
database: str | Path,
*,
iter_chunk_size=64,
loop: Optional[asyncio.AbstractEventLoop] = None,
loop: asyncio.AbstractEventLoop | None = None,
**kwargs: Any,
) -> Connection:
"""Create and return a connection proxy to the sqlite database."""
Expand Down
23 changes: 12 additions & 11 deletions aiosqlite/cursor.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
# Copyright Amethyst Reese
# Licensed under the MIT license
from __future__ import annotations

import sqlite3
from collections.abc import AsyncIterator, Iterable
from typing import Any, Callable, Optional, TYPE_CHECKING
from typing import Any, Callable, TYPE_CHECKING

if TYPE_CHECKING:
from .core import Connection


class Cursor:
def __init__(self, conn: "Connection", cursor: sqlite3.Cursor) -> None:
def __init__(self, conn: Connection, cursor: sqlite3.Cursor) -> None:
self.iter_chunk_size = conn._iter_chunk_size
self._conn = conn
self._cursor = cursor
Expand All @@ -32,8 +33,8 @@ async def _execute(self, fn, *args, **kwargs):
return await self._conn._execute(fn, *args, **kwargs)

async def execute(
self, sql: str, parameters: Optional[Iterable[Any]] = None
) -> "Cursor":
self, sql: str, parameters: Iterable[Any] | None = None
) -> Cursor:
"""Execute the given query."""
if parameters is None:
parameters = []
Expand All @@ -42,21 +43,21 @@ async def execute(

async def executemany(
self, sql: str, parameters: Iterable[Iterable[Any]]
) -> "Cursor":
) -> Cursor:
"""Execute the given multiquery."""
await self._execute(self._cursor.executemany, sql, parameters)
return self

async def executescript(self, sql_script: str) -> "Cursor":
async def executescript(self, sql_script: str) -> Cursor:
"""Execute a user script."""
await self._execute(self._cursor.executescript, sql_script)
return self

async def fetchone(self) -> Optional[sqlite3.Row]:
async def fetchone(self) -> sqlite3.Row | None:
"""Fetch a single row."""
return await self._execute(self._cursor.fetchone)

async def fetchmany(self, size: Optional[int] = None) -> Iterable[sqlite3.Row]:
async def fetchmany(self, size: int | None = None) -> Iterable[sqlite3.Row]:
"""Fetch up to `cursor.arraysize` number of rows."""
args: tuple[int, ...] = ()
if size is not None:
Expand All @@ -76,7 +77,7 @@ def rowcount(self) -> int:
return self._cursor.rowcount

@property
def lastrowid(self) -> Optional[int]:
def lastrowid(self) -> int | None:
return self._cursor.lastrowid

@property
Expand All @@ -92,11 +93,11 @@ def description(self) -> tuple[tuple[str, None, None, None, None, None, None], .
return self._cursor.description

@property
def row_factory(self) -> Optional[Callable[[sqlite3.Cursor, sqlite3.Row], object]]:
def row_factory(self) -> Callable[[sqlite3.Cursor, sqlite3.Row], object] | None:
return self._cursor.row_factory

@row_factory.setter
def row_factory(self, factory: Optional[type]) -> None:
def row_factory(self, factory: type | None) -> None:
self._cursor.row_factory = factory

@property
Expand Down
2 changes: 2 additions & 0 deletions aiosqlite/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,5 @@
# Licensed under the MIT license

from .smoke import SmokeTest

__all__ = ("SmokeTest",)
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ classifiers = [
]
requires-python = ">=3.9"
dependencies = [
"typing_extensions >= 4.0",
]

[project.optional-dependencies]
Expand Down