Skip to content

Commit

Permalink
wip: channel database changes
Browse files Browse the repository at this point in the history
  • Loading branch information
M1nd3r committed Feb 17, 2025
1 parent 12a00e5 commit c26c721
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 6 deletions.
13 changes: 9 additions & 4 deletions python/src/trezorlib/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from __future__ import annotations

import atexit
import functools
import logging
import os
Expand All @@ -30,7 +31,7 @@
from ..messages import Capability
from ..transport import Transport
from ..transport.session import Session, SessionV1
from ..transport.thp.channel_database import get_channel_db
from ..transport.thp.channel_database import ChannelDatabase, get_channel_db

LOG = logging.getLogger(__name__)

Expand Down Expand Up @@ -102,8 +103,9 @@ def get_passphrase(
raise exceptions.Cancelled from None


def get_client(transport: Transport) -> TrezorClient:
stored_channels = get_channel_db().load_stored_channels()
def get_client(transport: Transport, channel_database: ChannelDatabase) -> TrezorClient:
stored_channels = channel_database.load_stored_channels()

stored_transport_paths = [ch.transport_path for ch in stored_channels]
path = transport.get_path()
if path in stored_transport_paths:
Expand All @@ -120,6 +122,7 @@ def get_client(transport: Transport) -> TrezorClient:
client = TrezorClient(transport)
else:
client = TrezorClient(transport)
atexit.register(lambda: channel_database.save_channel(client.protocol))
return client


Expand All @@ -131,11 +134,13 @@ def __init__(
session_id: bytes | None,
passphrase_on_host: bool,
script: bool,
channel_database: ChannelDatabase,
) -> None:
self.path = path
self.session_id = session_id
self.passphrase_on_host = passphrase_on_host
self.script = script
self.channel_database = channel_database

def get_session(
self,
Expand Down Expand Up @@ -195,7 +200,7 @@ def get_transport(self) -> "Transport":
return transport.get_transport(self.path, prefix_search=True)

def get_client(self) -> TrezorClient:
return get_client(self.get_transport())
return get_client(self.get_transport(), self.channel_database)

def get_seedless_session(self) -> Session:
client = self.get_client()
Expand Down
6 changes: 4 additions & 2 deletions python/src/trezorlib/cli/trezorctl.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,9 @@ def cli_main(
except ValueError:
raise click.ClickException(f"Not a valid session id: {session_id}")

ctx.obj = TrezorConnection(path, bytes_session_id, passphrase_on_host, script)
ctx.obj = TrezorConnection(
path, bytes_session_id, passphrase_on_host, script, get_channel_db()
)

# Optionally record the screen into a specified directory.
if record:
Expand Down Expand Up @@ -305,7 +307,7 @@ def list_devices(no_resolve: bool) -> Optional[Iterable["Transport"]]:

for transport in enumerate_devices():
try:
client = get_client(transport)
client = get_client(transport, get_channel_db())
description = format_device_name(client.features)
except DeviceIsBusy:
description = "Device is in use by another process"
Expand Down

0 comments on commit c26c721

Please sign in to comment.