diff --git a/discovery/gossiper.go b/discovery/gossiper.go index 1a74e3e1590..764205e6c0d 100644 --- a/discovery/gossiper.go +++ b/discovery/gossiper.go @@ -12,7 +12,9 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcec/v2/ecdsa" "github.com/btcsuite/btcd/btcutil" + "github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/davecgh/go-spew/spew" "github.com/lightninglabs/neutrino/cache" @@ -192,14 +194,9 @@ type PinnedSyncers map[route.Vertex]struct{} // Config defines the configuration for the service. ALL elements within the // configuration MUST be non-nil for the service to carry out its duties. type Config struct { - // ChainHash is a hash that indicates which resident chain of the - // AuthenticatedGossiper. Any announcements that don't match this - // chain hash will be ignored. - // - // TODO(roasbeef): eventually make into map so can de-multiplex - // incoming announcements - // * also need to do same for Notifier - ChainHash chainhash.Hash + // ChainParams holds the chain parameters for the active network this + // node is participating on. + ChainParams *chaincfg.Params // Graph is the subsystem which is responsible for managing the // topology of lightning network. After incoming channel, node, channel @@ -574,7 +571,7 @@ func New(cfg Config, selfKeyDesc *keychain.KeyDescriptor) *AuthenticatedGossiper gossiper.vb = NewValidationBarrier(1000, gossiper.quit) gossiper.syncMgr = newSyncManager(&SyncManagerCfg{ - ChainHash: cfg.ChainHash, + ChainHash: *cfg.ChainParams.GenesisHash, ChanSeries: cfg.ChanSeries, RotateTicker: cfg.RotateTicker, HistoricalSyncTicker: cfg.HistoricalSyncTicker, @@ -1992,10 +1989,29 @@ func (d *AuthenticatedGossiper) processRejectedEdge( } // fetchPKScript fetches the output script for the given SCID. -func (d *AuthenticatedGossiper) fetchPKScript(chanID *lnwire.ShortChannelID) ( - []byte, error) { +func (d *AuthenticatedGossiper) fetchPKScript(chanID lnwire.ShortChannelID) ( + txscript.ScriptClass, btcutil.Address, error) { - return lnwallet.FetchPKScriptWithQuit(d.cfg.ChainIO, chanID, d.quit) + pkScript, err := lnwallet.FetchPKScriptWithQuit( + d.cfg.ChainIO, chanID, d.quit, + ) + if err != nil { + return txscript.WitnessUnknownTy, nil, err + } + + scriptClass, addrs, _, err := txscript.ExtractPkScriptAddrs( + pkScript, d.cfg.ChainParams, + ) + if err != nil { + return txscript.WitnessUnknownTy, nil, err + } + + if len(addrs) != 1 { + return txscript.WitnessUnknownTy, nil, fmt.Errorf("expected "+ + "1 address, got: %d", len(addrs)) + } + + return scriptClass, addrs[0], nil } // addNode processes the given node announcement, and adds it to our channel @@ -2482,16 +2498,16 @@ func (d *AuthenticatedGossiper) handleChanAnnouncement(nMsg *networkMsg, ops ...batch.SchedulerOption) ([]networkMsg, bool) { scid := ann.ShortChannelID + chainHash := d.cfg.ChainParams.GenesisHash log.Debugf("Processing ChannelAnnouncement1: peer=%v, short_chan_id=%v", nMsg.peer, scid.ToUint64()) // We'll ignore any channel announcements that target any chain other // than the set of chains we know of. - if !bytes.Equal(ann.ChainHash[:], d.cfg.ChainHash[:]) { + if !bytes.Equal(ann.ChainHash[:], chainHash[:]) { err := fmt.Errorf("ignoring ChannelAnnouncement1 from chain=%v"+ - ", gossiper on chain=%v", ann.ChainHash, - d.cfg.ChainHash) + ", gossiper on chain=%v", ann.ChainHash, chainHash) log.Errorf(err.Error()) key := newRejectCacheKey( @@ -2942,11 +2958,13 @@ func (d *AuthenticatedGossiper) handleChanUpdate(nMsg *networkMsg, log.Debugf("Processing ChannelUpdate: peer=%v, short_chan_id=%v, ", nMsg.peer, upd.ShortChannelID.ToUint64()) + chainHash := d.cfg.ChainParams.GenesisHash + // We'll ignore any channel updates that target any chain other than // the set of chains we know of. - if !bytes.Equal(upd.ChainHash[:], d.cfg.ChainHash[:]) { + if !bytes.Equal(upd.ChainHash[:], chainHash[:]) { err := fmt.Errorf("ignoring ChannelUpdate from chain=%v, "+ - "gossiper on chain=%v", upd.ChainHash, d.cfg.ChainHash) + "gossiper on chain=%v", upd.ChainHash, chainHash) log.Errorf(err.Error()) key := newRejectCacheKey( @@ -3700,7 +3718,7 @@ func (d *AuthenticatedGossiper) validateFundingTransaction( // Before we can add the channel to the channel graph, we need to obtain // the full funding outpoint that's encoded within the channel ID. fundingTx, err := lnwallet.FetchFundingTxWrapper( - d.cfg.ChainIO, &scid, d.quit, + d.cfg.ChainIO, scid, d.quit, ) if err != nil { //nolint:ll diff --git a/discovery/gossiper_test.go b/discovery/gossiper_test.go index b48ea8dbdf2..6b67b4d92da 100644 --- a/discovery/gossiper_test.go +++ b/discovery/gossiper_test.go @@ -16,6 +16,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcec/v2/ecdsa" "github.com/btcsuite/btcd/btcutil" + "github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/davecgh/go-spew/spew" @@ -616,6 +617,7 @@ func createUpdateAnnouncement(blockHeight uint32, htlcMinMsat := lnwire.MilliSatoshi(100) a := &lnwire.ChannelUpdate1{ + ChainHash: *chaincfg.MainNetParams.GenesisHash, ShortChannelID: lnwire.ShortChannelID{ BlockHeight: blockHeight, }, @@ -768,6 +770,7 @@ func (ctx *testCtx) createAnnouncementWithoutProof(blockHeight uint32, } a := &lnwire.ChannelAnnouncement1{ + ChainHash: *chaincfg.MainNetParams.GenesisHash, ShortChannelID: lnwire.ShortChannelID{ BlockHeight: blockHeight, TxIndex: 0, @@ -934,8 +937,9 @@ func createTestCtx(t *testing.T, startHeight uint32, isChanPeer bool) ( } gossiper := New(Config{ - ChainIO: chain, - Notifier: notifier, + ChainIO: chain, + ChainParams: &chaincfg.MainNetParams, + Notifier: notifier, Broadcast: func(senders map[route.Vertex]struct{}, msgs ...lnwire.Message) error { @@ -1653,6 +1657,7 @@ func TestSignatureAnnouncementRetryAtStartup(t *testing.T) { //nolint:ll gossiper := New(Config{ + ChainParams: &chaincfg.MainNetParams, Notifier: ctx.gossiper.cfg.Notifier, Broadcast: ctx.gossiper.cfg.Broadcast, NotifyWhenOnline: ctx.gossiper.reliableSender.cfg.NotifyWhenOnline, diff --git a/lnwallet/interface.go b/lnwallet/interface.go index 64f85463104..effd1bf9fba 100644 --- a/lnwallet/interface.go +++ b/lnwallet/interface.go @@ -754,7 +754,7 @@ func SupportedWallets() []string { // FetchFundingTxWrapper is a wrapper around FetchFundingTx, except that it will // exit when the supplied quit channel is closed. -func FetchFundingTxWrapper(chain BlockChainIO, chanID *lnwire.ShortChannelID, +func FetchFundingTxWrapper(chain BlockChainIO, chanID lnwire.ShortChannelID, quit chan struct{}) (*wire.MsgTx, error) { txChan := make(chan *wire.MsgTx, 1) @@ -789,7 +789,7 @@ func FetchFundingTxWrapper(chain BlockChainIO, chanID *lnwire.ShortChannelID, // TODO(roasbeef): replace with call to GetBlockTransaction? (would allow to // later use getblocktxn). func FetchFundingTx(chain BlockChainIO, - chanID *lnwire.ShortChannelID) (*wire.MsgTx, error) { + chanID lnwire.ShortChannelID) (*wire.MsgTx, error) { // First fetch the block hash by the block number encoded, then use // that hash to fetch the block itself. @@ -820,7 +820,7 @@ func FetchFundingTx(chain BlockChainIO, // FetchPKScriptWithQuit fetches the output script for the given SCID and exits // early with an error if the provided quit channel is closed before // completion. -func FetchPKScriptWithQuit(chain BlockChainIO, chanID *lnwire.ShortChannelID, +func FetchPKScriptWithQuit(chain BlockChainIO, chanID lnwire.ShortChannelID, quit chan struct{}) ([]byte, error) { tx, err := FetchFundingTxWrapper(chain, chanID, quit) @@ -829,7 +829,7 @@ func FetchPKScriptWithQuit(chain BlockChainIO, chanID *lnwire.ShortChannelID, } outputLocator := chanvalidate.ShortChanIDChanLocator{ - ID: *chanID, + ID: chanID, } output, _, err := outputLocator.Locate(tx) diff --git a/lnwire/announcement_signatures_2.go b/lnwire/announcement_signatures_2.go index 6e893dafdd6..7be90181ef9 100644 --- a/lnwire/announcement_signatures_2.go +++ b/lnwire/announcement_signatures_2.go @@ -3,6 +3,8 @@ package lnwire import ( "bytes" "io" + + "github.com/lightningnetwork/lnd/tlv" ) // AnnounceSignatures2 is a direct message between two endpoints of a @@ -14,27 +16,40 @@ type AnnounceSignatures2 struct { // Channel id is better for users and debugging and short channel id is // used for quick test on existence of the particular utxo inside the // blockchain, because it contains information about block. - ChannelID ChannelID + ChannelID tlv.RecordT[tlv.TlvType0, ChannelID] // ShortChannelID is the unique description of the funding transaction. // It is constructed with the most significant 3 bytes as the block // height, the next 3 bytes indicating the transaction index within the // block, and the least significant two bytes indicating the output // index which pays to the channel. - ShortChannelID ShortChannelID + ShortChannelID tlv.RecordT[tlv.TlvType2, ShortChannelID] // PartialSignature is the combination of the partial Schnorr signature // created for the node's bitcoin key with the partial signature created // for the node's node ID key. - PartialSignature PartialSig - - // ExtraOpaqueData is the set of data that was appended to this - // message, some of which we may not actually know how to iterate or - // parse. By holding onto this data, we ensure that we're able to - // properly validate the set of signatures that cover these new fields, - // and ensure we're able to make upgrades to the network in a forwards - // compatible manner. - ExtraOpaqueData ExtraOpaqueData + PartialSignature tlv.RecordT[tlv.TlvType4, PartialSig] + + // Any extra fields in the signed range that we do not yet know about, + // but we need to keep them for signature validation and to produce a + // valid message. + ExtraSignedFields +} + +// NewAnnSigs2 is a constructor for AnnounceSignatures2. +func NewAnnSigs2(chanID ChannelID, scid ShortChannelID, + partialSig PartialSig) *AnnounceSignatures2 { + + return &AnnounceSignatures2{ + ChannelID: tlv.NewRecordT[tlv.TlvType0, ChannelID](chanID), + ShortChannelID: tlv.NewRecordT[tlv.TlvType2, ShortChannelID]( + scid, + ), + PartialSignature: tlv.NewRecordT[tlv.TlvType4, PartialSig]( + partialSig, + ), + ExtraSignedFields: make(ExtraSignedFields), + } } // A compile time check to ensure AnnounceSignatures2 implements the @@ -50,32 +65,29 @@ var _ SizeableMessage = (*AnnounceSignatures2)(nil) // // This is part of the lnwire.Message interface. func (a *AnnounceSignatures2) Decode(r io.Reader, _ uint32) error { - return ReadElements(r, - &a.ChannelID, - &a.ShortChannelID, - &a.PartialSignature, - &a.ExtraOpaqueData, - ) -} - -// Encode serializes the target AnnounceSignatures2 into the passed io.Writer -// observing the protocol version specified. -// -// This is part of the lnwire.Message interface. -func (a *AnnounceSignatures2) Encode(w *bytes.Buffer, _ uint32) error { - if err := WriteChannelID(w, a.ChannelID); err != nil { + stream, err := tlv.NewStream(ProduceRecordsSorted( + &a.ChannelID, &a.ShortChannelID, &a.PartialSignature, + )...) + if err != nil { return err } - if err := WriteShortChannelID(w, a.ShortChannelID); err != nil { + typeMap, err := stream.DecodeWithParsedTypesP2P(r) + if err != nil { return err } - if err := WriteElement(w, a.PartialSignature); err != nil { - return err - } + a.ExtraSignedFields = ExtraSignedFieldsFromTypeMap(typeMap) - return WriteBytes(w, a.ExtraOpaqueData) + return nil +} + +// Encode serializes the target AnnounceSignatures2 into the passed io.Writer +// observing the protocol version specified. +// +// This is part of the lnwire.Message interface. +func (a *AnnounceSignatures2) Encode(w *bytes.Buffer, _ uint32) error { + return EncodePureTLVMessage(a, w) } // MsgType returns the integer uniquely identifying this message type on the @@ -93,16 +105,34 @@ func (a *AnnounceSignatures2) SerializedSize() (uint32, error) { return MessageSerializedSize(a) } +// AllRecords returns all the TLV records for the message. This will include all +// the records we know about along with any that we don't know about but that +// fall in the signed TLV range. +// +// NOTE: this is part of the PureTLVMessage interface. +func (a *AnnounceSignatures2) AllRecords() []tlv.Record { + recordProducers := []tlv.RecordProducer{ + &a.ChannelID, &a.ShortChannelID, + &a.PartialSignature, + } + + recordProducers = append(recordProducers, RecordsAsProducers( + tlv.MapToRecords(a.ExtraSignedFields), + )...) + + return ProduceRecordsSorted(recordProducers...) +} + // SCID returns the ShortChannelID of the channel. // // NOTE: this is part of the AnnounceSignatures interface. func (a *AnnounceSignatures2) SCID() ShortChannelID { - return a.ShortChannelID + return a.ShortChannelID.Val } // ChanID returns the ChannelID identifying the channel. // // NOTE: this is part of the AnnounceSignatures interface. func (a *AnnounceSignatures2) ChanID() ChannelID { - return a.ChannelID + return a.ChannelID.Val } diff --git a/lnwire/channel_announcement_2.go b/lnwire/channel_announcement_2.go index 57b3a24b8c6..f8a58c1bc91 100644 --- a/lnwire/channel_announcement_2.go +++ b/lnwire/channel_announcement_2.go @@ -12,9 +12,6 @@ import ( // ChannelAnnouncement2 message is used to announce the existence of a taproot // channel between two peers in the network. type ChannelAnnouncement2 struct { - // Signature is a Schnorr signature over the TLV stream of the message. - Signature Sig - // ChainHash denotes the target chain that this channel was opened // within. This value should be the genesis hash of the target chain. ChainHash tlv.RecordT[tlv.TlvType0, chainhash.Hash] @@ -59,47 +56,103 @@ type ChannelAnnouncement2 struct { // the funding output is a pure 2-of-2 MuSig aggregate public key. MerkleRootHash tlv.OptionalRecordT[tlv.TlvType16, [32]byte] - // ExtraOpaqueData is the set of data that was appended to this - // message, some of which we may not actually know how to iterate or - // parse. By holding onto this data, we ensure that we're able to - // properly validate the set of signatures that cover these new fields, - // and ensure we're able to make upgrades to the network in a forwards - // compatible manner. - ExtraOpaqueData ExtraOpaqueData + // Signature is a Schnorr signature over serialised signed-range TLV + // stream of the message. + Signature tlv.RecordT[tlv.TlvType160, Sig] + + // Any extra fields in the signed range that we do not yet know about, + // but we need to keep them for signature validation and to produce a + // valid message. + ExtraSignedFields } -// Decode deserializes a serialized AnnounceSignatures1 stored in the passed -// io.Reader observing the specified protocol version. +// Encode serializes the target AnnounceSignatures1 into the passed io.Writer +// observing the protocol version specified. // // This is part of the lnwire.Message interface. -func (c *ChannelAnnouncement2) Decode(r io.Reader, _ uint32) error { - err := ReadElement(r, &c.Signature) - if err != nil { - return err - } - c.Signature.ForceSchnorr() +func (c *ChannelAnnouncement2) Encode(w *bytes.Buffer, _ uint32) error { + return EncodePureTLVMessage(c, w) +} + +// AllRecords returns all the TLV records for the message. This will include all +// the records we know about along with any that we don't know about but that +// fall in the signed TLV range. +// +// NOTE: this is part of the PureTLVMessage interface. +func (c *ChannelAnnouncement2) AllRecords() []tlv.Record { + recordProducers := append( + c.allNonSignatureRecordProducers(), &c.Signature, + ) - return c.DecodeTLVRecords(r) + return ProduceRecordsSorted(recordProducers...) } -// DecodeTLVRecords decodes only the TLV section of the message. -func (c *ChannelAnnouncement2) DecodeTLVRecords(r io.Reader) error { - // First extract into extra opaque data. - var tlvRecords ExtraOpaqueData - if err := ReadElements(r, &tlvRecords); err != nil { - return err +func (c *ChannelAnnouncement2) allNonSignatureRecordProducers() []tlv.RecordProducer { //nolint:ll + // The chain-hash record is only included if it is _not_ equal to the + // bitcoin mainnet genisis block hash. + var recordProducers []tlv.RecordProducer + if !c.ChainHash.Val.IsEqual(chaincfg.MainNetParams.GenesisHash) { + hash := tlv.ZeroRecordT[tlv.TlvType0, [32]byte]() + hash.Val = c.ChainHash.Val + + recordProducers = append(recordProducers, &hash) } + recordProducers = append(recordProducers, + &c.Features, &c.ShortChannelID, &c.Capacity, &c.NodeID1, + &c.NodeID2, + ) + + c.BitcoinKey1.WhenSome(func(key tlv.RecordT[tlv.TlvType12, [33]byte]) { + recordProducers = append(recordProducers, &key) + }) + + c.BitcoinKey2.WhenSome(func(key tlv.RecordT[tlv.TlvType14, [33]byte]) { + recordProducers = append(recordProducers, &key) + }) + + c.MerkleRootHash.WhenSome( + func(hash tlv.RecordT[tlv.TlvType16, [32]byte]) { + recordProducers = append(recordProducers, &hash) + }, + ) + + recordProducers = append(recordProducers, RecordsAsProducers( + tlv.MapToRecords(c.ExtraSignedFields), + )...) + + return recordProducers +} + +// Decode deserializes a serialized AnnounceSignatures1 stored in the passed +// io.Reader observing the specified protocol version. +// +// This is part of the lnwire.Message interface. +func (c *ChannelAnnouncement2) Decode(r io.Reader, _ uint32) error { var ( chainHash = tlv.ZeroRecordT[tlv.TlvType0, [32]byte]() btcKey1 = tlv.ZeroRecordT[tlv.TlvType12, [33]byte]() btcKey2 = tlv.ZeroRecordT[tlv.TlvType14, [33]byte]() merkleRootHash = tlv.ZeroRecordT[tlv.TlvType16, [32]byte]() ) - typeMap, err := tlvRecords.ExtractRecords( - &chainHash, &c.Features, &c.ShortChannelID, &c.Capacity, - &c.NodeID1, &c.NodeID2, &btcKey1, &btcKey2, &merkleRootHash, - ) + stream, err := tlv.NewStream(ProduceRecordsSorted( + &chainHash, + &c.Features, + &c.ShortChannelID, + &c.Capacity, + &c.NodeID1, + &c.NodeID2, + &btcKey1, + &btcKey2, + &merkleRootHash, + &c.Signature, + )...) + if err != nil { + return err + } + c.Signature.Val.ForceSchnorr() + + typeMap, err := stream.DecodeWithParsedTypesP2P(r) if err != nil { return err } @@ -122,68 +175,68 @@ func (c *ChannelAnnouncement2) DecodeTLVRecords(r io.Reader) error { c.MerkleRootHash = tlv.SomeRecordT(merkleRootHash) } - if len(tlvRecords) != 0 { - c.ExtraOpaqueData = tlvRecords - } + c.ExtraSignedFields = ExtraSignedFieldsFromTypeMap(typeMap) return nil } -// Encode serializes the target AnnounceSignatures1 into the passed io.Writer -// observing the protocol version specified. -// -// This is part of the lnwire.Message interface. -func (c *ChannelAnnouncement2) Encode(w *bytes.Buffer, _ uint32) error { - _, err := w.Write(c.Signature.RawBytes()) +// DecodeNonSigTLVRecords decodes only the TLV section of the message. +func (c *ChannelAnnouncement2) DecodeNonSigTLVRecords(r io.Reader) error { + var ( + chainHash = tlv.ZeroRecordT[tlv.TlvType0, [32]byte]() + btcKey1 = tlv.ZeroRecordT[tlv.TlvType12, [33]byte]() + btcKey2 = tlv.ZeroRecordT[tlv.TlvType14, [33]byte]() + merkleRootHash = tlv.ZeroRecordT[tlv.TlvType16, [32]byte]() + ) + stream, err := tlv.NewStream(ProduceRecordsSorted( + &chainHash, + &c.Features, + &c.ShortChannelID, + &c.Capacity, + &c.NodeID1, + &c.NodeID2, + &btcKey1, + &btcKey2, + &merkleRootHash, + )...) if err != nil { return err } - _, err = c.DataToSign() + + typeMap, err := stream.DecodeWithParsedTypesP2P(r) if err != nil { return err } - return WriteBytes(w, c.ExtraOpaqueData) -} + // By default, the chain-hash is the bitcoin mainnet genesis block hash. + c.ChainHash.Val = *chaincfg.MainNetParams.GenesisHash + if _, ok := typeMap[c.ChainHash.TlvType()]; ok { + c.ChainHash.Val = chainHash.Val + } -// DataToSign encodes the data to be signed into the ExtraOpaqueData member and -// returns it. -func (c *ChannelAnnouncement2) DataToSign() ([]byte, error) { - // The chain-hash record is only included if it is _not_ equal to the - // bitcoin mainnet genisis block hash. - var recordProducers []tlv.RecordProducer - if !c.ChainHash.Val.IsEqual(chaincfg.MainNetParams.GenesisHash) { - hash := tlv.ZeroRecordT[tlv.TlvType0, [32]byte]() - hash.Val = c.ChainHash.Val + if _, ok := typeMap[c.BitcoinKey1.TlvType()]; ok { + c.BitcoinKey1 = tlv.SomeRecordT(btcKey1) + } - recordProducers = append(recordProducers, &hash) + if _, ok := typeMap[c.BitcoinKey2.TlvType()]; ok { + c.BitcoinKey2 = tlv.SomeRecordT(btcKey2) } - recordProducers = append(recordProducers, - &c.Features, &c.ShortChannelID, &c.Capacity, &c.NodeID1, - &c.NodeID2, - ) + if _, ok := typeMap[c.MerkleRootHash.TlvType()]; ok { + c.MerkleRootHash = tlv.SomeRecordT(merkleRootHash) + } - c.BitcoinKey1.WhenSome(func(key tlv.RecordT[tlv.TlvType12, [33]byte]) { - recordProducers = append(recordProducers, &key) - }) + c.ExtraSignedFields = ExtraSignedFieldsFromTypeMap(typeMap) - c.BitcoinKey2.WhenSome(func(key tlv.RecordT[tlv.TlvType14, [33]byte]) { - recordProducers = append(recordProducers, &key) - }) + return nil +} - c.MerkleRootHash.WhenSome( - func(hash tlv.RecordT[tlv.TlvType16, [32]byte]) { - recordProducers = append(recordProducers, &hash) - }, +// EncodeAllNonSigFields encodes the entire message to the given writer but +// excludes the signature field. +func (c *ChannelAnnouncement2) EncodeAllNonSigFields(w io.Writer) error { + return EncodeRecordsTo( + w, ProduceRecordsSorted(c.allNonSignatureRecordProducers()...), ) - - err := EncodeMessageExtraData(&c.ExtraOpaqueData, recordProducers...) - if err != nil { - return nil, err - } - - return c.ExtraOpaqueData, nil } // MsgType returns the integer uniquely identifying this message type on the @@ -209,6 +262,10 @@ var _ Message = (*ChannelAnnouncement2)(nil) // lnwire.SizeableMessage interface. var _ SizeableMessage = (*ChannelAnnouncement2)(nil) +// A compile time check to ensure ChannelAnnouncement2 implements the +// lnwire.PureTLVMessage interface. +var _ PureTLVMessage = (*ChannelAnnouncement2)(nil) + // Node1KeyBytes returns the bytes representing the public key of node 1 in the // channel. // diff --git a/lnwire/channel_id.go b/lnwire/channel_id.go index 1615eb74710..5c9eca34fbd 100644 --- a/lnwire/channel_id.go +++ b/lnwire/channel_id.go @@ -3,10 +3,12 @@ package lnwire import ( "encoding/binary" "encoding/hex" + "io" "math" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/tlv" ) const ( @@ -36,6 +38,40 @@ func (c ChannelID) String() string { return hex.EncodeToString(c[:]) } +// Record returns a TLV record that can be used to encode/decode a ChannelID +// to/from a TLV stream. +func (c *ChannelID) Record() tlv.Record { + return tlv.MakeStaticRecord(0, c, 32, encodeChannelID, decodeChannelID) +} + +func encodeChannelID(w io.Writer, val interface{}, buf *[8]byte) error { + if v, ok := val.(*ChannelID); ok { + bigSize := [32]byte(*v) + + return tlv.EBytes32(w, &bigSize, buf) + } + + return tlv.NewTypeForEncodingErr(val, "lnwire.ChannelID") +} + +func decodeChannelID(r io.Reader, val interface{}, buf *[8]byte, + l uint64) error { + + if v, ok := val.(*ChannelID); ok { + var id [32]byte + err := tlv.DBytes32(r, &id, buf, l) + if err != nil { + return err + } + + *v = id + + return nil + } + + return tlv.NewTypeForDecodingErr(val, "lnwire.ChannelID", l, l) +} + // NewChanIDFromOutPoint converts a target OutPoint into a ChannelID that is // usable within the network. In order to convert the OutPoint into a ChannelID, // we XOR the lower 2-bytes of the txid within the OutPoint with the big-endian diff --git a/lnwire/channel_update_2.go b/lnwire/channel_update_2.go index 56f7edf6b42..5e7a31e2ce6 100644 --- a/lnwire/channel_update_2.go +++ b/lnwire/channel_update_2.go @@ -22,10 +22,6 @@ const ( // HTLCs and other parameters. This message is also used to redeclare initially // set channel parameters. type ChannelUpdate2 struct { - // Signature is used to validate the announced data and prove the - // ownership of node id. - Signature Sig - // ChainHash denotes the target chain that this channel was opened // within. This value should be the genesis hash of the target chain. // Along with the short channel ID, this uniquely identifies the @@ -74,10 +70,22 @@ type ChannelUpdate2 struct { // millionth of a satoshi. FeeProportionalMillionths tlv.RecordT[tlv.TlvType18, uint32] - // ExtraOpaqueData is the set of data that was appended to this message - // to fill out the full maximum transport message size. These fields can - // be used to specify optional data such as custom TLV fields. - ExtraOpaqueData ExtraOpaqueData + // Signature is used to validate the announced data and prove the + // ownership of node id. + Signature tlv.RecordT[tlv.TlvType160, Sig] + + // Any extra fields in the signed range that we do not yet know about, + // but we need to keep them for signature validation and to produce a + // valid message. + ExtraSignedFields +} + +// Encode serializes the target ChannelUpdate2 into the passed io.Writer +// observing the protocol version specified. +// +// This is part of the lnwire.Message interface. +func (c *ChannelUpdate2) Encode(w *bytes.Buffer, _ uint32) error { + return EncodePureTLVMessage(c, w) } // Decode deserializes a serialized ChannelUpdate2 stored in the passed @@ -85,17 +93,6 @@ type ChannelUpdate2 struct { // // This is part of the lnwire.Message interface. func (c *ChannelUpdate2) Decode(r io.Reader, _ uint32) error { - err := ReadElement(r, &c.Signature) - if err != nil { - return err - } - c.Signature.ForceSchnorr() - - return c.DecodeTLVRecords(r) -} - -// DecodeTLVRecords decodes only the TLV section of the message. -func (c *ChannelUpdate2) DecodeTLVRecords(r io.Reader) error { // First extract into extra opaque data. var tlvRecords ExtraOpaqueData if err := ReadElements(r, &tlvRecords); err != nil { @@ -111,10 +108,12 @@ func (c *ChannelUpdate2) DecodeTLVRecords(r io.Reader) error { &secondPeer, &c.CLTVExpiryDelta, &c.HTLCMinimumMsat, &c.HTLCMaximumMsat, &c.FeeBaseMsat, &c.FeeProportionalMillionths, + &c.Signature, ) if err != nil { return err } + c.Signature.Val.ForceSchnorr() // By default, the chain-hash is the bitcoin mainnet genesis block hash. c.ChainHash.Val = *chaincfg.MainNetParams.GenesisHash @@ -150,38 +149,21 @@ func (c *ChannelUpdate2) DecodeTLVRecords(r io.Reader) error { c.FeeProportionalMillionths.Val = defaultFeeProportionalMillionths //nolint:ll } - if len(tlvRecords) != 0 { - c.ExtraOpaqueData = tlvRecords - } + c.ExtraSignedFields = ExtraSignedFieldsFromTypeMap(typeMap) return nil } -// Encode serializes the target ChannelUpdate2 into the passed io.Writer -// observing the protocol version specified. +// AllRecords returns all the TLV records for the message. This will include all +// the records we know about along with any that we don't know about but that +// fall in the signed TLV range. // -// This is part of the lnwire.Message interface. -func (c *ChannelUpdate2) Encode(w *bytes.Buffer, _ uint32) error { - _, err := w.Write(c.Signature.RawBytes()) - if err != nil { - return err - } - - _, err = c.DataToSign() - if err != nil { - return err - } - - return WriteBytes(w, c.ExtraOpaqueData) -} +// NOTE: this is part of the PureTLVMessage interface. +func (c *ChannelUpdate2) AllRecords() []tlv.Record { + var recordProducers []tlv.RecordProducer -// DataToSign is used to retrieve part of the announcement message which should -// be signed. For the ChannelUpdate2 message, this includes the serialised TLV -// records. -func (c *ChannelUpdate2) DataToSign() ([]byte, error) { // The chain-hash record is only included if it is _not_ equal to the // bitcoin mainnet genisis block hash. - var recordProducers []tlv.RecordProducer if !c.ChainHash.Val.IsEqual(chaincfg.MainNetParams.GenesisHash) { hash := tlv.ZeroRecordT[tlv.TlvType0, [32]byte]() hash.Val = c.ChainHash.Val @@ -190,7 +172,7 @@ func (c *ChannelUpdate2) DataToSign() ([]byte, error) { } recordProducers = append(recordProducers, - &c.ShortChannelID, &c.BlockHeight, + &c.ShortChannelID, &c.BlockHeight, &c.Signature, ) // Only include the disable flags if any bit is set. @@ -225,12 +207,11 @@ func (c *ChannelUpdate2) DataToSign() ([]byte, error) { ) } - err := EncodeMessageExtraData(&c.ExtraOpaqueData, recordProducers...) - if err != nil { - return nil, err - } + recordProducers = append(recordProducers, RecordsAsProducers( + tlv.MapToRecords(c.ExtraSignedFields), + )...) - return c.ExtraOpaqueData, nil + return ProduceRecordsSorted(recordProducers...) } // MsgType returns the integer uniquely identifying this message type on the @@ -248,8 +229,14 @@ func (c *ChannelUpdate2) SerializedSize() (uint32, error) { return MessageSerializedSize(c) } -func (c *ChannelUpdate2) ExtraData() ExtraOpaqueData { - return c.ExtraOpaqueData +func (c *ChannelUpdate2) ExtraData() (ExtraOpaqueData, error) { + var buf *bytes.Buffer + err := EncodeRecordsTo(buf, tlv.MapToRecords(c.ExtraSignedFields)) + if err != nil { + return nil, err + } + + return buf.Bytes(), nil } // A compile time check to ensure ChannelUpdate2 implements the diff --git a/lnwire/message.go b/lnwire/message.go index ea480075a1a..2944cb21ebe 100644 --- a/lnwire/message.go +++ b/lnwire/message.go @@ -63,6 +63,7 @@ const ( MsgReplyChannelRange = 264 MsgGossipTimestampRange = 265 MsgChannelAnnouncement2 = 267 + MsgNodeAnnouncement2 = 269 MsgChannelUpdate2 = 271 MsgKickoffSig = 777 @@ -190,6 +191,8 @@ func (t MessageType) String() string { return "MsgAnnounceSignatures2" case MsgChannelAnnouncement2: return "ChannelAnnouncement2" + case MsgNodeAnnouncement2: + return "NodeAnnouncement2" case MsgChannelUpdate2: return "ChannelUpdate2" default: @@ -350,6 +353,8 @@ func makeEmptyMessage(msgType MessageType) (Message, error) { msg = &AnnounceSignatures2{} case MsgChannelAnnouncement2: msg = &ChannelAnnouncement2{} + case MsgNodeAnnouncement2: + msg = &NodeAnnouncement2{} case MsgChannelUpdate2: msg = &ChannelUpdate2{} default: diff --git a/lnwire/node_announcement_2.go b/lnwire/node_announcement_2.go new file mode 100644 index 00000000000..944f71edb14 --- /dev/null +++ b/lnwire/node_announcement_2.go @@ -0,0 +1,457 @@ +package lnwire + +import ( + "bytes" + "encoding/binary" + "fmt" + "image/color" + "io" + "net" + + "github.com/lightningnetwork/lnd/tlv" + "github.com/lightningnetwork/lnd/tor" +) + +// NodeAnnouncement2 is used to signal the presence of a node on the network +// along with information about the node that can be used to connect to it. This +// announcement is to be used in the gossip 1.75 protocol. +type NodeAnnouncement2 struct { + // Features is the feature vector that encodes the features supported + // by the target node. + Features tlv.RecordT[tlv.TlvType0, RawFeatureVector] + + // Color is an optional field used to customize a node's appearance in + // maps and graphs. + Color tlv.OptionalRecordT[tlv.TlvType1, Color] + + // BlockHeight allows ordering in the case of multiple announcements. We + // should ignore the message if block height is not greater than the + // last-received. The block height must always be greater or equal to + // the block height that the channel funding transaction was confirmed + // in. + BlockHeight tlv.RecordT[tlv.TlvType2, uint32] + + // Alias is used to customize their node's appearance in maps and + // graphs. + Alias tlv.OptionalRecordT[tlv.TlvType3, []byte] + + // NodeID is the public key of the node creating the announcement. + NodeID tlv.RecordT[tlv.TlvType6, [33]byte] + + // IPV4Addrs is an optional list of ipv4 addresses that the node is + // reachable at. + IPV4Addrs tlv.OptionalRecordT[tlv.TlvType5, IPV4Addrs] + + // IPV6Addrs is an optional list of ipv6 addresses that the node is + // reachable at. + IPV6Addrs tlv.OptionalRecordT[tlv.TlvType7, IPV6Addrs] + + // TorV3Addrs is an optional list of tor v3 addresses that the node is + // reachable at. + TorV3Addrs tlv.OptionalRecordT[tlv.TlvType9, TorV3Addrs] + + // Signature is used to validate the announced data and prove the + // ownership of node id. + Signature tlv.RecordT[tlv.TlvType160, Sig] + + // Any extra fields in the signed range that we do not yet know about, + // but we need to keep them for signature validation and to produce a + // valid message. + ExtraSignedFields +} + +// AllRecords returns all the TLV records for the message. This will include all +// the records we know about along with any that we don't know about but that +// fall in the signed TLV range. +// +// NOTE: this is part of the PureTLVMessage interface. +func (n *NodeAnnouncement2) AllRecords() []tlv.Record { + recordProducers := []tlv.RecordProducer{ + &n.Features, + &n.BlockHeight, + &n.NodeID, + &n.Signature, + } + + n.Color.WhenSome(func(r tlv.RecordT[tlv.TlvType1, Color]) { + recordProducers = append(recordProducers, &r) + }) + + n.Alias.WhenSome(func(a tlv.RecordT[tlv.TlvType3, []byte]) { + recordProducers = append(recordProducers, &a) + }) + + n.IPV4Addrs.WhenSome(func(r tlv.RecordT[tlv.TlvType5, IPV4Addrs]) { + recordProducers = append(recordProducers, &r) + }) + + n.IPV6Addrs.WhenSome(func(r tlv.RecordT[tlv.TlvType7, IPV6Addrs]) { + recordProducers = append(recordProducers, &r) + }) + + n.TorV3Addrs.WhenSome(func(r tlv.RecordT[tlv.TlvType9, TorV3Addrs]) { + recordProducers = append(recordProducers, &r) + }) + + recordProducers = append(recordProducers, RecordsAsProducers( + tlv.MapToRecords(n.ExtraSignedFields), + )...) + + return ProduceRecordsSorted(recordProducers...) +} + +// Decode deserializes a serialized ChannelUpdate2 stored in the passed +// io.Reader observing the specified protocol version. +// +// This is part of the lnwire.Message interface. +func (n *NodeAnnouncement2) Decode(r io.Reader, _ uint32) error { + var ( + color = tlv.ZeroRecordT[tlv.TlvType1, Color]() + alias = tlv.ZeroRecordT[tlv.TlvType3, []byte]() + ipv4 = tlv.ZeroRecordT[tlv.TlvType5, IPV4Addrs]() + ipv6 = tlv.ZeroRecordT[tlv.TlvType7, IPV6Addrs]() + torV3 = tlv.ZeroRecordT[tlv.TlvType9, TorV3Addrs]() + ) + stream, err := tlv.NewStream(ProduceRecordsSorted( + &n.Features, + &n.BlockHeight, + &n.NodeID, + &n.Signature, + &alias, + &color, + &ipv4, + &ipv6, + &torV3, + )...) + if err != nil { + return err + } + n.Signature.Val.ForceSchnorr() + + typeMap, err := stream.DecodeWithParsedTypesP2P(r) + if err != nil { + return err + } + + if _, ok := typeMap[n.Alias.TlvType()]; ok { + n.Alias = tlv.SomeRecordT(alias) + } + + if _, ok := typeMap[n.Color.TlvType()]; ok { + n.Color = tlv.SomeRecordT(color) + } + + if _, ok := typeMap[n.IPV4Addrs.TlvType()]; ok { + n.IPV4Addrs = tlv.SomeRecordT(ipv4) + } + + if _, ok := typeMap[n.IPV6Addrs.TlvType()]; ok { + n.IPV6Addrs = tlv.SomeRecordT(ipv6) + } + + if _, ok := typeMap[n.TorV3Addrs.TlvType()]; ok { + n.TorV3Addrs = tlv.SomeRecordT(torV3) + } + + n.ExtraSignedFields = ExtraSignedFieldsFromTypeMap(typeMap) + + return nil +} + +// Encode serializes the target ChannelUpdate2 into the passed io.Writer +// observing the protocol version specified. +// +// This is part of the lnwire.Message interface. +func (n *NodeAnnouncement2) Encode(w *bytes.Buffer, _ uint32) error { + return EncodePureTLVMessage(n, w) +} + +// MsgType returns the integer uniquely identifying this message type on the +// wire. +// +// This is part of the lnwire.Message interface. +func (n *NodeAnnouncement2) MsgType() MessageType { + return MsgNodeAnnouncement2 +} + +// SerializedSize returns the serialized size of the message in bytes. +// +// This is part of the lnwire.SizeableMessage interface. +func (n *NodeAnnouncement2) SerializedSize() (uint32, error) { + return MessageSerializedSize(n) +} + +// A compile-time check to ensure NodeAnnouncement2 implements the Message +// interface. +var _ Message = (*NodeAnnouncement2)(nil) + +// A compile time check to ensure NodeAnnouncement2 implements the +// lnwire.SizeableMessage interface. +var _ SizeableMessage = (*NodeAnnouncement2)(nil) + +// A compile-time check to ensure NodeAnnouncement2 implements the +// PureTLVMessage interface. +var _ PureTLVMessage = (*NodeAnnouncement2)(nil) + +type Color color.RGBA + +func (c *Color) Record() tlv.Record { + return tlv.MakeStaticRecord(0, c, 3, rgbEncoder, rgbDecoder) +} + +func rgbEncoder(w io.Writer, val interface{}, _ *[8]byte) error { + if v, ok := val.(*Color); ok { + buf := bytes.NewBuffer(nil) + err := WriteColorRGBA(buf, color.RGBA(*v)) + if err != nil { + return err + } + _, err = w.Write(buf.Bytes()) + + return err + } + + return tlv.NewTypeForEncodingErr(val, "Color") +} + +func rgbDecoder(r io.Reader, val interface{}, _ *[8]byte, l uint64) error { + if v, ok := val.(*Color); ok { + return ReadElements(r, &v.R, &v.G, &v.B) + } + + return tlv.NewTypeForDecodingErr(val, "Color", l, 3) +} + +// ipv4AddrEncodedSize is the number of bytes required to encode a single ipv4 +// address. Four bytes are used to encode the IP address and two bytes for the +// port number. +const ipv4AddrEncodedSize = 4 + 2 + +// IPV4Addrs is a list of ipv4 addresses that can be encoded as a TLV record. +type IPV4Addrs []*net.TCPAddr + +// Record returns a Record that can be used to encode/decode a IPV4Addrs +// to/from a TLV stream. +func (a *IPV4Addrs) Record() tlv.Record { + return tlv.MakeDynamicRecord( + 0, a, a.EncodedSize, ipv4AddrsEncoder, ipv4AddrsDecoder, + ) +} + +// EncodedSize returns the number of bytes required to encode an IPV4Addrs +// variable. +func (a *IPV4Addrs) EncodedSize() uint64 { + return uint64(len(*a) * ipv4AddrEncodedSize) +} + +func ipv4AddrsEncoder(w io.Writer, val interface{}, _ *[8]byte) error { + if v, ok := val.(*IPV4Addrs); ok { + for _, ip := range *v { + _, err := w.Write(ip.IP.To4()) + if err != nil { + return err + } + var port [2]byte + binary.BigEndian.PutUint16(port[:], uint16(ip.Port)) + _, err = w.Write(port[:]) + + return err + } + } + + return tlv.NewTypeForEncodingErr(val, "lnwire.IPV4Addrs") +} + +func ipv4AddrsDecoder(r io.Reader, val interface{}, _ *[8]byte, + l uint64) error { + + if v, ok := val.(*IPV4Addrs); ok { + if l%(ipv4AddrEncodedSize) != 0 { + return fmt.Errorf("invalid ipv4 list encoding") + } + var ( + numAddrs = int(l / ipv4AddrEncodedSize) + addrs = make([]*net.TCPAddr, 0, numAddrs) + ip [4]byte + port [2]byte + ) + for len(addrs) < numAddrs { + _, err := r.Read(ip[:]) + if err != nil { + return err + } + _, err = r.Read(port[:]) + if err != nil { + return err + } + addrs = append(addrs, &net.TCPAddr{ + IP: ip[:], + Port: int(binary.BigEndian.Uint16(port[:])), + }) + } + *v = addrs + + return nil + } + + return tlv.NewTypeForEncodingErr(val, "lnwire.IPV4Addrs") +} + +// IPV6Addrs is a list of ipv6 addresses that can be encoded as a TLV record. +type IPV6Addrs []*net.TCPAddr + +// Record returns a Record that can be used to encode/decode a IPV4Addrs +// to/from a TLV stream. +func (a *IPV6Addrs) Record() tlv.Record { + return tlv.MakeDynamicRecord( + 0, a, a.EncodedSize, ipv6AddrsEncoder, ipv6AddrsDecoder, + ) +} + +// ipv6AddrEncodedSize is the number of bytes required to encode a single ipv6 +// address. Sixteen bytes are used to encode the IP address and two bytes for +// the port number. +const ipv6AddrEncodedSize = 16 + 2 + +// EncodedSize returns the number of bytes required to encode an IPV6Addrs +// variable. +func (a *IPV6Addrs) EncodedSize() uint64 { + return uint64(len(*a) * ipv6AddrEncodedSize) +} + +func ipv6AddrsEncoder(w io.Writer, val interface{}, _ *[8]byte) error { + if v, ok := val.(*IPV6Addrs); ok { + for _, ip := range *v { + _, err := w.Write(ip.IP.To16()) + if err != nil { + return err + } + var port [2]byte + binary.BigEndian.PutUint16(port[:], uint16(ip.Port)) + _, err = w.Write(port[:]) + + return err + } + } + + return tlv.NewTypeForEncodingErr(val, "lnwire.IPV6Addrs") +} + +func ipv6AddrsDecoder(r io.Reader, val interface{}, _ *[8]byte, + l uint64) error { + if v, ok := val.(*IPV6Addrs); ok { + if l%(ipv6AddrEncodedSize) != 0 { + return fmt.Errorf("invalid ipv6 list encoding") + } + var ( + numAddrs = int(l / ipv6AddrEncodedSize) + addrs = make([]*net.TCPAddr, 0, numAddrs) + ip [16]byte + port [2]byte + ) + for len(addrs) < numAddrs { + _, err := r.Read(ip[:]) + if err != nil { + return err + } + _, err = r.Read(port[:]) + if err != nil { + return err + } + addrs = append(addrs, &net.TCPAddr{ + IP: ip[:], + Port: int(binary.BigEndian.Uint16(port[:])), + }) + } + *v = addrs + return nil + } + + return tlv.NewTypeForEncodingErr(val, "lnwire.IPV6Addrs") +} + +// TorV3Addrs is a list of tor v3 addresses that can be encoded as a TLV record. +type TorV3Addrs []*tor.OnionAddr + +// torV3AddrEncodedSize is the number of bytes required to encode a single tor +// v3 address. +const torV3AddrEncodedSize = tor.V3DecodedLen + 2 + +// EncodedSize returns the number of bytes required to encode an TorV3Addrs +// variable. +func (a *TorV3Addrs) EncodedSize() uint64 { + return uint64(len(*a) * torV3AddrEncodedSize) +} + +// Record returns a Record that can be used to encode/decode a IPV4Addrs +// to/from a TLV stream. +func (a *TorV3Addrs) Record() tlv.Record { + return tlv.MakeDynamicRecord( + 0, a, a.EncodedSize, torV3AddrsEncoder, torV3AddrsDecoder, + ) +} + +func torV3AddrsEncoder(w io.Writer, val interface{}, _ *[8]byte) error { + if v, ok := val.(*TorV3Addrs); ok { + for _, addr := range *v { + encodedHostLen := tor.V3Len - tor.OnionSuffixLen + host, err := tor.Base32Encoding.DecodeString( + addr.OnionService[:encodedHostLen], + ) + if err != nil { + return err + } + if len(host) != tor.V3DecodedLen { + return fmt.Errorf("expected a tor v3 host "+ + "length of %d, got: %d", + tor.V2DecodedLen, len(host)) + } + if _, err = w.Write(host); err != nil { + return err + } + var port [2]byte + binary.BigEndian.PutUint16(port[:], uint16(addr.Port)) + _, err = w.Write(port[:]) + return err + } + } + + return tlv.NewTypeForEncodingErr(val, "lnwire.TorV3Addrs") +} + +func torV3AddrsDecoder(r io.Reader, val interface{}, _ *[8]byte, + l uint64) error { + + if v, ok := val.(*TorV3Addrs); ok { + if l%torV3AddrEncodedSize != 0 { + return fmt.Errorf("invalid tor v3 list encoding") + } + var ( + numAddrs = int(l / torV3AddrEncodedSize) + addrs = make([]*tor.OnionAddr, 0, numAddrs) + ip [tor.V3DecodedLen]byte + p [2]byte + ) + for len(addrs) < numAddrs { + _, err := r.Read(ip[:]) + if err != nil { + return err + } + _, err = r.Read(p[:]) + if err != nil { + return err + } + onionService := tor.Base32Encoding.EncodeToString(ip[:]) + onionService += tor.OnionSuffix + port := int(binary.BigEndian.Uint16(p[:])) + addrs = append(addrs, &tor.OnionAddr{ + OnionService: onionService, + Port: port, + }) + } + *v = addrs + return nil + } + + return tlv.NewTypeForEncodingErr(val, "lnwire.TorV3Addrs") +} diff --git a/lnwire/pure_tlv.go b/lnwire/pure_tlv.go new file mode 100644 index 00000000000..8e6f7bd9fc3 --- /dev/null +++ b/lnwire/pure_tlv.go @@ -0,0 +1,105 @@ +package lnwire + +import ( + "bytes" + + "github.com/lightningnetwork/lnd/tlv" +) + +const ( + // pureTLVUnsignedRangeOneStart defines the start of the first unsigned + // TLV range used for pure TLV messages. The range is inclusive of this + // number. + pureTLVUnsignedRangeOneStart = 160 + + // pureTLVSignedSecondRangeStart defines the start of the second signed + // TLV range used for pure TLV messages. The range is inclusive of this + // number. Note that the first range is the inclusive range of 0-159. + pureTLVSignedSecondRangeStart = 1000000000 + + // pureTLVUnsignedRangeTwoStart defines the start of the second unsigned + // TLV range used for pure TLV message. + pureTLVUnsignedRangeTwoStart = 3000000000 +) + +// PureTLVMessage describes an LN message that is a pure TLV stream. If the +// message includes a signature, it will sign all the TLV records in the +// inclusive ranges: 0 to 159 and 1000000000 to 2999999999. +type PureTLVMessage interface { + // AllRecords returns all the TLV records for the message. This will + // include all the records we know about along with any that we don't + // know about but that fall in the signed TLV range. + AllRecords() []tlv.Record +} + +// EncodePureTLVMessage encodes the given PureTLVMessage to the given buffer. +func EncodePureTLVMessage(msg PureTLVMessage, buf *bytes.Buffer) error { + return EncodeRecordsTo(buf, msg.AllRecords()) +} + +// SerialiseFieldsToSign serialises all the records from the given +// PureTLVMessage that fall within the signed TLV range. +func SerialiseFieldsToSign(msg PureTLVMessage) ([]byte, error) { + // Filter out all the fields not in the signed ranges. + var signedRecords []tlv.Record + for _, record := range msg.AllRecords() { + if InUnsignedRange(record.Type()) { + continue + } + + signedRecords = append(signedRecords, record) + } + + var buf bytes.Buffer + if err := EncodeRecordsTo(&buf, signedRecords); err != nil { + return nil, err + } + + return buf.Bytes(), nil +} + +// InUnsignedRange returns true if the given TLV type falls outside the TLV +// ranges that the signature of a pure TLV message will cover. +func InUnsignedRange(t tlv.Type) bool { + return (t >= pureTLVUnsignedRangeOneStart && + t < pureTLVSignedSecondRangeStart) || + t >= pureTLVUnsignedRangeTwoStart +} + +// ExtraSignedFields is a type that stores a map from TLV types in the signed +// range (for PureMessages) to their corresponding serialised values. This type +// can be used to keep around data that we don't yet understand but that we need +// for re-composing the wire message since the signature covers these fields. +type ExtraSignedFields map[uint64][]byte + +// ExtraSignedFieldsFromTypeMap is a helper that can be used alongside calls to +// the tlv.Stream DecodeWithParsedTypesP2P or DecodeWithParsedTypes methods to +// extract the tlv type and value pairs in the defined PureTLVMessage signed +// range which we have not handled with any of our defined Records. These +// methods will return a tlv.TypeMap containing the records that were extracted +// from an io.Reader. If the record was know and handled by a defined record, +// then the value accompanying the record's type in the map will be nil. +// Otherwise, if the record was unhandled, it will be non-nil. +func ExtraSignedFieldsFromTypeMap(m tlv.TypeMap) ExtraSignedFields { + extraFields := make(ExtraSignedFields) + for t, v := range m { + // If the value in the type map is nil, then it indicates that + // we know this type, and it was handled by one of the records + // we passed to the decode function vai the TLV stream. + if v == nil { + continue + } + + // No need to keep this field if it is unknown to us and is not + // in the sign range. + if InUnsignedRange(t) { + continue + } + + // Otherwise, this is an un-handled type, so we keep track of + // it for signature validation and re-encoding later on. + extraFields[uint64(t)] = v + } + + return extraFields +} diff --git a/lnwire/pure_tlv_test.go b/lnwire/pure_tlv_test.go new file mode 100644 index 00000000000..a81a89ecb6d --- /dev/null +++ b/lnwire/pure_tlv_test.go @@ -0,0 +1,389 @@ +package lnwire + +import ( + "bytes" + "io" + "testing" + + "github.com/btcsuite/btcd/btcec/v2" + "github.com/lightningnetwork/lnd/tlv" + "github.com/stretchr/testify/require" +) + +// TestPureTLVMessages tests the forwards compatibility of two versions of the +// same Lightning Network message that uses the Pure TLV format. This in essence +// tests that and older client is able to verify the signature over relevant +// data in a newer client's message. +func TestPureTLVMessage(t *testing.T) { + t.Parallel() + + var ( + _, pkA = btcec.PrivKeyFromBytes([]byte{1}) + _, pkB = btcec.PrivKeyFromBytes([]byte{2}) + capacity = MilliSatoshi(100) + ) + + // Test encode and decode of MsgV1 as is. + t.Run("Encode and Decode of MsgV1", func(t *testing.T) { + t.Parallel() + + msgOld := newMsgV1(pkA, &capacity) + + buf := bytes.NewBuffer(nil) + require.NoError(t, msgOld.Encode(buf, 0)) + + var msgOld2 MsgV1 + require.NoError(t, msgOld2.Decode(buf, 0)) + + require.Equal(t, msgOld, &msgOld2) + }) + + // Test encode and decode of MsgV2 as is. + t.Run("Encode and Decode of MsgV2", func(t *testing.T) { + t.Parallel() + + msgNew := newMsgV2( + pkA, &capacity, pkB, []byte{1, 2, 3, 4}, 90, 100, true, + ) + + buf := bytes.NewBuffer(nil) + require.NoError(t, msgNew.Encode(buf, 0)) + + var msgNew2 MsgV2 + require.NoError(t, msgNew2.Decode(buf, 0)) + + require.Equal(t, msgNew, &msgNew2) + }) + + // Create a MsgV2 and decode it into a MsgV1. Both the new client + // (MsgV2) and old client (MsgV1) should be able to generate the same + // digest that will be used to create and validate the signture. + t.Run("Encode MsgV2 and decode via MsgV1", func(t *testing.T) { + t.Parallel() + + var ( + buf = bytes.NewBuffer(nil) + msgV2 = newMsgV2( + pkA, &capacity, pkB, []byte{1, 2, 3, 4}, 100, + 90, true, + ) + ) + require.NoError(t, msgV2.Encode(buf, 0)) + + // Get the serialised bytes that would be signed for msgV2. + signData1, err := SerialiseFieldsToSign(msgV2) + require.NoError(t, err) + + // Decoding via the old message should store some of the extra + // fields. + var msgV1 MsgV1 + require.NoError(t, msgV1.Decode(buf, 0)) + require.NotEmpty(t, msgV1.ExtraSignedFields) + + // Show that the extra fields map contains unknown fields in the + // signed range but not unknown fields in the unsigned range. + _, ok := msgV1.ExtraSignedFields[uint64(msgV2.Num.TlvType())] //nolint:ll + require.True(t, ok) + _, ok = msgV1.ExtraSignedFields[uint64(msgV2.Other.TlvType())] //nolint:ll + require.False(t, ok) + + // The serialised bytes to verify the signature against should + // be the same though. + signData2, err := SerialiseFieldsToSign(&msgV1) + require.NoError(t, err) + + require.Equal(t, signData1, signData2) + + // Re-encoding via the old message should keep the extra fields. + buf = bytes.NewBuffer(nil) + require.NoError(t, msgV1.Encode(buf, 0)) + + var msgV1ReEncoded MsgV1 + require.NoError(t, msgV1ReEncoded.Decode(buf, 0)) + + require.Equal(t, &msgV1, &msgV1ReEncoded) + }) +} + +// MsgV1 represents a more minimal, first version of a Lightning Network +// message. +type MsgV1 struct { + // Two known fields in the signed range. + NodeKey tlv.RecordT[tlv.TlvType0, *btcec.PublicKey] + Capacity tlv.OptionalRecordT[tlv.TlvType1, MilliSatoshi] + + // Signature in the unsigned range. + Signature tlv.RecordT[tlv.TlvType160, Sig] + + // Any extra fields in the signed range that we do not yet know about, + // but we need to keep them for signature validation and to produce a + // valid message. + ExtraSignedFields +} + +var _ Message = (*MsgV1)(nil) +var _ PureTLVMessage = (*MsgV1)(nil) + +// newMsgV1 is a constructor for MsgV1. +func newMsgV1(nodeKey *btcec.PublicKey, capacity *MilliSatoshi) *MsgV1 { + newMsg := &MsgV1{ + NodeKey: tlv.NewPrimitiveRecord[tlv.TlvType0]( + nodeKey, + ), + Signature: tlv.NewRecordT[tlv.TlvType160]( + testSchnorrSig, + ), + ExtraSignedFields: make(ExtraSignedFields), + } + + if capacity != nil { + newMsg.Capacity = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType1](*capacity), + ) + } + + return newMsg +} + +// Decode deserializes a serialized MsgV1 in the passed io.Reader. +// +// This is part of the lnwire.Message interface. +func (g *MsgV1) Decode(r io.Reader, _ uint32) error { + var capacity = tlv.ZeroRecordT[tlv.TlvType1, MilliSatoshi]() + stream, err := tlv.NewStream( + ProduceRecordsSorted( + &g.NodeKey, + &capacity, + &g.Signature, + )..., + ) + if err != nil { + return err + } + g.Signature.Val.ForceSchnorr() + + typeMap, err := stream.DecodeWithParsedTypesP2P(r) + if err != nil { + return err + } + + if _, ok := typeMap[g.Capacity.TlvType()]; ok { + g.Capacity = tlv.SomeRecordT(capacity) + } + + g.ExtraSignedFields = ExtraSignedFieldsFromTypeMap(typeMap) + + return nil +} + +// Encode serializes the target MsgV1 into the passed buffer. +// +// This is part of the lnwire.Message interface. +func (g *MsgV1) Encode(buf *bytes.Buffer, _ uint32) error { + return EncodePureTLVMessage(g, buf) +} + +// MsgType returns the integer uniquely identifying this message type on the +// wire. +// +// This is part of the lnwire.Message interface. +func (g *MsgV1) MsgType() MessageType { + return 7777 +} + +// AllRecords returns all the TLV records for the message. This will +// include all the records we know about along with any that we don't +// know about but that fall in the signed TLV range. +// +// This is part of the PureTLVMessage interface. +func (g *MsgV1) AllRecords() []tlv.Record { + recordProducers := []tlv.RecordProducer{ + &g.NodeKey, + &g.Signature, + } + recordProducers = append( + recordProducers, + RecordsAsProducers( + tlv.MapToRecords(g.ExtraSignedFields), + )..., + ) + + g.Capacity.WhenSome( + func(capacity tlv.RecordT[tlv.TlvType1, MilliSatoshi]) { + recordProducers = append(recordProducers, &capacity) + }, + ) + + return ProduceRecordsSorted(recordProducers...) +} + +// MsgV2 represents a newer version of MsgV1 which contains more fields both in +// the unsigned and signed TLV ranges. +type MsgV2 struct { + NodeKey tlv.RecordT[tlv.TlvType0, *btcec.PublicKey] + Capacity tlv.OptionalRecordT[tlv.TlvType1, MilliSatoshi] + + // An additional fields (optional) in the signed range. + BitcoinKey tlv.OptionalRecordT[tlv.TlvType3, *btcec.PublicKey] + + // A zero length TLV in the signed range. + SecondPeer tlv.OptionalRecordT[tlv.TlvType5, TrueBoolean] + + // Signature in the unsigned range. + Signature tlv.RecordT[tlv.TlvType160, Sig] + + // Another field in the unsigned range. An older node can throw this + // away. + SPVProof tlv.RecordT[tlv.TlvType161, []byte] + + // A new field in the second signed range. An older node should keep + // this since it is part of the serialised message that is signed. + Num tlv.RecordT[tlv.TlvType1000000000, uint8] + + // Another field in the second unsigned-range. Older nodes may throw + // this away and it won't affect the digest used for signature creation + // and validation. + Other tlv.RecordT[tlv.TlvType3000000000, uint8] + + // Any extra fields in the signed range that we do not yet know about, + // but we need to keep them for signature validation and to produce a + // valid message. + ExtraSignedFields +} + +// newMsgV2 is a constructor for MsgV2. +func newMsgV2(nodeKey *btcec.PublicKey, capacity *MilliSatoshi, + btcKey *btcec.PublicKey, spvProof []byte, num, other uint8, + secondPeer bool) *MsgV2 { + + newMsg := &MsgV2{ + NodeKey: tlv.NewPrimitiveRecord[tlv.TlvType0](nodeKey), + SPVProof: tlv.NewPrimitiveRecord[tlv.TlvType161](spvProof), + Num: tlv.NewPrimitiveRecord[tlv.TlvType1000000000](num), + Other: tlv.NewPrimitiveRecord[tlv.TlvType3000000000](num), + Signature: tlv.NewRecordT[tlv.TlvType160]( + testSchnorrSig, + ), + ExtraSignedFields: make(ExtraSignedFields), + } + + if secondPeer { + newMsg.SecondPeer = tlv.SomeRecordT( + tlv.NewRecordT[tlv.TlvType5](TrueBoolean{}), + ) + } + + if capacity != nil { + newMsg.Capacity = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType1](*capacity), + ) + } + + if btcKey != nil { + newMsg.BitcoinKey = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType3](btcKey), + ) + } + + return newMsg +} + +// Decode deserializes a serialized MsgV2 in the passed io.Reader. +// +// This is part of the lnwire.Message interface. +func (g *MsgV2) Decode(r io.Reader, _ uint32) error { + var ( + capacity = tlv.ZeroRecordT[tlv.TlvType1, MilliSatoshi]() + btcKey = tlv.ZeroRecordT[tlv.TlvType3, *btcec.PublicKey]() + secondPeer = tlv.ZeroRecordT[tlv.TlvType5, TrueBoolean]() + ) + + stream, err := tlv.NewStream( + ProduceRecordsSorted( + &g.NodeKey, + &capacity, + &btcKey, + &secondPeer, + &g.Signature, + &g.SPVProof, + &g.Num, + &g.Other, + )..., + ) + if err != nil { + return err + } + g.Signature.Val.ForceSchnorr() + + typeMap, err := stream.DecodeWithParsedTypesP2P(r) + if err != nil { + return err + } + + if _, ok := typeMap[g.Capacity.TlvType()]; ok { + g.Capacity = tlv.SomeRecordT(capacity) + } + + if _, ok := typeMap[g.SecondPeer.TlvType()]; ok { + g.SecondPeer = tlv.SomeRecordT(secondPeer) + } + + if _, ok := typeMap[g.BitcoinKey.TlvType()]; ok { + g.BitcoinKey = tlv.SomeRecordT(btcKey) + } + + g.ExtraSignedFields = ExtraSignedFieldsFromTypeMap(typeMap) + + return nil +} + +// Encode serializes the target MsgV2 into the passed buffer. +// +// This is part of the lnwire.Message interface. +func (g *MsgV2) Encode(buf *bytes.Buffer, _ uint32) error { + return EncodePureTLVMessage(g, buf) +} + +// MsgType returns the integer uniquely identifying this message type on the +// wire. +// +// This is part of the lnwire.Message interface. +func (g *MsgV2) MsgType() MessageType { + return 7779 +} + +// AllRecords returns all the TLV records for the message. This will +// include all the records we know about along with any that we don't +// know about but that fall in the signed TLV range. +// +// This is part of the PureTLVMessage interface. +func (g *MsgV2) AllRecords() []tlv.Record { + recordProducers := []tlv.RecordProducer{ + &g.NodeKey, + &g.Signature, + &g.SPVProof, + &g.Num, + &g.Other, + } + recordProducers = append(recordProducers, RecordsAsProducers( + tlv.MapToRecords(g.ExtraSignedFields), + )...) + + g.Capacity.WhenSome( + func(c tlv.RecordT[tlv.TlvType1, MilliSatoshi]) { + recordProducers = append(recordProducers, &c) + }, + ) + g.BitcoinKey.WhenSome( + func(key tlv.RecordT[tlv.TlvType3, *btcec.PublicKey]) { + recordProducers = append(recordProducers, &key) + }, + ) + g.SecondPeer.WhenSome( + func(second tlv.RecordT[tlv.TlvType5, TrueBoolean]) { + recordProducers = append(recordProducers, &second) + }, + ) + + return ProduceRecordsSorted(recordProducers...) +} diff --git a/lnwire/test_message.go b/lnwire/test_message.go index 8b3d98400a1..969622d22cf 100644 --- a/lnwire/test_message.go +++ b/lnwire/test_message.go @@ -129,12 +129,29 @@ var _ TestMessage = (*AnnounceSignatures2)(nil) // // This is part of the TestMessage interface. func (a *AnnounceSignatures2) RandTestMessage(t *rapid.T) Message { - return &AnnounceSignatures2{ - ChannelID: RandChannelID(t), - ShortChannelID: RandShortChannelID(t), - PartialSignature: *RandPartialSig(t), - ExtraOpaqueData: RandExtraOpaqueData(t, nil), + var ( + chanID = RandChannelID(t) + scid = RandShortChannelID(t) + pSig = RandPartialSig(t) + ) + + msg := &AnnounceSignatures2{ + ChannelID: tlv.NewRecordT[tlv.TlvType0, ChannelID]( + chanID, + ), + ShortChannelID: tlv.NewRecordT[tlv.TlvType2](scid), + PartialSignature: tlv.NewRecordT[tlv.TlvType4, PartialSig]( + *pSig, + ), + ExtraSignedFields: make(map[uint64][]byte), + } + + randRecs, _ := RandSignedRangeRecords(t) + if len(randRecs) > 0 { + msg.ExtraSignedFields = ExtraSignedFields(randRecs) } + + return msg } // A compile time check to ensure ChannelAnnouncement1 implements the @@ -186,6 +203,97 @@ func (a *ChannelAnnouncement1) RandTestMessage(t *rapid.T) Message { } } +// A compile time check to ensure NodeAnnouncement2 implements the +// lnwire.TestMessage interface. +var _ TestMessage = (*NodeAnnouncement2)(nil) + +// RandTestMessage populates the message with random data suitable for testing. +// It uses the rapid testing framework to generate random values. +// +// This is part of the TestMessage interface. +func (n *NodeAnnouncement2) RandTestMessage(t *rapid.T) Message { + /* + + */ + + features := RandFeatureVector(t) + blockHeight := uint32(rapid.IntRange(0, 1000000).Draw(t, "blockHeight")) + + var nodeID [33]byte + copy(nodeID[:], RandPubKey(t).SerializeCompressed()) + + msg := &NodeAnnouncement2{ + Features: tlv.NewRecordT[tlv.TlvType0, RawFeatureVector]( + *features, + ), + BlockHeight: tlv.NewPrimitiveRecord[tlv.TlvType2, uint32]( + blockHeight, + ), + Alias: tlv.OptionalRecordT[tlv.TlvType3, []byte]{}, + NodeID: tlv.NewPrimitiveRecord[tlv.TlvType6, [33]byte]( + nodeID, + ), + IPV4Addrs: tlv.OptionalRecordT[tlv.TlvType5, IPV4Addrs]{}, + IPV6Addrs: tlv.OptionalRecordT[tlv.TlvType7, IPV6Addrs]{}, + TorV3Addrs: tlv.OptionalRecordT[tlv.TlvType9, TorV3Addrs]{}, + ExtraSignedFields: make(map[uint64][]byte), + } + + msg.Signature.Val = RandSignature(t) + msg.Signature.Val.ForceSchnorr() + + randRecs, _ := RandSignedRangeRecords(t) + if len(randRecs) > 0 { + msg.ExtraSignedFields = ExtraSignedFields(randRecs) + } + + if rapid.Bool().Draw(t, "includeColour") { + color := tlv.ZeroRecordT[tlv.TlvType1, Color]() + color.Val = Color{ + R: uint8(rapid.Uint16().Draw(t, "r")), + G: uint8(rapid.Uint16().Draw(t, "g")), + B: uint8(rapid.Uint16().Draw(t, "b")), + } + msg.Color = tlv.SomeRecordT(color) + } + + if rapid.Bool().Draw(t, "includeIpv4Addrs") { + ipv4Addr := RandTCP4Addr(t) + ipv4AddrRecord := tlv.ZeroRecordT[ + tlv.TlvType5, IPV4Addrs, + ]() + ipv4AddrRecord.Val = IPV4Addrs{ipv4Addr} + msg.IPV4Addrs = tlv.SomeRecordT(ipv4AddrRecord) + } + if rapid.Bool().Draw(t, "includeIpv6Addrs") { + ipv6Addr := RandTCP6Addr(t) + ipv6AddrRecord := tlv.ZeroRecordT[ + tlv.TlvType7, IPV6Addrs, + ]() + ipv6AddrRecord.Val = IPV6Addrs{ipv6Addr} + msg.IPV6Addrs = tlv.SomeRecordT(ipv6AddrRecord) + } + if rapid.Bool().Draw(t, "includeTorV3Addrs") { + torAddr := RandV3OnionAddr(t) + torAddrRecord := tlv.ZeroRecordT[ + tlv.TlvType9, TorV3Addrs, + ]() + torAddrRecord.Val = TorV3Addrs{torAddr} + msg.TorV3Addrs = tlv.SomeRecordT(torAddrRecord) + } + + if rapid.Bool().Draw(t, "includeAlias") { + alias := rapid.String().Draw(t, "alias") + aliasRec := tlv.ZeroRecordT[ + tlv.TlvType3, []byte, + ]() + aliasRec.Val = []byte(alias) + msg.Alias = tlv.SomeRecordT(aliasRec) + } + + return msg +} + // A compile time check to ensure ChannelAnnouncement2 implements the // lnwire.TestMessage interface. var _ TestMessage = (*ChannelAnnouncement2)(nil) @@ -213,7 +321,6 @@ func (c *ChannelAnnouncement2) RandTestMessage(t *rapid.T) Message { copy(chainHashObj[:], chainHash[:]) msg := &ChannelAnnouncement2{ - Signature: RandSignature(t), ChainHash: tlv.NewPrimitiveRecord[tlv.TlvType0, chainhash.Hash]( chainHashObj, ), @@ -232,10 +339,16 @@ func (c *ChannelAnnouncement2) RandTestMessage(t *rapid.T) Message { NodeID2: tlv.NewPrimitiveRecord[tlv.TlvType10, [33]byte]( nodeID2, ), - ExtraOpaqueData: RandExtraOpaqueData(t, nil), + ExtraSignedFields: make(map[uint64][]byte), } - msg.Signature.ForceSchnorr() + msg.Signature.Val = RandSignature(t) + msg.Signature.Val.ForceSchnorr() + + randRecs, _ := RandSignedRangeRecords(t) + if len(randRecs) > 0 { + msg.ExtraSignedFields = ExtraSignedFields(randRecs) + } // Randomly include optional fields if rapid.Bool().Draw(t, "includeBitcoinKey1") { @@ -473,7 +586,6 @@ func (c *ChannelUpdate2) RandTestMessage(t *rapid.T) Message { //nolint:ll msg := &ChannelUpdate2{ - Signature: RandSignature(t), ChainHash: tlv.NewPrimitiveRecord[tlv.TlvType0, chainhash.Hash]( chainHashObj, ), @@ -501,10 +613,11 @@ func (c *ChannelUpdate2) RandTestMessage(t *rapid.T) Message { FeeProportionalMillionths: tlv.NewPrimitiveRecord[tlv.TlvType18, uint32]( feeProportionalMillionths, ), - ExtraOpaqueData: RandExtraOpaqueData(t, nil), + ExtraSignedFields: make(map[uint64][]byte), } - msg.Signature.ForceSchnorr() + msg.Signature.Val = RandSignature(t) + msg.Signature.Val.ForceSchnorr() if rapid.Bool().Draw(t, "isSecondPeer") { msg.SecondPeer = tlv.SomeRecordT( @@ -688,7 +801,7 @@ var _ TestMessage = (*CommitSig)(nil) // // This is part of the TestMessage interface. func (c *CommitSig) RandTestMessage(t *rapid.T) Message { - cr, _ := RandCustomRecords(t, nil, true) + cr, _ := RandCustomRecords(t, nil) sig := &CommitSig{ ChanID: RandChannelID(t), CommitSig: RandSignature(t), @@ -1448,7 +1561,7 @@ func (s *Shutdown) RandTestMessage(t *rapid.T) Message { shutdownNonce = SomeShutdownNonce(RandMusig2Nonce(t)) } - cr, _ := RandCustomRecords(t, nil, true) + cr, _ := RandCustomRecords(t, nil) return &Shutdown{ ChannelID: RandChannelID(t), @@ -1505,7 +1618,7 @@ func (c *UpdateAddHTLC) RandTestMessage(t *rapid.T) Message { numRecords := rapid.IntRange(0, 5).Draw(t, "numRecords") if numRecords > 0 { - msg.CustomRecords, _ = RandCustomRecords(t, nil, true) + msg.CustomRecords, _ = RandCustomRecords(t, nil) } // 50/50 chance to add a blinding point @@ -1586,7 +1699,7 @@ func (c *UpdateFulfillHTLC) RandTestMessage(t *rapid.T) Message { PaymentPreimage: RandPaymentPreimage(t), } - cr, ignoreRecords := RandCustomRecords(t, nil, true) + cr, ignoreRecords := RandCustomRecords(t, nil) msg.CustomRecords = cr randData := RandExtraOpaqueData(t, ignoreRecords) diff --git a/lnwire/test_utils.go b/lnwire/test_utils.go index 1065cbacfe6..806b872d9cd 100644 --- a/lnwire/test_utils.go +++ b/lnwire/test_utils.go @@ -2,6 +2,7 @@ package lnwire import ( "crypto/sha256" + "encoding/binary" "fmt" "net" @@ -10,11 +11,12 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/fn/v2" + "github.com/lightningnetwork/lnd/tor" "github.com/stretchr/testify/require" "pgregory.net/rapid" ) -// RandChannelUpdate generates a random ChannelUpdate message using rapid's +// RandPartialSig generates a random PartialSig message using rapid's // generators. func RandPartialSig(t *rapid.T) *PartialSig { // Generate random private key bytes @@ -194,23 +196,37 @@ func RandNetAddrs(t *rapid.T) []net.Addr { } // RandCustomRecords generates random custom TLV records. -func RandCustomRecords(t *rapid.T, - ignoreRecords fn.Set[uint64], - custom bool) (CustomRecords, fn.Set[uint64]) { +func RandCustomRecords(t *rapid.T, ignoreRecords fn.Set[uint64]) (CustomRecords, + fn.Set[uint64]) { - numRecords := rapid.IntRange(0, 5).Draw(t, "numCustomRecords") + customRecords, set := RandTLVRecords( + t, ignoreRecords, MinCustomRecordsTlvType, + ) + + // Validate the custom records as a sanity check. + require.NoError(t, customRecords.Validate()) + + return customRecords, set +} + +// RandSignedRangeRecords generates a random set of signed records in the +// second "signed" tlv range for pure TLV messages. +func RandSignedRangeRecords(t *rapid.T) (CustomRecords, fn.Set[uint64]) { + return RandTLVRecords(t, nil, pureTLVSignedSecondRangeStart) +} + +// RandTLVRecords generates custom TLV records. +func RandTLVRecords(t *rapid.T, ignoreRecords fn.Set[uint64], + rangeStart int) (CustomRecords, fn.Set[uint64]) { + + numRecords := rapid.IntRange(0, 5).Draw(t, "numRecords") customRecords := make(CustomRecords) if numRecords == 0 { return nil, nil } - rangeStart := 0 - rangeStop := int(CustomTypeStart) - if custom { - rangeStart = 70_000 - rangeStop = 100_000 - } + rangeStop := rangeStart + 30_000 ignoreSet := fn.NewSet[uint64]() for i := 0; i < numRecords; i++ { @@ -254,7 +270,7 @@ func RandExtraOpaqueData(t *rapid.T, ignoreRecords fn.Set[uint64]) ExtraOpaqueData { // Make some random records. - cRecords, _ := RandCustomRecords(t, ignoreRecords, false) + cRecords, _ := RandTLVRecords(t, ignoreRecords, 0) if cRecords == nil { return ExtraOpaqueData{} } @@ -358,3 +374,54 @@ func RandOutPoint(t *rapid.T) wire.OutPoint { Index: vout, } } + +// RandTCP4Addr generates a random TCP4 address. +func RandTCP4Addr(t *rapid.T) *net.TCPAddr { + var ip [4]byte + ipBytes := rapid.SliceOfN(rapid.Byte(), 4, 4).Draw(t, "ip") + copy(ip[:], ipBytes) + + var port [2]byte + portBytes := rapid.SliceOfN(rapid.Byte(), 2, 2).Draw(t, "ip") + copy(port[:], portBytes) + + addrIP := net.IP(ip[:]) + addrPort := int(binary.BigEndian.Uint16(port[:])) + + return &net.TCPAddr{IP: addrIP, Port: addrPort} +} + +// RandTCP6Addr generates a random TCP6 address. +func RandTCP6Addr(t *rapid.T) *net.TCPAddr { + var ip [16]byte + ipBytes := rapid.SliceOfN(rapid.Byte(), 16, 16).Draw(t, "ip") + copy(ip[:], ipBytes) + + var port [2]byte + portBytes := rapid.SliceOfN(rapid.Byte(), 2, 2).Draw(t, "ip") + copy(port[:], portBytes) + + addrIP := net.IP(ip[:]) + addrPort := int(binary.BigEndian.Uint16(port[:])) + + return &net.TCPAddr{IP: addrIP, Port: addrPort} +} + +// RandV3OnionAddr generates a random V3 onion address. +func RandV3OnionAddr(t *rapid.T) *tor.OnionAddr { + var serviceID [tor.V3DecodedLen]byte + serviceIDBytes := rapid.SliceOfN( + rapid.Byte(), tor.V3DecodedLen, tor.V3DecodedLen, + ).Draw(t, "ip") + copy(serviceID[:], serviceIDBytes) + + var port [2]byte + portBytes := rapid.SliceOfN(rapid.Byte(), 2, 2).Draw(t, "ip") + copy(port[:], portBytes) + + onionService := tor.Base32Encoding.EncodeToString(serviceID[:]) + onionService += tor.OnionSuffix + addrPort := int(binary.BigEndian.Uint16(port[:])) + + return &tor.OnionAddr{OnionService: onionService, Port: addrPort} +} diff --git a/netann/channel_announcement.go b/netann/channel_announcement.go index 571e0584dc5..cbcce5335c9 100644 --- a/netann/channel_announcement.go +++ b/netann/channel_announcement.go @@ -8,7 +8,9 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcec/v2/schnorr" "github.com/btcsuite/btcd/btcec/v2/schnorr/musig2" + "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/txscript" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/tlv" @@ -108,7 +110,8 @@ func CreateChanAnnouncement(chanProof *models.ChannelAuthProof, // FetchPkScript defines a function that can be used to fetch the output script // for the transaction with the given SCID. -type FetchPkScript func(*lnwire.ShortChannelID) ([]byte, error) +type FetchPkScript func(lnwire.ShortChannelID) (txscript.ScriptClass, + btcutil.Address, error) // ValidateChannelAnn validates the channel announcement. func ValidateChannelAnn(a lnwire.ChannelAnnouncement, @@ -202,24 +205,124 @@ func validateChannelAnn1(a *lnwire.ChannelAnnouncement1) error { func validateChannelAnn2(a *lnwire.ChannelAnnouncement2, fetchPkScript FetchPkScript) error { + // Next, we fetch the funding transaction's PK script. We need this so + // that we know what type of channel we will be validating: P2WSH or + // P2TR. + scriptClass, scriptAddr, err := fetchPkScript(a.ShortChannelID.Val) + if err != nil { + return err + } + + var keys []*btcec.PublicKey + + switch scriptClass { + case txscript.WitnessV0ScriptHashTy: + keys, err = chanAnn2P2WSHMuSig2Keys(a) + if err != nil { + return err + } + case txscript.WitnessV1TaprootTy: + keys, err = chanAnn2P2TRMuSig2Keys(a, scriptAddr) + if err != nil { + return err + } + default: + return fmt.Errorf("invalid on-chain pk script type for "+ + "channel_announcement_2: %s", scriptClass) + } + + // Do a MuSig2 aggregation of the keys to obtain the aggregate key that + // the signature will be validated against. + aggKey, _, _, err := musig2.AggregateKeys(keys, true) + if err != nil { + return err + } + + // Get the message that the signature should have signed. dataHash, err := ChanAnn2DigestToSign(a) if err != nil { return err } - sig, err := a.Signature.ToSignature() + // Obtain the signature. + sig, err := a.Signature.Val.ToSignature() if err != nil { return err } + // Check that the signature is valid for the aggregate key given the + // message digest. + if !sig.Verify(dataHash.CloneBytes(), aggKey.FinalKey) { + return fmt.Errorf("invalid sig") + } + + return nil +} + +// chanAnn2P2WSHMuSig2Keys returns the set of keys that should be used to +// construct the aggregate key that the signature in an +// lnwire.ChannelAnnouncement2 message should be verified against in the case +// where the channel being announced is a P2WSH channel. +func chanAnn2P2WSHMuSig2Keys(a *lnwire.ChannelAnnouncement2) ( + []*btcec.PublicKey, error) { + nodeKey1, err := btcec.ParsePubKey(a.NodeID1.Val[:]) if err != nil { - return err + return nil, err } nodeKey2, err := btcec.ParsePubKey(a.NodeID2.Val[:]) if err != nil { - return err + return nil, err + } + + btcKeyMissingErrString := "bitcoin key %d missing for announcement " + + "of a P2WSH channel" + + btcKey1Bytes, err := a.BitcoinKey1.UnwrapOrErr( + fmt.Errorf(btcKeyMissingErrString, 1), + ) + if err != nil { + return nil, err + } + + btcKey1, err := btcec.ParsePubKey(btcKey1Bytes.Val[:]) + if err != nil { + return nil, err + } + + btcKey2Bytes, err := a.BitcoinKey2.UnwrapOrErr( + fmt.Errorf(btcKeyMissingErrString, 2), + ) + if err != nil { + return nil, err + } + + btcKey2, err := btcec.ParsePubKey(btcKey2Bytes.Val[:]) + if err != nil { + return nil, err + } + + return []*btcec.PublicKey{ + nodeKey1, nodeKey2, btcKey1, btcKey2, + }, nil +} + +// chanAnn2P2TRMuSig2Keys returns the set of keys that should be used to +// construct the aggregate key that the signature in an +// lnwire.ChannelAnnouncement2 message should be verified against in the case +// where the channel being announced is a P2TR channel. +func chanAnn2P2TRMuSig2Keys(a *lnwire.ChannelAnnouncement2, + scriptAddr btcutil.Address) ([]*btcec.PublicKey, error) { + + nodeKey1, err := btcec.ParsePubKey(a.NodeID1.Val[:]) + if err != nil { + return nil, err + } + + nodeKey2, err := btcec.ParsePubKey(a.NodeID2.Val[:]) + if err != nil { + return nil, err } keys := []*btcec.PublicKey{ @@ -240,49 +343,36 @@ func validateChannelAnn2(a *lnwire.ChannelAnnouncement2, bitcoinKey1, err := btcec.ParsePubKey(btcKey1.Val[:]) if err != nil { - return err + return nil, err } bitcoinKey2, err := btcec.ParsePubKey(btcKey2.Val[:]) if err != nil { - return err + return nil, err } keys = append(keys, bitcoinKey1, bitcoinKey2) } else { - // If bitcoin keys are not provided, then we need to get the - // on-chain output key since this will be the 3rd key in the - // 3-of-3 MuSig2 signature. - pkScript, err := fetchPkScript(&a.ShortChannelID.Val) - if err != nil { - return err - } - - outputKey, err := schnorr.ParsePubKey(pkScript[2:]) + // If bitcoin keys are not provided, then the on-chain output + // key is considered the 3rd key in the 3-of-3 MuSig2 signature. + outputKey, err := schnorr.ParsePubKey( + scriptAddr.ScriptAddress(), + ) if err != nil { - return err + return nil, err } keys = append(keys, outputKey) } - aggKey, _, _, err := musig2.AggregateKeys(keys, true) - if err != nil { - return err - } - - if !sig.Verify(dataHash.CloneBytes(), aggKey.FinalKey) { - return fmt.Errorf("invalid sig") - } - - return nil + return keys, nil } // ChanAnn2DigestToSign computes the digest of the message to be signed. func ChanAnn2DigestToSign(a *lnwire.ChannelAnnouncement2) (*chainhash.Hash, error) { - data, err := a.DataToSign() + data, err := lnwire.SerialiseFieldsToSign(a) if err != nil { return nil, err } diff --git a/netann/channel_announcement_test.go b/netann/channel_announcement_test.go index e8c5799b41c..7a1a44aa04f 100644 --- a/netann/channel_announcement_test.go +++ b/netann/channel_announcement_test.go @@ -9,6 +9,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/input" @@ -76,20 +77,25 @@ func TestChanAnnounce2Validation(t *testing.T) { t.Parallel() t.Run( - "test 4-of-4 MuSig2 channel announcement", - test4of4MuSig2ChanAnnouncement, + "test 4-of-4 MuSig2 P2TR channel announcement", + test4of4MuSig2P2TRChanAnnouncement, ) t.Run( - "test 3-of-3 MuSig2 channel announcement", + "test 3-of-3 MuSig2 P2TR channel announcement", test3of3MuSig2ChanAnnouncement, ) + + t.Run( + "test 4-of-4 MuSig2 P2WSH channel announcement", + test4of4MuSig2P2WSHChanAnnouncement, + ) } -// test4of4MuSig2ChanAnnouncement covers the case where both bitcoin keys are -// present in the channel announcement. In this case, the signature should be -// a 4-of-4 MuSig2. -func test4of4MuSig2ChanAnnouncement(t *testing.T) { +// test4of4MuSig2P2TRChanAnnouncement covers the case where the funding +// transaction PK script is a P2WSH. In this case, the signature should be valid +// for the MuSig2 4-of-4 aggregation of the node keys and the bitcoin keys. +func test4of4MuSig2P2WSHChanAnnouncement(t *testing.T) { t.Parallel() // Generate the keys for node 1 and node2. @@ -162,10 +168,138 @@ func test4of4MuSig2ChanAnnouncement(t *testing.T) { sig, err := lnwire.NewSigFromSignature(s) require.NoError(t, err) - ann.Signature = sig + ann.Signature.Val = sig + + // Create an accurate representation of what the on-chain pk script will + // look like. For this case, it is only important that we get the + // correct script class. + multiSigScript, err := input.GenMultiSigScript( + node1.btcPub.SerializeCompressed(), + node2.btcPub.SerializeCompressed(), + ) + require.NoError(t, err) + + scriptHash, err := input.WitnessScriptHash(multiSigScript) + require.NoError(t, err) + pkAddr, err := btcutil.NewAddressScriptHash( + scriptHash, &chaincfg.MainNetParams, + ) + require.NoError(t, err) + + // Create a mock tx fetcher that returns the expected script class and + // pk address. + fetchTx := func(lnwire.ShortChannelID) (txscript.ScriptClass, + btcutil.Address, error) { + + return txscript.WitnessV0ScriptHashTy, pkAddr, nil + } // Validate the announcement. - require.NoError(t, ValidateChannelAnn(ann, nil)) + require.NoError(t, ValidateChannelAnn(ann, fetchTx)) +} + +// test4of4MuSig2P2TRChanAnnouncement covers the case where both bitcoin keys +// are present in the channel announcement 2 and the funding transaction PK +// script is a P2TR. In this case, the signature should be a 4-of-4 MuSig2. +func test4of4MuSig2P2TRChanAnnouncement(t *testing.T) { + t.Parallel() + + // Generate the keys for node 1 and node2. + node1, node2 := genChanAnnKeys(t) + + // Build the unsigned channel announcement. + ann := buildUnsignedChanAnnouncement(node1, node2, true) + + // Serialise the bytes that need to be signed. + msg, err := ChanAnn2DigestToSign(ann) + require.NoError(t, err) + + var msgBytes [32]byte + copy(msgBytes[:], msg.CloneBytes()) + + // Generate the 4 nonces required for producing the signature. + var ( + node1NodeNonce = genNonceForPubKey(t, node1.nodePub) + node1BtcNonce = genNonceForPubKey(t, node1.btcPub) + node2NodeNonce = genNonceForPubKey(t, node2.nodePub) + node2BtcNonce = genNonceForPubKey(t, node2.btcPub) + ) + + nonceAgg, err := musig2.AggregateNonces([][66]byte{ + node1NodeNonce.PubNonce, + node1BtcNonce.PubNonce, + node2NodeNonce.PubNonce, + node2BtcNonce.PubNonce, + }) + require.NoError(t, err) + + pubKeys := []*btcec.PublicKey{ + node1.nodePub, node2.nodePub, node1.btcPub, node2.btcPub, + } + + // Let Node1 sign the announcement message with its node key. + psA1, err := musig2.Sign( + node1NodeNonce.SecNonce, node1.nodePriv, nonceAgg, pubKeys, + msgBytes, musig2.WithSortedKeys(), + ) + require.NoError(t, err) + + // Let Node1 sign the announcement message with its bitcoin key. + psA2, err := musig2.Sign( + node1BtcNonce.SecNonce, node1.btcPriv, nonceAgg, pubKeys, + msgBytes, musig2.WithSortedKeys(), + ) + require.NoError(t, err) + + // Let Node2 sign the announcement message with its node key. + psB1, err := musig2.Sign( + node2NodeNonce.SecNonce, node2.nodePriv, nonceAgg, pubKeys, + msgBytes, musig2.WithSortedKeys(), + ) + require.NoError(t, err) + + // Let Node2 sign the announcement message with its bitcoin key. + psB2, err := musig2.Sign( + node2BtcNonce.SecNonce, node2.btcPriv, nonceAgg, pubKeys, + msgBytes, musig2.WithSortedKeys(), + ) + require.NoError(t, err) + + // Finally, combine the partial signatures from Node1 and Node2 and add + // the signature to the announcement message. + s := musig2.CombineSigs(psA1.R, []*musig2.PartialSignature{ + psA1, psA2, psB1, psB2, + }) + + sig, err := lnwire.NewSigFromSignature(s) + require.NoError(t, err) + + ann.Signature.Val = sig + + // Create an accurate representation of what the on-chain pk script will + // look like. For this case, it is only important that we get the + // correct script class. + combinedKey, _, _, err := musig2.AggregateKeys( + []*btcec.PublicKey{node1.btcPub, node2.btcPub}, true, + ) + require.NoError(t, err) + + pkAddr, err := btcutil.NewAddressTaproot( + combinedKey.FinalKey.SerializeCompressed()[1:], + &chaincfg.MainNetParams, + ) + require.NoError(t, err) + + // Create a mock tx fetcher that returns the expected script class and + // pk address. + fetchTx := func(lnwire.ShortChannelID) (txscript.ScriptClass, + btcutil.Address, error) { + + return txscript.WitnessV1TaprootTy, pkAddr, nil + } + + // Validate the announcement. + require.NoError(t, ValidateChannelAnn(ann, fetchTx)) } // test3of3MuSig2ChanAnnouncement covers the case where no bitcoin keys are @@ -220,14 +354,17 @@ func test3of3MuSig2ChanAnnouncement(t *testing.T) { }) require.NoError(t, err) - pkScript, err := input.PayToTaprootScript(outputKey) + pkAddr, err := btcutil.NewAddressTaproot( + outputKey.SerializeCompressed()[1:], &chaincfg.MainNetParams, + ) require.NoError(t, err) - // We'll pass in a mock tx fetcher that will return the funding output - // containing this key. This is needed since the output key can not be - // determined from the channel announcement itself. - fetchTx := func(chanID *lnwire.ShortChannelID) ([]byte, error) { - return pkScript, nil + // Create a mock tx fetcher that returns the expected script class + // and pk address. + fetchTx := func(lnwire.ShortChannelID) (txscript.ScriptClass, + btcutil.Address, error) { + + return txscript.WitnessV1TaprootTy, pkAddr, nil } pubKeys := []*btcec.PublicKey{node1.nodePub, node2.nodePub, outputKey} @@ -262,7 +399,7 @@ func test3of3MuSig2ChanAnnouncement(t *testing.T) { sig, err := lnwire.NewSigFromSignature(s) require.NoError(t, err) - ann.Signature = sig + ann.Signature.Val = sig // Validate the announcement. require.NoError(t, ValidateChannelAnn(ann, fetchTx)) diff --git a/netann/channel_update.go b/netann/channel_update.go index efc5cf61e49..3180961afdf 100644 --- a/netann/channel_update.go +++ b/netann/channel_update.go @@ -235,7 +235,7 @@ func verifyChannelUpdate2Signature(c *lnwire.ChannelUpdate2, return fmt.Errorf("unable to reconstruct message data: %w", err) } - nodeSig, err := c.Signature.ToSignature() + nodeSig, err := c.Signature.Val.ToSignature() if err != nil { return err } @@ -323,7 +323,7 @@ func ChanUpdate2DigestTag() []byte { // chanUpdate2DigestToSign computes the digest of the ChannelUpdate2 message to // be signed. func chanUpdate2DigestToSign(c *lnwire.ChannelUpdate2) ([]byte, error) { - data, err := c.DataToSign() + data, err := lnwire.SerialiseFieldsToSign(c) if err != nil { return nil, err } diff --git a/server.go b/server.go index ad521cd6736..c49055fb5b1 100644 --- a/server.go +++ b/server.go @@ -1152,7 +1152,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr, Graph: s.graphBuilder, ChainIO: s.cc.ChainIO, Notifier: s.cc.ChainNotifier, - ChainHash: *s.cfg.ActiveNetParams.GenesisHash, + ChainParams: s.cfg.ActiveNetParams.Params, Broadcast: s.BroadcastMessage, ChanSeries: chanSeries, NotifyWhenOnline: s.NotifyWhenOnline,