Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
# Changelog

## 1.6.1 /2025-02-03
* RuntimeCache updates by @thewhaleking in https://github.com/opentensor/async-substrate-interface/pull/260
* fix memory leak by @thewhaleking in https://github.com/opentensor/async-substrate-interface/pull/261
* Avoid Race Condition on SQLite Table Creation by @thewhaleking in https://github.com/opentensor/async-substrate-interface/pull/263

**Full Changelog**: https://github.com/opentensor/async-substrate-interface/compare/v1.6.0...v1.6.1

## 1.6.0 /2025-01-27
* Fix typo by @thewhaleking in https://github.com/opentensor/async-substrate-interface/pull/258
* Improve Disk Caching by @thewhaleking in https://github.com/opentensor/async-substrate-interface/pull/227
Expand Down
43 changes: 34 additions & 9 deletions async_substrate_interface/async_substrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import socket
import ssl
import warnings
from contextlib import suppress
from unittest.mock import AsyncMock
from hashlib import blake2b
from typing import (
Expand Down Expand Up @@ -40,7 +39,6 @@
from websockets.asyncio.client import connect, ClientConnection
from websockets.exceptions import (
ConnectionClosed,
WebSocketException,
)
from websockets.protocol import State

Expand Down Expand Up @@ -708,6 +706,10 @@ async def _cancel(self):
logger.debug("Cancelling send/recv tasks")
if self._send_recv_task is not None:
self._send_recv_task.cancel()
try:
await self._send_recv_task
except asyncio.CancelledError:
pass
except asyncio.CancelledError:
pass
except Exception as e:
Expand Down Expand Up @@ -777,16 +779,31 @@ async def _handler(self, ws: ClientConnection) -> Union[None, Exception]:
logger.debug("WS handler attached")
recv_task = asyncio.create_task(self._start_receiving(ws))
send_task = asyncio.create_task(self._start_sending(ws))
done, pending = await asyncio.wait(
[recv_task, send_task],
return_when=asyncio.FIRST_COMPLETED,
)
try:
done, pending = await asyncio.wait(
[recv_task, send_task],
return_when=asyncio.FIRST_COMPLETED,
)
except asyncio.CancelledError:
# Handler was cancelled, clean up child tasks
for task in [recv_task, send_task]:
if not task.done():
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
raise
loop = asyncio.get_running_loop()
should_reconnect = False
is_retry = False

for task in pending:
task.cancel()
try:
await task
except asyncio.CancelledError:
pass

for task in done:
task_res = task.result()
Expand Down Expand Up @@ -887,6 +904,14 @@ async def _exit_with_timer(self):

async def shutdown(self):
logger.debug("Shutdown requested")
# Cancel the exit timer task if it exists
if self._exit_task is not None:
self._exit_task.cancel()
try:
await self._exit_task
except asyncio.CancelledError:
pass
self._exit_task = None
try:
await asyncio.wait_for(self._cancel(), timeout=10.0)
except asyncio.TimeoutError:
Expand Down Expand Up @@ -990,8 +1015,9 @@ async def _start_sending(self, ws) -> Exception:
)
if to_send is not None:
to_send_ = json.loads(to_send)
self._received[to_send_["id"]].set_exception(e)
self._received[to_send_["id"]].cancel()
if to_send_["id"] in self._received:
self._received[to_send_["id"]].set_exception(e)
self._received[to_send_["id"]].cancel()
else:
for i in self._received.keys():
self._received[i].set_exception(e)
Expand Down Expand Up @@ -1975,7 +2001,6 @@ async def result_handler(

if subscription_result is not None:
reached = True
logger.info("REACHED!")
# Handler returned end result: unsubscribe from further updates
async with self.ws as ws:
await ws.unsubscribe(
Expand Down
7 changes: 7 additions & 0 deletions async_substrate_interface/sync_substrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3398,5 +3398,12 @@ def close(self):
self.ws.close()
except AttributeError:
pass
# Clear lru_cache on instance methods to allow garbage collection
self.get_runtime_for_version.cache_clear()
self.get_parent_block_hash.cache_clear()
self.get_block_runtime_info.cache_clear()
self.get_block_runtime_version_for.cache_clear()
self.supports_rpc_method.cache_clear()
self.get_block_hash.cache_clear()

encode_scale = SubstrateMixin._encode_scale
102 changes: 77 additions & 25 deletions async_substrate_interface/types.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import bisect
import logging
import os
from abc import ABC
from collections import defaultdict, deque
from collections.abc import Iterable
from contextlib import suppress
from dataclasses import dataclass
from datetime import datetime
from typing import Optional, Union, Any
from typing import Optional, Union, Any, Sequence

import scalecodec.types
from bt_decode import PortableRegistry, encode as encode_by_type_string
Expand All @@ -17,9 +19,11 @@

from .const import SS58_FORMAT
from .utils import json
from .utils.cache import AsyncSqliteDB
from .utils.cache import AsyncSqliteDB, LRUCache

logger = logging.getLogger("async_substrate_interface")
SUBSTRATE_RUNTIME_CACHE_SIZE = int(os.getenv("SUBSTRATE_RUNTIME_CACHE_SIZE", "16"))
SUBSTRATE_CACHE_METHOD_SIZE = int(os.getenv("SUBSTRATE_CACHE_METHOD_SIZE", "512"))


class RuntimeCache:
Expand All @@ -41,11 +45,45 @@ class RuntimeCache:
versions: dict[int, "Runtime"]
last_used: Optional["Runtime"]

def __init__(self):
self.blocks = {}
self.block_hashes = {}
self.versions = {}
self.last_used = None
def __init__(self, known_versions: Optional[Sequence[tuple[int, int]]] = None):
    """
    Create an empty runtime cache, optionally preloaded with known versions.

    Args:
        known_versions: optional sequence of (block, specVersion) pairs. When
            non-empty, it is handed to `add_known_versions`.
    """
    # Most recently stored/retrieved Runtime; nothing has been used yet.
    self.last_used: Optional["Runtime"] = None

    # block number -> block hash
    self.blocks: LRUCache = LRUCache(max_size=SUBSTRATE_CACHE_METHOD_SIZE)
    # block hash -> specVersion
    self.block_hashes: LRUCache = LRUCache(max_size=SUBSTRATE_CACHE_METHOD_SIZE)
    # specVersion -> Runtime
    self.versions: LRUCache = LRUCache(max_size=SUBSTRATE_RUNTIME_CACHE_SIZE)

    # (block, specVersion) pairs, filled in by `add_known_versions`
    self.known_versions: list[tuple[int, int]] = []
    # Block numbers of every known version except the last; used for binary search
    self._known_version_blocks: list[int] = []

    if known_versions:
        self.add_known_versions(known_versions)

def add_known_versions(self, known_versions: Sequence[tuple[int, int]]):
    """
    Preload the (block, specVersion) pairs at which the runtime is known to change.

    E.g.
    [
        (561, 102),
        (1075, 103),
        ...,
        (7257645, 367)
    ]

    This mapping is generally user-created or pulled from an external API, such as
    https://api.tao.app/docs#/chain/get_runtime_versions_api_beta_chain_runtime_version_get

    By preloading the known versions, there can be significantly fewer chain calls to determine version.

    Note that the last runtime in the supplied known versions will be ignored for lookups, as otherwise
    we would have to assume that the final known version never changes.

    Args:
        known_versions: sequence of (block, specVersion) pairs, in any order. Any previously stored
            known versions are replaced.
    """
    # `sorted` already returns a new list, so the caller's sequence is left untouched.
    ordered = sorted(known_versions, key=lambda v: v[0])
    self.known_versions = ordered
    # Cache block numbers (excluding last) for O(log n) binary search lookups
    self._known_version_blocks = [block for block, _ in ordered[:-1]]

def add_item(
self,
Expand All @@ -59,11 +97,11 @@ def add_item(
"""
self.last_used = runtime
if block is not None and block_hash is not None:
self.blocks[block] = block_hash
self.blocks.set(block, block_hash)
if block_hash is not None and runtime_version is not None:
self.block_hashes[block_hash] = runtime_version
self.block_hashes.set(block_hash, runtime_version)
if runtime_version is not None:
self.versions[runtime_version] = runtime
self.versions.set(runtime_version, runtime)

def retrieve(
self,
Expand All @@ -75,26 +113,35 @@ def retrieve(
Retrieves a Runtime object from the cache, using the key of its block number, block hash, or runtime version.
Retrieval happens in this order. If no Runtime is found mapped to any of your supplied keys, returns `None`.
"""
# No reason to do this lookup if the runtime version is already supplied in this call
if block is not None and runtime_version is None and self._known_version_blocks:
# _known_version_blocks excludes the last item (see note in `add_known_versions`)
idx = bisect.bisect_right(self._known_version_blocks, block) - 1
if idx >= 0:
runtime_version = self.known_versions[idx][1]

runtime = None
if block is not None:
if block_hash is not None:
self.blocks[block] = block_hash
self.blocks.set(block, block_hash)
if runtime_version is not None:
self.block_hashes[block_hash] = runtime_version
with suppress(KeyError):
runtime = self.versions[self.block_hashes[self.blocks[block]]]
self.block_hashes.set(block_hash, runtime_version)
with suppress(AttributeError):
runtime = self.versions.get(
self.block_hashes.get(self.blocks.get(block))
)
self.last_used = runtime
return runtime
if block_hash is not None:
if runtime_version is not None:
self.block_hashes[block_hash] = runtime_version
with suppress(KeyError):
runtime = self.versions[self.block_hashes[block_hash]]
self.block_hashes.set(block_hash, runtime_version)
with suppress(AttributeError):
runtime = self.versions.get(self.block_hashes.get(block_hash))
self.last_used = runtime
return runtime
if runtime_version is not None:
with suppress(KeyError):
runtime = self.versions[runtime_version]
runtime = self.versions.get(runtime_version)
if runtime is not None:
self.last_used = runtime
return runtime
return runtime
Expand All @@ -110,16 +157,21 @@ async def load_from_disk(self, chain_endpoint: str):
logger.debug("No runtime mappings in disk cache")
else:
logger.debug("Found runtime mappings in disk cache")
self.blocks = block_mapping
self.block_hashes = block_hash_mapping
self.versions = {
x: Runtime.deserialize(y) for x, y in runtime_version_mapping.items()
}
self.blocks.cache = block_mapping
self.block_hashes.cache = block_hash_mapping
for x, y in runtime_version_mapping.items():
self.versions.cache[x] = Runtime.deserialize(y)

async def dump_to_disk(self, chain_endpoint: str):
db = AsyncSqliteDB(chain_endpoint=chain_endpoint)
blocks = self.blocks.cache
block_hashes = self.block_hashes.cache
versions = self.versions.cache
await db.dump_runtime_cache(
chain_endpoint, self.blocks, self.block_hashes, self.versions
chain=chain_endpoint,
block_mapping=blocks,
block_hash_mapping=block_hashes,
version_mapping=versions,
)


Expand Down
41 changes: 33 additions & 8 deletions async_substrate_interface/utils/cache.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import inspect
import weakref
from collections import OrderedDict
import functools
import logging
Expand Down Expand Up @@ -60,6 +61,7 @@ async def _create_if_not_exists(self, chain: str, table_name: str):
);
"""
)
await self._db.commit()
await self._db.execute(
f"""
CREATE TRIGGER IF NOT EXISTS prune_rows_trigger_{table_name} AFTER INSERT ON {table_name}
Expand All @@ -81,8 +83,8 @@ async def __call__(self, chain, other_self, func, args, kwargs) -> Optional[Any]
if not self._db:
_ensure_dir()
self._db = await aiosqlite.connect(CACHE_LOCATION)
table_name = _get_table_name(func)
local_chain = await self._create_if_not_exists(chain, table_name)
table_name = _get_table_name(func)
local_chain = await self._create_if_not_exists(chain, table_name)
key = pickle.dumps((args, kwargs or None))
try:
cursor: aiosqlite.Cursor = await self._db.execute(
Expand Down Expand Up @@ -111,9 +113,9 @@ async def load_runtime_cache(self, chain: str) -> tuple[dict, dict, dict]:
if not self._db:
_ensure_dir()
self._db = await aiosqlite.connect(CACHE_LOCATION)
block_mapping = {}
block_hash_mapping = {}
version_mapping = {}
block_mapping = OrderedDict()
block_hash_mapping = OrderedDict()
version_mapping = OrderedDict()
tables = {
"RuntimeCache_blocks": block_mapping,
"RuntimeCache_block_hashes": block_hash_mapping,
Expand Down Expand Up @@ -419,6 +421,26 @@ async def __call__(self, *args: Any, **kwargs: Any) -> Any:
self._inflight.pop(key, None)


class _WeakMethod:
    """
    Callable wrapper around a bound method that references its instance only weakly,
    so the wrapper does not keep that instance alive.

    The bound method's signature is exposed via `__signature__` for introspection.
    """

    def __init__(self, method):
        # Split the bound method into its plain function and a weak reference to
        # the object it was bound to.
        self._func = method.__func__
        self._instance_ref = weakref.ref(method.__self__)
        # Signature of the bound method (i.e. without 'self'), computed once here so
        # that the bound method itself does not need to be retained.
        self.__signature__ = inspect.signature(method)

    def __call__(self, *args, **kwargs):
        target = self._instance_ref()
        if target is not None:
            return self._func(target, *args, **kwargs)
        # The weak reference is dead: there is no instance left to call the method on.
        raise ReferenceError("Instance has been garbage collected")


class _CachedFetcherMethod:
"""
Helper class for using CachedFetcher with method caches (rather than functions)
Expand All @@ -428,18 +450,21 @@ def __init__(self, method, max_size: int, cache_key_index: int):
self.method = method
self.max_size = max_size
self.cache_key_index = cache_key_index
self._instances = {}
# Use WeakKeyDictionary to avoid preventing garbage collection of instances
self._instances: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()

def __get__(self, instance, owner):
if instance is None:
return self

# Cache per-instance
# Cache per-instance (weak references allow GC when instance is no longer used)
if instance not in self._instances:
bound_method = self.method.__get__(instance, owner)
# Use weak reference wrapper to avoid preventing GC of instance
weak_method = _WeakMethod(bound_method)
self._instances[instance] = CachedFetcher(
max_size=self.max_size,
method=bound_method,
method=weak_method,
cache_key_index=self.cache_key_index,
)
return self._instances[instance]
Expand Down
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "async-substrate-interface"
version = "1.6.0"
version = "1.6.1"
description = "Asyncio library for interacting with substrate. Mostly API-compatible with py-substrate-interface"
readme = "README.md"
license = { file = "LICENSE" }
Expand Down Expand Up @@ -37,6 +37,8 @@ classifiers = [
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Programming Language :: Python :: 3.14",
"Programming Language :: Python :: 3 :: Only",
]

Expand Down
Loading
Loading