diff --git a/async_substrate_interface/async_substrate.py b/async_substrate_interface/async_substrate.py index 91c4ca0..a9249c2 100644 --- a/async_substrate_interface/async_substrate.py +++ b/async_substrate_interface/async_substrate.py @@ -11,7 +11,6 @@ import socket import ssl import warnings -from contextlib import suppress from unittest.mock import AsyncMock from hashlib import blake2b from typing import ( @@ -40,7 +39,6 @@ from websockets.asyncio.client import connect, ClientConnection from websockets.exceptions import ( ConnectionClosed, - WebSocketException, ) from websockets.protocol import State diff --git a/async_substrate_interface/types.py b/async_substrate_interface/types.py index 8878497..842e260 100644 --- a/async_substrate_interface/types.py +++ b/async_substrate_interface/types.py @@ -1,11 +1,13 @@ +import bisect import logging +import os from abc import ABC from collections import defaultdict, deque from collections.abc import Iterable from contextlib import suppress from dataclasses import dataclass from datetime import datetime -from typing import Optional, Union, Any +from typing import Optional, Union, Any, Sequence import scalecodec.types from bt_decode import PortableRegistry, encode as encode_by_type_string @@ -17,9 +19,11 @@ from .const import SS58_FORMAT from .utils import json -from .utils.cache import AsyncSqliteDB +from .utils.cache import AsyncSqliteDB, LRUCache logger = logging.getLogger("async_substrate_interface") +SUBSTRATE_RUNTIME_CACHE_SIZE = int(os.getenv("SUBSTRATE_RUNTIME_CACHE_SIZE", "16")) +SUBSTRATE_CACHE_METHOD_SIZE = int(os.getenv("SUBSTRATE_CACHE_METHOD_SIZE", "512")) class RuntimeCache: @@ -41,11 +45,45 @@ class RuntimeCache: versions: dict[int, "Runtime"] last_used: Optional["Runtime"] - def __init__(self): - self.blocks = {} - self.block_hashes = {} - self.versions = {} - self.last_used = None + def __init__(self, known_versions: Optional[Sequence[tuple[int, int]]] = None): + # {block: block_hash, ...} + self.blocks: 
LRUCache = LRUCache(max_size=SUBSTRATE_CACHE_METHOD_SIZE) + # {block_hash: specVersion, ...} + self.block_hashes: LRUCache = LRUCache(max_size=SUBSTRATE_CACHE_METHOD_SIZE) + # {specVersion: Runtime, ...} + self.versions: LRUCache = LRUCache(max_size=SUBSTRATE_RUNTIME_CACHE_SIZE) + # [(block, specVersion), ...] + self.known_versions: list[tuple[int, int]] = [] + # [block, ...] for binary search (excludes last item) + self._known_version_blocks: list[int] = [] + if known_versions: + self.add_known_versions(known_versions) + self.last_used: Optional["Runtime"] = None + + def add_known_versions(self, known_versions: Sequence[tuple[int, int]]): + """ + Known versions are a sequence of (block, specVersion) pairs marking the blocks at which the runtime changed. + + E.g. + [ + (561, 102), + (1075, 103), + ..., + (7257645, 367) + ] + + This mapping is generally user-created or pulled from an external API, such as + https://api.tao.app/docs#/chain/get_runtime_versions_api_beta_chain_runtime_version_get + + By preloading the known versions, there can be significantly fewer chain calls to determine version. + + Note that the last runtime in the supplied known versions will be ignored, as otherwise we would + have to assume that the final known version never changes. 
+ """ + known_versions = list(sorted(known_versions, key=lambda v: v[0])) + self.known_versions = known_versions + # Cache block numbers (excluding last) for O(log n) binary search lookups + self._known_version_blocks = [v[0] for v in known_versions[:-1]] def add_item( self, @@ -59,11 +97,11 @@ def add_item( """ self.last_used = runtime if block is not None and block_hash is not None: - self.blocks[block] = block_hash + self.blocks.set(block, block_hash) if block_hash is not None and runtime_version is not None: - self.block_hashes[block_hash] = runtime_version + self.block_hashes.set(block_hash, runtime_version) if runtime_version is not None: - self.versions[runtime_version] = runtime + self.versions.set(runtime_version, runtime) def retrieve( self, @@ -75,26 +113,35 @@ def retrieve( Retrieves a Runtime object from the cache, using the key of its block number, block hash, or runtime version. Retrieval happens in this order. If no Runtime is found mapped to any of your supplied keys, returns `None`. 
""" + # No reason to do this lookup if the runtime version is already supplied in this call + if block is not None and runtime_version is None and self._known_version_blocks: + # _known_version_blocks excludes the last item (see note in `add_known_versions`) + idx = bisect.bisect_right(self._known_version_blocks, block) - 1 + if idx >= 0: + runtime_version = self.known_versions[idx][1] + runtime = None if block is not None: if block_hash is not None: - self.blocks[block] = block_hash + self.blocks.set(block, block_hash) if runtime_version is not None: - self.block_hashes[block_hash] = runtime_version - with suppress(KeyError): - runtime = self.versions[self.block_hashes[self.blocks[block]]] + self.block_hashes.set(block_hash, runtime_version) + with suppress(AttributeError): + runtime = self.versions.get( + self.block_hashes.get(self.blocks.get(block)) + ) self.last_used = runtime return runtime if block_hash is not None: if runtime_version is not None: - self.block_hashes[block_hash] = runtime_version - with suppress(KeyError): - runtime = self.versions[self.block_hashes[block_hash]] + self.block_hashes.set(block_hash, runtime_version) + with suppress(AttributeError): + runtime = self.versions.get(self.block_hashes.get(block_hash)) self.last_used = runtime return runtime if runtime_version is not None: - with suppress(KeyError): - runtime = self.versions[runtime_version] + runtime = self.versions.get(runtime_version) + if runtime is not None: self.last_used = runtime return runtime return runtime @@ -110,16 +157,21 @@ async def load_from_disk(self, chain_endpoint: str): logger.debug("No runtime mappings in disk cache") else: logger.debug("Found runtime mappings in disk cache") - self.blocks = block_mapping - self.block_hashes = block_hash_mapping - self.versions = { - x: Runtime.deserialize(y) for x, y in runtime_version_mapping.items() - } + self.blocks.cache = block_mapping + self.block_hashes.cache = block_hash_mapping + for x, y in 
runtime_version_mapping.items(): + self.versions.cache[x] = Runtime.deserialize(y) async def dump_to_disk(self, chain_endpoint: str): db = AsyncSqliteDB(chain_endpoint=chain_endpoint) + blocks = self.blocks.cache + block_hashes = self.block_hashes.cache + versions = self.versions.cache await db.dump_runtime_cache( - chain_endpoint, self.blocks, self.block_hashes, self.versions + chain=chain_endpoint, + block_mapping=blocks, + block_hash_mapping=block_hashes, + version_mapping=versions, ) diff --git a/async_substrate_interface/utils/cache.py b/async_substrate_interface/utils/cache.py index 24c609c..1ee82b3 100644 --- a/async_substrate_interface/utils/cache.py +++ b/async_substrate_interface/utils/cache.py @@ -111,9 +111,9 @@ async def load_runtime_cache(self, chain: str) -> tuple[dict, dict, dict]: if not self._db: _ensure_dir() self._db = await aiosqlite.connect(CACHE_LOCATION) - block_mapping = {} - block_hash_mapping = {} - version_mapping = {} + block_mapping = OrderedDict() + block_hash_mapping = OrderedDict() + version_mapping = OrderedDict() tables = { "RuntimeCache_blocks": block_mapping, "RuntimeCache_block_hashes": block_hash_mapping, diff --git a/tests/unit_tests/test_types.py b/tests/unit_tests/test_types.py index f2e13b4..928d809 100644 --- a/tests/unit_tests/test_types.py +++ b/tests/unit_tests/test_types.py @@ -111,15 +111,15 @@ async def test_runtime_cache_from_disk(): substrate.initialized = True # runtime cache should be completely empty - assert substrate.runtime_cache.block_hashes == {} - assert substrate.runtime_cache.blocks == {} - assert substrate.runtime_cache.versions == {} + assert len(substrate.runtime_cache.block_hashes.cache) == 0 + assert len(substrate.runtime_cache.blocks.cache) == 0 + assert len(substrate.runtime_cache.versions.cache) == 0 await substrate.initialize() # after initialization, runtime cache should still be completely empty - assert substrate.runtime_cache.block_hashes == {} - assert substrate.runtime_cache.blocks == {} 
- assert substrate.runtime_cache.versions == {} + assert len(substrate.runtime_cache.block_hashes.cache) == 0 + assert len(substrate.runtime_cache.blocks.cache) == 0 + assert len(substrate.runtime_cache.versions.cache) == 0 await substrate.close() # ensure we have created the SQLite DB during initialize() @@ -136,7 +136,7 @@ async def test_runtime_cache_from_disk(): substrate.initialized = True await substrate.initialize() - assert substrate.runtime_cache.blocks == {fake_block: fake_hash} + assert substrate.runtime_cache.blocks.cache == {fake_block: fake_hash} # add an item to the cache substrate.runtime_cache.add_item( runtime=None, block_hash=new_fake_hash, block=new_fake_block