Skip to content

Commit 6f1099c

Browse files
committed
lnworker: pass PaySession to route building methods.
Simplifies method signatures and groups quite a large set of parameters that are mostly passed down along the call stack but functionally belong together. Also removes bolt11 specific parameters in method signatures, e.g. r_tags.
1 parent 1e4728b commit 6f1099c

File tree

4 files changed

+193
-228
lines changed

4 files changed

+193
-228
lines changed

electrum/lnrouter.py

Lines changed: 154 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,21 +22,23 @@
2222
# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
2323
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
2424
# SOFTWARE.
25-
25+
import asyncio
2626
import queue
2727
from collections import defaultdict
28-
from typing import Sequence, Tuple, Optional, Dict, TYPE_CHECKING, Set, Callable
28+
from typing import Sequence, Tuple, Optional, Dict, TYPE_CHECKING, Set, Callable, NamedTuple
2929
import time
3030
import threading
3131
from threading import RLock
3232
from math import inf
3333

3434
import attr
3535

36+
from .crypto import sha256
37+
from .lnonion import OnionRoutingFailure, OnionFailureCode
3638
from .util import profiler, with_lock
3739
from .logging import Logger
3840
from .lnutil import (NUM_MAX_EDGES_IN_PAYMENT_PATH, ShortChannelID, LnFeatures,
39-
NBLOCK_CLTV_DELTA_TOO_FAR_INTO_FUTURE, PaymentFeeBudget)
41+
NBLOCK_CLTV_DELTA_TOO_FAR_INTO_FUTURE, PaymentFeeBudget, HtlcLog, PaymentFailure)
4042
from .channel_db import ChannelDB, Policy, NodeInfo
4143

