diff --git a/async_substrate_interface/async_substrate.py b/async_substrate_interface/async_substrate.py index a9249c2..97aa604 100644 --- a/async_substrate_interface/async_substrate.py +++ b/async_substrate_interface/async_substrate.py @@ -706,6 +706,10 @@ async def _cancel(self): logger.debug("Cancelling send/recv tasks") if self._send_recv_task is not None: self._send_recv_task.cancel() + try: + await self._send_recv_task + except asyncio.CancelledError: + pass except asyncio.CancelledError: pass except Exception as e: @@ -775,16 +779,31 @@ async def _handler(self, ws: ClientConnection) -> Union[None, Exception]: logger.debug("WS handler attached") recv_task = asyncio.create_task(self._start_receiving(ws)) send_task = asyncio.create_task(self._start_sending(ws)) - done, pending = await asyncio.wait( - [recv_task, send_task], - return_when=asyncio.FIRST_COMPLETED, - ) + try: + done, pending = await asyncio.wait( + [recv_task, send_task], + return_when=asyncio.FIRST_COMPLETED, + ) + except asyncio.CancelledError: + # Handler was cancelled, clean up child tasks + for task in [recv_task, send_task]: + if not task.done(): + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + raise loop = asyncio.get_running_loop() should_reconnect = False is_retry = False for task in pending: task.cancel() + try: + await task + except asyncio.CancelledError: + pass for task in done: task_res = task.result() @@ -885,6 +904,14 @@ async def _exit_with_timer(self): async def shutdown(self): logger.debug("Shutdown requested") + # Cancel the exit timer task if it exists + if self._exit_task is not None: + self._exit_task.cancel() + try: + await self._exit_task + except asyncio.CancelledError: + pass + self._exit_task = None try: await asyncio.wait_for(self._cancel(), timeout=10.0) except asyncio.TimeoutError: @@ -988,8 +1015,9 @@ async def _start_sending(self, ws) -> Exception: ) if to_send is not None: to_send_ = json.loads(to_send) - self._received[to_send_["id"]].set_exception(e) - self._received[to_send_["id"]].cancel() + if to_send_["id"] in self._received: + self._received[to_send_["id"]].set_exception(e) + self._received[to_send_["id"]].cancel() else: for i in self._received.keys(): self._received[i].set_exception(e) diff --git a/async_substrate_interface/sync_substrate.py b/async_substrate_interface/sync_substrate.py index 8ddd90b..5b6db72 100644 --- a/async_substrate_interface/sync_substrate.py +++ b/async_substrate_interface/sync_substrate.py @@ -3398,5 +3398,12 @@ def close(self): self.ws.close() except AttributeError: pass + # Clear lru_cache on instance methods to allow garbage collection + self.get_runtime_for_version.cache_clear() + self.get_parent_block_hash.cache_clear() + self.get_block_runtime_info.cache_clear() + self.get_block_runtime_version_for.cache_clear() + self.supports_rpc_method.cache_clear() + self.get_block_hash.cache_clear() encode_scale = SubstrateMixin._encode_scale diff --git a/async_substrate_interface/utils/cache.py b/async_substrate_interface/utils/cache.py index 1ee82b3..1b235e2 100644 --- a/async_substrate_interface/utils/cache.py +++ b/async_substrate_interface/utils/cache.py @@ -1,5 +1,6 @@ import asyncio import inspect +import weakref from collections import OrderedDict import functools import logging @@ -419,6 +420,26 @@ async def __call__(self, *args: Any, **kwargs: Any) -> Any: self._inflight.pop(key, None) +class _WeakMethod: + """ + Weak reference to a bound method that allows the instance to be garbage collected. + Preserves the method's signature for introspection. + """ + + def __init__(self, method): + self._func = method.__func__ + self._instance_ref = weakref.ref(method.__self__) + # Store the bound method's signature (without 'self') for inspect.signature() to find. + # We capture this once at creation time to avoid holding references to the bound method. + self.__signature__ = inspect.signature(method) + + def __call__(self, *args, **kwargs): + instance = self._instance_ref() + if instance is None: + raise ReferenceError("Instance has been garbage collected") + return self._func(instance, *args, **kwargs) + + class _CachedFetcherMethod: """ Helper class for using CachedFetcher with method caches (rather than functions) @@ -428,18 +449,21 @@ def __init__(self, method, max_size: int, cache_key_index: int): self.method = method self.max_size = max_size self.cache_key_index = cache_key_index - self._instances = {} + # Use WeakKeyDictionary to avoid preventing garbage collection of instances + self._instances: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary() def __get__(self, instance, owner): if instance is None: return self - # Cache per-instance + # Cache per-instance (weak references allow GC when instance is no longer used) if instance not in self._instances: bound_method = self.method.__get__(instance, owner) + # Use weak reference wrapper to avoid preventing GC of instance + weak_method = _WeakMethod(bound_method) self._instances[instance] = CachedFetcher( max_size=self.max_size, - method=bound_method, + method=weak_method, cache_key_index=self.cache_key_index, ) return self._instances[instance] diff --git a/tests/helpers/settings.py b/tests/helpers/settings.py index 0e9e1da..ff6b3f7 100644 --- a/tests/helpers/settings.py +++ b/tests/helpers/settings.py @@ -33,6 +33,6 @@ environ.get("SUBSTRATE_AURA_NODE_URL") or "wss://acala-rpc-1.aca-api.network" ) -ARCHIVE_ENTRYPOINT = "wss://archive.chain.opentensor.ai:443" +ARCHIVE_ENTRYPOINT = "wss://archive.sub.latent.to" LATENT_LITE_ENTRYPOINT = "wss://lite.sub.latent.to:443" diff --git a/tests/integration_tests/test_disk_cache.py b/tests/integration_tests/test_disk_cache.py index cdebcc6..063eca1 100644 --- a/tests/integration_tests/test_disk_cache.py +++ b/tests/integration_tests/test_disk_cache.py @@ -5,13 +5,15 @@ AsyncSubstrateInterface, ) from async_substrate_interface.sync_substrate import SubstrateInterface +from tests.helpers.settings import LATENT_LITE_ENTRYPOINT @pytest.mark.asyncio async def test_disk_cache(): print("Testing test_disk_cache") - entrypoint = "wss://entrypoint-finney.opentensor.ai:443" - async with DiskCachedAsyncSubstrateInterface(entrypoint) as disk_cached_substrate: + async with DiskCachedAsyncSubstrateInterface( + LATENT_LITE_ENTRYPOINT, ss58_format=42, chain_name="Bittensor" + ) as disk_cached_substrate: current_block = await disk_cached_substrate.get_block_number(None) block_hash = await disk_cached_substrate.get_block_hash(current_block) parent_block_hash = await disk_cached_substrate.get_parent_block_hash( @@ -42,7 +44,9 @@ async def test_disk_cache(): assert block_runtime_info == block_runtime_info_from_cache assert block_runtime_version_for == block_runtime_version_from_cache # Verify data integrity with non-disk cached Async Substrate Interface - async with AsyncSubstrateInterface(entrypoint) as non_cache_substrate: + async with AsyncSubstrateInterface( + LATENT_LITE_ENTRYPOINT, ss58_format=42, chain_name="Bittensor" + ) as non_cache_substrate: block_hash_non_cache = await non_cache_substrate.get_block_hash(current_block) parent_block_hash_non_cache = await non_cache_substrate.get_parent_block_hash( block_hash_non_cache @@ -60,7 +64,9 @@ async def test_disk_cache(): assert block_runtime_info == block_runtime_info_non_cache assert block_runtime_version_for == block_runtime_version_for_non_cache # Verify data integrity with sync Substrate Interface - with SubstrateInterface(entrypoint) as sync_substrate: + with SubstrateInterface( + LATENT_LITE_ENTRYPOINT, ss58_format=42, chain_name="Bittensor" + ) as sync_substrate: block_hash_sync = sync_substrate.get_block_hash(current_block) parent_block_hash_sync = sync_substrate.get_parent_block_hash( block_hash_non_cache @@ -76,7 +82,9 @@ async def test_disk_cache(): assert block_runtime_info == block_runtime_info_sync assert block_runtime_version_for == block_runtime_version_for_sync # Verify data is pulling from disk cache - async with DiskCachedAsyncSubstrateInterface(entrypoint) as disk_cached_substrate: + async with DiskCachedAsyncSubstrateInterface( + LATENT_LITE_ENTRYPOINT, ss58_format=42, chain_name="Bittensor" + ) as disk_cached_substrate: start = time.monotonic() new_block_hash = await disk_cached_substrate.get_block_hash(current_block) new_time = time.monotonic() diff --git a/tests/unit_tests/asyncio_/test_substrate_interface.py b/tests/unit_tests/asyncio_/test_substrate_interface.py index 1253e6c..721804b 100644 --- a/tests/unit_tests/asyncio_/test_substrate_interface.py +++ b/tests/unit_tests/asyncio_/test_substrate_interface.py @@ -1,13 +1,17 @@ import asyncio +import tracemalloc from unittest.mock import AsyncMock, MagicMock, ANY import pytest from websockets.exceptions import InvalidURI from websockets.protocol import State -from async_substrate_interface.async_substrate import AsyncSubstrateInterface +from async_substrate_interface.async_substrate import ( + AsyncSubstrateInterface, + get_async_substrate_interface, +) from async_substrate_interface.types import ScaleObj -from tests.helpers.settings import ARCHIVE_ENTRYPOINT +from tests.helpers.settings import ARCHIVE_ENTRYPOINT, LATENT_LITE_ENTRYPOINT @pytest.mark.asyncio @@ -139,3 +143,35 @@ async def test_runtime_switching(): assert one is not None assert two is not None print("test_runtime_switching succeeded") + + +@pytest.mark.asyncio +async def test_memory_leak(): + import gc + + # Stop any existing tracemalloc and start fresh + tracemalloc.stop() + tracemalloc.start() + two_mb = 2 * 1024 * 1024 + + # Warmup: populate caches before taking baseline + for _ in range(2): + subtensor = await get_async_substrate_interface(LATENT_LITE_ENTRYPOINT) + await subtensor.close() + + baseline_snapshot = tracemalloc.take_snapshot() + + for i in range(5): + subtensor = await get_async_substrate_interface(LATENT_LITE_ENTRYPOINT) + await subtensor.close() + gc.collect() + + snapshot = tracemalloc.take_snapshot() + stats = snapshot.compare_to(baseline_snapshot, "lineno") + total_diff = sum(stat.size_diff for stat in stats) + current, peak = tracemalloc.get_traced_memory() + # Allow cumulative growth up to 2MB per iteration from baseline + assert total_diff < two_mb * (i + 1), ( + f"Loop {i}: diff={total_diff / 1024:.2f} KiB, current={current / 1024:.2f} KiB, " + f"peak={peak / 1024:.2f} KiB" + ) diff --git a/tests/unit_tests/sync/test_substrate_interface.py b/tests/unit_tests/sync/test_substrate_interface.py index 68f51b4..54a5b7d 100644 --- a/tests/unit_tests/sync/test_substrate_interface.py +++ b/tests/unit_tests/sync/test_substrate_interface.py @@ -1,9 +1,10 @@ +import tracemalloc from unittest.mock import MagicMock from async_substrate_interface.sync_substrate import SubstrateInterface from async_substrate_interface.types import ScaleObj -from tests.helpers.settings import ARCHIVE_ENTRYPOINT +from tests.helpers.settings import ARCHIVE_ENTRYPOINT, LATENT_LITE_ENTRYPOINT def test_runtime_call(monkeypatch): @@ -90,3 +91,34 @@ def test_runtime_switching(): assert substrate.get_extrinsics(block_number=block) is not None assert substrate.get_extrinsics(block_number=block - 21) is not None print("test_runtime_switching succeeded") + + +def test_memory_leak(): + import gc + + # Stop any existing tracemalloc and start fresh + tracemalloc.stop() + tracemalloc.start() + two_mb = 2 * 1024 * 1024 + + # Warmup: populate caches before taking baseline + for _ in range(2): + subtensor = SubstrateInterface(LATENT_LITE_ENTRYPOINT) + subtensor.close() + + baseline_snapshot = tracemalloc.take_snapshot() + + for i in range(5): + subtensor = SubstrateInterface(LATENT_LITE_ENTRYPOINT) + subtensor.close() + gc.collect() + + snapshot = tracemalloc.take_snapshot() + stats = snapshot.compare_to(baseline_snapshot, "lineno") + total_diff = sum(stat.size_diff for stat in stats) + current, peak = tracemalloc.get_traced_memory() + # Allow cumulative growth up to 2MB per iteration from baseline + assert total_diff < two_mb * (i + 1), ( + f"Loop {i}: diff={total_diff / 1024:.2f} KiB, current={current / 1024:.2f} KiB, " + f"peak={peak / 1024:.2f} KiB" + )