Merged
2 changes: 0 additions & 2 deletions async_substrate_interface/async_substrate.py
@@ -11,7 +11,6 @@
import socket
import ssl
import warnings
from contextlib import suppress
from unittest.mock import AsyncMock
from hashlib import blake2b
from typing import (
@@ -40,7 +39,6 @@
from websockets.asyncio.client import connect, ClientConnection
from websockets.exceptions import (
ConnectionClosed,
WebSocketException,
)
from websockets.protocol import State

102 changes: 77 additions & 25 deletions async_substrate_interface/types.py
@@ -1,11 +1,13 @@
import bisect
import logging
import os
from abc import ABC
from collections import defaultdict, deque
from collections.abc import Iterable
from contextlib import suppress
from dataclasses import dataclass
from datetime import datetime
from typing import Optional, Union, Any
from typing import Optional, Union, Any, Sequence

import scalecodec.types
from bt_decode import PortableRegistry, encode as encode_by_type_string
@@ -17,9 +19,11 @@

from .const import SS58_FORMAT
from .utils import json
from .utils.cache import AsyncSqliteDB
from .utils.cache import AsyncSqliteDB, LRUCache

logger = logging.getLogger("async_substrate_interface")
SUBSTRATE_RUNTIME_CACHE_SIZE = int(os.getenv("SUBSTRATE_RUNTIME_CACHE_SIZE", "16"))
SUBSTRATE_CACHE_METHOD_SIZE = int(os.getenv("SUBSTRATE_CACHE_METHOD_SIZE", "512"))
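
As an illustration, a minimal sketch of overriding these cache sizes through the environment variables introduced above (the variable names come from this diff; the chosen values are arbitrary):

    import os

    # The sizes are read once, at module import time, so set the variables
    # before importing async_substrate_interface.types.
    os.environ["SUBSTRATE_RUNTIME_CACHE_SIZE"] = "32"    # Runtime objects kept in memory
    os.environ["SUBSTRATE_CACHE_METHOD_SIZE"] = "1024"   # block->hash and hash->version entries

    from async_substrate_interface.types import RuntimeCache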


class RuntimeCache:
@@ -41,11 +45,45 @@ class RuntimeCache:
versions: dict[int, "Runtime"]
last_used: Optional["Runtime"]

def __init__(self):
self.blocks = {}
self.block_hashes = {}
self.versions = {}
self.last_used = None
def __init__(self, known_versions: Optional[Sequence[tuple[int, int]]] = None):
# {block: block_hash, ...}
self.blocks: LRUCache = LRUCache(max_size=SUBSTRATE_CACHE_METHOD_SIZE)
# {block_hash: specVersion, ...}
self.block_hashes: LRUCache = LRUCache(max_size=SUBSTRATE_CACHE_METHOD_SIZE)
# {specVersion: Runtime, ...}
self.versions: LRUCache = LRUCache(max_size=SUBSTRATE_RUNTIME_CACHE_SIZE)
# [(block, specVersion), ...]
self.known_versions: list[tuple[int, int]] = []
# [block, ...] for binary search (excludes last item)
self._known_version_blocks: list[int] = []
if known_versions:
self.add_known_versions(known_versions)
self.last_used: Optional["Runtime"] = None

def add_known_versions(self, known_versions: Sequence[tuple[int, int]]):
"""
Known versions are (block, specVersion) pairs marking the blocks at which the runtime changes.

E.g.
[
(561, 102),
(1075, 103),
...,
(7257645, 367)
]

This mapping is generally user-created or pulled from an external API, such as
https://api.tao.app/docs#/chain/get_runtime_versions_api_beta_chain_runtime_version_get

By preloading the known versions, significantly fewer chain calls are needed to determine the runtime version.

Note that the last runtime in the supplied known versions is ignored, as otherwise we would
have to assume that the final known version never changes.
"""
known_versions = list(sorted(known_versions, key=lambda v: v[0]))
self.known_versions = known_versions
# Cache block numbers (excluding last) for O(log n) binary search lookups
self._known_version_blocks = [v[0] for v in known_versions[:-1]]
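
A standalone sketch of the lookup described in the docstring above, using only the standard-library `bisect` module (this mirrors, but is not, the `retrieve` logic below):

    import bisect

    known_versions = [(561, 102), (1075, 103), (7257645, 367)]
    # The last entry is excluded from the search list, since we cannot assume
    # the final known runtime never changes.
    blocks = [block for block, _ in known_versions[:-1]]

    def spec_version_for(block: int):
        idx = bisect.bisect_right(blocks, block) - 1
        return known_versions[idx][1] if idx >= 0 else None

    assert spec_version_for(600) == 102     # within the first known runtime
    assert spec_version_for(2000) == 103    # after the second change
    assert spec_version_for(100) is None    # before any known version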

def add_item(
self,
@@ -59,11 +97,11 @@ def add_item(
"""
self.last_used = runtime
if block is not None and block_hash is not None:
self.blocks[block] = block_hash
self.blocks.set(block, block_hash)
if block_hash is not None and runtime_version is not None:
self.block_hashes[block_hash] = runtime_version
self.block_hashes.set(block_hash, runtime_version)
if runtime_version is not None:
self.versions[runtime_version] = runtime
self.versions.set(runtime_version, runtime)

def retrieve(
self,
@@ -75,26 +113,35 @@ def retrieve(
Retrieves a Runtime object from the cache, using the key of its block number, block hash, or runtime version.
Retrieval happens in this order. If no Runtime is found mapped to any of your supplied keys, returns `None`.
"""
# No reason to do this lookup if the runtime version is already supplied in this call
if block is not None and runtime_version is None and self._known_version_blocks:
# _known_version_blocks excludes the last item (see note in `add_known_versions`)
idx = bisect.bisect_right(self._known_version_blocks, block) - 1
if idx >= 0:
runtime_version = self.known_versions[idx][1]

runtime = None
if block is not None:
if block_hash is not None:
self.blocks[block] = block_hash
self.blocks.set(block, block_hash)
if runtime_version is not None:
self.block_hashes[block_hash] = runtime_version
with suppress(KeyError):
runtime = self.versions[self.block_hashes[self.blocks[block]]]
self.block_hashes.set(block_hash, runtime_version)
with suppress(AttributeError):
runtime = self.versions.get(
self.block_hashes.get(self.blocks.get(block))
)
self.last_used = runtime
return runtime
if block_hash is not None:
if runtime_version is not None:
self.block_hashes[block_hash] = runtime_version
with suppress(KeyError):
runtime = self.versions[self.block_hashes[block_hash]]
self.block_hashes.set(block_hash, runtime_version)
with suppress(AttributeError):
runtime = self.versions.get(self.block_hashes.get(block_hash))
self.last_used = runtime
return runtime
if runtime_version is not None:
with suppress(KeyError):
runtime = self.versions[runtime_version]
runtime = self.versions.get(runtime_version)
if runtime is not None:
self.last_used = runtime
return runtime
return runtime
@@ -110,16 +157,21 @@ async def load_from_disk(self, chain_endpoint: str):
logger.debug("No runtime mappings in disk cache")
else:
logger.debug("Found runtime mappings in disk cache")
self.blocks = block_mapping
self.block_hashes = block_hash_mapping
self.versions = {
x: Runtime.deserialize(y) for x, y in runtime_version_mapping.items()
}
self.blocks.cache = block_mapping
self.block_hashes.cache = block_hash_mapping
for x, y in runtime_version_mapping.items():
self.versions.cache[x] = Runtime.deserialize(y)

async def dump_to_disk(self, chain_endpoint: str):
db = AsyncSqliteDB(chain_endpoint=chain_endpoint)
blocks = self.blocks.cache
block_hashes = self.block_hashes.cache
versions = self.versions.cache
await db.dump_runtime_cache(
chain_endpoint, self.blocks, self.block_hashes, self.versions
chain=chain_endpoint,
block_mapping=blocks,
block_hash_mapping=block_hashes,
version_mapping=versions,
)
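
Taken together, a hedged usage sketch of the reworked cache (identifiers come from this diff; the Runtime construction itself is elided because it is not part of these changes):

    from async_substrate_interface.types import RuntimeCache

    # known_versions here is illustrative; in practice it is user-supplied or
    # pulled from an external API, as the docstring above notes.
    cache = RuntimeCache(known_versions=[(561, 102), (1075, 103), (7257645, 367)])

    # Once a Runtime has been fetched from the chain, store it under its keys:
    # cache.add_item(runtime=fetched_runtime, block=900, block_hash="0xabc", runtime_version=102)

    # Later lookups are then served from the LRU caches rather than the chain:
    # cache.retrieve(block=900)            -> fetched_runtime
    # cache.retrieve(runtime_version=102)  -> fetched_runtime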


6 changes: 3 additions & 3 deletions async_substrate_interface/utils/cache.py
@@ -111,9 +111,9 @@ async def load_runtime_cache(self, chain: str) -> tuple[dict, dict, dict]:
if not self._db:
_ensure_dir()
self._db = await aiosqlite.connect(CACHE_LOCATION)
block_mapping = {}
block_hash_mapping = {}
version_mapping = {}
block_mapping = OrderedDict()
block_hash_mapping = OrderedDict()
version_mapping = OrderedDict()
tables = {
"RuntimeCache_blocks": block_mapping,
"RuntimeCache_block_hashes": block_hash_mapping,
14 changes: 7 additions & 7 deletions tests/unit_tests/test_types.py
@@ -111,15 +111,15 @@ async def test_runtime_cache_from_disk():
substrate.initialized = True

# runtime cache should be completely empty
assert substrate.runtime_cache.block_hashes == {}
assert substrate.runtime_cache.blocks == {}
assert substrate.runtime_cache.versions == {}
assert len(substrate.runtime_cache.block_hashes.cache) == 0
assert len(substrate.runtime_cache.blocks.cache) == 0
assert len(substrate.runtime_cache.versions.cache) == 0
await substrate.initialize()

# after initialization, runtime cache should still be completely empty
assert substrate.runtime_cache.block_hashes == {}
assert substrate.runtime_cache.blocks == {}
assert substrate.runtime_cache.versions == {}
assert len(substrate.runtime_cache.block_hashes.cache) == 0
assert len(substrate.runtime_cache.blocks.cache) == 0
assert len(substrate.runtime_cache.versions.cache) == 0
await substrate.close()

# ensure we have created the SQLite DB during initialize()
@@ -136,7 +136,7 @@

substrate.initialized = True
await substrate.initialize()
assert substrate.runtime_cache.blocks == {fake_block: fake_hash}
assert substrate.runtime_cache.blocks.cache == {fake_block: fake_hash}
# add an item to the cache
substrate.runtime_cache.add_item(
runtime=None, block_hash=new_fake_hash, block=new_fake_block