Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 34 additions & 6 deletions async_substrate_interface/async_substrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -706,6 +706,10 @@ async def _cancel(self):
logger.debug("Cancelling send/recv tasks")
if self._send_recv_task is not None:
self._send_recv_task.cancel()
try:
await self._send_recv_task
except asyncio.CancelledError:
pass
except asyncio.CancelledError:
pass
except Exception as e:
Expand Down Expand Up @@ -775,16 +779,31 @@ async def _handler(self, ws: ClientConnection) -> Union[None, Exception]:
logger.debug("WS handler attached")
recv_task = asyncio.create_task(self._start_receiving(ws))
send_task = asyncio.create_task(self._start_sending(ws))
done, pending = await asyncio.wait(
[recv_task, send_task],
return_when=asyncio.FIRST_COMPLETED,
)
try:
done, pending = await asyncio.wait(
[recv_task, send_task],
return_when=asyncio.FIRST_COMPLETED,
)
except asyncio.CancelledError:
# Handler was cancelled, clean up child tasks
for task in [recv_task, send_task]:
if not task.done():
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
raise
loop = asyncio.get_running_loop()
should_reconnect = False
is_retry = False

for task in pending:
task.cancel()
try:
await task
except asyncio.CancelledError:
pass

for task in done:
task_res = task.result()
Expand Down Expand Up @@ -885,6 +904,14 @@ async def _exit_with_timer(self):

async def shutdown(self):
logger.debug("Shutdown requested")
# Cancel the exit timer task if it exists
if self._exit_task is not None:
self._exit_task.cancel()
try:
await self._exit_task
except asyncio.CancelledError:
pass
self._exit_task = None
try:
await asyncio.wait_for(self._cancel(), timeout=10.0)
except asyncio.TimeoutError:
Expand Down Expand Up @@ -988,8 +1015,9 @@ async def _start_sending(self, ws) -> Exception:
)
if to_send is not None:
to_send_ = json.loads(to_send)
self._received[to_send_["id"]].set_exception(e)
self._received[to_send_["id"]].cancel()
if to_send_["id"] in self._received:
self._received[to_send_["id"]].set_exception(e)
self._received[to_send_["id"]].cancel()
else:
for i in self._received.keys():
self._received[i].set_exception(e)
Expand Down
7 changes: 7 additions & 0 deletions async_substrate_interface/sync_substrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3398,5 +3398,12 @@ def close(self):
self.ws.close()
except AttributeError:
pass
# Clear lru_cache on instance methods to allow garbage collection
self.get_runtime_for_version.cache_clear()
self.get_parent_block_hash.cache_clear()
self.get_block_runtime_info.cache_clear()
self.get_block_runtime_version_for.cache_clear()
self.supports_rpc_method.cache_clear()
self.get_block_hash.cache_clear()

encode_scale = SubstrateMixin._encode_scale
30 changes: 27 additions & 3 deletions async_substrate_interface/utils/cache.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import inspect
import weakref
from collections import OrderedDict
import functools
import logging
Expand Down Expand Up @@ -419,6 +420,26 @@ async def __call__(self, *args: Any, **kwargs: Any) -> Any:
self._inflight.pop(key, None)


class _WeakMethod:
"""
Weak reference to a bound method that allows the instance to be garbage collected.
Preserves the method's signature for introspection.
"""

def __init__(self, method):
self._func = method.__func__
self._instance_ref = weakref.ref(method.__self__)
# Store the bound method's signature (without 'self') for inspect.signature() to find.
# We capture this once at creation time to avoid holding references to the bound method.
self.__signature__ = inspect.signature(method)

def __call__(self, *args, **kwargs):
instance = self._instance_ref()
if instance is None:
raise ReferenceError("Instance has been garbage collected")
return self._func(instance, *args, **kwargs)


