Skip to content

[5/7]: multi: thread ChannelUpdate through codebase #8254

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 15 commits into
base: elle-g175-thread-interfaces-2
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions channeldb/models/channel_edge_policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -256,3 +256,36 @@ func (c *ChannelEdgePolicy2) GetToNode() [33]byte {
// A compile-time check to ensure that ChannelEdgePolicy2 implements the
// ChannelEdgePolicy interface.
var _ ChannelEdgePolicy = (*ChannelEdgePolicy2)(nil)

// EdgePolicyFromUpdate converts the given lnwire.ChannelUpdate into the
// corresponding ChannelEdgePolicy type.
func EdgePolicyFromUpdate(update lnwire.ChannelUpdate) (
ChannelEdgePolicy, error) {

switch upd := update.(type) {
case *lnwire.ChannelUpdate1:
//nolint:lll
return &ChannelEdgePolicy1{
SigBytes: upd.Signature.ToSignatureBytes(),
ChannelID: upd.ShortChannelID.ToUint64(),
LastUpdate: time.Unix(int64(upd.Timestamp), 0),
MessageFlags: upd.MessageFlags,
ChannelFlags: upd.ChannelFlags,
TimeLockDelta: upd.TimeLockDelta,
MinHTLC: upd.HtlcMinimumMsat,
MaxHTLC: upd.HtlcMaximumMsat,
FeeBaseMSat: lnwire.MilliSatoshi(upd.BaseFee),
FeeProportionalMillionths: lnwire.MilliSatoshi(upd.FeeRate),
ExtraOpaqueData: upd.ExtraOpaqueData,
}, nil

case *lnwire.ChannelUpdate2:
return &ChannelEdgePolicy2{
ChannelUpdate2: *upd,
}, nil

default:
return nil, fmt.Errorf("unhandled implementation of "+
"lnwire.ChannelUpdate: %T", update)
}
}
56 changes: 31 additions & 25 deletions discovery/gossiper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ type mockGraphSource struct {
mu sync.Mutex
nodes []channeldb.LightningNode
infos map[uint64]models.ChannelEdgeInfo
edges map[uint64][]models.ChannelEdgePolicy1
edges map[uint64][]models.ChannelEdgePolicy
zombies map[uint64][][33]byte
chansToReject map[uint64]struct{}
addEdgeErrCode fn.Option[graph.ErrorCode]
Expand All @@ -103,7 +103,7 @@ func newMockRouter(height uint32) *mockGraphSource {
return &mockGraphSource{
bestHeight: height,
infos: make(map[uint64]models.ChannelEdgeInfo),
edges: make(map[uint64][]models.ChannelEdgePolicy1),
edges: make(map[uint64][]models.ChannelEdgePolicy),
zombies: make(map[uint64][][33]byte),
chansToReject: make(map[uint64]struct{}),
}
Expand Down Expand Up @@ -161,20 +161,22 @@ func (r *mockGraphSource) queueValidationFail(chanID uint64) {
r.chansToReject[chanID] = struct{}{}
}

func (r *mockGraphSource) UpdateEdge(edge *models.ChannelEdgePolicy1,
func (r *mockGraphSource) UpdateEdge(edge models.ChannelEdgePolicy,
_ ...batch.SchedulerOption) error {

r.mu.Lock()
defer r.mu.Unlock()

if len(r.edges[edge.ChannelID]) == 0 {
r.edges[edge.ChannelID] = make([]models.ChannelEdgePolicy1, 2)
chanID := edge.SCID().ToUint64()

if len(r.edges[chanID]) == 0 {
r.edges[chanID] = make([]models.ChannelEdgePolicy, 2)
}

if edge.ChannelFlags&lnwire.ChanUpdateDirection == 0 {
r.edges[edge.ChannelID][0] = *edge
if edge.IsNode1() {
r.edges[chanID][0] = edge
} else {
r.edges[edge.ChannelID][1] = *edge
r.edges[chanID][1] = edge
}

return nil
Expand Down Expand Up @@ -218,7 +220,6 @@ func (r *mockGraphSource) ForAllOutgoingChannels(cb func(tx kvdb.RTx,

r.mu.Lock()
defer r.mu.Unlock()

chans := make(map[uint64]channeldb.ChannelEdge)
for _, info := range r.infos {
info := info
Expand All @@ -230,17 +231,16 @@ func (r *mockGraphSource) ForAllOutgoingChannels(cb func(tx kvdb.RTx,
for _, edges := range r.edges {
edges := edges

edge := chans[edges[0].ChannelID]
edge.Policy1 = &edges[0]
chans[edges[0].ChannelID] = edge
edge := chans[edges[0].SCID().ToUint64()]
edge.Policy1 = edges[0]
chans[edges[0].SCID().ToUint64()] = edge
}

for _, channel := range chans {
if err := cb(nil, channel.Info, channel.Policy1); err != nil {
return err
}
}

return nil
}

Expand Down Expand Up @@ -271,14 +271,14 @@ func (r *mockGraphSource) GetChannelByID(chanID lnwire.ShortChannelID) (
return chanInfo, nil, nil, nil
}

var edge1 *models.ChannelEdgePolicy1
var edge1 models.ChannelEdgePolicy
if !reflect.DeepEqual(edges[0], models.ChannelEdgePolicy1{}) {
edge1 = &edges[0]
edge1 = edges[0]
}

var edge2 *models.ChannelEdgePolicy1
var edge2 models.ChannelEdgePolicy
if !reflect.DeepEqual(edges[1], models.ChannelEdgePolicy1{}) {
edge2 = &edges[1]
edge2 = edges[1]
}

return chanInfo, edge1, edge2, nil
Expand Down Expand Up @@ -379,15 +379,21 @@ func (r *mockGraphSource) IsStaleEdgePolicy(chanID lnwire.ShortChannelID,
}

switch {
case flags&lnwire.ChanUpdateDirection == 0 &&
!reflect.DeepEqual(edges[0], models.ChannelEdgePolicy1{}):

return !timestamp.After(edges[0].LastUpdate)

case flags&lnwire.ChanUpdateDirection == 1 &&
!reflect.DeepEqual(edges[1], models.ChannelEdgePolicy1{}):
case flags&lnwire.ChanUpdateDirection == 0 && edges[0] != nil:
switch edge := edges[0].(type) {
case *models.ChannelEdgePolicy1:
return !timestamp.After(edge.LastUpdate)
default:
panic(fmt.Sprintf("unhandled: %T", edges[0]))
}

return !timestamp.After(edges[1].LastUpdate)
case flags&lnwire.ChanUpdateDirection == 1 && edges[1] != nil:
switch edge := edges[1].(type) {
case *models.ChannelEdgePolicy1:
return !timestamp.After(edge.LastUpdate)
default:
panic(fmt.Sprintf("unhandled: %T", edges[1]))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about the chan upd2 case?

}

default:
return false
Expand Down
40 changes: 13 additions & 27 deletions graph/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -1512,49 +1512,35 @@ type routingMsg struct {

// ApplyChannelUpdate validates a channel update and if valid, applies it to the
// database. It returns a bool indicating whether the updates were successful.
func (b *Builder) ApplyChannelUpdate(msg *lnwire.ChannelUpdate1) bool {
ch, _, _, err := b.GetChannelByID(msg.ShortChannelID)
func (b *Builder) ApplyChannelUpdate(msg lnwire.ChannelUpdate) bool {
ch, _, _, err := b.GetChannelByID(msg.SCID())
if err != nil {
log.Errorf("Unable to retrieve channel by id: %v", err)
return false
}

var pubKey *btcec.PublicKey

switch msg.ChannelFlags & lnwire.ChanUpdateDirection {
case 0:
if msg.IsNode1() {
pubKey, _ = ch.NodeKey1()

case 1:
} else {
pubKey, _ = ch.NodeKey2()
}

// Exit early if the pubkey cannot be decided.
if pubKey == nil {
log.Errorf("Unable to decide pubkey with ChannelFlags=%v",
msg.ChannelFlags)
err = lnwire.ValidateChannelUpdateAnn(pubKey, ch.GetCapacity(), msg)
if err != nil {
log.Errorf("Unable to validate channel update: %v", err)
return false
}

err = lnwire.ValidateChannelUpdateAnn(pubKey, ch.GetCapacity(), msg)
edgePolicy, err := models.EdgePolicyFromUpdate(msg)
if err != nil {
log.Errorf("Unable to validate channel update: %v", err)
log.Errorf("Unable to convert update message to edge "+
"policy: %v", err)

return false
}

err = b.UpdateEdge(&models.ChannelEdgePolicy1{
SigBytes: msg.Signature.ToSignatureBytes(),
ChannelID: msg.ShortChannelID.ToUint64(),
LastUpdate: time.Unix(int64(msg.Timestamp), 0),
MessageFlags: msg.MessageFlags,
ChannelFlags: msg.ChannelFlags,
TimeLockDelta: msg.TimeLockDelta,
MinHTLC: msg.HtlcMinimumMsat,
MaxHTLC: msg.HtlcMaximumMsat,
FeeBaseMSat: lnwire.MilliSatoshi(msg.BaseFee),
FeeProportionalMillionths: lnwire.MilliSatoshi(msg.FeeRate),
ExtraOpaqueData: msg.ExtraOpaqueData,
})
err = b.UpdateEdge(edgePolicy)
if err != nil && !IsError(err, ErrIgnored, ErrOutdated) {
log.Errorf("Unable to apply channel update: %v", err)
return false
Expand Down Expand Up @@ -1621,7 +1607,7 @@ func (b *Builder) AddEdge(edge models.ChannelEdgeInfo,
// considered as not fully constructed.
//
// NOTE: This method is part of the ChannelGraphSource interface.
func (b *Builder) UpdateEdge(update *models.ChannelEdgePolicy1,
func (b *Builder) UpdateEdge(update models.ChannelEdgePolicy,
op ...batch.SchedulerOption) error {

rMsg := &routingMsg{
Expand Down
2 changes: 1 addition & 1 deletion graph/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ type ChannelGraphSource interface {

// UpdateEdge is used to update edge information, without this message
// edge considered as not fully constructed.
UpdateEdge(policy *models.ChannelEdgePolicy1,
UpdateEdge(policy models.ChannelEdgePolicy,
op ...batch.SchedulerOption) error

// IsStaleNode returns true if the graph source has a node announcement
Expand Down
4 changes: 2 additions & 2 deletions routing/mock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ func (m *mockPaymentSessionOld) RequestRoute(_, _ lnwire.MilliSatoshi,
return r, nil
}

func (m *mockPaymentSessionOld) UpdateAdditionalEdge(_ *lnwire.ChannelUpdate1,
func (m *mockPaymentSessionOld) UpdateAdditionalEdge(_ lnwire.ChannelUpdate,
_ *btcec.PublicKey, _ *models.CachedEdgePolicy) bool {

return false
Expand Down Expand Up @@ -702,7 +702,7 @@ func (m *mockPaymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi,
return args.Get(0).(*route.Route), args.Error(1)
}

func (m *mockPaymentSession) UpdateAdditionalEdge(msg *lnwire.ChannelUpdate1,
func (m *mockPaymentSession) UpdateAdditionalEdge(msg lnwire.ChannelUpdate,
pubKey *btcec.PublicKey, policy *models.CachedEdgePolicy) bool {

args := m.Called(msg, pubKey, policy)
Expand Down
5 changes: 3 additions & 2 deletions routing/payment_lifecycle.go
Original file line number Diff line number Diff line change
Expand Up @@ -897,7 +897,7 @@ func (p *paymentLifecycle) handleFailureMessage(rt *route.Route,
// SendToRoute where there's no payment lifecycle.
if p.paySession != nil {
policy = p.paySession.GetAdditionalEdgePolicy(
errSource, update.ShortChannelID.ToUint64(),
errSource, update.SCID().ToUint64(),
)
if policy != nil {
isAdditionalEdge = true
Expand All @@ -907,7 +907,8 @@ func (p *paymentLifecycle) handleFailureMessage(rt *route.Route,
// Apply channel update to additional edge policy.
if isAdditionalEdge {
if !p.paySession.UpdateAdditionalEdge(
update, errSource, policy) {
update, errSource, policy,
) {

log.Debugf("Invalid channel update received: node=%v",
errVertex)
Expand Down
15 changes: 9 additions & 6 deletions routing/payment_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,9 @@ type PaymentSession interface {
// (private channels) and applies the update from the message. Returns
// a boolean to indicate whether the update has been applied without
// error.
UpdateAdditionalEdge(msg *lnwire.ChannelUpdate1,
pubKey *btcec.PublicKey, policy *models.CachedEdgePolicy) bool
UpdateAdditionalEdge(msg lnwire.ChannelUpdate,
pubKey *btcec.PublicKey,
policy *models.CachedEdgePolicy) bool

// GetAdditionalEdgePolicy uses the public key and channel ID to query
// the ephemeral channel edge policy for additional edges. Returns a nil
Expand Down Expand Up @@ -431,7 +432,7 @@ func (p *paymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi,
// validates the message signature and checks it's up to date, then applies the
// updates to the supplied policy. It returns a boolean to indicate whether
// there's an error when applying the updates.
func (p *paymentSession) UpdateAdditionalEdge(msg *lnwire.ChannelUpdate1,
func (p *paymentSession) UpdateAdditionalEdge(msg lnwire.ChannelUpdate,
pubKey *btcec.PublicKey, policy *models.CachedEdgePolicy) bool {

// Validate the message signature.
Expand All @@ -442,10 +443,12 @@ func (p *paymentSession) UpdateAdditionalEdge(msg *lnwire.ChannelUpdate1,
return false
}

fwdingPolicy := msg.ForwardingPolicy()

// Update channel policy for the additional edge.
policy.TimeLockDelta = msg.TimeLockDelta
policy.FeeBaseMSat = lnwire.MilliSatoshi(msg.BaseFee)
policy.FeeProportionalMillionths = lnwire.MilliSatoshi(msg.FeeRate)
policy.TimeLockDelta = fwdingPolicy.TimeLockDelta
policy.FeeBaseMSat = fwdingPolicy.BaseFee
policy.FeeProportionalMillionths = fwdingPolicy.FeeRate

log.Debugf("New private channel update applied: %v",
lnutils.SpewLogClosure(msg))
Expand Down
2 changes: 1 addition & 1 deletion routing/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ type Config struct {

// ApplyChannelUpdate can be called to apply a new channel update to the
// graph that we received from a payment failure.
ApplyChannelUpdate func(msg *lnwire.ChannelUpdate1) bool
ApplyChannelUpdate func(msg lnwire.ChannelUpdate) bool

// ClosedSCIDs is used by the router to fetch closed channels.
//
Expand Down