4244
if TYPE_CHECKING:
@@ -719,3 +721,152 @@ def find_route(
719721
route = self.create_route_from_path(
720722
path, my_channels=my_sending_channels, private_route_edges=private_route_edges)
721723
return route
724+
725+
726+
class SentHtlcInfo(NamedTuple):
727+
route: LNPaymentRoute
728+
payment_secret_orig: bytes
729+
payment_secret_bucket: bytes
730+
amount_msat: int
731+
bucket_msat: int
732+
amount_receiver_msat: int
733+
trampoline_fee_level: Optional[int]
734+
trampoline_route: Optional[LNPaymentRoute]
735+
736+
737+
class PaySession(Logger):
738+
def __init__(
739+
self,
740+
*,
741+
payment_hash: bytes,
742+
payment_secret: bytes,
743+
initial_trampoline_fee_level: int,
744+
invoice_features: int,
745+
r_tags,
746+
min_final_cltv_delta: int, # delta for last node (typically from invoice)
747+
amount_to_pay: int, # total payment amount final receiver will get
748+
invoice_pubkey: bytes,
749+
uses_trampoline: bool, # whether sender uses trampoline or gossip
750+
use_two_trampolines: bool, # whether legacy payments will try to use two trampolines
751+
):
752+
assert payment_hash
753+
assert payment_secret
754+
self.payment_hash = payment_hash
755+
self.payment_secret = payment_secret
756+
self.payment_key = payment_hash + payment_secret
757+
Logger.__init__(self)
758+
759+
self.invoice_features = LnFeatures(invoice_features)
760+
self.r_tags = r_tags
761+
self.min_final_cltv_delta = min_final_cltv_delta
762+
self.amount_to_pay = amount_to_pay
763+
self.invoice_pubkey = invoice_pubkey
764+
765+
self.sent_htlcs_q = asyncio.Queue() # type: asyncio.Queue[HtlcLog]
766+
self.start_time = time.time()
767+
768+
self.uses_trampoline = uses_trampoline
769+
self.trampoline_fee_level = initial_trampoline_fee_level
770+
self.failed_trampoline_routes = []
771+
self.use_two_trampolines = use_two_trampolines
772+
self._sent_buckets = dict() # psecret_bucket -> (amount_sent, amount_failed)
773+
774+
self._amount_inflight = 0 # what we sent in htlcs (that receiver gets, without fees)
775+
self._nhtlcs_inflight = 0
776+
self.is_active = True # is still trying to send new htlcs?
777+
778+
def diagnostic_name(self):
779+
pkey = sha256(self.payment_key)
780+
return f"{self.payment_hash[:4].hex()}-{pkey[:2].hex()}"
781+
782+
def maybe_raise_trampoline_fee(self, htlc_log: HtlcLog):
783+
if htlc_log.trampoline_fee_level == self.trampoline_fee_level:
784+
self.trampoline_fee_level += 1
785+
self.failed_trampoline_routes = []
786+
self.logger.info(f'raising trampoline fee level {self.trampoline_fee_level}')
787+
else:
788+
self.logger.info(f'NOT raising trampoline fee level, already at {self.trampoline_fee_level}')
789+
790+
def handle_failed_trampoline_htlc(self, *, htlc_log: HtlcLog, failure_msg: OnionRoutingFailure):
791+
# FIXME The trampoline nodes in the path are chosen randomly.
792+
# Some of the errors might depend on how we have chosen them.
793+
# Having more attempts is currently useful in part because of the randomness,
794+
# instead we should give feedback to create_routes_for_payment.
795+
# Sometimes the trampoline node fails to send a payment and returns
796+
# TEMPORARY_CHANNEL_FAILURE, while it succeeds with a higher trampoline fee.
797+
if failure_msg.code in (
798+
OnionFailureCode.TRAMPOLINE_FEE_INSUFFICIENT,
799+
OnionFailureCode.TRAMPOLINE_EXPIRY_TOO_SOON,
800+
OnionFailureCode.TEMPORARY_CHANNEL_FAILURE):
801+
# TODO: parse the node policy here (not returned by eclair yet)
802+
# TODO: erring node is always the first trampoline even if second
803+
# trampoline demands more fees, we can't influence this
804+
self.maybe_raise_trampoline_fee(htlc_log)
805+
elif self.use_two_trampolines:
806+
self.use_two_trampolines = False
807+
elif failure_msg.code in (
808+
OnionFailureCode.UNKNOWN_NEXT_PEER,
809+
OnionFailureCode.TEMPORARY_NODE_FAILURE):
810+
trampoline_route = htlc_log.route
811+
r = [hop.end_node.hex() for hop in trampoline_route]
812+
self.logger.info(f'failed trampoline route: {r}')
813+
if r not in self.failed_trampoline_routes:
814+
self.failed_trampoline_routes.append(r)
815+
else:
816+
pass # maybe the route was reused between different MPP parts
817+
else:
818+
raise PaymentFailure(failure_msg.code_name())
819+
820+
async def wait_for_one_htlc_to_resolve(self) -> HtlcLog:
821+
self.logger.info(f"waiting... amount_inflight={self._amount_inflight}. nhtlcs_inflight={self._nhtlcs_inflight}")
822+
htlc_log = await self.sent_htlcs_q.get()
823+
self._amount_inflight -= htlc_log.amount_msat
824+
self._nhtlcs_inflight -= 1
825+
if self._amount_inflight < 0 or self._nhtlcs_inflight < 0:
826+
raise Exception(f"amount_inflight={self._amount_inflight}, nhtlcs_inflight={self._nhtlcs_inflight}. both should be >= 0 !")
827+
return htlc_log
828+
829+
def add_new_htlc(self, sent_htlc_info: SentHtlcInfo):
830+
self._nhtlcs_inflight += 1
831+
self._amount_inflight += sent_htlc_info.amount_receiver_msat
832+
if self._amount_inflight > self.amount_to_pay: # safety belts
833+
raise Exception(f"amount_inflight={self._amount_inflight} > amount_to_pay={self.amount_to_pay}")
834+
shi = sent_htlc_info
835+
bkey = shi.payment_secret_bucket
836+
# if we sent MPP to a trampoline, add item to sent_buckets
837+
if self.uses_trampoline and shi.amount_msat != shi.bucket_msat:
838+
if bkey not in self._sent_buckets:
839+
self._sent_buckets[bkey] = (0, 0)
840+
amount_sent, amount_failed = self._sent_buckets[bkey]
841+
amount_sent += shi.amount_receiver_msat
842+
self._sent_buckets[bkey] = amount_sent, amount_failed
843+
844+
def on_htlc_fail_get_fail_amt_to_propagate(self, sent_htlc_info: SentHtlcInfo) -> Optional[int]:
845+
shi = sent_htlc_info
846+
# check sent_buckets if we use trampoline
847+
bkey = shi.payment_secret_bucket
848+
if self.uses_trampoline and bkey in self._sent_buckets:
849+
amount_sent, amount_failed = self._sent_buckets[bkey]
850+
amount_failed += shi.amount_receiver_msat
851+
self._sent_buckets[bkey] = amount_sent, amount_failed
852+
if amount_sent != amount_failed:
853+
self.logger.info('bucket still active...')
854+
return None
855+
self.logger.info('bucket failed')
856+
return amount_sent
857+
# not using trampoline buckets
858+
return shi.amount_receiver_msat
859+
860+
def get_outstanding_amount_to_send(self) -> int:
861+
return self.amount_to_pay - self._amount_inflight
862+
863+
def can_be_deleted(self) -> bool:
864+
"""Returns True iff finished sending htlcs AND all pending htlcs have resolved."""
865+
if self.is_active:
866+
return False
867+
# note: no one is consuming from sent_htlcs_q anymore
868+
nhtlcs_resolved = self.sent_htlcs_q.qsize()
869+
assert nhtlcs_resolved <= self._nhtlcs_inflight
870+
return nhtlcs_resolved == self._nhtlcs_inflight
871+
872+

0 commit comments

Comments
 (0)