class _CachedFetcherMethod:
"""
Helper class for using CachedFetcher with method caches (rather than functions)
Expand All @@ -428,18 +449,21 @@ def __init__(self, method, max_size: int, cache_key_index: int):
self.method = method
self.max_size = max_size
self.cache_key_index = cache_key_index
self._instances = {}
# Use WeakKeyDictionary to avoid preventing garbage collection of instances
self._instances: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()

def __get__(self, instance, owner):
if instance is None:
return self

# Cache per-instance
# Cache per-instance (weak references allow GC when instance is no longer used)
if instance not in self._instances:
bound_method = self.method.__get__(instance, owner)
# Use weak reference wrapper to avoid preventing GC of instance
weak_method = _WeakMethod(bound_method)
self._instances[instance] = CachedFetcher(
max_size=self.max_size,
method=bound_method,
method=weak_method,
cache_key_index=self.cache_key_index,
)
return self._instances[instance]
Expand Down
2 changes: 1 addition & 1 deletion tests/helpers/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,6 @@
environ.get("SUBSTRATE_AURA_NODE_URL") or "wss://acala-rpc-1.aca-api.network"
)

ARCHIVE_ENTRYPOINT = "wss://archive.chain.opentensor.ai:443"
ARCHIVE_ENTRYPOINT = "wss://archive.sub.latent.to"

LATENT_LITE_ENTRYPOINT = "wss://lite.sub.latent.to:443"
18 changes: 13 additions & 5 deletions tests/integration_tests/test_disk_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@
AsyncSubstrateInterface,
)
from async_substrate_interface.sync_substrate import SubstrateInterface
from tests.helpers.settings import LATENT_LITE_ENTRYPOINT


@pytest.mark.asyncio
async def test_disk_cache():
print("Testing test_disk_cache")
entrypoint = "wss://entrypoint-finney.opentensor.ai:443"
async with DiskCachedAsyncSubstrateInterface(entrypoint) as disk_cached_substrate:
async with DiskCachedAsyncSubstrateInterface(
LATENT_LITE_ENTRYPOINT, ss58_format=42, chain_name="Bittensor"
) as disk_cached_substrate:
current_block = await disk_cached_substrate.get_block_number(None)
block_hash = await disk_cached_substrate.get_block_hash(current_block)
parent_block_hash = await disk_cached_substrate.get_parent_block_hash(
Expand Down Expand Up @@ -42,7 +44,9 @@ async def test_disk_cache():
assert block_runtime_info == block_runtime_info_from_cache
assert block_runtime_version_for == block_runtime_version_from_cache
# Verify data integrity with non-disk cached Async Substrate Interface
async with AsyncSubstrateInterface(entrypoint) as non_cache_substrate:
async with AsyncSubstrateInterface(
LATENT_LITE_ENTRYPOINT, ss58_format=42, chain_name="Bittensor"
) as non_cache_substrate:
block_hash_non_cache = await non_cache_substrate.get_block_hash(current_block)
parent_block_hash_non_cache = await non_cache_substrate.get_parent_block_hash(
block_hash_non_cache
Expand All @@ -60,7 +64,9 @@ async def test_disk_cache():
assert block_runtime_info == block_runtime_info_non_cache
assert block_runtime_version_for == block_runtime_version_for_non_cache
# Verify data integrity with sync Substrate Interface
with SubstrateInterface(entrypoint) as sync_substrate:
with SubstrateInterface(
LATENT_LITE_ENTRYPOINT, ss58_format=42, chain_name="Bittensor"
) as sync_substrate:
block_hash_sync = sync_substrate.get_block_hash(current_block)
parent_block_hash_sync = sync_substrate.get_parent_block_hash(
block_hash_non_cache
Expand All @@ -76,7 +82,9 @@ async def test_disk_cache():
assert block_runtime_info == block_runtime_info_sync
assert block_runtime_version_for == block_runtime_version_for_sync
# Verify data is pulling from disk cache
async with DiskCachedAsyncSubstrateInterface(entrypoint) as disk_cached_substrate:
async with DiskCachedAsyncSubstrateInterface(
LATENT_LITE_ENTRYPOINT, ss58_format=42, chain_name="Bittensor"
) as disk_cached_substrate:
start = time.monotonic()
new_block_hash = await disk_cached_substrate.get_block_hash(current_block)
new_time = time.monotonic()
Expand Down
40 changes: 38 additions & 2 deletions tests/unit_tests/asyncio_/test_substrate_interface.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
import asyncio
import tracemalloc
from unittest.mock import AsyncMock, MagicMock, ANY

import pytest
from websockets.exceptions import InvalidURI
from websockets.protocol import State

from async_substrate_interface.async_substrate import AsyncSubstrateInterface
from async_substrate_interface.async_substrate import (
AsyncSubstrateInterface,
get_async_substrate_interface,
)
from async_substrate_interface.types import ScaleObj
from tests.helpers.settings import ARCHIVE_ENTRYPOINT
from tests.helpers.settings import ARCHIVE_ENTRYPOINT, LATENT_LITE_ENTRYPOINT


@pytest.mark.asyncio
Expand Down Expand Up @@ -139,3 +143,35 @@ async def test_runtime_switching():
assert one is not None
assert two is not None
print("test_runtime_switching succeeded")


@pytest.mark.asyncio
async def test_memory_leak():
    """
    Check that repeatedly creating and closing an async substrate interface
    does not grow traced memory by more than 2 MiB per iteration.

    Two warmup connect/close cycles run before the baseline snapshot; five
    measured cycles follow, each compared against that baseline.
    """
    import gc

    two_mb = 2 * 1024 * 1024

    # Stop any existing tracemalloc and start fresh
    tracemalloc.stop()
    tracemalloc.start()
    try:
        # Warmup: populate caches before taking baseline
        for _ in range(2):
            subtensor = await get_async_substrate_interface(LATENT_LITE_ENTRYPOINT)
            await subtensor.close()
        # Collect warmup garbage first, otherwise it inflates the baseline
        # and can hide real growth in the measured iterations.
        gc.collect()

        baseline_snapshot = tracemalloc.take_snapshot()

        for i in range(5):
            subtensor = await get_async_substrate_interface(LATENT_LITE_ENTRYPOINT)
            await subtensor.close()
            gc.collect()

            snapshot = tracemalloc.take_snapshot()
            stats = snapshot.compare_to(baseline_snapshot, "lineno")
            total_diff = sum(stat.size_diff for stat in stats)
            current, peak = tracemalloc.get_traced_memory()
            # Allow cumulative growth up to 2MB per iteration from baseline
            assert total_diff < two_mb * (i + 1), (
                f"Loop {i}: diff={total_diff / 1024:.2f} KiB, current={current / 1024:.2f} KiB, "
                f"peak={peak / 1024:.2f} KiB"
            )
    finally:
        # Always stop tracing, even on assertion failure, so the tracing
        # overhead does not carry over into the rest of the test session.
        tracemalloc.stop()
34 changes: 33 additions & 1 deletion tests/unit_tests/sync/test_substrate_interface.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import tracemalloc
from unittest.mock import MagicMock

from async_substrate_interface.sync_substrate import SubstrateInterface
from async_substrate_interface.types import ScaleObj

from tests.helpers.settings import ARCHIVE_ENTRYPOINT
from tests.helpers.settings import ARCHIVE_ENTRYPOINT, LATENT_LITE_ENTRYPOINT


def test_runtime_call(monkeypatch):
Expand Down Expand Up @@ -90,3 +91,34 @@ def test_runtime_switching():
assert substrate.get_extrinsics(block_number=block) is not None
assert substrate.get_extrinsics(block_number=block - 21) is not None
print("test_runtime_switching succeeded")


def test_memory_leak():
    """
    Check that repeatedly creating and closing a sync substrate interface
    does not grow traced memory by more than 2 MiB per iteration.

    Two warmup create/close cycles run before the baseline snapshot; five
    measured cycles follow, each compared against that baseline.
    """
    import gc

    two_mb = 2 * 1024 * 1024

    # Stop any existing tracemalloc and start fresh
    tracemalloc.stop()
    tracemalloc.start()
    try:
        # Warmup: populate caches before taking baseline
        for _ in range(2):
            subtensor = SubstrateInterface(LATENT_LITE_ENTRYPOINT)
            subtensor.close()
        # Collect warmup garbage first, otherwise it inflates the baseline
        # and can hide real growth in the measured iterations.
        gc.collect()

        baseline_snapshot = tracemalloc.take_snapshot()

        for i in range(5):
            subtensor = SubstrateInterface(LATENT_LITE_ENTRYPOINT)
            subtensor.close()
            gc.collect()

            snapshot = tracemalloc.take_snapshot()
            stats = snapshot.compare_to(baseline_snapshot, "lineno")
            total_diff = sum(stat.size_diff for stat in stats)
            current, peak = tracemalloc.get_traced_memory()
            # Allow cumulative growth up to 2MB per iteration from baseline
            assert total_diff < two_mb * (i + 1), (
                f"Loop {i}: diff={total_diff / 1024:.2f} KiB, current={current / 1024:.2f} KiB, "
                f"peak={peak / 1024:.2f} KiB"
            )
    finally:
        # Always stop tracing, even on assertion failure, so the tracing
        # overhead does not carry over into the rest of the test session.
        tracemalloc.stop()