diff --git a/channeldb/models/channel_edge_policy.go b/channeldb/models/channel_edge_policy.go index eece025ffe8..ac7bf651ee7 100644 --- a/channeldb/models/channel_edge_policy.go +++ b/channeldb/models/channel_edge_policy.go @@ -256,3 +256,36 @@ func (c *ChannelEdgePolicy2) GetToNode() [33]byte { // A compile-time check to ensure that ChannelEdgePolicy2 implements the // ChannelEdgePolicy interface. var _ ChannelEdgePolicy = (*ChannelEdgePolicy2)(nil) + +// EdgePolicyFromUpdate converts the given lnwire.ChannelUpdate into the +// corresponding ChannelEdgePolicy type. +func EdgePolicyFromUpdate(update lnwire.ChannelUpdate) ( + ChannelEdgePolicy, error) { + + switch upd := update.(type) { + case *lnwire.ChannelUpdate1: + //nolint:lll + return &ChannelEdgePolicy1{ + SigBytes: upd.Signature.ToSignatureBytes(), + ChannelID: upd.ShortChannelID.ToUint64(), + LastUpdate: time.Unix(int64(upd.Timestamp), 0), + MessageFlags: upd.MessageFlags, + ChannelFlags: upd.ChannelFlags, + TimeLockDelta: upd.TimeLockDelta, + MinHTLC: upd.HtlcMinimumMsat, + MaxHTLC: upd.HtlcMaximumMsat, + FeeBaseMSat: lnwire.MilliSatoshi(upd.BaseFee), + FeeProportionalMillionths: lnwire.MilliSatoshi(upd.FeeRate), + ExtraOpaqueData: upd.ExtraOpaqueData, + }, nil + + case *lnwire.ChannelUpdate2: + return &ChannelEdgePolicy2{ + ChannelUpdate2: *upd, + }, nil + + default: + return nil, fmt.Errorf("unhandled implementation of "+ + "lnwire.ChannelUpdate: %T", update) + } +} diff --git a/discovery/chan_series.go b/discovery/chan_series.go index d91396a541b..93a37f0701c 100644 --- a/discovery/chan_series.go +++ b/discovery/chan_series.go @@ -61,7 +61,7 @@ type ChannelGraphTimeSeries interface { // specified short channel ID. If no channel updates are known for the // channel, then an empty slice will be returned. FetchChanUpdates(chain chainhash.Hash, - shortChanID lnwire.ShortChannelID) ([]*lnwire.ChannelUpdate1, + shortChanID lnwire.ShortChannelID) ([]lnwire.ChannelUpdate, error) } @@ -332,7 +332,7 @@ func (c *ChanSeries) FetchChanAnns(chain chainhash.Hash, // // NOTE: This is part of the ChannelGraphTimeSeries interface. func (c *ChanSeries) FetchChanUpdates(chain chainhash.Hash, - shortChanID lnwire.ShortChannelID) ([]*lnwire.ChannelUpdate1, error) { + shortChanID lnwire.ShortChannelID) ([]lnwire.ChannelUpdate, error) { chanInfo, e1, e2, err := c.graph.FetchChannelEdgesByID( shortChanID.ToUint64(), @@ -341,7 +341,7 @@ func (c *ChanSeries) FetchChanUpdates(chain chainhash.Hash, return nil, err } - chanUpdates := make([]*lnwire.ChannelUpdate1, 0, 2) + chanUpdates := make([]lnwire.ChannelUpdate, 0, 2) if e1 != nil { chanUpdate, err := netann.ChannelUpdateFromEdge(chanInfo, e1) if err != nil { diff --git a/discovery/gossiper.go b/discovery/gossiper.go index 4a7929614dc..842d58681e1 100644 --- a/discovery/gossiper.go +++ b/discovery/gossiper.go @@ -9,7 +9,6 @@ import ( "time" "github.com/btcsuite/btcd/btcec/v2" - "github.com/btcsuite/btcd/btcec/v2/ecdsa" "github.com/btcsuite/btcd/btcec/v2/schnorr/musig2" "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg/chainhash" @@ -262,14 +261,14 @@ type Config struct { // use to determine which messages need to be resent for a given peer. MessageStore GossipMessageStore - // AnnSigner is an instance of the MessageSigner interface which will - // be used to manually sign any outgoing channel updates. The signer - // implementation should be backed by the public key of the backing - // Lightning node. + // AnnSigner is an instance of the MessageSignerRing interface which + // will be used to manually sign any outgoing channel updates. The + // signer implementation should be backed by the public key of the + // backing Lightning node. // // TODO(roasbeef): extract ann crafting + sign from fundingMgr into // here? - AnnSigner lnwallet.MessageSigner + AnnSigner keychain.MessageSignerRing // ScidCloser is an instance of ClosedChannelTracker that helps the // gossiper cut down on spam channel announcements for already closed @@ -337,8 +336,7 @@ type Config struct { // SignAliasUpdate is used to re-sign a channel update using the // remote's alias if the option-scid-alias feature bit was negotiated. - SignAliasUpdate func(u *lnwire.ChannelUpdate1) (*ecdsa.Signature, - error) + SignAliasUpdate func(u lnwire.ChannelUpdate) error // FindBaseByAlias finds the SCID stored in the graph by an alias SCID. // This is used for channels that have negotiated the option-scid-alias @@ -941,10 +939,9 @@ type channelUpdateID struct { // retrieve all necessary data to validate the channel existence. channelID lnwire.ShortChannelID - // Flags least-significant bit must be set to 0 if the creating node - // corresponds to the first node in the previously sent channel - // announcement and 1 otherwise. - flags lnwire.ChanUpdateChanFlags + disabled bool + + direction bool } // msgWithSenders is a wrapper struct around a message, and the set of peers @@ -1053,32 +1050,49 @@ func (d *deDupedAnnouncements) addMsg(message networkMsg) { // Channel updates are identified by the (short channel id, // channelflags) tuple. - case *lnwire.ChannelUpdate1: + case lnwire.ChannelUpdate: sender := route.NewVertex(message.source) deDupKey := channelUpdateID{ - msg.ShortChannelID, - msg.ChannelFlags, + msg.SCID(), + msg.IsDisabled(), + msg.IsNode1(), } - oldTimestamp := uint32(0) + var ( + older = false + newer = true + ) mws, ok := d.channelUpdates[deDupKey] if ok { // If we already have seen this message, record its // timestamp. - update, ok := mws.msg.(*lnwire.ChannelUpdate1) + oldMsg, ok := mws.msg.(lnwire.ChannelUpdate) if !ok { - log.Errorf("Expected *lnwire.ChannelUpdate1, "+ - "got: %T", mws.msg) + log.Errorf("expected type "+ + "lnwire.ChannelUpdate, got: %T", + mws.msg) + + return + } + cmp, err := msg.CmpAge(oldMsg) + if err != nil { return } - oldTimestamp = update.Timestamp + newer = false + switch cmp { + case lnwire.LessThan: + older = true + case lnwire.GreaterThan: + newer = true + default: + } } // If we already had this message with a strictly newer // timestamp, then we'll just discard the message we got. - if oldTimestamp > msg.Timestamp { + if older { log.Debugf("Ignored outdated network message: "+ "peer=%v, msg=%s", message.peer, msg.MsgType()) return @@ -1087,7 +1101,7 @@ func (d *deDupedAnnouncements) addMsg(message networkMsg) { // If the message we just got is newer than what we previously // have seen, or this is the first time we see it, then we'll // add it to our map of announcements. - if oldTimestamp < msg.Timestamp { + if newer { mws = msgWithSenders{ msg: msg, isLocal: !message.isRemote, @@ -1608,8 +1622,8 @@ func (d *AuthenticatedGossiper) isRecentlyRejectedMsg(msg lnwire.Message, var scid uint64 switch m := msg.(type) { - case *lnwire.ChannelUpdate1: - scid = m.ShortChannelID.ToUint64() + case lnwire.ChannelUpdate: + scid = m.SCID().ToUint64() case lnwire.ChannelAnnouncement: scid = m.SCID().ToUint64() @@ -1828,23 +1842,14 @@ func (d *AuthenticatedGossiper) processChanPolicyUpdate( var defaultAlias lnwire.ShortChannelID foundAlias, _ := d.cfg.GetAlias(chanID) if foundAlias != defaultAlias { - chanUpdate.ShortChannelID = foundAlias + chanUpdate.SetSCID(foundAlias) - sig, err := d.cfg.SignAliasUpdate(chanUpdate) + err := d.cfg.SignAliasUpdate(chanUpdate) if err != nil { log.Errorf("Unable to sign alias "+ "update: %v", err) continue } - - lnSig, err := lnwire.NewSigFromSignature(sig) - if err != nil { - log.Errorf("Unable to create sig: %v", - err) - continue - } - - chanUpdate.Signature = lnSig } remotePubKey := remotePubFromChanInfo( @@ -1857,7 +1862,7 @@ func (d *AuthenticatedGossiper) processChanPolicyUpdate( log.Errorf("Unable to reliably send %v for "+ "channel=%v to peer=%x: %v", chanUpdate.MsgType(), - chanUpdate.ShortChannelID, + chanUpdate.SCID(), remotePubKey, err) } continue @@ -2093,7 +2098,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement( // A new authenticated channel edge update has arrived. This indicates // that the directional information for an already known channel has // been updated. - case *lnwire.ChannelUpdate1: + case lnwire.ChannelUpdate: return d.handleChanUpdate(nMsg, msg, schedulerOp) // A new signature announcement has been received. This indicates @@ -2116,7 +2121,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement( // ChannelEdgeInfo1 should be inspected. func (d *AuthenticatedGossiper) processZombieUpdate( chanInfo models.ChannelEdgeInfo, scid lnwire.ShortChannelID, - msg *lnwire.ChannelUpdate1) error { + msg lnwire.ChannelUpdate) error { // Since we've deemed the update as not stale above, before marking it // live, we'll make sure it has been signed by the correct party. If we @@ -2132,7 +2137,7 @@ func (d *AuthenticatedGossiper) processZombieUpdate( } if pubKey == nil { return fmt.Errorf("incorrect pubkey to resurrect zombie "+ - "with chan_id=%v", msg.ShortChannelID) + "with chan_id=%v", msg.SCID()) } err := msg.VerifySig(pubKey) @@ -2140,7 +2145,6 @@ func (d *AuthenticatedGossiper) processZombieUpdate( return fmt.Errorf("unable to verify channel "+ "update signature: %v", err) } - // With the signature valid, we'll proceed to mark the // edge as live and wait for the channel announcement to // come through again. @@ -2155,13 +2159,13 @@ func (d *AuthenticatedGossiper) processZombieUpdate( case err != nil: return fmt.Errorf("unable to remove edge with "+ "chan_id=%v from zombie index: %v", - msg.ShortChannelID, err) + msg.SCID(), err) default: } log.Debugf("Removed edge with chan_id=%v from zombie "+ - "index", msg.ShortChannelID) + "index", msg.SCID()) return nil } @@ -2255,28 +2259,22 @@ func (d *AuthenticatedGossiper) isMsgStale(msg lnwire.Message) bool { // updateChannel creates a new fully signed update for the channel, and updates // the underlying graph with the new state. func (d *AuthenticatedGossiper) updateChannel(edgeInfo models.ChannelEdgeInfo, - edgePolicy models.ChannelEdgePolicy) (lnwire.ChannelAnnouncement, - *lnwire.ChannelUpdate1, error) { + edge models.ChannelEdgePolicy) (lnwire.ChannelAnnouncement, + lnwire.ChannelUpdate, error) { // Parse the unsigned edge into a channel update. chanUpdate, err := netann.UnsignedChannelUpdateFromEdge( - edgeInfo.GetChainHash(), edgePolicy, + edgeInfo.GetChainHash(), edge, ) if err != nil { return nil, nil, err } - edge, ok := edgePolicy.(*models.ChannelEdgePolicy1) - if !ok { - return nil, nil, fmt.Errorf("expected "+ - "*models.ChannelEdgePolicy1, got: %T", edgePolicy) - } - // We'll generate a new signature over a digest of the channel // announcement itself and update the timestamp to ensure it propagate. err = netann.SignChannelUpdate( d.cfg.AnnSigner, d.selfKeyLoc, chanUpdate, - netann.ChanUpdSetTimestamp, + netann.ChanUpdSetTimestamp(d.latestHeight()), ) if err != nil { return nil, nil, err @@ -2284,8 +2282,25 @@ func (d *AuthenticatedGossiper) updateChannel(edgeInfo models.ChannelEdgeInfo, // Next, we'll set the new signature in place, and update the reference // in the backing slice. - edge.LastUpdate = time.Unix(int64(chanUpdate.Timestamp), 0) - edge.SigBytes = chanUpdate.Signature.ToSignatureBytes() + switch e := edge.(type) { + case *models.ChannelEdgePolicy1: + chanUpd, ok := chanUpdate.(*lnwire.ChannelUpdate1) + if !ok { + return nil, nil, fmt.Errorf("wanted chan update 1") + } + + e.LastUpdate = time.Unix(int64(chanUpd.Timestamp), 0) + e.SigBytes = chanUpd.Signature.ToSignatureBytes() + + case *models.ChannelEdgePolicy2: + chanUpd, ok := chanUpdate.(*lnwire.ChannelUpdate2) + if !ok { + return nil, nil, fmt.Errorf("wanted chan update 2") + } + + e.BlockHeight = chanUpd.BlockHeight + e.Signature = chanUpd.Signature + } // To ensure that our signature is valid, we'll verify it ourself // before committing it to the slice returned. @@ -2313,6 +2328,10 @@ func (d *AuthenticatedGossiper) updateChannel(edgeInfo models.ChannelEdgeInfo, if err != nil { return nil, nil, err } + + case *models.ChannelEdgeInfo2: + chanAnn = chanAnn2FromEdgeInfo2(info) + default: return nil, nil, fmt.Errorf("unhandled "+ "implementation of models.ChannelEdgeInfo: "+ @@ -2367,6 +2386,15 @@ func chanAnn1FromEdgeInfo1(info *models.ChannelEdgeInfo1) ( return chanAnn, nil } +func chanAnn2FromEdgeInfo2( + info *models.ChannelEdgeInfo2) *lnwire.ChannelAnnouncement2 { + + chanAnn := info.ChannelAnnouncement2 + chanAnn.Signature = info.Signature + + return &chanAnn +} + // SyncManager returns the gossiper's SyncManager instance. func (d *AuthenticatedGossiper) SyncManager() *SyncManager { return d.syncMgr @@ -2375,48 +2403,114 @@ func (d *AuthenticatedGossiper) SyncManager() *SyncManager { // IsKeepAliveUpdate determines whether this channel update is considered a // keep-alive update based on the previous channel update processed for the same // direction. -func IsKeepAliveUpdate(update *lnwire.ChannelUpdate1, - prev *models.ChannelEdgePolicy1) bool { +func IsKeepAliveUpdate(update lnwire.ChannelUpdate, + prevPolicy models.ChannelEdgePolicy) (bool, error) { - // Both updates should be from the same direction. - if update.ChannelFlags&lnwire.ChanUpdateDirection != - prev.ChannelFlags&lnwire.ChanUpdateDirection { + switch upd := update.(type) { + case *lnwire.ChannelUpdate1: + prev, ok := prevPolicy.(*models.ChannelEdgePolicy1) + if !ok { + return false, fmt.Errorf("expected chan edge policy 1") + } - return false - } + // Both updates should be from the same direction. + if upd.ChannelFlags&lnwire.ChanUpdateDirection != + prev.ChannelFlags&lnwire.ChanUpdateDirection { - // The timestamp should always increase for a keep-alive update. - timestamp := time.Unix(int64(update.Timestamp), 0) - if !timestamp.After(prev.LastUpdate) { - return false - } + return false, nil + } - // None of the remaining fields should change for a keep-alive update. - if update.ChannelFlags.IsDisabled() != prev.ChannelFlags.IsDisabled() { - return false - } - if lnwire.MilliSatoshi(update.BaseFee) != prev.FeeBaseMSat { - return false - } - if lnwire.MilliSatoshi(update.FeeRate) != prev.FeeProportionalMillionths { - return false - } - if update.TimeLockDelta != prev.TimeLockDelta { - return false - } - if update.HtlcMinimumMsat != prev.MinHTLC { - return false - } - if update.MessageFlags.HasMaxHtlc() && !prev.MessageFlags.HasMaxHtlc() { - return false - } - if update.HtlcMaximumMsat != prev.MaxHTLC { - return false - } - if !bytes.Equal(update.ExtraOpaqueData, prev.ExtraOpaqueData) { - return false + // The timestamp should always increase for a keep-alive update. + timestamp := time.Unix(int64(upd.Timestamp), 0) + if !timestamp.After(prev.LastUpdate) { + return false, nil + } + + // None of the remaining fields should change for a keep-alive + // update. + if upd.ChannelFlags.IsDisabled() != + prev.ChannelFlags.IsDisabled() { + + return false, nil + } + if lnwire.MilliSatoshi(upd.BaseFee) != prev.FeeBaseMSat { + return false, nil + } + if lnwire.MilliSatoshi(upd.FeeRate) != + prev.FeeProportionalMillionths { + + return false, nil + } + if upd.TimeLockDelta != prev.TimeLockDelta { + return false, nil + } + if upd.HtlcMinimumMsat != prev.MinHTLC { + return false, nil + } + if upd.MessageFlags.HasMaxHtlc() && + !prev.MessageFlags.HasMaxHtlc() { + + return false, nil + } + if upd.HtlcMaximumMsat != prev.MaxHTLC { + return false, nil + } + if !bytes.Equal(upd.ExtraOpaqueData, prev.ExtraOpaqueData) { + return false, nil + } + + return true, nil + + case *lnwire.ChannelUpdate2: + prev, ok := prevPolicy.(*models.ChannelEdgePolicy2) + if !ok { + return false, fmt.Errorf("expected chan edge policy 2") + } + + // Both updates should be from the same direction. + if upd.IsNode1() != prev.IsNode1() { + return false, nil + } + + // The block-height should always increase for a keep-alive + // update. + if upd.BlockHeight.Val <= prev.BlockHeight.Val { + return false, nil + } + + // None of the remaining fields should change for a keep-alive + // update. + if upd.IsDisabled() != prev.IsDisabled() { + return false, nil + } + fwd := upd.ForwardingPolicy() + prevFwd := upd.ForwardingPolicy() + + if fwd.BaseFee != prevFwd.BaseFee { + return false, nil + } + if fwd.FeeRate != prevFwd.FeeRate { + return false, nil + } + if fwd.TimeLockDelta != prevFwd.TimeLockDelta { + return false, nil + } + if fwd.MinHTLC != prevFwd.MinHTLC { + return false, nil + } + if fwd.MaxHTLC != prevFwd.MinHTLC { + return false, nil + } + if !bytes.Equal(upd.ExtraOpaqueData, prev.ExtraOpaqueData) { + return false, nil + } + + return true, nil + + default: + return false, fmt.Errorf("unhandled implementation of "+ + "ChannelUpdate: %T", update) } - return true } // latestHeight returns the gossiper's latest height known of the chain. @@ -2836,7 +2930,7 @@ func (d *AuthenticatedGossiper) handleChanAnnouncement(nMsg *networkMsg, // Reprocess the message, making sure we return an // error to the original caller in case the gossiper // shuts down. - case *lnwire.ChannelUpdate1: + case lnwire.ChannelUpdate: log.Debugf("Reprocessing ChannelUpdate for "+ "shortChanID=%v", scid.ToUint64()) @@ -2879,22 +2973,28 @@ func (d *AuthenticatedGossiper) handleChanAnnouncement(nMsg *networkMsg, // handleChanUpdate processes a new channel update. func (d *AuthenticatedGossiper) handleChanUpdate(nMsg *networkMsg, - upd *lnwire.ChannelUpdate1, - ops []batch.SchedulerOption) ([]networkMsg, bool) { + upd lnwire.ChannelUpdate, ops []batch.SchedulerOption) ([]networkMsg, + bool) { + + var ( + scid = upd.SCID() + chainHash = upd.GetChainHash() + ) log.Debugf("Processing ChannelUpdate: peer=%v, short_chan_id=%v, ", - nMsg.peer, upd.ShortChannelID.ToUint64()) + nMsg.peer, scid) // We'll ignore any channel updates that target any chain other than // the set of chains we know of. - if !bytes.Equal(upd.ChainHash[:], d.cfg.ChainHash[:]) { - err := fmt.Errorf("ignoring ChannelUpdate from chain=%v, "+ - "gossiper on chain=%v", upd.ChainHash, d.cfg.ChainHash) + if !bytes.Equal(chainHash[:], d.cfg.ChainHash[:]) { + err := fmt.Errorf("ignoring %s from chain=%v, "+ + "gossiper on chain=%v", upd.MsgType(), chainHash, + d.cfg.ChainHash) + log.Errorf(err.Error()) key := newRejectCacheKey( - upd.ShortChannelID.ToUint64(), - sourceToPub(nMsg.source), + scid.ToUint64(), sourceToPub(nMsg.source), ) _, _ = d.recentRejects.Put(key, &cachedReject{}) @@ -2902,8 +3002,8 @@ func (d *AuthenticatedGossiper) handleChanUpdate(nMsg *networkMsg, return nil, false } - blockHeight := upd.ShortChannelID.BlockHeight - shortChanID := upd.ShortChannelID.ToUint64() + blockHeight := upd.SCID().BlockHeight + shortChanID := upd.SCID().ToUint64() // If the advertised inclusionary block is beyond our knowledge of the // chain tip, then we'll put the announcement in limbo to be fully @@ -2911,8 +3011,8 @@ func (d *AuthenticatedGossiper) handleChanUpdate(nMsg *networkMsg, // alias SCID, we'll skip the isPremature check. This is necessary // since aliases start at block height 16_000_000. d.Lock() - if nMsg.isRemote && !d.cfg.IsAlias(upd.ShortChannelID) && - d.isPremature(upd.ShortChannelID, 0, nMsg) { + if nMsg.isRemote && !d.cfg.IsAlias(scid) && + d.isPremature(scid, 0, nMsg) { log.Warnf("Update announcement for short_chan_id(%v), is "+ "premature: advertises height %v, only height %v is "+ @@ -2923,27 +3023,22 @@ func (d *AuthenticatedGossiper) handleChanUpdate(nMsg *networkMsg, } d.Unlock() - // Before we perform any of the expensive checks below, we'll check - // whether this update is stale or is for a zombie channel in order to - // quickly reject it. - timestamp := time.Unix(int64(upd.Timestamp), 0) - // Fetch the SCID we should be using to lock the channelMtx and make // graph queries with. - graphScid, err := d.cfg.FindBaseByAlias(upd.ShortChannelID) + graphScid, err := d.cfg.FindBaseByAlias(scid) if err != nil { // Fallback and set the graphScid to the peer-provided SCID. // This will occur for non-option-scid-alias channels and for // public option-scid-alias channels after 6 confirmations. // Once public option-scid-alias channels have 6 confs, we'll // ignore ChannelUpdates with one of their aliases. - graphScid = upd.ShortChannelID + graphScid = scid } - if d.cfg.Graph.IsStaleEdgePolicy( - graphScid, timestamp, upd.ChannelFlags, - ) { - + // Before we perform any of the expensive checks below, we'll check + // whether this update is stale or is for a zombie channel in order to + // quickly reject it. + if d.cfg.Graph.IsStaleEdgePolicy(graphScid, upd) { log.Debugf("Ignored stale edge policy for short_chan_id(%v): "+ "peer=%v, msg=%s, is_remote=%v", shortChanID, nMsg.peer, nMsg.msg.MsgType(), nMsg.isRemote, @@ -2955,18 +3050,41 @@ func (d *AuthenticatedGossiper) handleChanUpdate(nMsg *networkMsg, // Check that the ChanUpdate is not too far into the future, this could // reveal some faulty implementation therefore we log an error. - if time.Until(timestamp) > graph.DefaultChannelPruneExpiry { - log.Errorf("Skewed timestamp (%v) for edge policy of "+ - "short_chan_id(%v), timestamp too far in the future: "+ - "peer=%v, msg=%s, is_remote=%v", timestamp.Unix(), - shortChanID, nMsg.peer, nMsg.msg.MsgType(), - nMsg.isRemote, - ) + // TODO(elle): abstract this check + switch u := upd.(type) { + case *lnwire.ChannelUpdate1: + timestamp := time.Unix(int64(u.Timestamp), 0) + + if time.Until(timestamp) > graph.DefaultChannelPruneExpiry { + log.Errorf("Skewed timestamp (%v) for edge policy of "+ + "short_chan_id(%v), timestamp too far in the future: "+ + "peer=%v, msg=%s, is_remote=%v", timestamp.Unix(), + shortChanID, nMsg.peer, nMsg.msg.MsgType(), + nMsg.isRemote, + ) - nMsg.err <- fmt.Errorf("skewed timestamp of edge policy, "+ - "timestamp too far in the future: %v", timestamp.Unix()) + nMsg.err <- fmt.Errorf("skewed timestamp of edge policy, "+ + "timestamp too far in the future: %v", timestamp.Unix()) - return nil, false + return nil, false + } + + case *lnwire.ChannelUpdate2: + if int64(u.BlockHeight.Val)-int64(d.latestHeight()) > + int64(graph.DefaultChannelPruneExpiry.Hours()*6) { + + log.Errorf("Skewed blockheight (%v) for edge policy "+ + "of short_chan_id(%v), blockheight too far "+ + "in the future: peer=%v, msg=%s, is_remote=%v", + u.BlockHeight.Val, shortChanID, nMsg.peer, + nMsg.msg.MsgType(), nMsg.isRemote, + ) + + nMsg.err <- fmt.Errorf("skewed blockheight of edge policy, "+ + "timestamp too far in the future: %v", u.BlockHeight) + + return nil, false + } } // Get the node pub key as far since we don't have it in the channel @@ -3007,7 +3125,7 @@ func (d *AuthenticatedGossiper) handleChanUpdate(nMsg *networkMsg, // If the edge corresponding to this ChannelUpdate was not // found in the graph, this might be a channel in the process // of being opened, and we haven't processed our own - // ChannelAnnouncement yet, hence it is not not found in the + // ChannelAnnouncement yet, hence it is not found in the // graph. This usually gets resolved after the channel proofs // are exchanged and the channel is broadcasted to the rest of // the network, but in case this is a private channel this @@ -3060,7 +3178,7 @@ func (d *AuthenticatedGossiper) handleChanUpdate(nMsg *networkMsg, nMsg.err <- err key := newRejectCacheKey( - upd.ShortChannelID.ToUint64(), + scid.ToUint64(), sourceToPub(nMsg.source), ) _, _ = d.recentRejects.Put(key, &cachedReject{}) @@ -3074,15 +3192,16 @@ func (d *AuthenticatedGossiper) handleChanUpdate(nMsg *networkMsg, var ( pubKey *btcec.PublicKey edgeToUpdate models.ChannelEdgePolicy + direction int ) - direction := upd.ChannelFlags & lnwire.ChanUpdateDirection - switch direction { - case 0: + if upd.IsNode1() { pubKey, _ = chanInfo.NodeKey1() edgeToUpdate = e1 - case 1: + direction = 0 + } else { pubKey, _ = chanInfo.NodeKey2() edgeToUpdate = e2 + direction = 1 } var chanID = chanInfo.GetChanID() @@ -3100,39 +3219,42 @@ func (d *AuthenticatedGossiper) handleChanUpdate(nMsg *networkMsg, if err != nil { rErr := fmt.Errorf("unable to validate channel update "+ "announcement for short_chan_id=%v: %v", - spew.Sdump(upd.ShortChannelID), err) + spew.Sdump(scid), err) log.Error(rErr) nMsg.err <- rErr return nil, false } - var edge *models.ChannelEdgePolicy1 - if edgeToUpdate != nil { - var ok bool - edge, ok = edgeToUpdate.(*models.ChannelEdgePolicy1) - if !ok { - rErr := fmt.Errorf("expected "+ - "*models.ChannelEdgePolicy1, got: %T", - edgeToUpdate) - - log.Error(rErr) - nMsg.err <- rErr - - return nil, false - } - } - // If we have a previous version of the edge being updated, we'll want // to rate limit its updates to prevent spam throughout the network. - if nMsg.isRemote && edge != nil { + if nMsg.isRemote && edgeToUpdate != nil { // If it's a keep-alive update, we'll only propagate one if // it's been a day since the previous. This follows our own // heuristic of sending keep-alive updates after the same // duration (see retransmitStaleAnns). - timeSinceLastUpdate := timestamp.Sub(edge.LastUpdate) - if IsKeepAliveUpdate(upd, edge) { - if timeSinceLastUpdate < d.cfg.RebroadcastInterval { + isKeepAlive, err := IsKeepAliveUpdate(upd, edgeToUpdate) + if err != nil { + log.Errorf("Could not determine if update is "+ + "keepalive: %v", err) + nMsg.err <- err + + return nil, false + } + + if isKeepAlive { + within, err := d.updateWithinRebroadcastInterval( + upd, edgeToUpdate, + ) + if err != nil { + log.Errorf("Could not determine if update is "+ + "within rebroadcast interval: %v", err) + nMsg.err <- err + + return nil, false + } + + if !within { log.Debugf("Ignoring keep alive update not "+ "within %v period for channel %v", d.cfg.RebroadcastInterval, shortChanID) @@ -3151,7 +3273,7 @@ func (d *AuthenticatedGossiper) handleChanUpdate(nMsg *networkMsg, // multiple aliases for a channel and we may otherwise // rate-limit only a single alias of the channel, // instead of the whole channel. - baseScid := chanID + baseScid := chanInfo.GetChanID() d.Lock() rls, ok := d.chanUpdateRateLimiter[baseScid] if !ok { @@ -3182,18 +3304,23 @@ func (d *AuthenticatedGossiper) handleChanUpdate(nMsg *networkMsg, // different alias. This might mean that SigBytes is incorrect as it // signs a different SCID than the database SCID, but since there will // only be a difference if AuthProof == nil, this is fine. - update := &models.ChannelEdgePolicy1{ - SigBytes: upd.Signature.ToSignatureBytes(), - ChannelID: chanID, - LastUpdate: timestamp, - MessageFlags: upd.MessageFlags, - ChannelFlags: upd.ChannelFlags, - TimeLockDelta: upd.TimeLockDelta, - MinHTLC: upd.HtlcMinimumMsat, - MaxHTLC: upd.HtlcMaximumMsat, - FeeBaseMSat: lnwire.MilliSatoshi(upd.BaseFee), - FeeProportionalMillionths: lnwire.MilliSatoshi(upd.FeeRate), - ExtraOpaqueData: upd.ExtraOpaqueData, + update, err := models.EdgePolicyFromUpdate(upd) + if err != nil { + rErr := fmt.Errorf("unable to convert update to policy for "+ + "short_chan_id=%v: %v", spew.Sdump(scid), err) + + log.Error(rErr) + nMsg.err <- rErr + + return nil, false + } + switch upd := update.(type) { + case *models.ChannelEdgePolicy1: + upd.ChannelID = chanInfo.GetChanID() + case *models.ChannelEdgePolicy2: + upd.ShortChannelID.Val = lnwire.NewShortChanIDFromInt( + chanInfo.GetChanID(), + ) } if err := d.cfg.Graph.UpdateEdge(update, ops...); err != nil { @@ -3209,7 +3336,8 @@ func (d *AuthenticatedGossiper) handleChanUpdate(nMsg *networkMsg, // Since we know the stored SCID in the graph, we'll // cache that SCID. key := newRejectCacheKey( - chanID, sourceToPub(nMsg.source), + chanInfo.GetChanID(), + sourceToPub(nMsg.source), ) _, _ = d.recentRejects.Put(key, &cachedReject{}) @@ -3218,41 +3346,33 @@ func (d *AuthenticatedGossiper) handleChanUpdate(nMsg *networkMsg, } nMsg.err <- err + return nil, false } // If this is a local ChannelUpdate without an AuthProof, it means it // is an update to a channel that is not (yet) supposed to be announced - // to the greater network. However, our channel counter party will need + // to the greater network. However, our channel counterparty will need // to be given the update, so we'll try sending the update directly to // the remote peer. if !nMsg.isRemote && chanInfo.GetAuthProof() == nil { - if nMsg.optionalMsgFields != nil { + if nMsg.optionalMsgFields != nil && + nMsg.optionalMsgFields.remoteAlias != nil { + + // The remoteAlias field was specified, meaning + // that we should replace the SCID in the + // update with the remote's alias. We'll also + // need to re-sign the channel update. This is + // required for option-scid-alias feature-bit + // negotiated channels. remoteAlias := nMsg.optionalMsgFields.remoteAlias - if remoteAlias != nil { - // The remoteAlias field was specified, meaning - // that we should replace the SCID in the - // update with the remote's alias. We'll also - // need to re-sign the channel update. This is - // required for option-scid-alias feature-bit - // negotiated channels. - upd.ShortChannelID = *remoteAlias - - sig, err := d.cfg.SignAliasUpdate(upd) - if err != nil { - log.Error(err) - nMsg.err <- err - return nil, false - } + upd.SetSCID(*remoteAlias) - lnSig, err := lnwire.NewSigFromSignature(sig) - if err != nil { - log.Error(err) - nMsg.err <- err - return nil, false - } - - upd.Signature = lnSig + err := d.cfg.SignAliasUpdate(upd) + if err != nil { + log.Error(err) + nMsg.err <- err + return nil, false } } @@ -3269,7 +3389,7 @@ func (d *AuthenticatedGossiper) handleChanUpdate(nMsg *networkMsg, if err != nil { err := fmt.Errorf("unable to reliably send %v for "+ "channel=%v to peer=%x: %v", upd.MsgType(), - upd.ShortChannelID, remotePubKey, err) + scid, remotePubKey, err) nMsg.err <- err return nil, false } @@ -3282,7 +3402,7 @@ func (d *AuthenticatedGossiper) handleChanUpdate(nMsg *networkMsg, // contains an alias because the network would reject this. var announcements []networkMsg if chanInfo.GetAuthProof() != nil && - !d.cfg.IsAlias(upd.ShortChannelID) { + !d.cfg.IsAlias(scid) { announcements = append(announcements, networkMsg{ peer: nMsg.peer, @@ -3294,9 +3414,9 @@ func (d *AuthenticatedGossiper) handleChanUpdate(nMsg *networkMsg, nMsg.err <- nil - log.Debugf("Processed ChannelUpdate: peer=%v, short_chan_id=%v, "+ - "timestamp=%v", nMsg.peer, upd.ShortChannelID.ToUint64(), - timestamp) + log.Debugf("Processed %s: peer=%v, short_chan_id=%v, ", upd.MsgType(), + nMsg.peer, scid.ToUint64()) + return announcements, true } @@ -3756,6 +3876,39 @@ func (d *AuthenticatedGossiper) ShouldDisconnect(pubkey *btcec.PublicKey) ( return false, nil } +func (d *AuthenticatedGossiper) updateWithinRebroadcastInterval( + upd lnwire.ChannelUpdate, policy models.ChannelEdgePolicy) (bool, + error) { + + switch update := upd.(type) { + case *lnwire.ChannelUpdate1: + pol, ok := policy.(*models.ChannelEdgePolicy1) + if !ok { + return false, fmt.Errorf("expected chan edge policy 1") + } + + timestamp := time.Unix(int64(update.Timestamp), 0) + timeSinceLastUpdate := timestamp.Sub(pol.LastUpdate) + + return timeSinceLastUpdate >= d.cfg.RebroadcastInterval, nil + + case *lnwire.ChannelUpdate2: + pol, ok := policy.(*models.ChannelEdgePolicy2) + if !ok { + return false, fmt.Errorf("expected chan edge policy 2") + } + + blocksSinceLastUpdate := update.BlockHeight.Val - + pol.BlockHeight.Val + + return blocksSinceLastUpdate >= + uint32(d.cfg.RebroadcastInterval.Hours()*6), nil + + default: + return false, fmt.Errorf("unhandled impl of Chan Update") + } +} + func buildChanProof(ann lnwire.ChannelAnnouncement) ( models.ChannelAuthProof, error) { diff --git a/discovery/gossiper_test.go b/discovery/gossiper_test.go index 3c994c203ef..908dc6cee57 100644 --- a/discovery/gossiper_test.go +++ b/discovery/gossiper_test.go @@ -14,7 +14,6 @@ import ( "time" "github.com/btcsuite/btcd/btcec/v2" - "github.com/btcsuite/btcd/btcec/v2/ecdsa" "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" @@ -94,7 +93,7 @@ type mockGraphSource struct { mu sync.Mutex nodes []channeldb.LightningNode infos map[uint64]models.ChannelEdgeInfo - edges map[uint64][]models.ChannelEdgePolicy1 + edges map[uint64][]models.ChannelEdgePolicy zombies map[uint64][][33]byte chansToReject map[uint64]struct{} addEdgeErrCode fn.Option[graph.ErrorCode] @@ -104,7 +103,7 @@ func newMockRouter(height uint32) *mockGraphSource { return &mockGraphSource{ bestHeight: height, infos: make(map[uint64]models.ChannelEdgeInfo), - edges: make(map[uint64][]models.ChannelEdgePolicy1), + edges: make(map[uint64][]models.ChannelEdgePolicy), zombies: make(map[uint64][][33]byte), chansToReject: make(map[uint64]struct{}), } @@ -162,20 +161,22 @@ func (r *mockGraphSource) queueValidationFail(chanID uint64) { r.chansToReject[chanID] = struct{}{} } -func (r *mockGraphSource) UpdateEdge(edge *models.ChannelEdgePolicy1, +func (r *mockGraphSource) UpdateEdge(edge models.ChannelEdgePolicy, _ ...batch.SchedulerOption) error { r.mu.Lock() defer r.mu.Unlock() - if len(r.edges[edge.ChannelID]) == 0 { - r.edges[edge.ChannelID] = make([]models.ChannelEdgePolicy1, 2) + chanID := edge.SCID().ToUint64() + + if len(r.edges[chanID]) == 0 { + r.edges[chanID] = make([]models.ChannelEdgePolicy, 2) } - if edge.ChannelFlags&lnwire.ChanUpdateDirection == 0 { - r.edges[edge.ChannelID][0] = *edge + if edge.IsNode1() { + r.edges[chanID][0] = edge } else { - r.edges[edge.ChannelID][1] = *edge + r.edges[chanID][1] = edge } return nil @@ -219,7 +220,6 @@ func (r *mockGraphSource) ForAllOutgoingChannels(cb func(tx kvdb.RTx, r.mu.Lock() defer r.mu.Unlock() - chans := make(map[uint64]channeldb.ChannelEdge) for _, info := range r.infos { info := info @@ -231,9 +231,9 @@ func (r *mockGraphSource) ForAllOutgoingChannels(cb func(tx kvdb.RTx, for _, edges := range r.edges { edges := edges - edge := chans[edges[0].ChannelID] - edge.Policy1 = &edges[0] - chans[edges[0].ChannelID] = edge + edge := chans[edges[0].SCID().ToUint64()] + edge.Policy1 = edges[0] + chans[edges[0].SCID().ToUint64()] = edge } for _, channel := range chans { @@ -241,7 +241,6 @@ func (r *mockGraphSource) ForAllOutgoingChannels(cb func(tx kvdb.RTx, return err } } - return nil } @@ -267,22 +266,24 @@ func (r *mockGraphSource) GetChannelByID(chanID lnwire.ShortChannelID) ( }, nil, nil, channeldb.ErrZombieEdge } + chanInfoCP := chanInfo.Copy() + edges := r.edges[chanID.ToUint64()] if len(edges) == 0 { - return chanInfo, nil, nil, nil + return chanInfoCP, nil, nil, nil } - var edge1 *models.ChannelEdgePolicy1 + var edge1 models.ChannelEdgePolicy if !reflect.DeepEqual(edges[0], models.ChannelEdgePolicy1{}) { - edge1 = &edges[0] + edge1 = edges[0] } - var edge2 *models.ChannelEdgePolicy1 + var edge2 models.ChannelEdgePolicy if !reflect.DeepEqual(edges[1], models.ChannelEdgePolicy1{}) { - edge2 = &edges[1] + edge2 = edges[1] } - return chanInfo, edge1, edge2, nil + return chanInfoCP, edge1, edge2, nil } func (r *mockGraphSource) FetchLightningNode( @@ -358,11 +359,18 @@ func (r *mockGraphSource) IsKnownEdge(chanID lnwire.ShortChannelID) bool { // IsStaleEdgePolicy returns true if the graph source has a channel edge for // the passed channel ID (and flags) that have a more recent timestamp. func (r *mockGraphSource) IsStaleEdgePolicy(chanID lnwire.ShortChannelID, - timestamp time.Time, flags lnwire.ChanUpdateChanFlags) bool { + policy lnwire.ChannelUpdate) bool { r.mu.Lock() defer r.mu.Unlock() + pol, ok := policy.(*lnwire.ChannelUpdate1) + if !ok { + panic("expected chan update 1") + } + + timestamp := time.Unix(int64(pol.Timestamp), 0) + chanIDInt := chanID.ToUint64() edges, ok := r.edges[chanIDInt] if !ok { @@ -372,7 +380,6 @@ func (r *mockGraphSource) IsStaleEdgePolicy(chanID lnwire.ShortChannelID, if !isZombie { return false } - // Since it exists within our zombie index, we'll check that it // respects the router's live edge horizon to determine whether // it is stale or not. @@ -380,15 +387,21 @@ func (r *mockGraphSource) IsStaleEdgePolicy(chanID lnwire.ShortChannelID, } switch { - case flags&lnwire.ChanUpdateDirection == 0 && - !reflect.DeepEqual(edges[0], models.ChannelEdgePolicy1{}): - - return !timestamp.After(edges[0].LastUpdate) - - case flags&lnwire.ChanUpdateDirection == 1 && - !reflect.DeepEqual(edges[1], models.ChannelEdgePolicy1{}): + case policy.IsNode1() && edges[0] != nil: + switch edge := edges[0].(type) { + case *models.ChannelEdgePolicy1: + return !timestamp.After(edge.LastUpdate) + default: + panic(fmt.Sprintf("unhandled: %T", edges[0])) + } - return !timestamp.After(edges[1].LastUpdate) + case !policy.IsNode1() && edges[1] != nil: + switch edge := edges[1].(type) { + case *models.ChannelEdgePolicy1: + return !timestamp.After(edge.LastUpdate) + default: + panic(fmt.Sprintf("unhandled: %T", edges[1])) + } default: return false @@ -759,10 +772,8 @@ func createTestCtx(t *testing.T, startHeight uint32, isChanPeer bool) ( return false } - signAliasUpdate := func(*lnwire.ChannelUpdate1) (*ecdsa.Signature, - error) { - - return nil, nil + signAliasUpdate := func(lnwire.ChannelUpdate) error { + return nil } findBaseByAlias := func(lnwire.ShortChannelID) (lnwire.ShortChannelID, @@ -1472,10 +1483,8 @@ func TestSignatureAnnouncementRetryAtStartup(t *testing.T) { return false } - signAliasUpdate := func(*lnwire.ChannelUpdate1) (*ecdsa.Signature, - error) { - - return nil, nil + signAliasUpdate := func(lnwire.ChannelUpdate) error { + return nil } findBaseByAlias := func(lnwire.ShortChannelID) (lnwire.ShortChannelID, @@ -1881,7 +1890,8 @@ func TestDeDuplicatedAnnouncements(t *testing.T) { assertChannelUpdate := func(channelUpdate *lnwire.ChannelUpdate1) { channelKey := channelUpdateID{ ua3.ShortChannelID, - ua3.ChannelFlags, + ua3.IsDisabled(), + ua3.IsNode1(), } mws, ok := announcements.channelUpdates[channelKey] @@ -2824,7 +2834,7 @@ func TestRetransmit(t *testing.T) { switch msg.(type) { case lnwire.ChannelAnnouncement: chanAnn++ - case *lnwire.ChannelUpdate1: + case lnwire.ChannelUpdate: chanUpd++ case *lnwire.NodeAnnouncement: nodeAnn++ @@ -3311,7 +3321,7 @@ func TestSendChannelUpdateReliably(t *testing.T) { } switch msg := msg.(type) { - case *lnwire.ChannelUpdate1: + case lnwire.ChannelUpdate: assertMessage(t, staleChannelUpdate, msg) case *lnwire.AnnounceSignatures1: assertMessage(t, batch.localProofAnn, msg) diff --git a/discovery/message_store.go b/discovery/message_store.go index e336c1281bf..156d56caa37 100644 --- a/discovery/message_store.go +++ b/discovery/message_store.go @@ -85,8 +85,8 @@ func msgShortChanID(msg lnwire.Message) (lnwire.ShortChannelID, error) { switch msg := msg.(type) { case lnwire.AnnounceSignatures: shortChanID = msg.SCID() - case *lnwire.ChannelUpdate1: - shortChanID = msg.ShortChannelID + case lnwire.ChannelUpdate: + shortChanID = msg.SCID() default: return shortChanID, ErrUnsupportedMessage } @@ -160,7 +160,7 @@ func (s *MessageStore) DeleteMessage(msg lnwire.Message, // In the event that we're attempting to delete a ChannelUpdate // from the store, we'll make sure that we're actually deleting // the correct one as it can be overwritten. - if msg, ok := msg.(*lnwire.ChannelUpdate1); ok { + if msg, ok := msg.(lnwire.ChannelUpdate); ok { // Deleting a value from a bucket that doesn't exist // acts as a NOP, so we'll return if a message doesn't // exist under this key. @@ -176,13 +176,18 @@ func (s *MessageStore) DeleteMessage(msg lnwire.Message, // If the timestamps don't match, then the update stored // should be the latest one, so we'll avoid deleting it. - m, ok := dbMsg.(*lnwire.ChannelUpdate1) + m, ok := dbMsg.(lnwire.ChannelUpdate) if !ok { return fmt.Errorf("expected "+ - "*lnwire.ChannelUpdate1, got: %T", - dbMsg) + "lnwire.ChannelUpdate, got: %T", dbMsg) + } + + diff, err := msg.CmpAge(m) + if err != nil { + return err } - if msg.Timestamp != m.Timestamp { + + if diff != lnwire.EqualTo { return nil } } diff --git a/discovery/message_store_test.go b/discovery/message_store_test.go index 36c082e36f2..10189b9022a 100644 --- a/discovery/message_store_test.go +++ b/discovery/message_store_test.go @@ -116,10 +116,10 @@ func TestMessageStoreMessages(t *testing.T) { for _, msg := range peerMsgs { var shortChanID uint64 switch msg := msg.(type) { - case *lnwire.AnnounceSignatures1: - shortChanID = msg.ShortChannelID.ToUint64() - case *lnwire.ChannelUpdate1: - shortChanID = msg.ShortChannelID.ToUint64() + case lnwire.AnnounceSignatures: + shortChanID = msg.SCID().ToUint64() + case lnwire.ChannelUpdate: + shortChanID = msg.SCID().ToUint64() default: t.Fatalf("found unexpected message type %T", msg) } diff --git a/discovery/syncer.go b/discovery/syncer.go index 886aa4be011..a3a2945fa3d 100644 --- a/discovery/syncer.go +++ b/discovery/syncer.go @@ -1413,16 +1413,16 @@ func (g *GossipSyncer) FilterGossipMsgs(msgs ...msgWithSenders) { // to quickly check if we should forward a chan ann, based on the known // channel updates for a channel. chanUpdateIndex := make( - map[lnwire.ShortChannelID][]*lnwire.ChannelUpdate1, + map[lnwire.ShortChannelID][]lnwire.ChannelUpdate, ) for _, msg := range msgs { - chanUpdate, ok := msg.msg.(*lnwire.ChannelUpdate1) + chanUpdate, ok := msg.msg.(lnwire.ChannelUpdate) if !ok { continue } - chanUpdateIndex[chanUpdate.ShortChannelID] = append( - chanUpdateIndex[chanUpdate.ShortChannelID], chanUpdate, + chanUpdateIndex[chanUpdate.SCID()] = append( + chanUpdateIndex[chanUpdate.SCID()], chanUpdate, ) } @@ -1475,7 +1475,16 @@ func (g *GossipSyncer) FilterGossipMsgs(msgs ...msgWithSenders) { } for _, chanUpdate := range chanUpdates { - if passesFilter(chanUpdate.Timestamp) { + update, ok := chanUpdate.(*lnwire.ChannelUpdate1) + if !ok { + log.Errorf("expected "+ + "*lnwire.ChannelUpdate1, "+ + "got: %T", update) + + continue + } + + if passesFilter(update.Timestamp) { msgsToSend = append(msgsToSend, msg) break } diff --git a/discovery/syncer_test.go b/discovery/syncer_test.go index 0ee635a0f2f..f176580578e 100644 --- a/discovery/syncer_test.go +++ b/discovery/syncer_test.go @@ -52,7 +52,7 @@ type mockChannelGraphTimeSeries struct { annResp chan []lnwire.Message updateReq chan lnwire.ShortChannelID - updateResp chan []*lnwire.ChannelUpdate1 + updateResp chan []lnwire.ChannelUpdate } func newMockChannelGraphTimeSeries( @@ -74,7 +74,7 @@ func newMockChannelGraphTimeSeries( annResp: make(chan []lnwire.Message, 1), updateReq: make(chan lnwire.ShortChannelID, 1), - updateResp: make(chan []*lnwire.ChannelUpdate1, 1), + updateResp: make(chan []lnwire.ChannelUpdate, 1), } } @@ -149,7 +149,7 @@ func (m *mockChannelGraphTimeSeries) FetchChanAnns(chain chainhash.Hash, return <-m.annResp, nil } func (m *mockChannelGraphTimeSeries) FetchChanUpdates(chain chainhash.Hash, - shortChanID lnwire.ShortChannelID) ([]*lnwire.ChannelUpdate1, error) { + shortChanID lnwire.ShortChannelID) ([]lnwire.ChannelUpdate, error) { m.updateReq <- shortChanID @@ -369,8 +369,8 @@ func TestGossipSyncerFilterGossipMsgsAllInMemory(t *testing.T) { } // If so, then we'll send back the missing update. - chanSeries.updateResp <- []*lnwire.ChannelUpdate1{ - { + chanSeries.updateResp <- []lnwire.ChannelUpdate{ + &lnwire.ChannelUpdate1{ ShortChannelID: lnwire.NewShortChanIDFromInt(25), Timestamp: unixStamp(5), }, diff --git a/docs/release-notes/release-notes-0.19.0.md b/docs/release-notes/release-notes-0.19.0.md index 382e9180178..5feedd25baf 100644 --- a/docs/release-notes/release-notes-0.19.0.md +++ b/docs/release-notes/release-notes-0.19.0.md @@ -57,6 +57,7 @@ * Use the new interfaces added for Gossip 1.75 throughout the codebase [1](https://github.com/lightningnetwork/lnd/pull/8252/). [2](https://github.com/lightningnetwork/lnd/pull/8253). + [3](https://github.com/lightningnetwork/lnd/pull/8254). ## Testing ## Database diff --git a/funding/manager.go b/funding/manager.go index 4f4334dd113..94a163dfea9 100644 --- a/funding/manager.go +++ b/funding/manager.go @@ -4144,7 +4144,7 @@ func (f *Manager) ensureInitialForwardingPolicy(chanID lnwire.ChannelID, // send out to the network after a new channel has been created locally. type chanAnnouncement struct { chanAnn lnwire.ChannelAnnouncement - chanUpdateAnn *lnwire.ChannelUpdate1 + chanUpdateAnn lnwire.ChannelUpdate chanProof lnwire.AnnounceSignatures } diff --git a/funding/manager_test.go b/funding/manager_test.go index 26fd0ca3c13..78788ffe612 100644 --- a/funding/manager_test.go +++ b/funding/manager_test.go @@ -1210,7 +1210,7 @@ func assertChannelAnnouncements(t *testing.T, alice, bob *testNode, switch m := msg.(type) { case lnwire.ChannelAnnouncement: gotChannelAnnouncement = true - case *lnwire.ChannelUpdate1: + case lnwire.ChannelUpdate: // The channel update sent by the node should // advertise the MinHTLC value required by the @@ -1225,31 +1225,33 @@ func assertChannelAnnouncements(t *testing.T, alice, bob *testNode, baseFee := aliceCfg.DefaultRoutingPolicy.BaseFee feeRate := aliceCfg.DefaultRoutingPolicy.FeeRate - require.EqualValues(t, 1, m.MessageFlags) + pol := m.ForwardingPolicy() + + require.True(t, pol.HasMaxHTLC) // We might expect a custom MinHTLC value. if len(customMinHtlc) > 0 { minHtlc = customMinHtlc[j] } - require.Equal(t, minHtlc, m.HtlcMinimumMsat) + require.Equal(t, minHtlc, pol.MinHTLC) // We might expect a custom MaxHltc value. if len(customMaxHtlc) > 0 { maxHtlc = customMaxHtlc[j] } - require.Equal(t, maxHtlc, m.HtlcMaximumMsat) + require.Equal(t, maxHtlc, pol.MaxHTLC) // We might expect a custom baseFee value. if len(baseFees) > 0 { baseFee = baseFees[j] } - require.EqualValues(t, baseFee, m.BaseFee) + require.EqualValues(t, baseFee, pol.BaseFee) // We might expect a custom feeRate value. if len(feeRates) > 0 { feeRate = feeRates[j] } - require.EqualValues(t, feeRate, m.FeeRate) + require.EqualValues(t, feeRate, pol.FeeRate) gotChannelUpdate = true } diff --git a/graph/builder.go b/graph/builder.go index d74c3b6b0df..107f5a7eebd 100644 --- a/graph/builder.go +++ b/graph/builder.go @@ -1512,49 +1512,35 @@ type routingMsg struct { // ApplyChannelUpdate validates a channel update and if valid, applies it to the // database. It returns a bool indicating whether the updates were successful. -func (b *Builder) ApplyChannelUpdate(msg *lnwire.ChannelUpdate1) bool { - ch, _, _, err := b.GetChannelByID(msg.ShortChannelID) +func (b *Builder) ApplyChannelUpdate(msg lnwire.ChannelUpdate) bool { + ch, _, _, err := b.GetChannelByID(msg.SCID()) if err != nil { log.Errorf("Unable to retrieve channel by id: %v", err) return false } var pubKey *btcec.PublicKey - - switch msg.ChannelFlags & lnwire.ChanUpdateDirection { - case 0: + if msg.IsNode1() { pubKey, _ = ch.NodeKey1() - - case 1: + } else { pubKey, _ = ch.NodeKey2() } - // Exit early if the pubkey cannot be decided. - if pubKey == nil { - log.Errorf("Unable to decide pubkey with ChannelFlags=%v", - msg.ChannelFlags) + err = lnwire.ValidateChannelUpdateAnn(pubKey, ch.GetCapacity(), msg) + if err != nil { + log.Errorf("Unable to validate channel update: %v", err) return false } - err = lnwire.ValidateChannelUpdateAnn(pubKey, ch.GetCapacity(), msg) + edgePolicy, err := models.EdgePolicyFromUpdate(msg) if err != nil { - log.Errorf("Unable to validate channel update: %v", err) + log.Errorf("Unable to convert update message to edge "+ + "policy: %v", err) + return false } - err = b.UpdateEdge(&models.ChannelEdgePolicy1{ - SigBytes: msg.Signature.ToSignatureBytes(), - ChannelID: msg.ShortChannelID.ToUint64(), - LastUpdate: time.Unix(int64(msg.Timestamp), 0), - MessageFlags: msg.MessageFlags, - ChannelFlags: msg.ChannelFlags, - TimeLockDelta: msg.TimeLockDelta, - MinHTLC: msg.HtlcMinimumMsat, - MaxHTLC: msg.HtlcMaximumMsat, - FeeBaseMSat: lnwire.MilliSatoshi(msg.BaseFee), - FeeProportionalMillionths: lnwire.MilliSatoshi(msg.FeeRate), - ExtraOpaqueData: msg.ExtraOpaqueData, - }) + err = b.UpdateEdge(edgePolicy) if err != nil && !IsError(err, ErrIgnored, ErrOutdated) { log.Errorf("Unable to apply channel update: %v", err) return false @@ -1621,7 +1607,7 @@ func (b *Builder) AddEdge(edge models.ChannelEdgeInfo, // considered as not fully constructed. // // NOTE: This method is part of the ChannelGraphSource interface. -func (b *Builder) UpdateEdge(update *models.ChannelEdgePolicy1, +func (b *Builder) UpdateEdge(update models.ChannelEdgePolicy, op ...batch.SchedulerOption) error { rMsg := &routingMsg{ @@ -1775,53 +1761,108 @@ func (b *Builder) IsKnownEdge(chanID lnwire.ShortChannelID) bool { // // NOTE: This method is part of the ChannelGraphSource interface. func (b *Builder) IsStaleEdgePolicy(chanID lnwire.ShortChannelID, - timestamp time.Time, flags lnwire.ChanUpdateChanFlags) bool { + update lnwire.ChannelUpdate) bool { - edge1Timestamp, edge2Timestamp, exists, isZombie, err := - b.cfg.Graph.HasChannelEdge1(chanID.ToUint64()) - if err != nil { - log.Debugf("Check stale edge policy got error: %v", err) - return false - } + var ( + disabled = update.IsDisabled() + isNode1 = update.IsNode1() + ) + + switch upd := update.(type) { + case *lnwire.ChannelUpdate1: + timestamp := time.Unix(int64(upd.Timestamp), 0) - // If we know of the edge as a zombie, then we'll make some additional - // checks to determine if the new policy is fresh. - if isZombie { - // When running with AssumeChannelValid, we also prune channels - // if both of their edges are disabled. We'll mark the new - // policy as stale if it remains disabled. - if b.cfg.AssumeChannelValid { - isDisabled := flags&lnwire.ChanUpdateDisabled == - lnwire.ChanUpdateDisabled - if isDisabled { - return true + edge1Timestamp, edge2Timestamp, exists, isZombie, err := + b.cfg.Graph.HasChannelEdge1(chanID.ToUint64()) + if err != nil { + log.Debugf("Check stale edge policy got error: %v", err) + + return false + } + + // If we know of the edge as a zombie, then we'll make some + // additional checks to determine if the new policy is fresh. + if isZombie { + // When running with AssumeChannelValid, we also prune + // channels if both of their edges are disabled. We'll + // mark the new policy as stale if it remains disabled. + if b.cfg.AssumeChannelValid { + if disabled { + return true + } } + + // Otherwise, we'll fall back to our usual + // ChannelPruneExpiry. + return time.Since(timestamp) > b.cfg.ChannelPruneExpiry } - // Otherwise, we'll fall back to our usual ChannelPruneExpiry. - return time.Since(timestamp) > b.cfg.ChannelPruneExpiry - } + // If we don't know of the edge, then it means it's fresh (thus + // not stale). + if !exists { + return false + } - // If we don't know of the edge, then it means it's fresh (thus not - // stale). - if !exists { - return false - } + // As edges are directional edge node has a unique policy for + // the direction of the edge they control. Therefore we first + // check if we already have the most up to date information for + // that edge. If so, then we can exit early. + switch { + case isNode1: + return !edge1Timestamp.Before(timestamp) - // As edges are directional edge node has a unique policy for the - // direction of the edge they control. Therefore, we first check if we - // already have the most up-to-date information for that edge. If so, - // then we can exit early. - switch { - // A flag set of 0 indicates this is an announcement for the "first" - // node in the channel. - case flags&lnwire.ChanUpdateDirection == 0: - return !edge1Timestamp.Before(timestamp) - - // Similarly, a flag set of 1 indicates this is an announcement for the - // "second" node in the channel. - case flags&lnwire.ChanUpdateDirection == 1: - return !edge2Timestamp.Before(timestamp) + case !isNode1: + return !edge2Timestamp.Before(timestamp) + } + + case *lnwire.ChannelUpdate2: + height := upd.BlockHeight + + edge1Height, edge2Height, exists, isZombie, err := + b.cfg.Graph.HasChannelEdge2(chanID.ToUint64()) + if err != nil { + log.Debugf("Check stale edge policy got error: %v", err) + + return false + } + + // If we know of the edge as a zombie, then we'll make some + // additional checks to determine if the new policy is fresh. + if isZombie { + // When running with AssumeChannelValid, we also prune + // channels if both of their edges are disabled. We'll + // mark the new policy as stale if it remains disabled. + if b.cfg.AssumeChannelValid { + if disabled { + return true + } + } + + // Otherwise, we'll fall back to our usual + // ChannelPruneExpiry. + blocksSince := b.SyncedHeight() - height.Val + + return blocksSince > + uint32(b.cfg.ChannelPruneExpiry.Hours()*6) + } + + // If we don't know of the edge, then it means it's fresh (thus + // not stale). + if !exists { + return false + } + + // As edges are directional edge node has a unique policy for + // the direction of the edge they control. Therefore we first + // check if we already have the most up to date information for + // that edge. If so, then we can exit early. + switch { + case isNode1: + return edge1Height >= height.Val + + case !isNode1: + return edge2Height >= height.Val + } } return false diff --git a/graph/builder_test.go b/graph/builder_test.go index de3afaaabf9..b883b2cc6f9 100644 --- a/graph/builder_test.go +++ b/graph/builder_test.go @@ -1142,13 +1142,17 @@ func TestIsStaleEdgePolicy(t *testing.T) { // If we query for staleness before adding the edge, we should get // false. - updateTimeStamp := time.Unix(123, 0) - if ctx.builder.IsStaleEdgePolicy(*chanID, updateTimeStamp, 0) { - t.Fatalf("router failed to detect fresh edge policy") + time1 := 123 + updateTimeStamp := time.Unix(int64(time1), 0) + update1 := &lnwire.ChannelUpdate1{ + Timestamp: uint32(time1), } - if ctx.builder.IsStaleEdgePolicy(*chanID, updateTimeStamp, 1) { - t.Fatalf("router failed to detect fresh edge policy") + update2 := &lnwire.ChannelUpdate1{ + Timestamp: uint32(time1), + ChannelFlags: lnwire.ChanUpdateDirection, } + require.False(t, ctx.builder.IsStaleEdgePolicy(*chanID, update1)) + require.False(t, ctx.builder.IsStaleEdgePolicy(*chanID, update2)) edge := &models.ChannelEdgeInfo1{ ChannelID: chanID.ToUint64(), @@ -1193,20 +1197,22 @@ func TestIsStaleEdgePolicy(t *testing.T) { // Now that the edges have been added, an identical (chanID, flag, // timestamp) tuple for each edge should be detected as a stale edge. - if !ctx.builder.IsStaleEdgePolicy(*chanID, updateTimeStamp, 0) { + if !ctx.builder.IsStaleEdgePolicy(*chanID, update1) { t.Fatalf("router failed to detect stale edge policy") } - if !ctx.builder.IsStaleEdgePolicy(*chanID, updateTimeStamp, 1) { + if !ctx.builder.IsStaleEdgePolicy(*chanID, update2) { t.Fatalf("router failed to detect stale edge policy") } // If we now update the timestamp for both edges, the router should // detect that this tuple represents a fresh edge. - updateTimeStamp = time.Unix(9999, 0) - if ctx.builder.IsStaleEdgePolicy(*chanID, updateTimeStamp, 0) { + time2 := 9999 + update1.Timestamp = uint32(time2) + update2.Timestamp = uint32(time2) + if ctx.builder.IsStaleEdgePolicy(*chanID, update1) { t.Fatalf("router failed to detect fresh edge policy") } - if ctx.builder.IsStaleEdgePolicy(*chanID, updateTimeStamp, 1) { + if ctx.builder.IsStaleEdgePolicy(*chanID, update2) { t.Fatalf("router failed to detect fresh edge policy") } } diff --git a/graph/interfaces.go b/graph/interfaces.go index df8d16932bc..1eede0a70d9 100644 --- a/graph/interfaces.go +++ b/graph/interfaces.go @@ -39,7 +39,7 @@ type ChannelGraphSource interface { // UpdateEdge is used to update edge information, without this message // edge considered as not fully constructed. - UpdateEdge(policy *models.ChannelEdgePolicy1, + UpdateEdge(policy models.ChannelEdgePolicy, op ...batch.SchedulerOption) error // IsStaleNode returns true if the graph source has a node announcement @@ -59,8 +59,8 @@ type ChannelGraphSource interface { // IsStaleEdgePolicy returns true if the graph source has a channel // edge for the passed channel ID (and flags) that have a more recent // timestamp. - IsStaleEdgePolicy(chanID lnwire.ShortChannelID, timestamp time.Time, - flags lnwire.ChanUpdateChanFlags) bool + IsStaleEdgePolicy(chanID lnwire.ShortChannelID, + policy lnwire.ChannelUpdate) bool // MarkEdgeLive clears an edge from our zombie index, deeming it as // live. diff --git a/graph/validation_barrier.go b/graph/validation_barrier.go index c1de127bad9..74ca8962048 100644 --- a/graph/validation_barrier.go +++ b/graph/validation_barrier.go @@ -146,7 +146,7 @@ func (v *ValidationBarrier) InitJobDependencies(job interface{}) { // initialization needs to be done beyond just occupying a job slot. case models.ChannelEdgePolicy: return - case *lnwire.ChannelUpdate1: + case lnwire.ChannelUpdate: return case *lnwire.NodeAnnouncement: // TODO(roasbeef): node ann needs to wait on existing channel updates @@ -201,11 +201,11 @@ func (v *ValidationBarrier) WaitForDependants(job interface{}) error { jobDesc = fmt.Sprintf("job=channeldb.LightningNode, pub=%s", vertex) - case *lnwire.ChannelUpdate1: - signals, ok = v.chanEdgeDependencies[msg.ShortChannelID] + case lnwire.ChannelUpdate: + signals, ok = v.chanEdgeDependencies[msg.SCID()] jobDesc = fmt.Sprintf("job=lnwire.ChannelUpdate, scid=%v", - msg.ShortChannelID.ToUint64()) + msg.SCID().ToUint64()) case *lnwire.NodeAnnouncement: vertex := route.Vertex(msg.NodeID) @@ -296,8 +296,8 @@ func (v *ValidationBarrier) SignalDependants(job interface{}, allow bool) { delete(v.nodeAnnDependencies, route.Vertex(msg.PubKeyBytes)) case *lnwire.NodeAnnouncement: delete(v.nodeAnnDependencies, route.Vertex(msg.NodeID)) - case *lnwire.ChannelUpdate1: - delete(v.chanEdgeDependencies, msg.ShortChannelID) + case lnwire.ChannelUpdate: + delete(v.chanEdgeDependencies, msg.SCID()) case models.ChannelEdgePolicy: delete(v.chanEdgeDependencies, msg.SCID()) diff --git a/htlcswitch/interceptable_switch.go b/htlcswitch/interceptable_switch.go index 5517eb82dd1..07e22bec589 100644 --- a/htlcswitch/interceptable_switch.go +++ b/htlcswitch/interceptable_switch.go @@ -702,7 +702,7 @@ func (f *interceptedForward) FailWithCode(code lnwire.FailCode) error { return err } - failureMsg = lnwire.NewExpiryTooSoon(*update) + failureMsg = lnwire.NewExpiryTooSoon(update) default: return ErrUnsupportedFailureCode diff --git a/htlcswitch/interfaces.go b/htlcswitch/interfaces.go index c2520428f51..423402e4fdd 100644 --- a/htlcswitch/interfaces.go +++ b/htlcswitch/interfaces.go @@ -85,7 +85,7 @@ type scidAliasHandler interface { // HTLCs on option_scid_alias channels. attachFailAliasUpdate(failClosure func( sid lnwire.ShortChannelID, - incoming bool) *lnwire.ChannelUpdate1) + incoming bool) lnwire.ChannelUpdate) // getAliases fetches the link's underlying aliases. This is used by // the Switch to determine whether to forward an HTLC and where to diff --git a/htlcswitch/link.go b/htlcswitch/link.go index 5934048eba0..428f60a85f5 100644 --- a/htlcswitch/link.go +++ b/htlcswitch/link.go @@ -120,7 +120,7 @@ type ChannelLinkConfig struct { // provide payment senders our latest policy when sending encrypted // error messages. FetchLastChannelUpdate func(lnwire.ShortChannelID) ( - *lnwire.ChannelUpdate1, error) + lnwire.ChannelUpdate, error) // Peer is a lightning network node with which we have the channel link // opened. @@ -262,7 +262,7 @@ type ChannelLinkConfig struct { // FailAliasUpdate is a function used to fail an HTLC for an // option_scid_alias channel. FailAliasUpdate func(sid lnwire.ShortChannelID, - incoming bool) *lnwire.ChannelUpdate1 + incoming bool) lnwire.ChannelUpdate // GetAliases is used by the link and switch to fetch the set of // aliases for a given link. @@ -764,7 +764,7 @@ func shouldAdjustCommitFee(netFee, chanFee, } // failCb is used to cut down on the argument verbosity. -type failCb func(update *lnwire.ChannelUpdate1) lnwire.FailureMessage +type failCb func(update lnwire.ChannelUpdate) lnwire.FailureMessage // createFailureWithUpdate creates a ChannelUpdate when failing an incoming or // outgoing HTLC. It may return a FailureMessage that references a channel's @@ -2962,7 +2962,7 @@ func (l *channelLink) getAliases() []lnwire.ShortChannelID { // // Part of the scidAliasHandler interface. func (l *channelLink) attachFailAliasUpdate(closure func( - sid lnwire.ShortChannelID, incoming bool) *lnwire.ChannelUpdate1) { + sid lnwire.ShortChannelID, incoming bool) lnwire.ChannelUpdate) { l.Lock() l.cfg.FailAliasUpdate = closure @@ -3054,8 +3054,8 @@ func (l *channelLink) CheckHtlcForward(payHash [32]byte, // As part of the returned error, we'll send our latest routing // policy so the sending node obtains the most up to date data. - cb := func(upd *lnwire.ChannelUpdate1) lnwire.FailureMessage { - return lnwire.NewFeeInsufficient(amtToForward, *upd) + cb := func(upd lnwire.ChannelUpdate) lnwire.FailureMessage { + return lnwire.NewFeeInsufficient(amtToForward, upd) } failure := l.createFailureWithUpdate(false, originalScid, cb) return NewLinkError(failure) @@ -3082,9 +3082,9 @@ func (l *channelLink) CheckHtlcForward(payHash [32]byte, // Grab the latest routing policy so the sending node is up to // date with our current policy. - cb := func(upd *lnwire.ChannelUpdate1) lnwire.FailureMessage { + cb := func(upd lnwire.ChannelUpdate) lnwire.FailureMessage { return lnwire.NewIncorrectCltvExpiry( - incomingTimeout, *upd, + incomingTimeout, upd, ) } failure := l.createFailureWithUpdate(false, originalScid, cb) @@ -3131,8 +3131,8 @@ func (l *channelLink) canSendHtlc(policy models.ForwardingPolicy, // As part of the returned error, we'll send our latest routing // policy so the sending node obtains the most up to date data. - cb := func(upd *lnwire.ChannelUpdate1) lnwire.FailureMessage { - return lnwire.NewAmountBelowMinimum(amt, *upd) + cb := func(upd lnwire.ChannelUpdate) lnwire.FailureMessage { + return lnwire.NewAmountBelowMinimum(amt, upd) } failure := l.createFailureWithUpdate(false, originalScid, cb) return NewLinkError(failure) @@ -3146,7 +3146,7 @@ func (l *channelLink) canSendHtlc(policy models.ForwardingPolicy, // As part of the returned error, we'll send our latest routing // policy so the sending node obtains the most up-to-date data. - cb := func(upd *lnwire.ChannelUpdate1) lnwire.FailureMessage { + cb := func(upd lnwire.ChannelUpdate) lnwire.FailureMessage { return lnwire.NewTemporaryChannelFailure(upd) } failure := l.createFailureWithUpdate(false, originalScid, cb) @@ -3161,8 +3161,8 @@ func (l *channelLink) canSendHtlc(policy models.ForwardingPolicy, "outgoing_expiry=%v, best_height=%v", payHash[:], timeout, heightNow) - cb := func(upd *lnwire.ChannelUpdate1) lnwire.FailureMessage { - return lnwire.NewExpiryTooSoon(*upd) + cb := func(upd lnwire.ChannelUpdate) lnwire.FailureMessage { + return lnwire.NewExpiryTooSoon(upd) } failure := l.createFailureWithUpdate(false, originalScid, cb) return NewLinkError(failure) @@ -3181,7 +3181,7 @@ func (l *channelLink) canSendHtlc(policy models.ForwardingPolicy, if amt > l.Bandwidth() { l.log.Warnf("insufficient bandwidth to route htlc: %v is "+ "larger than %v", amt, l.Bandwidth()) - cb := func(upd *lnwire.ChannelUpdate1) lnwire.FailureMessage { + cb := func(upd lnwire.ChannelUpdate) lnwire.FailureMessage { return lnwire.NewTemporaryChannelFailure(upd) } failure := l.createFailureWithUpdate(false, originalScid, cb) @@ -3694,7 +3694,8 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg, l.log.Errorf("unable to encode the "+ "remaining route %v", err) - cb := func(upd *lnwire.ChannelUpdate1) lnwire.FailureMessage { //nolint:lll + //nolint:lll + cb := func(upd lnwire.ChannelUpdate) lnwire.FailureMessage { return lnwire.NewTemporaryChannelFailure(upd) } diff --git a/htlcswitch/link_test.go b/htlcswitch/link_test.go index 4d07b28efc4..01b7d18b6ca 100644 --- a/htlcswitch/link_test.go +++ b/htlcswitch/link_test.go @@ -6166,13 +6166,13 @@ func TestForwardingAsymmetricTimeLockPolicies(t *testing.T) { // forwarding policy. func TestCheckHtlcForward(t *testing.T) { fetchLastChannelUpdate := func(lnwire.ShortChannelID) ( - *lnwire.ChannelUpdate1, error) { + lnwire.ChannelUpdate, error) { return &lnwire.ChannelUpdate1{}, nil } failAliasUpdate := func(sid lnwire.ShortChannelID, - incoming bool) *lnwire.ChannelUpdate1 { + incoming bool) lnwire.ChannelUpdate { return nil } diff --git a/htlcswitch/mock.go b/htlcswitch/mock.go index 761e77801ce..9deb762ff39 100644 --- a/htlcswitch/mock.go +++ b/htlcswitch/mock.go @@ -16,7 +16,6 @@ import ( "time" "github.com/btcsuite/btcd/btcec/v2" - "github.com/btcsuite/btcd/btcec/v2/ecdsa" "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/wire" "github.com/go-errors/errors" @@ -166,10 +165,20 @@ type mockServer struct { var _ lnpeer.Peer = (*mockServer)(nil) func initSwitchWithDB(startingHeight uint32, db *channeldb.DB) (*Switch, error) { - signAliasUpdate := func(u *lnwire.ChannelUpdate1) (*ecdsa.Signature, - error) { + signAliasUpdate := func(update lnwire.ChannelUpdate) error { + s, err := lnwire.NewSigFromSignature(testSig) + if err != nil { + return err + } - return testSig, nil + switch u := update.(type) { + case *lnwire.ChannelUpdate1: + u.Signature = s + case *lnwire.ChannelUpdate2: + u.Signature = s + } + + return nil } cfg := Config{ @@ -182,7 +191,7 @@ func initSwitchWithDB(startingHeight uint32, db *channeldb.DB) (*Switch, error) events: make(map[time.Time]channeldb.ForwardingEvent), }, FetchLastChannelUpdate: func(scid lnwire.ShortChannelID) ( - *lnwire.ChannelUpdate1, error) { + lnwire.ChannelUpdate, error) { return &lnwire.ChannelUpdate1{ ShortChannelID: scid, @@ -734,7 +743,7 @@ type mockChannelLink struct { checkHtlcForwardResult *LinkError failAliasUpdate func(sid lnwire.ShortChannelID, - incoming bool) *lnwire.ChannelUpdate1 + incoming bool) lnwire.ChannelUpdate confirmedZC bool } @@ -869,7 +878,7 @@ func (f *mockChannelLink) AttachMailBox(mailBox MailBox) { } func (f *mockChannelLink) attachFailAliasUpdate(closure func( - sid lnwire.ShortChannelID, incoming bool) *lnwire.ChannelUpdate1) { + sid lnwire.ShortChannelID, incoming bool) lnwire.ChannelUpdate) { f.failAliasUpdate = closure } diff --git a/htlcswitch/switch.go b/htlcswitch/switch.go index 2c7a6287455..58506a99867 100644 --- a/htlcswitch/switch.go +++ b/htlcswitch/switch.go @@ -9,7 +9,6 @@ import ( "sync/atomic" "time" - "github.com/btcsuite/btcd/btcec/v2/ecdsa" "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/wire" "github.com/davecgh/go-spew/spew" @@ -174,7 +173,7 @@ type Config struct { // provide payment senders our latest policy when sending encrypted // error messages. FetchLastChannelUpdate func(lnwire.ShortChannelID) ( - *lnwire.ChannelUpdate1, error) + lnwire.ChannelUpdate, error) // Notifier is an instance of a chain notifier that we'll use to signal // the switch when a new block has arrived. @@ -221,8 +220,7 @@ type Config struct { // option_scid_alias channels. This avoids a potential privacy leak by // replacing the public, confirmed SCID with the alias in the // ChannelUpdate. - SignAliasUpdate func(u *lnwire.ChannelUpdate1) (*ecdsa.Signature, - error) + SignAliasUpdate func(u lnwire.ChannelUpdate) error // IsAlias returns whether or not a given SCID is an alias. IsAlias func(scid lnwire.ShortChannelID) bool @@ -2616,7 +2614,7 @@ func (s *Switch) failMailboxUpdate(outgoingScid, // and the caller is expected to handle this properly. In this case, a return // to the original non-alias behavior is expected. func (s *Switch) failAliasUpdate(scid lnwire.ShortChannelID, - incoming bool) *lnwire.ChannelUpdate1 { + incoming bool) lnwire.ChannelUpdate { // This function does not defer the unlocking because of the database // lookups for ChannelUpdate. @@ -2645,13 +2643,8 @@ func (s *Switch) failAliasUpdate(scid lnwire.ShortChannelID, } // Replace the baseScid with the passed-in alias. - update.ShortChannelID = scid - sig, err := s.cfg.SignAliasUpdate(update) - if err != nil { - return nil - } - - update.Signature, err = lnwire.NewSigFromSignature(sig) + update.SetSCID(scid) + err = s.cfg.SignAliasUpdate(update) if err != nil { return nil } @@ -2671,13 +2664,8 @@ func (s *Switch) failAliasUpdate(scid lnwire.ShortChannelID, // In the incoming case, we want to ensure that we don't leak // the UTXO in case the channel is private. In the outgoing // case, since the alias was used, we do the same thing. - update.ShortChannelID = scid - sig, err := s.cfg.SignAliasUpdate(update) - if err != nil { - return nil - } - - update.Signature, err = lnwire.NewSigFromSignature(sig) + update.SetSCID(scid) + err = s.cfg.SignAliasUpdate(update) if err != nil { return nil } @@ -2726,13 +2714,8 @@ func (s *Switch) failAliasUpdate(scid lnwire.ShortChannelID, // We will replace and sign the update with the first alias. // Since this happens on the incoming side, it's not actually // possible to know what the sender used in the onion. - update.ShortChannelID = aliases[0] - sig, err := s.cfg.SignAliasUpdate(update) - if err != nil { - return nil - } - - update.Signature, err = lnwire.NewSigFromSignature(sig) + update.SetSCID(aliases[0]) + err := s.cfg.SignAliasUpdate(update) if err != nil { return nil } @@ -2794,7 +2777,9 @@ func (s *Switch) handlePacketAdd(packet *htlcPacket, // sure that HTLC is not from the source node. if s.cfg.RejectHTLC { failure := NewDetailedLinkError( - &lnwire.FailChannelDisabled{}, + &lnwire.FailChannelDisabled{ + Update: &lnwire.ChannelUpdate1{}, + }, OutgoingFailureForwardsDisabled, ) diff --git a/htlcswitch/switch_test.go b/htlcswitch/switch_test.go index 825ee6c652b..6220b71c997 100644 --- a/htlcswitch/switch_test.go +++ b/htlcswitch/switch_test.go @@ -3382,7 +3382,10 @@ func TestHtlcNotifier(t *testing.T) { return getThreeHopEvents( channels, htlcID, ts, htlc, hops, &LinkError{ - msg: &lnwire.FailChannelDisabled{}, + //nolint:lll + msg: &lnwire.FailChannelDisabled{ + Update: &lnwire.ChannelUpdate1{}, + }, FailureDetail: OutgoingFailureForwardsDisabled, }, preimage, @@ -3951,7 +3954,7 @@ func TestSwitchHoldForward(t *testing.T) { // Simulate an error during the composition of the failure message. currentCallback := c.s.cfg.FetchLastChannelUpdate c.s.cfg.FetchLastChannelUpdate = func( - lnwire.ShortChannelID) (*lnwire.ChannelUpdate1, error) { + lnwire.ShortChannelID) (lnwire.ChannelUpdate, error) { return nil, errors.New("cannot fetch update") } @@ -5045,7 +5048,7 @@ func testSwitchForwardFailAlias(t *testing.T, zeroConf bool) { msg := failPacket.linkFailure.msg failMsg, ok := msg.(*lnwire.FailTemporaryChannelFailure) require.True(t, ok) - require.Equal(t, aliceAlias, failMsg.Update.ShortChannelID) + require.Equal(t, aliceAlias, failMsg.Update.SCID()) case <-s2.quit: t.Fatal("switch shutting down, failed to forward packet") } @@ -5228,7 +5231,7 @@ func testSwitchAliasFailAdd(t *testing.T, zeroConf, private, useAlias bool) { msg := failPacket.linkFailure.msg failMsg, ok := msg.(*lnwire.FailTemporaryChannelFailure) require.True(t, ok) - require.Equal(t, outgoingChanID, failMsg.Update.ShortChannelID) + require.Equal(t, outgoingChanID, failMsg.Update.SCID()) case <-s.quit: t.Fatal("switch shutting down, failed to receive fail packet") } @@ -5428,7 +5431,7 @@ func testSwitchHandlePacketForward(t *testing.T, zeroConf, private, msg := failPacket.linkFailure.msg failMsg, ok := msg.(*lnwire.FailAmountBelowMinimum) require.True(t, ok) - require.Equal(t, outgoingChanID, failMsg.Update.ShortChannelID) + require.Equal(t, outgoingChanID, failMsg.Update.SCID()) case <-s.quit: t.Fatal("switch shutting down, failed to receive failure") } @@ -5583,7 +5586,7 @@ func testSwitchAliasInterceptFail(t *testing.T, zeroConf bool) { failureMsg, ok := failure.(*lnwire.FailTemporaryChannelFailure) require.True(t, ok) - failScid := failureMsg.Update.ShortChannelID + failScid := failureMsg.Update.SCID() isAlias := failScid == aliceAlias || failScid == aliceAlias2 require.True(t, isAlias) diff --git a/htlcswitch/test_utils.go b/htlcswitch/test_utils.go index abd48e806dd..1c5be3976cc 100644 --- a/htlcswitch/test_utils.go +++ b/htlcswitch/test_utils.go @@ -91,8 +91,8 @@ func genIDs() (lnwire.ChannelID, lnwire.ChannelID, lnwire.ShortChannelID, // mockGetChanUpdateMessage helper function which returns topology update of // the channel -func mockGetChanUpdateMessage(_ lnwire.ShortChannelID) (*lnwire.ChannelUpdate1, - error) { +func mockGetChanUpdateMessage(_ lnwire.ShortChannelID) ( + lnwire.ChannelUpdate, error) { return &lnwire.ChannelUpdate1{ Signature: wireSig, diff --git a/keychain/derivation.go b/keychain/derivation.go index 2b1c43444eb..050058ca6ce 100644 --- a/keychain/derivation.go +++ b/keychain/derivation.go @@ -262,6 +262,12 @@ type SingleKeyMessageSigner interface { // hashing it first, with the wrapped private key and returns the // signature in the compact, public key recoverable format. SignMessageCompact(message []byte, doubleHash bool) ([]byte, error) + + // SignMessageSchnorr signs the given message, single or double SHA256 + // hashing it first, with the private key described in the key locator + // and the optional Taproot tweak applied to the private key. + SignMessageSchnorr(keyLoc KeyLocator, msg []byte, doubleHash bool, + taprootTweak, tag []byte) (*schnorr.Signature, error) } // ECDHRing is an interface that abstracts away basic low-level ECDH shared key diff --git a/keychain/signer.go b/keychain/signer.go index 9605e72ec1f..6fd856f7f4f 100644 --- a/keychain/signer.go +++ b/keychain/signer.go @@ -3,7 +3,9 @@ package keychain import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcec/v2/ecdsa" + "github.com/btcsuite/btcd/btcec/v2/schnorr" "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/txscript" ) func NewPubKeyMessageSigner(pubKey *btcec.PublicKey, keyLoc KeyLocator, @@ -42,6 +44,14 @@ func (p *PubKeyMessageSigner) SignMessageCompact(msg []byte, return p.digestSigner.SignMessageCompact(p.keyLoc, msg, doubleHash) } +func (p *PubKeyMessageSigner) SignMessageSchnorr(keyLoc KeyLocator, msg []byte, + doubleHash bool, taprootTweak, tag []byte) (*schnorr.Signature, error) { + + return p.digestSigner.SignMessageSchnorr( + keyLoc, msg, doubleHash, taprootTweak, tag, + ) +} + func NewPrivKeyMessageSigner(privKey *btcec.PrivateKey, keyLoc KeyLocator) *PrivKeyMessageSigner { @@ -88,5 +98,28 @@ func (p *PrivKeyMessageSigner) SignMessageCompact(msg []byte, return ecdsa.SignCompact(p.privKey, digest, true) } +func (p *PrivKeyMessageSigner) SignMessageSchnorr(_ KeyLocator, msg []byte, + doubleHash bool, taprootTweak, tag []byte) (*schnorr.Signature, error) { + + // If a tag was provided, we need to take the tagged hash of the input. + var digest []byte + switch { + case len(tag) > 0: + taggedHash := chainhash.TaggedHash(tag, msg) + digest = taggedHash[:] + case doubleHash: + digest = chainhash.DoubleHashB(msg) + default: + digest = chainhash.HashB(msg) + } + + privKey := p.privKey + if len(taprootTweak) > 0 { + privKey = txscript.TweakTaprootPrivKey(*privKey, taprootTweak) + } + + return schnorr.Sign(privKey, digest) +} + var _ SingleKeyMessageSigner = (*PubKeyMessageSigner)(nil) var _ SingleKeyMessageSigner = (*PrivKeyMessageSigner)(nil) diff --git a/lnrpc/invoicesrpc/addinvoice.go b/lnrpc/invoicesrpc/addinvoice.go index a5d199b8e66..fbd56d7c7a7 100644 --- a/lnrpc/invoicesrpc/addinvoice.go +++ b/lnrpc/invoicesrpc/addinvoice.go @@ -571,7 +571,7 @@ func AddInvoice(ctx context.Context, cfg *AddInvoiceConfig, // key is used to sign the invoice so that the sender // can derive the true pub key of the recipient. if !blind { - return cfg.NodeSigner.SignMessageCompact( + return cfg.NodeSigner.SignMessageCompactNoKeyLoc( //nolint:lll msg, false, ) } diff --git a/lnrpc/routerrpc/router_backend.go b/lnrpc/routerrpc/router_backend.go index b37df7218f8..49c53d3c592 100644 --- a/lnrpc/routerrpc/router_backend.go +++ b/lnrpc/routerrpc/router_backend.go @@ -1523,8 +1523,14 @@ func marshallWireError(msg lnwire.FailureMessage, response.Code = lnrpc.Failure_INVALID_REALM case *lnwire.FailExpiryTooSoon: + update1, update2, err := marshallChannelUpdate(onionErr.Update) + if err != nil { + return err + } + response.Code = lnrpc.Failure_EXPIRY_TOO_SOON - response.ChannelUpdate = marshallChannelUpdate(&onionErr.Update) + response.ChannelUpdate = update1 + response.ChannelUpdate_2 = update2 case *lnwire.FailExpiryTooFar: response.Code = lnrpc.Failure_EXPIRY_TOO_FAR @@ -1542,28 +1548,58 @@ func marshallWireError(msg lnwire.FailureMessage, response.OnionSha_256 = onionErr.OnionSHA256[:] case *lnwire.FailAmountBelowMinimum: + update1, update2, err := marshallChannelUpdate(onionErr.Update) + if err != nil { + return err + } + response.Code = lnrpc.Failure_AMOUNT_BELOW_MINIMUM - response.ChannelUpdate = marshallChannelUpdate(&onionErr.Update) + response.ChannelUpdate = update1 + response.ChannelUpdate_2 = update2 response.HtlcMsat = uint64(onionErr.HtlcMsat) case *lnwire.FailFeeInsufficient: + update1, update2, err := marshallChannelUpdate(onionErr.Update) + if err != nil { + return err + } + response.Code = lnrpc.Failure_FEE_INSUFFICIENT - response.ChannelUpdate = marshallChannelUpdate(&onionErr.Update) + response.ChannelUpdate = update1 + response.ChannelUpdate_2 = update2 response.HtlcMsat = uint64(onionErr.HtlcMsat) case *lnwire.FailIncorrectCltvExpiry: + update1, update2, err := marshallChannelUpdate(onionErr.Update) + if err != nil { + return err + } + response.Code = lnrpc.Failure_INCORRECT_CLTV_EXPIRY - response.ChannelUpdate = marshallChannelUpdate(&onionErr.Update) + response.ChannelUpdate = update1 + response.ChannelUpdate_2 = update2 response.CltvExpiry = onionErr.CltvExpiry case *lnwire.FailChannelDisabled: + update1, update2, err := marshallChannelUpdate(onionErr.Update) + if err != nil { + return err + } + response.Code = lnrpc.Failure_CHANNEL_DISABLED - response.ChannelUpdate = marshallChannelUpdate(&onionErr.Update) + response.ChannelUpdate = update1 + response.ChannelUpdate_2 = update2 response.Flags = uint32(onionErr.Flags) case *lnwire.FailTemporaryChannelFailure: + update1, update2, err := marshallChannelUpdate(onionErr.Update) + if err != nil { + return err + } + response.Code = lnrpc.Failure_TEMPORARY_CHANNEL_FAILURE - response.ChannelUpdate = marshallChannelUpdate(onionErr.Update) + response.ChannelUpdate = update1 + response.ChannelUpdate_2 = update2 case *lnwire.FailRequiredNodeFeatureMissing: response.Code = lnrpc.Failure_REQUIRED_NODE_FEATURE_MISSING @@ -1605,24 +1641,49 @@ func marshallWireError(msg lnwire.FailureMessage, // marshallChannelUpdate marshalls a channel update as received over the wire to // the router rpc format. -func marshallChannelUpdate(update *lnwire.ChannelUpdate1) *lnrpc.ChannelUpdate { +func marshallChannelUpdate(update lnwire.ChannelUpdate) (*lnrpc.ChannelUpdate, + *lnrpc.ChannelUpdate2, error) { + if update == nil { - return nil - } - - return &lnrpc.ChannelUpdate{ - Signature: update.Signature.RawBytes(), - ChainHash: update.ChainHash[:], - ChanId: update.ShortChannelID.ToUint64(), - Timestamp: update.Timestamp, - MessageFlags: uint32(update.MessageFlags), - ChannelFlags: uint32(update.ChannelFlags), - TimeLockDelta: uint32(update.TimeLockDelta), - HtlcMinimumMsat: uint64(update.HtlcMinimumMsat), - BaseFee: update.BaseFee, - FeeRate: update.FeeRate, - HtlcMaximumMsat: uint64(update.HtlcMaximumMsat), - ExtraOpaqueData: update.ExtraOpaqueData, + return nil, nil, nil + } + + switch upd := update.(type) { + case *lnwire.ChannelUpdate1: + return &lnrpc.ChannelUpdate{ + Signature: upd.Signature.RawBytes(), + ChainHash: upd.ChainHash[:], + ChanId: upd.ShortChannelID.ToUint64(), + Timestamp: upd.Timestamp, + MessageFlags: uint32(upd.MessageFlags), + ChannelFlags: uint32(upd.ChannelFlags), + TimeLockDelta: uint32(upd.TimeLockDelta), + HtlcMinimumMsat: uint64(upd.HtlcMinimumMsat), + BaseFee: upd.BaseFee, + FeeRate: upd.FeeRate, + HtlcMaximumMsat: uint64(upd.HtlcMaximumMsat), + ExtraOpaqueData: upd.ExtraOpaqueData, + }, nil, nil + + case *lnwire.ChannelUpdate2: + return nil, &lnrpc.ChannelUpdate2{ + Signature: upd.Signature.RawBytes(), + ChainHash: upd.ChainHash.Val[:], + ChanId: upd.ShortChannelID.Val.ToUint64(), + BlockHeight: upd.BlockHeight.Val, + DisabledFlags: uint32(upd.DisabledFlags.Val), + Direction: upd.SecondPeer.IsSome(), + TimeLockDelta: uint32(upd.CLTVExpiryDelta.Val), + BaseFee: upd.FeeBaseMsat.Val, + FeeRate: upd.FeeProportionalMillionths.Val, + HtlcMinimumMsat: uint64(upd.HTLCMinimumMsat.Val), + HtlcMaximumMsat: uint64(upd.HTLCMaximumMsat.Val), + ExtraOpaqueData: upd.ExtraOpaqueData, + }, nil + + default: + return nil, nil, fmt.Errorf("unhandled implementation of "+ + "lnwire.ChannelUpdate: %T", update) } } diff --git a/lntest/mock/signer.go b/lntest/mock/signer.go index 1d30204ea9c..a1853f353d0 100644 --- a/lntest/mock/signer.go +++ b/lntest/mock/signer.go @@ -205,3 +205,68 @@ func (s *SingleSigner) SignMessage(keyLoc keychain.KeyLocator, } return ecdsa.Sign(s.Privkey, digest), nil } + +// SignMessageCompact signs the given message, single or double SHA256 hashing +// it first, with the private key described in the key locator and returns the +// signature in the compact, public key recoverable format. +// +// NOTE: This is part of the keychain.MessageSignerRing interface. +func (s *SingleSigner) SignMessageCompact(keyLoc keychain.KeyLocator, + msg []byte, doubleHash bool) ([]byte, error) { + + mockKeyLoc := s.KeyLoc + if s.KeyLoc.IsEmpty() { + mockKeyLoc = idKeyLoc + } + + if keyLoc != mockKeyLoc { + return nil, fmt.Errorf("unknown public key") + } + + var digest []byte + if doubleHash { + digest = chainhash.DoubleHashB(msg) + } else { + digest = chainhash.HashB(msg) + } + + return ecdsa.SignCompact(s.Privkey, digest, true) +} + +// SignMessageSchnorr signs the given message, single or double SHA256 hashing +// it first, with the private key described in the key locator and the optional +// Taproot tweak applied to the private key. +// +// NOTE: this is part of the keychain.MessageSignerRing interface. +func (s *SingleSigner) SignMessageSchnorr(keyLoc keychain.KeyLocator, + msg []byte, doubleHash bool, taprootTweak, tag []byte) ( + *schnorr.Signature, error) { + + mockKeyLoc := s.KeyLoc + if s.KeyLoc.IsEmpty() { + mockKeyLoc = idKeyLoc + } + + if keyLoc != mockKeyLoc { + return nil, fmt.Errorf("unknown public key") + } + + privKey := s.Privkey + if len(taprootTweak) > 0 { + privKey = txscript.TweakTaprootPrivKey(*privKey, taprootTweak) + } + + // If a tag was provided, we need to take the tagged hash of the input. + var digest []byte + switch { + case len(tag) > 0: + taggedHash := chainhash.TaggedHash(tag, msg) + digest = taggedHash[:] + case doubleHash: + digest = chainhash.DoubleHashB(msg) + default: + digest = chainhash.HashB(msg) + } + + return schnorr.Sign(privKey, digest) +} diff --git a/lnwire/onion_error.go b/lnwire/onion_error.go index 9a669271d9d..1def2d9f97e 100644 --- a/lnwire/onion_error.go +++ b/lnwire/onion_error.go @@ -600,7 +600,7 @@ func (f *FailInvalidOnionKey) Error() string { // unable to pull out a fully valid version, then we'll fall back to the // regular parsing mechanism which includes the length prefix an NO type byte. func parseChannelUpdateCompatibilityMode(reader io.Reader, length uint16, - chanUpdate *ChannelUpdate1, pver uint32) error { + pver uint32) (ChannelUpdate, error) { // Instantiate a LimitReader because there may be additional data // present after the channel update. Without limiting the stream, the @@ -613,28 +613,50 @@ func parseChannelUpdateCompatibilityMode(reader io.Reader, length uint16, // buffer so we can decide how to parse the remainder of it. maybeTypeBytes, err := r.Peek(2) if err != nil { - return err + return nil, err + } + + var ( + typeInt = binary.BigEndian.Uint16(maybeTypeBytes) + chanUpdate ChannelUpdate + hasTypeBytes bool + ) + switch typeInt { + case MsgChannelUpdate: + chanUpdate = &ChannelUpdate1{} + hasTypeBytes = true + + case MsgChannelUpdate2: + chanUpdate = &ChannelUpdate2{} + hasTypeBytes = true + + default: + // Some older nodes will not have the type prefix in front of + // their channel updates as there was initially some ambiguity + // in the spec. This should ony be the case for the + // ChannelUpdate2 message. + chanUpdate = &ChannelUpdate1{} } - // Some nodes well prefix an additional set of bytes in front of their - // channel updates. These bytes will _almost_ always be 258 or the type - // of the ChannelUpdate message. - typeInt := binary.BigEndian.Uint16(maybeTypeBytes) - if typeInt == MsgChannelUpdate { + if hasTypeBytes { // At this point it's likely the case that this is a channel // update message with its type prefixed, so we'll snip off the // first two bytes and parse it as normal. var throwAwayTypeBytes [2]byte _, err := r.Read(throwAwayTypeBytes[:]) if err != nil { - return err + return nil, err } } // At this pint, we've either decided to keep the entire thing, or snip // off the first two bytes. In either case, we can just read it as // normal. - return chanUpdate.Decode(r, pver) + if err = chanUpdate.Decode(r, pver); err != nil { + return nil, err + } + + return chanUpdate, nil } // FailTemporaryChannelFailure is if an otherwise unspecified transient error @@ -647,12 +669,12 @@ type FailTemporaryChannelFailure struct { // which caused the failure. // // NOTE: This field is optional. - Update *ChannelUpdate1 + Update ChannelUpdate } // NewTemporaryChannelFailure creates new instance of the FailTemporaryChannelFailure. func NewTemporaryChannelFailure( - update *ChannelUpdate1) *FailTemporaryChannelFailure { + update ChannelUpdate) *FailTemporaryChannelFailure { return &FailTemporaryChannelFailure{Update: update} } @@ -687,11 +709,14 @@ func (f *FailTemporaryChannelFailure) Decode(r io.Reader, pver uint32) error { } if length != 0 { - f.Update = &ChannelUpdate1{} - - return parseChannelUpdateCompatibilityMode( - r, length, f.Update, pver, + update, err := parseChannelUpdateCompatibilityMode( + r, length, pver, ) + if err != nil { + return err + } + + f.Update = update } return nil @@ -722,12 +747,12 @@ type FailAmountBelowMinimum struct { // Update is used to update information about state of the channel // which caused the failure. - Update ChannelUpdate1 + Update ChannelUpdate } // NewAmountBelowMinimum creates new instance of the FailAmountBelowMinimum. func NewAmountBelowMinimum(htlcMsat MilliSatoshi, - update ChannelUpdate1) *FailAmountBelowMinimum { + update ChannelUpdate) *FailAmountBelowMinimum { return &FailAmountBelowMinimum{ HtlcMsat: htlcMsat, @@ -763,11 +788,16 @@ func (f *FailAmountBelowMinimum) Decode(r io.Reader, pver uint32) error { return err } - f.Update = ChannelUpdate1{} - - return parseChannelUpdateCompatibilityMode( - r, length, &f.Update, pver, + update, err := parseChannelUpdateCompatibilityMode( + r, length, pver, ) + if err != nil { + return err + } + + f.Update = update + + return nil } // Encode writes the failure in bytes stream. @@ -778,7 +808,7 @@ func (f *FailAmountBelowMinimum) Encode(w *bytes.Buffer, pver uint32) error { return err } - return writeOnionErrorChanUpdate(w, &f.Update, pver) + return writeOnionErrorChanUpdate(w, f.Update, pver) } // FailFeeInsufficient is returned if the HTLC does not pay sufficient fee, we @@ -792,12 +822,13 @@ type FailFeeInsufficient struct { // Update is used to update information about state of the channel // which caused the failure. - Update ChannelUpdate1 + Update ChannelUpdate } // NewFeeInsufficient creates new instance of the FailFeeInsufficient. func NewFeeInsufficient(htlcMsat MilliSatoshi, - update ChannelUpdate1) *FailFeeInsufficient { + update ChannelUpdate) *FailFeeInsufficient { + return &FailFeeInsufficient{ HtlcMsat: htlcMsat, Update: update, @@ -832,11 +863,14 @@ func (f *FailFeeInsufficient) Decode(r io.Reader, pver uint32) error { return err } - f.Update = ChannelUpdate1{} + update, err := parseChannelUpdateCompatibilityMode(r, length, pver) + if err != nil { + return err + } - return parseChannelUpdateCompatibilityMode( - r, length, &f.Update, pver, - ) + f.Update = update + + return nil } // Encode writes the failure in bytes stream. @@ -847,7 +881,7 @@ func (f *FailFeeInsufficient) Encode(w *bytes.Buffer, pver uint32) error { return err } - return writeOnionErrorChanUpdate(w, &f.Update, pver) + return writeOnionErrorChanUpdate(w, f.Update, pver) } // FailIncorrectCltvExpiry is returned if outgoing cltv value does not match @@ -863,12 +897,12 @@ type FailIncorrectCltvExpiry struct { // Update is used to update information about state of the channel // which caused the failure. - Update ChannelUpdate1 + Update ChannelUpdate } // NewIncorrectCltvExpiry creates new instance of the FailIncorrectCltvExpiry. func NewIncorrectCltvExpiry(cltvExpiry uint32, - update ChannelUpdate1) *FailIncorrectCltvExpiry { + update ChannelUpdate) *FailIncorrectCltvExpiry { return &FailIncorrectCltvExpiry{ CltvExpiry: cltvExpiry, @@ -901,11 +935,14 @@ func (f *FailIncorrectCltvExpiry) Decode(r io.Reader, pver uint32) error { return err } - f.Update = ChannelUpdate1{} + update, err := parseChannelUpdateCompatibilityMode(r, length, pver) + if err != nil { + return err + } + + f.Update = update - return parseChannelUpdateCompatibilityMode( - r, length, &f.Update, pver, - ) + return nil } // Encode writes the failure in bytes stream. @@ -916,7 +953,7 @@ func (f *FailIncorrectCltvExpiry) Encode(w *bytes.Buffer, pver uint32) error { return err } - return writeOnionErrorChanUpdate(w, &f.Update, pver) + return writeOnionErrorChanUpdate(w, f.Update, pver) } // FailExpiryTooSoon is returned if the ctlv-expiry is too near, we tell them @@ -926,11 +963,11 @@ func (f *FailIncorrectCltvExpiry) Encode(w *bytes.Buffer, pver uint32) error { type FailExpiryTooSoon struct { // Update is used to update information about state of the channel // which caused the failure. - Update ChannelUpdate1 + Update ChannelUpdate } // NewExpiryTooSoon creates new instance of the FailExpiryTooSoon. -func NewExpiryTooSoon(update ChannelUpdate1) *FailExpiryTooSoon { +func NewExpiryTooSoon(update ChannelUpdate) *FailExpiryTooSoon { return &FailExpiryTooSoon{ Update: update, } @@ -959,18 +996,21 @@ func (f *FailExpiryTooSoon) Decode(r io.Reader, pver uint32) error { return err } - f.Update = ChannelUpdate1{} + update, err := parseChannelUpdateCompatibilityMode(r, length, pver) + if err != nil { + return err + } - return parseChannelUpdateCompatibilityMode( - r, length, &f.Update, pver, - ) + f.Update = update + + return nil } // Encode writes the failure in bytes stream. // // NOTE: Part of the Serializable interface. func (f *FailExpiryTooSoon) Encode(w *bytes.Buffer, pver uint32) error { - return writeOnionErrorChanUpdate(w, &f.Update, pver) + return writeOnionErrorChanUpdate(w, f.Update, pver) } // FailChannelDisabled is returned if the channel is disabled, we tell them the @@ -985,12 +1025,12 @@ type FailChannelDisabled struct { // Update is used to update information about state of the channel // which caused the failure. - Update ChannelUpdate1 + Update ChannelUpdate } // NewChannelDisabled creates new instance of the FailChannelDisabled. func NewChannelDisabled(flags uint16, - update ChannelUpdate1) *FailChannelDisabled { + update ChannelUpdate) *FailChannelDisabled { return &FailChannelDisabled{ Flags: flags, @@ -1026,11 +1066,14 @@ func (f *FailChannelDisabled) Decode(r io.Reader, pver uint32) error { return err } - f.Update = ChannelUpdate1{} + update, err := parseChannelUpdateCompatibilityMode(r, length, pver) + if err != nil { + return err + } - return parseChannelUpdateCompatibilityMode( - r, length, &f.Update, pver, - ) + f.Update = update + + return nil } // Encode writes the failure in bytes stream. @@ -1041,7 +1084,7 @@ func (f *FailChannelDisabled) Encode(w *bytes.Buffer, pver uint32) error { return err } - return writeOnionErrorChanUpdate(w, &f.Update, pver) + return writeOnionErrorChanUpdate(w, f.Update, pver) } // FailFinalIncorrectCltvExpiry is returned if the outgoing_cltv_value does not @@ -1514,7 +1557,7 @@ func makeEmptyOnionError(code FailCode) (FailureMessage, error) { // writeOnionErrorChanUpdate writes out a ChannelUpdate using the onion error // format. The format is that we first write out the true serialized length of // the channel update, followed by the serialized channel update itself. -func writeOnionErrorChanUpdate(w *bytes.Buffer, chanUpdate *ChannelUpdate1, +func writeOnionErrorChanUpdate(w *bytes.Buffer, chanUpdate ChannelUpdate, pver uint32) error { // First, we encode the channel update in a temporary buffer in order diff --git a/lnwire/onion_error_test.go b/lnwire/onion_error_test.go index 37cd94b8acd..6d540586921 100644 --- a/lnwire/onion_error_test.go +++ b/lnwire/onion_error_test.go @@ -20,7 +20,7 @@ var ( testType = uint64(3) testOffset = uint16(24) sig, _ = NewSigFromSignature(testSig) - testChannelUpdate = ChannelUpdate1{ + testChannelUpdate = &ChannelUpdate1{ Signature: sig, ShortChannelID: NewShortChanIDFromInt(1), Timestamp: 1, @@ -46,7 +46,7 @@ var onionFailures = []FailureMessage{ NewInvalidOnionVersion(testOnionHash), NewInvalidOnionHmac(testOnionHash), NewInvalidOnionKey(testOnionHash), - NewTemporaryChannelFailure(&testChannelUpdate), + NewTemporaryChannelFailure(testChannelUpdate), NewTemporaryChannelFailure(nil), NewAmountBelowMinimum(testAmount, testChannelUpdate), NewFeeInsufficient(testAmount, testChannelUpdate), @@ -137,9 +137,8 @@ func TestChannelUpdateCompatibilityParsing(t *testing.T) { // Now that we have the set of bytes encoded, we'll ensure that we're // able to decode it using our compatibility method, as it's a regular // encoded channel update message. - var newChanUpdate ChannelUpdate1 - err := parseChannelUpdateCompatibilityMode( - &b, uint16(b.Len()), &newChanUpdate, 0, + newChanUpdate, err := parseChannelUpdateCompatibilityMode( + &b, uint16(b.Len()), 0, ) require.NoError(t, err, "unable to parse channel update") @@ -164,9 +163,8 @@ func TestChannelUpdateCompatibilityParsing(t *testing.T) { // We should be able to properly parse the encoded channel update // message even with the extra two bytes. - var newChanUpdate2 ChannelUpdate1 - err = parseChannelUpdateCompatibilityMode( - &b, uint16(b.Len()), &newChanUpdate2, 0, + newChanUpdate2, err := parseChannelUpdateCompatibilityMode( + &b, uint16(b.Len()), 0, ) require.NoError(t, err, "unable to parse channel update") @@ -185,7 +183,7 @@ func TestWriteOnionErrorChanUpdate(t *testing.T) { // raw serialized length. var b bytes.Buffer update := testChannelUpdate - trueUpdateLength, err := WriteMessage(&b, &update, 0) + trueUpdateLength, err := WriteMessage(&b, update, 0) if err != nil { t.Fatalf("unable to write update: %v", err) } @@ -193,7 +191,7 @@ func TestWriteOnionErrorChanUpdate(t *testing.T) { // Next, we'll use the function to encode the update as we would in a // onion error message. var errorBuf bytes.Buffer - err = writeOnionErrorChanUpdate(&errorBuf, &update, 0) + err = writeOnionErrorChanUpdate(&errorBuf, update, 0) require.NoError(t, err, "unable to encode onion error") // Finally, read the length encoded and ensure that it matches the raw diff --git a/netann/chan_status_manager.go b/netann/chan_status_manager.go index de1dc427734..a75db58ddac 100644 --- a/netann/chan_status_manager.go +++ b/netann/chan_status_manager.go @@ -7,9 +7,9 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/keychain" - "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" ) @@ -49,7 +49,10 @@ type ChanStatusConfig struct { OurKeyLoc keychain.KeyLocator // MessageSigner signs messages that validate under OurPubKey. - MessageSigner lnwallet.MessageSigner + MessageSigner keychain.MessageSignerRing + + // BestBlockView gives access to the current best block. + BestBlockView chainntnfs.BestBlockView // IsChannelActive checks whether the channel identified by the provided // ChannelID is considered active. This should only return true if the @@ -60,7 +63,7 @@ type ChanStatusConfig struct { // ApplyChannelUpdate processes new ChannelUpdates signed by our node by // updating our local routing table and broadcasting the update to our // peers. - ApplyChannelUpdate func(*lnwire.ChannelUpdate1, *wire.OutPoint, + ApplyChannelUpdate func(lnwire.ChannelUpdate, *wire.OutPoint, bool) error // DB stores the set of channels that are to be monitored. @@ -634,9 +637,14 @@ func (m *ChanStatusManager) signAndSendNextUpdate(outpoint wire.OutPoint, return err } + height, err := m.cfg.BestBlockView.BestHeight() + if err != nil { + return err + } + err = SignChannelUpdate( m.cfg.MessageSigner, m.cfg.OurKeyLoc, chanUpdate, - ChanUpdSetDisable(disabled), ChanUpdSetTimestamp, + ChanUpdSetDisable(disabled), ChanUpdSetTimestamp(height), ) if err != nil { return err @@ -650,7 +658,7 @@ func (m *ChanStatusManager) signAndSendNextUpdate(outpoint wire.OutPoint, // in case our ChannelEdgePolicy is not found in the database. Also returns if // the channel is private by checking AuthProof for nil. func (m *ChanStatusManager) fetchLastChanUpdateByOutPoint(op wire.OutPoint) ( - *lnwire.ChannelUpdate1, bool, error) { + lnwire.ChannelUpdate, bool, error) { // Get the edge info and policies for this channel from the graph. info, edge1, edge2, err := m.cfg.Graph.FetchChannelEdgesByOutpoint(&op) @@ -681,7 +689,7 @@ func (m *ChanStatusManager) loadInitialChanState( // Determine the channel's starting status by inspecting the disable bit // on last announcement we sent out. var initialStatus ChanStatus - if lastUpdate.ChannelFlags&lnwire.ChanUpdateDisabled == 0 { + if !lastUpdate.IsDisabled() { initialStatus = ChanStatusEnabled } else { initialStatus = ChanStatusDisabled diff --git a/netann/chan_status_manager_test.go b/netann/chan_status_manager_test.go index 1288056a883..e5e0e48b9f8 100644 --- a/netann/chan_status_manager_test.go +++ b/netann/chan_status_manager_test.go @@ -15,6 +15,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcec/v2/ecdsa" "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/keychain" @@ -135,7 +136,7 @@ type mockGraph struct { chanPols2 map[wire.OutPoint]*models.ChannelEdgePolicy1 sidToCid map[lnwire.ShortChannelID]wire.OutPoint - updates chan *lnwire.ChannelUpdate1 + updates chan lnwire.ChannelUpdate } func newMockGraph(t *testing.T, numChannels int, @@ -147,7 +148,7 @@ func newMockGraph(t *testing.T, numChannels int, chanPols1: make(map[wire.OutPoint]*models.ChannelEdgePolicy1), chanPols2: make(map[wire.OutPoint]*models.ChannelEdgePolicy1), sidToCid: make(map[lnwire.ShortChannelID]wire.OutPoint), - updates: make(chan *lnwire.ChannelUpdate1, 2*numChannels), + updates: make(chan lnwire.ChannelUpdate, 2*numChannels), } for i := 0; i < numChannels; i++ { @@ -186,46 +187,47 @@ func (g *mockGraph) FetchChannelEdgesByOutpoint( return info, pol1, pol2, nil } -func (g *mockGraph) ApplyChannelUpdate(update *lnwire.ChannelUpdate1, +func (g *mockGraph) ApplyChannelUpdate(update lnwire.ChannelUpdate, op *wire.OutPoint, private bool) error { g.mu.Lock() defer g.mu.Unlock() - outpoint, ok := g.sidToCid[update.ShortChannelID] + outpoint, ok := g.sidToCid[update.SCID()] if !ok { return fmt.Errorf("unknown short channel id: %v", - update.ShortChannelID) + update.SCID()) } pol1 := g.chanPols1[outpoint] pol2 := g.chanPols2[outpoint] - // Determine which policy we should update by making the flags on the // policies and updates, and seeing which match up. var update1 bool + switch { - case update.ChannelFlags&lnwire.ChanUpdateDirection == - pol1.ChannelFlags&lnwire.ChanUpdateDirection: + case update.IsNode1() == pol1.IsNode1(): update1 = true - case update.ChannelFlags&lnwire.ChanUpdateDirection == - pol2.ChannelFlags&lnwire.ChanUpdateDirection: + case update.IsNode1() == pol2.IsNode1(): update1 = false default: return fmt.Errorf("unable to find policy to update") } - timestamp := time.Unix(int64(update.Timestamp), 0) + upd, ok := update.(*lnwire.ChannelUpdate1) + if !ok { + return fmt.Errorf("expected channel update 1") + } + timestamp := time.Unix(int64(upd.Timestamp), 0) policy := &models.ChannelEdgePolicy1{ - ChannelID: update.ShortChannelID.ToUint64(), - ChannelFlags: update.ChannelFlags, + ChannelID: upd.ShortChannelID.ToUint64(), + ChannelFlags: upd.ChannelFlags, LastUpdate: timestamp, SigBytes: testSigBytes, } - if update1 { g.chanPols1[outpoint] = policy } else { @@ -344,6 +346,7 @@ func newManagerCfg(t *testing.T, numChannels int, ApplyChannelUpdate: graph.ApplyChannelUpdate, DB: graph, Graph: graph, + BestBlockView: &mockBlockView{}, } return cfg, graph, htlcSwitch @@ -515,23 +518,23 @@ func (h *testHarness) assertUpdates(channels []*channeldb.OpenChannel, for { select { case upd := <-h.graph.updates: + scid := upd.SCID() + // Assert that the received short channel id is one that // we expect. If no updates were expected, this will // always fail on the first update received. - if _, ok := expSids[upd.ShortChannelID]; !ok { + if _, ok := expSids[scid]; !ok { h.t.Fatalf("received update for unexpected "+ - "short chan id: %v", upd.ShortChannelID) + "short chan id: %v", scid) } // Assert that the disabled bit is set properly. - enabled := upd.ChannelFlags&lnwire.ChanUpdateDisabled != - lnwire.ChanUpdateDisabled - if expEnabled != enabled { + if expEnabled != !upd.IsDisabled() { h.t.Fatalf("expected enabled: %v, actual: %v", - expEnabled, enabled) + expEnabled, !upd.IsDisabled()) } - recvdSids[upd.ShortChannelID] = struct{}{} + recvdSids[scid] = struct{}{} case <-timeout: // Time is up, assert that the correct number of unique @@ -937,3 +940,11 @@ func TestChanStatusManagerStateMachine(t *testing.T) { }) } } + +type mockBlockView struct { + chainntnfs.BestBlockView +} + +func (m *mockBlockView) BestHeight() (uint32, error) { + return 0, nil +} diff --git a/netann/channel_update.go b/netann/channel_update.go index 902b8a17663..41ff9f1f397 100644 --- a/netann/channel_update.go +++ b/netann/channel_update.go @@ -9,7 +9,6 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/keychain" - "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" ) @@ -18,37 +17,47 @@ import ( var ErrUnableToExtractChanUpdate = fmt.Errorf("unable to extract ChannelUpdate") // ChannelUpdateModifier is a closure that makes in-place modifications to an -// lnwire.ChannelUpdate. -type ChannelUpdateModifier func(*lnwire.ChannelUpdate1) +type ChannelUpdateModifier func(lnwire.ChannelUpdate) // ChanUpdSetDisable is a functional option that sets the disabled channel flag // if disabled is true, and clears the bit otherwise. func ChanUpdSetDisable(disabled bool) ChannelUpdateModifier { - return func(update *lnwire.ChannelUpdate1) { - if disabled { - // Set the bit responsible for marking a channel as - // disabled. - update.ChannelFlags |= lnwire.ChanUpdateDisabled - } else { - // Clear the bit responsible for marking a channel as - // disabled. - update.ChannelFlags &= ^lnwire.ChanUpdateDisabled - } + return func(update lnwire.ChannelUpdate) { + update.SetDisabledFlag(disabled) } } // ChanUpdSetTimestamp is a functional option that sets the timestamp of the // update to the current time, or increments it if the timestamp is already in // the future. -func ChanUpdSetTimestamp(update *lnwire.ChannelUpdate1) { - newTimestamp := uint32(time.Now().Unix()) - if newTimestamp <= update.Timestamp { - // Increment the prior value to ensure the timestamp - // monotonically increases, otherwise the update won't - // propagate. - newTimestamp = update.Timestamp + 1 +func ChanUpdSetTimestamp(bestBlockHeight uint32) ChannelUpdateModifier { + return func(update lnwire.ChannelUpdate) { + switch upd := update.(type) { + case *lnwire.ChannelUpdate1: + newTimestamp := uint32(time.Now().Unix()) + if newTimestamp <= upd.Timestamp { + // Increment the prior value to ensure the + // timestamp monotonically increases, otherwise + // the update won't propagate. + newTimestamp = upd.Timestamp + 1 + } + upd.Timestamp = newTimestamp + + case *lnwire.ChannelUpdate2: + newBlockHeight := bestBlockHeight + if newBlockHeight <= upd.BlockHeight.Val { + // Increment the prior value to ensure the + // blockHeight monotonically increases, + // otherwise the update won't propagate. + newBlockHeight = upd.BlockHeight.Val + 1 + } + upd.BlockHeight.Val = newBlockHeight + + default: + log.Errorf("unhandled implementation of "+ + "lnwire.ChannelUpdate: %T", update) + } } - update.Timestamp = newTimestamp } // SignChannelUpdate applies the given modifiers to the passed @@ -57,24 +66,54 @@ func ChanUpdSetTimestamp(update *lnwire.ChannelUpdate1) { // monotonically increase from the prior. // // NOTE: This method modifies the given update. -func SignChannelUpdate(signer lnwallet.MessageSigner, keyLoc keychain.KeyLocator, - update *lnwire.ChannelUpdate1, mods ...ChannelUpdateModifier) error { +func SignChannelUpdate(signer keychain.MessageSignerRing, + keyLoc keychain.KeyLocator, update lnwire.ChannelUpdate, + mods ...ChannelUpdateModifier) error { // Apply the requested changes to the channel update. for _, modifier := range mods { modifier(update) } - // Create the DER-encoded ECDSA signature over the message digest. - sig, err := SignAnnouncement(signer, keyLoc, update) - if err != nil { - return err - } + switch upd := update.(type) { + case *lnwire.ChannelUpdate1: + data, err := upd.DataToSign() + if err != nil { + return err + } + + sig, err := signer.SignMessage(keyLoc, data, true) + if err != nil { + return err + } + + // Parse the DER-encoded signature into a fixed-size 64-byte + // array. + upd.Signature, err = lnwire.NewSigFromSignature(sig) + if err != nil { + return err + } + + case *lnwire.ChannelUpdate2: + data, err := upd.DataToSign() + if err != nil { + return err + } + + sig, err := signer.SignMessageSchnorr( + keyLoc, data, false, nil, upd.DigestTag(), + ) + if err != nil { + return err + } - // Parse the DER-encoded signature into a fixed-size 64-byte array. - update.Signature, err = lnwire.NewSigFromSignature(sig) - if err != nil { - return err + upd.Signature, err = lnwire.NewSigFromSignature(sig) + if err != nil { + return err + } + default: + return fmt.Errorf("unhandled implementation of "+ + "ChannelUpdate: %T", update) } return nil @@ -86,12 +125,12 @@ func SignChannelUpdate(signer lnwallet.MessageSigner, keyLoc keychain.KeyLocator // NOTE: The passed policies can be nil. func ExtractChannelUpdate(ownerPubKey []byte, info models.ChannelEdgeInfo, policies ...models.ChannelEdgePolicy) ( - *lnwire.ChannelUpdate1, error) { + lnwire.ChannelUpdate, error) { // Helper function to extract the owner of the given policy. - owner := func(edge *models.ChannelEdgePolicy1) []byte { + owner := func(edge models.ChannelEdgePolicy) []byte { var pubKey *btcec.PublicKey - if edge.ChannelFlags&lnwire.ChanUpdateDirection == 0 { + if edge.IsNode1() { pubKey, _ = info.NodeKey1() } else { pubKey, _ = info.NodeKey2() @@ -107,13 +146,7 @@ func ExtractChannelUpdate(ownerPubKey []byte, // Extract the channel update from the policy we own, if any. for _, edge := range policies { - e, ok := edge.(*models.ChannelEdgePolicy1) - if !ok { - return nil, fmt.Errorf("expected "+ - "*models.ChannelEdgePolicy1, got: %T", edge) - } - - if edge != nil && bytes.Equal(ownerPubKey, owner(e)) { + if edge != nil && bytes.Equal(ownerPubKey, owner(edge)) { return ChannelUpdateFromEdge(info, edge) } } @@ -124,12 +157,15 @@ func ExtractChannelUpdate(ownerPubKey []byte, // UnsignedChannelUpdateFromEdge reconstructs an unsigned ChannelUpdate from the // given edge info and policy. func UnsignedChannelUpdateFromEdge(chainHash chainhash.Hash, - policy models.ChannelEdgePolicy) (*lnwire.ChannelUpdate1, error) { + policy models.ChannelEdgePolicy) (lnwire.ChannelUpdate, error) { switch p := policy.(type) { case *models.ChannelEdgePolicy1: return unsignedChanPolicy1ToUpdate(chainHash, p), nil + case *models.ChannelEdgePolicy2: + return unsignedChanPolicy2ToUpdate(chainHash, p), nil + default: return nil, fmt.Errorf("unhandled implementation of the "+ "models.ChanelEdgePolicy interface: %T", policy) @@ -154,10 +190,36 @@ func unsignedChanPolicy1ToUpdate(chainHash chainhash.Hash, } } +func unsignedChanPolicy2ToUpdate(chainHash chainhash.Hash, + policy *models.ChannelEdgePolicy2) *lnwire.ChannelUpdate2 { + + update := &lnwire.ChannelUpdate2{ + ShortChannelID: policy.ShortChannelID, + BlockHeight: policy.BlockHeight, + DisabledFlags: policy.DisabledFlags, + SecondPeer: policy.SecondPeer, + CLTVExpiryDelta: policy.CLTVExpiryDelta, + HTLCMinimumMsat: policy.HTLCMinimumMsat, + HTLCMaximumMsat: policy.HTLCMaximumMsat, + FeeBaseMsat: policy.FeeBaseMsat, + FeeProportionalMillionths: policy.FeeProportionalMillionths, + ExtraOpaqueData: policy.ExtraOpaqueData, + } + update.ChainHash.Val = chainHash + + return update +} + // ChannelUpdateFromEdge reconstructs a signed ChannelUpdate from the given // edge info and policy. func ChannelUpdateFromEdge(info models.ChannelEdgeInfo, - policy models.ChannelEdgePolicy) (*lnwire.ChannelUpdate1, error) { + policy models.ChannelEdgePolicy) (lnwire.ChannelUpdate, error) { + + return signedChannelUpdateFromEdge(info.GetChainHash(), policy) +} + +func signedChannelUpdateFromEdge(chainHash chainhash.Hash, + policy models.ChannelEdgePolicy) (lnwire.ChannelUpdate, error) { switch p := policy.(type) { case *models.ChannelEdgePolicy1: @@ -171,7 +233,23 @@ func ChannelUpdateFromEdge(info models.ChannelEdgeInfo, return nil, err } - update := unsignedChanPolicy1ToUpdate(info.GetChainHash(), p) + update := unsignedChanPolicy1ToUpdate(chainHash, p) + update.Signature = s + + return update, nil + + case *models.ChannelEdgePolicy2: + sig, err := p.Signature.ToSignature() + if err != nil { + return nil, err + } + + s, err := lnwire.NewSigFromSignature(sig) + if err != nil { + return nil, err + } + + update := unsignedChanPolicy2ToUpdate(chainHash, p) update.Signature = s return update, nil diff --git a/netann/channel_update_test.go b/netann/channel_update_test.go index 689b48caba4..7dfc753fb31 100644 --- a/netann/channel_update_test.go +++ b/netann/channel_update_test.go @@ -14,6 +14,8 @@ import ( ) type mockSigner struct { + keychain.MessageSignerRing + err error } @@ -43,7 +45,7 @@ type updateDisableTest struct { startEnabled bool disable bool startTime time.Time - signer lnwallet.MessageSigner + signer keychain.MessageSignerRing expErr error } @@ -131,7 +133,7 @@ func TestUpdateDisableFlag(t *testing.T) { err := netann.SignChannelUpdate( tc.signer, testKeyLoc, newUpdate, netann.ChanUpdSetDisable(tc.disable), - netann.ChanUpdSetTimestamp, + netann.ChanUpdSetTimestamp(0), ) var fail bool diff --git a/netann/node_signer.go b/netann/node_signer.go index 4cb8cea01c8..aeb95882fe4 100644 --- a/netann/node_signer.go +++ b/netann/node_signer.go @@ -4,8 +4,8 @@ import ( "fmt" "github.com/btcsuite/btcd/btcec/v2/ecdsa" + "github.com/btcsuite/btcd/btcec/v2/schnorr" "github.com/lightningnetwork/lnd/keychain" - "github.com/lightningnetwork/lnd/lnwallet" ) // NodeSigner is an implementation of the MessageSigner interface backed by the @@ -43,15 +43,52 @@ func (n *NodeSigner) SignMessage(keyLoc keychain.KeyLocator, return sig, nil } -// SignMessageCompact signs a single or double sha256 digest of the msg +// SignMessageCompactNoKeyLoc signs a single or double sha256 digest of the msg // parameter under the resident node's private key. The returned signature is a -// pubkey-recoverable signature. -func (n *NodeSigner) SignMessageCompact(msg []byte, doubleHash bool) ([]byte, - error) { +// pubkey-recoverable signature. No key locator is required for this since the +// NodeSigner already has the key to sign with. +func (n *NodeSigner) SignMessageCompactNoKeyLoc(msg []byte, doubleHash bool) ( + []byte, error) { return n.keySigner.SignMessageCompact(msg, doubleHash) } +// SignMessageCompact signs the given message, single or double SHA256 hashing +// it first, with the private key described in the key locator and returns the +// signature in the compact, public key recoverable format. +// +// NOTE: this is part of the keychain.MessageSignerRing interface. +func (n *NodeSigner) SignMessageCompact(keyLoc keychain.KeyLocator, msg []byte, + doubleHash bool) ([]byte, error) { + + // If this isn't our identity public key, then we'll exit early with an + // error as we can't sign with this key. + if keyLoc != n.keySigner.KeyLocator() { + return nil, fmt.Errorf("unknown public key locator") + } + + return n.SignMessageCompactNoKeyLoc(msg, doubleHash) +} + +// SignMessageSchnorr signs the given message, single or double SHA256 hashing +// it first, with the private key described in the key locator and the optional +// Taproot tweak applied to the private key. +// +// NOTE: this is part of the keychain.MessageSignerRing interface. +func (n *NodeSigner) SignMessageSchnorr(keyLoc keychain.KeyLocator, msg []byte, + doubleHash bool, taprootTweak, tag []byte) (*schnorr.Signature, error) { + + // If this isn't our identity public key, then we'll exit early with an + // error as we can't sign with this key. + if keyLoc != n.keySigner.KeyLocator() { + return nil, fmt.Errorf("unknown public key locator") + } + + return n.keySigner.SignMessageSchnorr( + keyLoc, msg, doubleHash, taprootTweak, tag, + ) +} + // A compile time check to ensure that NodeSigner implements the MessageSigner // interface. -var _ lnwallet.MessageSigner = (*NodeSigner)(nil) +var _ keychain.MessageSignerRing = (*NodeSigner)(nil) diff --git a/peer/brontide.go b/peer/brontide.go index 6446f202c87..e1289959b37 100644 --- a/peer/brontide.go +++ b/peer/brontide.go @@ -301,7 +301,7 @@ type Config struct { // FetchLastChanUpdate fetches our latest channel update for a target // channel. - FetchLastChanUpdate func(lnwire.ShortChannelID) (*lnwire.ChannelUpdate1, + FetchLastChanUpdate func(lnwire.ShortChannelID) (lnwire.ChannelUpdate, error) // FundingManager is an implementation of the funding.Controller interface. @@ -1966,6 +1966,7 @@ out: } case *lnwire.ChannelUpdate1, + *lnwire.ChannelUpdate2, *lnwire.ChannelAnnouncement1, *lnwire.ChannelAnnouncement2, *lnwire.NodeAnnouncement, @@ -2242,6 +2243,12 @@ func messageSummary(msg lnwire.Message) string { msg.ShortChannelID.ToUint64(), msg.MessageFlags, msg.ChannelFlags, time.Unix(int64(msg.Timestamp), 0)) + case *lnwire.ChannelUpdate2: + return fmt.Sprintf("chain_hash=%v, short_chan_id=%v, "+ + "is_disabled=%v, is_node_1=%v, block_height=%v", + msg.ChainHash, msg.ShortChannelID.Val.ToUint64(), + msg.IsDisabled(), msg.IsNode1(), msg.BlockHeight) + case *lnwire.NodeAnnouncement: return fmt.Sprintf("node=%x, update_time=%v", msg.NodeID, time.Unix(int64(msg.Timestamp), 0)) diff --git a/peer/test_utils.go b/peer/test_utils.go index 0575acca55c..4e1e5f60025 100644 --- a/peer/test_utils.go +++ b/peer/test_utils.go @@ -611,7 +611,7 @@ func createTestPeer(t *testing.T) *peerTestCtx { IsChannelActive: func(lnwire.ChannelID) bool { return true }, - ApplyChannelUpdate: func(*lnwire.ChannelUpdate1, + ApplyChannelUpdate: func(lnwire.ChannelUpdate, *wire.OutPoint, bool) error { return nil @@ -718,8 +718,9 @@ func createTestPeer(t *testing.T) *peerTestCtx { return nil }, PongBuf: make([]byte, lnwire.MaxPongBytes), - FetchLastChanUpdate: func(chanID lnwire.ShortChannelID, - ) (*lnwire.ChannelUpdate1, error) { + FetchLastChanUpdate: func( + chanID lnwire.ShortChannelID) (lnwire.ChannelUpdate, + error) { return &lnwire.ChannelUpdate1{}, nil }, diff --git a/routing/missioncontrol_test.go b/routing/missioncontrol_test.go index 4a0f7387152..f68e433d30a 100644 --- a/routing/missioncontrol_test.go +++ b/routing/missioncontrol_test.go @@ -197,7 +197,7 @@ func TestMissionControl(t *testing.T) { // A node level failure should bring probability of all known channels // back to zero. - ctx.reportFailure(0, lnwire.NewExpiryTooSoon(lnwire.ChannelUpdate1{})) + ctx.reportFailure(0, lnwire.NewExpiryTooSoon(&lnwire.ChannelUpdate1{})) ctx.expectP(1000, 0) // Check whether history snapshot looks sane. @@ -219,14 +219,14 @@ func TestMissionControlChannelUpdate(t *testing.T) { // Report a policy related failure. Because it is the first, we don't // expect a penalty. ctx.reportFailure( - 0, lnwire.NewFeeInsufficient(0, lnwire.ChannelUpdate1{}), + 0, lnwire.NewFeeInsufficient(0, &lnwire.ChannelUpdate1{}), ) ctx.expectP(100, testAprioriHopProbability) // Report another failure for the same channel. We expect it to be // pruned. ctx.reportFailure( - 0, lnwire.NewFeeInsufficient(0, lnwire.ChannelUpdate1{}), + 0, lnwire.NewFeeInsufficient(0, &lnwire.ChannelUpdate1{}), ) ctx.expectP(100, 0) } diff --git a/routing/mock_test.go b/routing/mock_test.go index e0dab5e7989..d8870b3431c 100644 --- a/routing/mock_test.go +++ b/routing/mock_test.go @@ -182,7 +182,7 @@ func (m *mockPaymentSessionOld) RequestRoute(_, _ lnwire.MilliSatoshi, return r, nil } -func (m *mockPaymentSessionOld) UpdateAdditionalEdge(_ *lnwire.ChannelUpdate1, +func (m *mockPaymentSessionOld) UpdateAdditionalEdge(_ lnwire.ChannelUpdate, _ *btcec.PublicKey, _ *models.CachedEdgePolicy) bool { return false @@ -702,7 +702,7 @@ func (m *mockPaymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi, return args.Get(0).(*route.Route), args.Error(1) } -func (m *mockPaymentSession) UpdateAdditionalEdge(msg *lnwire.ChannelUpdate1, +func (m *mockPaymentSession) UpdateAdditionalEdge(msg lnwire.ChannelUpdate, pubKey *btcec.PublicKey, policy *models.CachedEdgePolicy) bool { args := m.Called(msg, pubKey, policy) diff --git a/routing/payment_lifecycle.go b/routing/payment_lifecycle.go index 5244d4d636c..8318f40a487 100644 --- a/routing/payment_lifecycle.go +++ b/routing/payment_lifecycle.go @@ -897,7 +897,7 @@ func (p *paymentLifecycle) handleFailureMessage(rt *route.Route, // SendToRoute where there's no payment lifecycle. if p.paySession != nil { policy = p.paySession.GetAdditionalEdgePolicy( - errSource, update.ShortChannelID.ToUint64(), + errSource, update.SCID().ToUint64(), ) if policy != nil { isAdditionalEdge = true @@ -907,7 +907,8 @@ func (p *paymentLifecycle) handleFailureMessage(rt *route.Route, // Apply channel update to additional edge policy. if isAdditionalEdge { if !p.paySession.UpdateAdditionalEdge( - update, errSource, policy) { + update, errSource, policy, + ) { log.Debugf("Invalid channel update received: node=%v", errVertex) diff --git a/routing/payment_session.go b/routing/payment_session.go index 88c45e0151e..3010579808f 100644 --- a/routing/payment_session.go +++ b/routing/payment_session.go @@ -144,8 +144,9 @@ type PaymentSession interface { // (private channels) and applies the update from the message. Returns // a boolean to indicate whether the update has been applied without // error. - UpdateAdditionalEdge(msg *lnwire.ChannelUpdate1, - pubKey *btcec.PublicKey, policy *models.CachedEdgePolicy) bool + UpdateAdditionalEdge(msg lnwire.ChannelUpdate, + pubKey *btcec.PublicKey, + policy *models.CachedEdgePolicy) bool // GetAdditionalEdgePolicy uses the public key and channel ID to query // the ephemeral channel edge policy for additional edges. Returns a nil @@ -431,7 +432,7 @@ func (p *paymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi, // validates the message signature and checks it's up to date, then applies the // updates to the supplied policy. It returns a boolean to indicate whether // there's an error when applying the updates. -func (p *paymentSession) UpdateAdditionalEdge(msg *lnwire.ChannelUpdate1, +func (p *paymentSession) UpdateAdditionalEdge(msg lnwire.ChannelUpdate, pubKey *btcec.PublicKey, policy *models.CachedEdgePolicy) bool { // Validate the message signature. @@ -442,10 +443,12 @@ func (p *paymentSession) UpdateAdditionalEdge(msg *lnwire.ChannelUpdate1, return false } + fwdingPolicy := msg.ForwardingPolicy() + // Update channel policy for the additional edge. - policy.TimeLockDelta = msg.TimeLockDelta - policy.FeeBaseMSat = lnwire.MilliSatoshi(msg.BaseFee) - policy.FeeProportionalMillionths = lnwire.MilliSatoshi(msg.FeeRate) + policy.TimeLockDelta = fwdingPolicy.TimeLockDelta + policy.FeeBaseMSat = fwdingPolicy.BaseFee + policy.FeeProportionalMillionths = fwdingPolicy.FeeRate log.Debugf("New private channel update applied: %v", lnutils.SpewLogClosure(msg)) diff --git a/routing/result_interpretation_test.go b/routing/result_interpretation_test.go index 68b527e5a9c..2a9c153833d 100644 --- a/routing/result_interpretation_test.go +++ b/routing/result_interpretation_test.go @@ -164,8 +164,9 @@ var resultTestCases = []resultTestCase{ name: "fail expiry too soon", route: &routeFourHop, failureSrcIdx: 3, - failure: lnwire.NewExpiryTooSoon(lnwire.ChannelUpdate1{}), - + failure: lnwire.NewExpiryTooSoon( + &lnwire.ChannelUpdate1{}, + ), expectedResult: &interpretedResult{ pairResults: map[DirectedNodePair]pairResult{ getTestPair(0, 1): failPairResult(0), @@ -267,7 +268,7 @@ var resultTestCases = []resultTestCase{ route: &routeFourHop, failureSrcIdx: 2, failure: lnwire.NewFeeInsufficient( - 0, lnwire.ChannelUpdate1{}, + 0, &lnwire.ChannelUpdate1{}, ), expectedResult: &interpretedResult{ pairResults: map[DirectedNodePair]pairResult{ diff --git a/routing/router.go b/routing/router.go index e9282db6c1e..01d2ebb1390 100644 --- a/routing/router.go +++ b/routing/router.go @@ -284,7 +284,7 @@ type Config struct { // ApplyChannelUpdate can be called to apply a new channel update to the // graph that we received from a payment failure. - ApplyChannelUpdate func(msg *lnwire.ChannelUpdate1) bool + ApplyChannelUpdate func(msg lnwire.ChannelUpdate) bool // ClosedSCIDs is used by the router to fetch closed channels. // @@ -1270,20 +1270,20 @@ func (r *ChannelRouter) sendPayment(ctx context.Context, // extractChannelUpdate examines the error and extracts the channel update. func (r *ChannelRouter) extractChannelUpdate( - failure lnwire.FailureMessage) *lnwire.ChannelUpdate1 { + failure lnwire.FailureMessage) lnwire.ChannelUpdate { - var update *lnwire.ChannelUpdate1 + var update lnwire.ChannelUpdate switch onionErr := failure.(type) { case *lnwire.FailExpiryTooSoon: - update = &onionErr.Update + update = onionErr.Update case *lnwire.FailAmountBelowMinimum: - update = &onionErr.Update + update = onionErr.Update case *lnwire.FailFeeInsufficient: - update = &onionErr.Update + update = onionErr.Update case *lnwire.FailIncorrectCltvExpiry: - update = &onionErr.Update + update = onionErr.Update case *lnwire.FailChannelDisabled: - update = &onionErr.Update + update = onionErr.Update case *lnwire.FailTemporaryChannelFailure: update = onionErr.Update } diff --git a/routing/router_test.go b/routing/router_test.go index cab0bff27ff..cf2c64794f5 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -15,8 +15,10 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcec/v2/ecdsa" + "github.com/btcsuite/btcd/btcec/v2/schnorr" "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/davecgh/go-spew/spew" "github.com/go-errors/errors" @@ -28,6 +30,7 @@ import ( "github.com/lightningnetwork/lnd/graph" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/input" + "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" @@ -146,7 +149,7 @@ func createTestCtxFromGraphInstanceAssumeValid(t *testing.T, MissionControl: mc, } - graphBuilder := newMockGraphBuilder(graphInstance.graph) + graphBuilder := newMockGraphBuilder(t, graphInstance.graph) router, err := New(Config{ SelfNode: sourceNode.PubKeyBytes, @@ -221,16 +224,50 @@ func createTestCtxFromFile(t *testing.T, // Add valid signature to channel update simulated as error received from the // network. func signErrChanUpdate(t *testing.T, key *btcec.PrivateKey, - errChanUpdate *lnwire.ChannelUpdate1) { + errChanUpdate lnwire.ChannelUpdate) { - chanUpdateMsg, err := errChanUpdate.DataToSign() - require.NoError(t, err, "failed to retrieve data to sign") + signer := &mockSigner{key: key} + err := netann.SignChannelUpdate( + signer, keychain.KeyLocator{}, errChanUpdate, + ) + require.NoError(t, err) +} + +type mockSigner struct { + key *btcec.PrivateKey + keychain.MessageSignerRing +} + +func (s *mockSigner) SignMessage(keyLoc keychain.KeyLocator, msg []byte, + doubleHash bool) (*ecdsa.Signature, error) { + + digest := chainhash.DoubleHashB(msg) + sig := ecdsa.Sign(s.key, digest) + + return sig, nil +} + +func (s *mockSigner) SignMessageSchnorr(keyLoc keychain.KeyLocator, msg []byte, + doubleHash bool, taprootTweak, tag []byte) (*schnorr.Signature, + error) { + + var digest []byte + switch { + case len(tag) > 0: + taggedHash := chainhash.TaggedHash(tag, msg) + digest = taggedHash[:] + case doubleHash: + digest = chainhash.DoubleHashB(msg) + default: + digest = chainhash.HashB(msg) + } - digest := chainhash.DoubleHashB(chanUpdateMsg) - sig := ecdsa.Sign(key, digest) + privKey := s.key + if len(taprootTweak) > 0 { + privKey = txscript.TweakTaprootPrivKey(*privKey, taprootTweak) + } - errChanUpdate.Signature, err = lnwire.NewSigFromSignature(sig) - require.NoError(t, err, "failed to create new signature") + return schnorr.Sign(privKey, digest) } // TestFindRoutesWithFeeLimit asserts that routes found by the FindRoutes method @@ -513,7 +550,7 @@ func TestChannelUpdateValidation(t *testing.T) { func(firstHop lnwire.ShortChannelID) ([32]byte, error) { return [32]byte{}, htlcswitch.NewForwardingError( &lnwire.FailFeeInsufficient{ - Update: errChanUpdate, + Update: &errChanUpdate, }, 1, ) @@ -613,11 +650,8 @@ func TestSendPaymentErrorRepeatedFeeInsufficient(t *testing.T) { ) require.NoError(t, err, "unable to fetch chan id") - edgeUpdToFail, ok := edgeUpdateToFail.(*models.ChannelEdgePolicy1) - require.True(t, ok) - errChanUpdate, err := netann.UnsignedChannelUpdateFromEdge( - chainhash.Hash{}, edgeUpdToFail, + chainhash.Hash{}, edgeUpdateToFail, ) require.NoError(t, err) @@ -636,15 +670,17 @@ func TestSendPaymentErrorRepeatedFeeInsufficient(t *testing.T) { roasbeefSongokuChanID, ) if firstHop == roasbeefSongoku { - return [32]byte{}, htlcswitch.NewForwardingError( - // Within our error, we'll add a - // channel update which is meant to - // reflect the new fee schedule for the - // node/channel. - &lnwire.FailFeeInsufficient{ - Update: *errChanUpdate, - }, 1, - ) + if firstHop == roasbeefSongoku { + return [32]byte{}, htlcswitch.NewForwardingError( + // Within our error, we'll add a + // channel update which is meant to + // reflect the new fee schedule for the + // node/channel. + &lnwire.FailFeeInsufficient{ + Update: errChanUpdate, + }, 1, + ) + } } return preImage, nil @@ -751,7 +787,7 @@ func TestSendPaymentErrorFeeInsufficientPrivateEdge(t *testing.T) { // reflect the new fee schedule for the // node/channel. &lnwire.FailFeeInsufficient{ - Update: errChanUpdate, + Update: &errChanUpdate, }, 1, ) }, @@ -877,7 +913,7 @@ func TestSendPaymentPrivateEdgeUpdateFeeExceedsLimit(t *testing.T) { // reflect the new fee schedule for the // node/channel. &lnwire.FailFeeInsufficient{ - Update: errChanUpdate, + Update: &errChanUpdate, }, 1, ) }, @@ -970,7 +1006,7 @@ func TestSendPaymentErrorNonFinalTimeLockErrors(t *testing.T) { if firstHop == roasbeefSongoku { return [32]byte{}, htlcswitch.NewForwardingError( &lnwire.FailExpiryTooSoon{ - Update: *errChanUpdate, + Update: errChanUpdate, }, 1, ) } @@ -1018,7 +1054,7 @@ func TestSendPaymentErrorNonFinalTimeLockErrors(t *testing.T) { if firstHop == roasbeefSongoku { return [32]byte{}, htlcswitch.NewForwardingError( &lnwire.FailIncorrectCltvExpiry{ - Update: *errChanUpdate, + Update: errChanUpdate, }, 1, ) } @@ -1412,7 +1448,7 @@ func TestSendToRouteStructuredError(t *testing.T) { testCases := map[int]lnwire.FailureMessage{ finalHopIndex: lnwire.NewFailIncorrectDetails(payAmt, 100), 1: &lnwire.FailFeeInsufficient{ - Update: lnwire.ChannelUpdate1{}, + Update: &lnwire.ChannelUpdate1{}, }, } @@ -2940,13 +2976,15 @@ func createDummyLightningPayment(t *testing.T, } type mockGraphBuilder struct { + t *testing.T rejectUpdate bool - updateEdge func(update *models.ChannelEdgePolicy1) error + updateEdge func(update models.ChannelEdgePolicy) error } -func newMockGraphBuilder(graph graph.DB) *mockGraphBuilder { +func newMockGraphBuilder(t *testing.T, graph graph.DB) *mockGraphBuilder { return &mockGraphBuilder{ - updateEdge: func(update *models.ChannelEdgePolicy1) error { + t: t, + updateEdge: func(update models.ChannelEdgePolicy) error { return graph.UpdateEdgePolicy(update) }, } @@ -2956,26 +2994,15 @@ func (m *mockGraphBuilder) setNextReject(reject bool) { m.rejectUpdate = reject } -func (m *mockGraphBuilder) ApplyChannelUpdate(msg *lnwire.ChannelUpdate1) bool { +func (m *mockGraphBuilder) ApplyChannelUpdate(msg lnwire.ChannelUpdate) bool { if m.rejectUpdate { return false } - err := m.updateEdge(&models.ChannelEdgePolicy1{ - SigBytes: msg.Signature.ToSignatureBytes(), - ChannelID: msg.ShortChannelID.ToUint64(), - LastUpdate: time.Unix(int64(msg.Timestamp), 0), - MessageFlags: msg.MessageFlags, - ChannelFlags: msg.ChannelFlags, - TimeLockDelta: msg.TimeLockDelta, - MinHTLC: msg.HtlcMinimumMsat, - MaxHTLC: msg.HtlcMaximumMsat, - FeeBaseMSat: lnwire.MilliSatoshi(msg.BaseFee), - FeeProportionalMillionths: lnwire.MilliSatoshi(msg.FeeRate), - ExtraOpaqueData: msg.ExtraOpaqueData, - }) + policy, err := models.EdgePolicyFromUpdate(msg) + require.NoError(m.t, err) - return err == nil + return m.updateEdge(policy) == nil } type mockChain struct { diff --git a/rpcserver.go b/rpcserver.go index 2396db13ed9..dfa6de53cfc 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -1671,7 +1671,7 @@ func (r *rpcServer) SignMessage(_ context.Context, } in.Msg = append(signedMsgPrefix, in.Msg...) - sigBytes, err := r.server.nodeSigner.SignMessageCompact( + sigBytes, err := r.server.nodeSigner.SignMessageCompactNoKeyLoc( in.Msg, !in.SingleHash, ) if err != nil { diff --git a/server.go b/server.go index 763c1e61815..58f5ac7fe68 100644 --- a/server.go +++ b/server.go @@ -16,7 +16,6 @@ import ( "time" "github.com/btcsuite/btcd/btcec/v2" - "github.com/btcsuite/btcd/btcec/v2/ecdsa" "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/connmgr" @@ -717,6 +716,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr, ApplyChannelUpdate: s.applyChannelUpdate, DB: s.chanStateDB, Graph: dbs.GraphDB.ChannelGraph(), + BestBlockView: s.cc.BestBlockTracker, } chanStatusMgr, err := netann.NewChanStatusManager(chanStatusMgrCfg) @@ -1740,18 +1740,11 @@ func (s *server) UpdateRoutingConfig(cfg *routing.MissionControlConfig) { routerCfg.MaxMcHistory = cfg.MaxMcHistory } -// signAliasUpdate takes a ChannelUpdate and returns the signature. This is -// used for option_scid_alias channels where the ChannelUpdate to be sent back -// may differ from what is on disk. -func (s *server) signAliasUpdate(u *lnwire.ChannelUpdate1) (*ecdsa.Signature, - error) { - - data, err := u.DataToSign() - if err != nil { - return nil, err - } - - return s.cc.MsgSigner.SignMessage(s.identityKeyLoc, data, true) +// signAliasUpdate takes a ChannelUpdate and re-signs it. The signature is set +// the update accordingly. This is used for option_scid_alias channels where the +// ChannelUpdate to be sent back may differ from what is on disk. +func (s *server) signAliasUpdate(u lnwire.ChannelUpdate) error { + return netann.SignChannelUpdate(s.cc.KeyRing, s.identityKeyLoc, u) } // createLivenessMonitor creates a set of health checks using our configured @@ -4817,10 +4810,10 @@ func (s *server) fetchNodeAdvertisedAddrs(pub *btcec.PublicKey) ([]net.Addr, err // fetchLastChanUpdate returns a function which is able to retrieve our latest // channel update for a target channel. func (s *server) fetchLastChanUpdate() func(lnwire.ShortChannelID) ( - *lnwire.ChannelUpdate1, error) { + lnwire.ChannelUpdate, error) { ourPubKey := s.identityECDH.PubKey().SerializeCompressed() - return func(cid lnwire.ShortChannelID) (*lnwire.ChannelUpdate1, error) { + return func(cid lnwire.ShortChannelID) (lnwire.ChannelUpdate, error) { info, edge1, edge2, err := s.graphBuilder.GetChannelByID(cid) if err != nil { return nil, err @@ -4835,7 +4828,7 @@ func (s *server) fetchLastChanUpdate() func(lnwire.ShortChannelID) ( // applyChannelUpdate applies the channel update to the different sub-systems of // the server. The useAlias boolean denotes whether or not to send an alias in // place of the real SCID. -func (s *server) applyChannelUpdate(update *lnwire.ChannelUpdate1, +func (s *server) applyChannelUpdate(update lnwire.ChannelUpdate, op *wire.OutPoint, useAlias bool) error { var (