diff --git a/config_builder.go b/config_builder.go index 49a0b6c507b..69e7066ce77 100644 --- a/config_builder.go +++ b/config_builder.go
@@ -874,6 +874,10 @@ type DatabaseInstances struct { // configuration. TowerServerDB watchtower.DB
+ // PeerStorageDB is the database that stores the data that peers share
+ // with us for backup.
+ PeerStorageDB kvdb.Backend
+ // WalletDB is the configuration for loading the wallet database using // the btcwallet's loader. WalletDB btcwallet.LoaderOption
@@ -941,6 +945,7 @@ func (d *DefaultDatabaseBuilder) BuildDatabase( DecayedLogDB: databaseBackends.DecayedLogDB, WalletDB: databaseBackends.WalletDB, NativeSQLStore: databaseBackends.NativeSQLStore,
+ PeerStorageDB: databaseBackends.PeerStorageDB,
} cleanUp := func() { // We can just close the returned close functions directly. Even
diff --git a/docs/release-notes/release-notes-0.18.0.md b/docs/release-notes/release-notes-0.18.0.md index 7628726ef0b..7820f064fe9 100644 --- a/docs/release-notes/release-notes-0.18.0.md +++ b/docs/release-notes/release-notes-0.18.0.md
@@ -487,6 +487,12 @@ [spec change](https://github.com/lightning/bolts/pull/1092/commits/e0ee59f3c92b7c98be8dfc47b7db358b45baf9de) that meant we shouldn't require it.
+* [Implement the feature bits, messages and functionality to back up peer
+data](https://github.com/lightningnetwork/lnd/pull/8490)
+ This PR implements the feature bits, messages and functionality to back up
+ peer data, as described in the peer backup proposal:
+ https://github.com/lightning/bolts/pull/1110
+ ## Testing * Added fuzz tests for [onion
diff --git a/feature/default_sets.go b/feature/default_sets.go index cc802fe8593..9a85b5521c4 100644 --- a/feature/default_sets.go +++ b/feature/default_sets.go
@@ -92,4 +92,8 @@ var defaultSetDesc = setDesc{ SetInit: {}, // I SetNodeAnn: {}, // N },
+ lnwire.ProvideStorageOptional: {
+ SetInit: {}, // I
+ SetNodeAnn: {}, // N
+ },
}
diff --git a/feature/manager.go b/feature/manager.go index c7029e89382..1819dfdb780 100644 --- a/feature/manager.go +++ b/feature/manager.go
@@ -66,6 +66,9 @@ type Config struct { // CustomFeatures is a set of custom features to advertise in each // set. CustomFeatures map[Set][]lnwire.FeatureBit
+
+ // NoPeerStorage unsets any bits signalling support for peer storage.
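+ // When set, lnd will not advertise the provide_storage bit in its
+ // init or node announcement feature sets.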
+ NoPeerStorage bool } // Manager is responsible for generating feature vectors for different requested @@ -188,6 +191,10 @@ func newManager(cfg Config, desc setDesc) (*Manager, error) { raw.Unset(lnwire.RouteBlindingOptional) raw.Unset(lnwire.RouteBlindingRequired) } + if cfg.NoPeerStorage { + raw.Unset(lnwire.ProvideStorageOptional) + } + for _, custom := range cfg.CustomFeatures[set] { if custom > set.Maximum() { return nil, fmt.Errorf("feature bit: %v "+ diff --git a/lncfg/db.go b/lncfg/db.go index 4072a51feaf..3640362c5a4 100644 --- a/lncfg/db.go +++ b/lncfg/db.go @@ -25,12 +25,14 @@ const ( TowerClientDBName = "wtclient.db" TowerServerDBName = "watchtower.db" WalletDBName = "wallet.db" + PeerStorageDBName = "peer_storage.db" - SqliteChannelDBName = "channel.sqlite" - SqliteChainDBName = "chain.sqlite" - SqliteNeutrinoDBName = "neutrino.sqlite" - SqliteTowerDBName = "watchtower.sqlite" - SqliteNativeDBName = "lnd.sqlite" + SqliteChannelDBName = "channel.sqlite" + SqliteChainDBName = "chain.sqlite" + SqliteNeutrinoDBName = "neutrino.sqlite" + SqliteTowerDBName = "watchtower.sqlite" + SqliteNativeDBName = "lnd.sqlite" + SqlitePeerStorageDBName = "peerStorage.sqlite" BoltBackend = "bolt" EtcdBackend = "etcd" @@ -67,6 +69,10 @@ const ( // NSNeutrinoDB is the namespace name that we use for the neutrino DB. NSNeutrinoDB = "neutrinodb" + + // NSPeerStorageDB is the namespace name that we use for peer storage + // DB. + NSPeerStorageDB = "peerstoragedb" ) // DB holds database configuration for LND. @@ -227,6 +233,10 @@ type DatabaseBackends struct { // server data. This might be nil if the watchtower server is disabled. TowerServerDB kvdb.Backend + // PeerStorageDB points to the database backend that stores the backup + // data that peers share with us. + PeerStorageDB kvdb.Backend + // WalletDB is an option that instructs the wallet loader where to load // the underlying wallet database from. WalletDB btcwallet.LoaderOption @@ -349,6 +359,16 @@ func (db *DB) GetBackends(ctx context.Context, chanDBPath, } closeFuncs[NSTowerServerDB] = etcdTowerServerBackend.Close + etcdPeerStorageBackend, err := kvdb.Open( + kvdb.EtcdBackendName, ctx, + db.Etcd.CloneWithSubNamespace(NSPeerStorageDB), + ) + if err != nil { + return nil, fmt.Errorf("error opening etcd "+ + "peer storage DB: %w", err) + } + closeFuncs[NSPeerStorageDB] = etcdPeerStorageBackend.Close + etcdWalletBackend, err := kvdb.Open( kvdb.EtcdBackendName, ctx, db.Etcd. @@ -357,7 +377,7 @@ func (db *DB) GetBackends(ctx context.Context, chanDBPath, ) if err != nil { return nil, fmt.Errorf("error opening etcd macaroon "+ - "DB: %v", err) + "DB: %w", err) } closeFuncs[NSWalletDB] = etcdWalletBackend.Close @@ -371,6 +391,7 @@ func (db *DB) GetBackends(ctx context.Context, chanDBPath, DecayedLogDB: etcdDecayedLogBackend, TowerClientDB: etcdTowerClientBackend, TowerServerDB: etcdTowerServerBackend, + PeerStorageDB: etcdPeerStorageBackend, // The wallet loader will attempt to use/create the // wallet in the replicated remote DB if we're running // in a clustered environment. 
This will ensure that all @@ -439,6 +460,16 @@ func (db *DB) GetBackends(ctx context.Context, chanDBPath, } closeFuncs[NSTowerServerDB] = postgresTowerServerBackend.Close
+ postgresPeerStorageBackend, err := kvdb.Open(
+ kvdb.PostgresBackendName, ctx,
+ postgresConfig, NSPeerStorageDB,
+ )
+ if err != nil {
+ return nil, fmt.Errorf("error opening postgres "+
+ "peer storage DB: %w", err)
+ }
+ closeFuncs[NSPeerStorageDB] = postgresPeerStorageBackend.Close
+ postgresWalletBackend, err := kvdb.Open( kvdb.PostgresBackendName, ctx, postgresConfig, NSWalletDB,
@@ -482,6 +513,7 @@ func (db *DB) GetBackends(ctx context.Context, chanDBPath, DecayedLogDB: postgresDecayedLogBackend, TowerClientDB: postgresTowerClientBackend, TowerServerDB: postgresTowerServerBackend,
+ PeerStorageDB: postgresPeerStorageBackend,
// The wallet loader will attempt to use/create the // wallet in the replicated remote DB if we're running // in a clustered environment. This will ensure that all
@@ -561,6 +593,16 @@ func (db *DB) GetBackends(ctx context.Context, chanDBPath, } closeFuncs[NSTowerServerDB] = sqliteTowerServerBackend.Close
+ sqlitePeerStorageBackend, err := kvdb.Open(
+ kvdb.SqliteBackendName, ctx, sqliteConfig, chanDBPath,
+ SqlitePeerStorageDBName, NSPeerStorageDB,
+ )
+ if err != nil {
+ return nil, fmt.Errorf("error opening sqlite peer "+
+ "storage DB: %w", err)
+ }
+ closeFuncs[NSPeerStorageDB] = sqlitePeerStorageBackend.Close
+ sqliteWalletBackend, err := kvdb.Open( kvdb.SqliteBackendName, ctx, sqliteConfig, walletDBPath, SqliteChainDBName, NSWalletDB,
@@ -605,6 +647,7 @@ func (db *DB) GetBackends(ctx context.Context, chanDBPath, DecayedLogDB: sqliteDecayedLogBackend, TowerClientDB: sqliteTowerClientBackend, TowerServerDB: sqliteTowerServerBackend,
+ PeerStorageDB: sqlitePeerStorageBackend,
// The wallet loader will attempt to use/create the // wallet in the replicated remote DB if we're running // in a clustered environment. This will ensure that all
@@ -645,6 +688,20 @@ func (db *DB) GetBackends(ctx context.Context, chanDBPath, } closeFuncs[NSMacaroonDB] = macaroonBackend.Close
+ peerStorageBackend, err := kvdb.GetBoltBackend(&kvdb.BoltBackendConfig{
+ DBPath: chanDBPath,
+ DBFileName: PeerStorageDBName,
+ DBTimeout: db.Bolt.DBTimeout,
+ NoFreelistSync: db.Bolt.NoFreelistSync,
+ AutoCompact: db.Bolt.AutoCompact,
+ AutoCompactMinAge: db.Bolt.AutoCompactMinAge,
+ })
+ if err != nil {
+ return nil, fmt.Errorf("error opening peer storage DB: "+
+ "%w", err)
+ }
+ closeFuncs[NSPeerStorageDB] = peerStorageBackend.Close
+ decayedLogBackend, err := kvdb.GetBoltBackend(&kvdb.BoltBackendConfig{ DBPath: chanDBPath, DBFileName: DecayedLogDbName,
@@ -710,6 +767,7 @@ func (db *DB) GetBackends(ctx context.Context, chanDBPath, DecayedLogDB: decayedLogBackend, TowerClientDB: towerClientBackend, TowerServerDB: towerServerBackend,
+ PeerStorageDB: peerStorageBackend,
// When "running locally", LND will use the bbolt wallet.db to // store the wallet located in the chain data dir, parametrized // by the active network. The wallet loader has its own cleanup
diff --git a/lncfg/protocol.go b/lncfg/protocol.go index e98b4dcf885..6f6687baf77 100644 --- a/lncfg/protocol.go +++ b/lncfg/protocol.go
@@ -57,6 +57,10 @@ type ProtocolOptions struct { // NoRouteBlindingOption disables forwarding of payments in blinded routes.
NoRouteBlindingOption bool `long:"no-route-blinding" description:"do not forward payments that are a part of a blinded route"`
+
+ // OptionPeerStorage, when set to true, enables storage of backup data
+ // shared by peers.
+ OptionPeerStorage bool `long:"peer-storage" description:"store peers' backup data"`
} // Wumbo returns true if lnd should permit the creation and acceptance of wumbo
@@ -105,3 +109,9 @@ func (l *ProtocolOptions) NoTimestampsQuery() bool { func (l *ProtocolOptions) NoRouteBlinding() bool { return l.NoRouteBlindingOption }
+
+// PeerStorage returns true if we want to enable storage of backup data
+// shared by peers.
+func (l *ProtocolOptions) PeerStorage() bool {
+ return l.OptionPeerStorage
+}
diff --git a/lncfg/protocol_integration.go b/lncfg/protocol_integration.go index 841f8e9eb67..95ab5bb3fbb 100644 --- a/lncfg/protocol_integration.go +++ b/lncfg/protocol_integration.go
@@ -60,6 +60,10 @@ type ProtocolOptions struct { // NoRouteBlindingOption disables forwarding of payments in blinded routes. NoRouteBlindingOption bool `long:"no-route-blinding" description:"do not forward payments that are a part of a blinded route"`
+
+ // OptionPeerStorage, when set to true, enables storage of backup data
+ // shared by peers.
+ OptionPeerStorage bool `long:"peer-storage" description:"store peers' backup data"`
} // Wumbo returns true if lnd should permit the creation and acceptance of wumbo
@@ -100,3 +104,9 @@ func (l *ProtocolOptions) NoAnySegwit() bool { func (l *ProtocolOptions) NoRouteBlinding() bool { return l.NoRouteBlindingOption }
+
+// PeerStorage returns true if we want to enable storage of backup data
+// shared by peers.
+func (l *ProtocolOptions) PeerStorage() bool {
+ return l.OptionPeerStorage
+}
diff --git a/lnwire/features.go b/lnwire/features.go index e4dd7f4f81c..2af7494f544 100644 --- a/lnwire/features.go +++ b/lnwire/features.go
@@ -171,6 +171,16 @@ const ( // sender-generated preimages according to BOLT XX. AMPOptional FeatureBit = 31
+ // ProvideStorageRequired is a required feature bit that
+ // indicates that a node offers storing arbitrary data for its peers.
+ // See: https://github.com/lightning/bolts/pull/1110/
+ ProvideStorageRequired FeatureBit = 42
+
+ // ProvideStorageOptional is an optional feature bit that
+ // indicates that a node offers storing arbitrary data for its peers.
+ // See: https://github.com/lightning/bolts/pull/1110/
+ ProvideStorageOptional FeatureBit = 43
+ // ExplicitChannelTypeRequired is a required bit that denotes that a // connection established with this node is to use explicit channel // commitment types for negotiation instead of the existing implicit
@@ -331,6 +341,8 @@ var Features = map[FeatureBit]string{ SimpleTaprootChannelsOptionalFinal: "simple-taproot-chans", SimpleTaprootChannelsRequiredStaging: "simple-taproot-chans-x", SimpleTaprootChannelsOptionalStaging: "simple-taproot-chans-x",
+ ProvideStorageOptional: "provide-storage",
+ ProvideStorageRequired: "provide-storage",
} // RawFeatureVector represents a set of feature bits as defined in BOLT-09. A
diff --git a/lnwire/fuzz_test.go b/lnwire/fuzz_test.go index 542a2f0c0a9..adb40fbaaa2 100644 --- a/lnwire/fuzz_test.go +++ b/lnwire/fuzz_test.go
@@ -900,3 +900,25 @@ func FuzzClosingComplete(f *testing.F) { harness(t, data) }) }
+
+func FuzzPeerStorage(f *testing.F) {
+ f.Fuzz(func(t *testing.T, data []byte) {
+ // Prefix with PeerStorage.
+ data = prefixWithMsgType(data, MsgPeerStorage)
+
+ // Pass the message into our general fuzz harness for wire
+ // messages!
+ harness(t, data)
+ })
+}
+
+func FuzzPeerStorageRetrieval(f *testing.F) {
+ f.Fuzz(func(t *testing.T, data []byte) {
+ // Prefix with PeerStorageRetrieval.
+ data = prefixWithMsgType(data, MsgPeerStorageRetrieval)
+
+ // Pass the message into our general fuzz harness for wire
+ // messages!
+ harness(t, data)
+ })
+}
diff --git a/lnwire/lnwire.go b/lnwire/lnwire.go index 8ab082b0bdb..27eab07a3ea 100644 --- a/lnwire/lnwire.go +++ b/lnwire/lnwire.go
@@ -460,6 +460,9 @@ func WriteElement(w *bytes.Buffer, element interface{}) error { case ExtraOpaqueData: return e.Encode(w)
+ case PeerStorageBlob:
+ return e.Encode(w)
+ default: return fmt.Errorf("unknown type in WriteElement: %T", e) }
@@ -939,6 +942,9 @@ func ReadElement(r io.Reader, element interface{}) error { case *ExtraOpaqueData: return e.Decode(r)
+ case *PeerStorageBlob:
+ return e.Decode(r)
+ default: return fmt.Errorf("unknown type in ReadElement: %T", e) }
diff --git a/lnwire/lnwire_test.go b/lnwire/lnwire_test.go index e4c5c6baf55..925b1d79973 100644 --- a/lnwire/lnwire_test.go +++ b/lnwire/lnwire_test.go
@@ -1595,6 +1595,18 @@ func TestLightningWireProtocol(t *testing.T) { return mainScenario(&m) }, },
+ {
+ msgType: MsgPeerStorage,
+ scenario: func(m PeerStorage) bool {
+ return mainScenario(&m)
+ },
+ },
+ {
+ msgType: MsgPeerStorageRetrieval,
+ scenario: func(m PeerStorageRetrieval) bool {
+ return mainScenario(&m)
+ },
+ },
} for _, test := range tests { var config *quick.Config
diff --git a/lnwire/message.go b/lnwire/message.go index bcee9f86d4d..e8bad6785c7 100644 --- a/lnwire/message.go +++ b/lnwire/message.go
@@ -23,6 +23,8 @@ type MessageType uint16 // Lightning protocol. const ( MsgWarning MessageType = 1
+ MsgPeerStorage = 7
+ MsgPeerStorageRetrieval = 9
MsgInit = 16 MsgError = 17 MsgPing = 18
@@ -152,6 +154,10 @@ func (t MessageType) String() string { return "ClosingComplete" case MsgClosingSig: return "ClosingSig"
+ case MsgPeerStorage:
+ return "PeerStorage"
+ case MsgPeerStorageRetrieval:
+ return "PeerStorageRetrieval"
default: return "" }
@@ -279,6 +285,10 @@ func makeEmptyMessage(msgType MessageType) (Message, error) { msg = &ClosingComplete{} case MsgClosingSig: msg = &ClosingSig{}
+ case MsgPeerStorage:
+ msg = &PeerStorage{}
+ case MsgPeerStorageRetrieval:
+ msg = &PeerStorageRetrieval{}
default: // If the message is not within our custom range and has not // specifically been overridden, return an unknown message.
diff --git a/lnwire/message_test.go b/lnwire/message_test.go index bbb434785f8..b6e635ff50c 100644 --- a/lnwire/message_test.go +++ b/lnwire/message_test.go
@@ -290,6 +290,8 @@ func makeAllMessages(t testing.TB, r *rand.Rand) []lnwire.Message { msgAll = append(msgAll, newMsgGossipTimestampRange(t, r)) msgAll = append(msgAll, newMsgQueryShortChanIDsZlib(t, r)) msgAll = append(msgAll, newMsgReplyChannelRangeZlib(t, r))
+ msgAll = append(msgAll, newMsgPeerStorage(t, r))
+ msgAll = append(msgAll, newMsgPeerStorageRetrieval(t, r))
return msgAll }
@@ -883,6 +885,36 @@ func newMsgGossipTimestampRange(t testing.TB, return msg }
+
+func newMsgPeerStorage(t testing.TB, r *rand.Rand) *lnwire.PeerStorage {
+ t.Helper()
+
+ // Read random bytes.
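+ // r.Intn returns a value in [0, MaxPeerStorageBytes), so the blob
+ // always fits within the peer storage size limit.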
+ data := make([]byte, r.Intn(lnwire.MaxPeerStorageBytes))
+ _, err := r.Read(data)
+ require.NoError(t, err, "unable to generate peer storage "+
+ "blob")
+
+ return &lnwire.PeerStorage{
+ Blob: data,
+ }
+}
+
+func newMsgPeerStorageRetrieval(t testing.TB,
+ r *rand.Rand) *lnwire.PeerStorageRetrieval {
+
+ t.Helper()
+
+ // Read random bytes.
+ data := make([]byte, r.Intn(lnwire.MaxPeerStorageBytes))
+ _, err := r.Read(data)
+ require.NoError(t, err, "unable to generate peer storage "+
+ "blob")
+
+ return &lnwire.PeerStorageRetrieval{
+ Blob: data,
+ }
+}
+ func randRawKey(t testing.TB) [33]byte { t.Helper()
diff --git a/lnwire/peer_storage.go b/lnwire/peer_storage.go new file mode 100644 index 00000000000..47232bf521c --- /dev/null +++ b/lnwire/peer_storage.go
@@ -0,0 +1,61 @@
+package lnwire
+
+import (
+ "bytes"
+ "errors"
+ "io"
+)
+
+// MaxPeerStorageBytes is the maximum size in bytes of the blob in a peer
+// storage message.
+const MaxPeerStorageBytes = 65531
+
+// ErrPeerStorageBytesExceeded is returned when the Peer Storage blob's size
+// exceeds MaxPeerStorageBytes.
+var ErrPeerStorageBytesExceeded = errors.New("peer storage bytes exceeded")
+
+// PeerStorage contains a data blob that the sending peer would like the
+// receiving peer to store.
+type PeerStorage struct {
+ // Blob is data for the receiving peer to store from the sender.
+ Blob PeerStorageBlob
+}
+
+// NewPeerStorageMsg creates a new instance of the PeerStorage message.
+func NewPeerStorageMsg(data PeerStorageBlob) (*PeerStorage, error) {
+ if len(data) > MaxPeerStorageBytes {
+ return nil, ErrPeerStorageBytesExceeded
+ }
+
+ return &PeerStorage{
+ Blob: data,
+ }, nil
+}
+
+// A compile time check to ensure PeerStorage implements the lnwire.Message
+// interface.
+var _ Message = (*PeerStorage)(nil)
+
+// Decode deserializes a serialized PeerStorage message stored in the passed
+// io.Reader observing the specified protocol version.
+//
+// This is part of the lnwire.Message interface.
+func (msg *PeerStorage) Decode(r io.Reader, _ uint32) error {
+ return ReadElement(r, &msg.Blob)
+}
+
+// Encode serializes the target PeerStorage into the passed io.Writer observing
+// the protocol version specified.
+//
+// This is part of the lnwire.Message interface.
+func (msg *PeerStorage) Encode(w *bytes.Buffer, _ uint32) error {
+ return WriteElement(w, msg.Blob)
+}
+
+// MsgType returns the integer uniquely identifying this message type on the
+// wire.
+//
+// This is part of the lnwire.Message interface.
+func (msg *PeerStorage) MsgType() MessageType {
+ return MsgPeerStorage
+}
diff --git a/lnwire/peer_storage_blob.go b/lnwire/peer_storage_blob.go new file mode 100644 index 00000000000..f5f137cf18d --- /dev/null +++ b/lnwire/peer_storage_blob.go
@@ -0,0 +1,49 @@
+package lnwire
+
+import (
+ "bytes"
+ "encoding/binary"
+ "io"
+)
+
+// PeerStorageBlob is the type of the data sent and received by peers in the
+// `PeerStorage` and `PeerStorageRetrieval` messages.
+type PeerStorageBlob []byte
+
+// Encode writes the PeerStorageBlob to the passed bytes.Buffer.
+func (p *PeerStorageBlob) Encode(w *bytes.Buffer) error {
+ // Write length first.
+ var l [2]byte
+ blob := *p
+ binary.BigEndian.PutUint16(l[:], uint16(len(blob)))
+ if _, err := w.Write(l[:]); err != nil {
+ return err
+ }
+
+ // Then, write in the actual blob.
+ if _, err := w.Write(blob); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+// Decode reads the passed io.Reader into a PeerStorageBlob.
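+// The wire encoding is a 2-byte big-endian length prefix followed by the raw
+// blob, mirroring Encode above.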
+func (p *PeerStorageBlob) Decode(r io.Reader) error {
+ // Read length first.
+ var l [2]byte
+ if _, err := io.ReadFull(r, l[:]); err != nil {
+ return err
+ }
+ peerStorageLen := binary.BigEndian.Uint16(l[:])
+
+ // Then read the actual blob.
+ storageBlob := make(PeerStorageBlob, peerStorageLen)
+ if _, err := io.ReadFull(r, storageBlob); err != nil {
+ return err
+ }
+
+ *p = storageBlob
+
+ return nil
+}
diff --git a/lnwire/peer_storage_retrieval.go b/lnwire/peer_storage_retrieval.go new file mode 100644 index 00000000000..d5d4d9fe2e4 --- /dev/null +++ b/lnwire/peer_storage_retrieval.go
@@ -0,0 +1,50 @@
+package lnwire
+
+import (
+ "bytes"
+ "io"
+)
+
+// PeerStorageRetrieval contains the blob from the last PeerStorage message
+// received from a particular peer. It is sent to that peer on reconnection,
+// after the `init` message and before the `channelReestablish` message.
+type PeerStorageRetrieval struct {
+ // Blob contains data a peer backs up for another.
+ Blob PeerStorageBlob
+}
+
+// NewPeerStorageRetrievalMsg creates a new instance of the
+// PeerStorageRetrieval message.
+func NewPeerStorageRetrievalMsg(data PeerStorageBlob) *PeerStorageRetrieval {
+ return &PeerStorageRetrieval{
+ Blob: data,
+ }
+}
+
+// A compile time check to ensure PeerStorageRetrieval implements the
+// lnwire.Message interface.
+var _ Message = (*PeerStorageRetrieval)(nil)
+
+// Decode deserializes a serialized PeerStorageRetrieval message stored in the
+// passed io.Reader observing the specified protocol version.
+//
+// This is part of the lnwire.Message interface.
+func (msg *PeerStorageRetrieval) Decode(r io.Reader, _ uint32) error {
+ return ReadElement(r, &msg.Blob)
+}
+
+// Encode serializes the target PeerStorageRetrieval into the passed io.Writer
+// observing the protocol version specified.
+//
+// This is part of the lnwire.Message interface.
+func (msg *PeerStorageRetrieval) Encode(w *bytes.Buffer, _ uint32) error {
+ return WriteElement(w, msg.Blob)
+}
+
+// MsgType returns the integer uniquely identifying this message type on the
+// wire.
+//
+// This is part of the lnwire.Message interface.
+func (msg *PeerStorageRetrieval) MsgType() MessageType {
+ return MsgPeerStorageRetrieval
+}
diff --git a/peer/brontide.go b/peer/brontide.go index 1fa6c1311a8..d39036b3ccf 100644 --- a/peer/brontide.go +++ b/peer/brontide.go
@@ -76,6 +76,8 @@ const ( // ErrorBufferSize is the number of historic peer errors that we store. ErrorBufferSize = 10
+
+ // backupUpdateInterval is the minimum duration between successive
+ // writes of a peer's backup data to the store.
+ backupUpdateInterval = 1 * time.Second
) var (
@@ -139,6 +141,21 @@ type TimestampedError struct { Timestamp time.Time }
+// PeerDataStore is an interface for persisting, retrieving, and deleting the
+// backup data that a peer shares with us.
+//
+//nolint:revive
+type PeerDataStore interface {
+ // Store saves backup data received from peers.
+ Store(data []byte) error
+
+ // Delete removes a peer's data from the store.
+ Delete() error
+
+ // Retrieve fetches the data of a peer from the store.
+ Retrieve() ([]byte, error)
+}
+ // Config defines configuration fields that are necessary for a peer object // to function. type Config struct {
@@ -372,6 +389,10 @@ type Config struct { // Quit is the server's quit channel. If this is closed, we halt operation. Quit chan struct{}
+
+ // PeerDataStore is the storage layer that helps us store other peers'
+ // backup data.
+ PeerDataStore PeerDataStore
} // Brontide is an active peer on the Lightning Network.
This struct is responsible @@ -495,6 +516,11 @@ type Brontide struct { // log is a peer-specific logging instance. log btclog.Logger
+
+ // backupData holds the most recent backup blob received from this
+ // peer, pending persistence by the peerStorageWriter.
+ backupData lnwire.PeerStorageBlob
+
+ // bMtx guards backupData and is signalled when new data arrives or
+ // when the peer is shutting down.
+ bMtx *sync.Cond
} // A compile-time check to ensure that Brontide satisfies the lnpeer.Peer interface.
@@ -526,6 +552,7 @@ func NewBrontide(cfg Config) *Brontide { startReady: make(chan struct{}), quit: make(chan struct{}), log: build.NewPrefixLog(logPrefix, peerLog),
+ bMtx: sync.NewCond(&sync.Mutex{}),
} var (
@@ -724,15 +751,39 @@ func (p *Brontide) Start() error { } }
+ // Send this peer its backup data if we have it. This is sent after the
+ // init message and before the channelReestablish message.
+ if p.LocalFeatures().HasFeature(lnwire.ProvideStorageOptional) {
+ data, err := p.cfg.PeerDataStore.Retrieve()
+ if err != nil {
+ return fmt.Errorf("unable to retrieve peer "+
+ "backup data: %w", err)
+ }
+
+ if data != nil {
+ if err := p.writeMessage(
+ lnwire.NewPeerStorageRetrievalMsg(data),
+ ); err != nil {
+ return fmt.Errorf("unable to send "+
+ "PeerStorageRetrieval msg to peer on "+
+ "connection: %w", err)
+ }
+ }
+ }
+ err = p.pingManager.Start() if err != nil { return fmt.Errorf("could not start ping manager %w", err) }
- p.wg.Add(4)
+ p.wg.Add(5)
go p.queueHandler() go p.writeHandler() go p.channelManager()
+
+ // Initialize peerStorageWriter before readHandler to ensure readiness
+ // for storing `PeerStorage` messages upon receipt.
+ go p.peerStorageWriter()
go p.readHandler() // Signal to any external processes that the peer is now active.
@@ -754,6 +805,75 @@ func (p *Brontide) Start() error { return nil }
+
+// peerStorageWriter persists this peer's backup data, delaying each write by
+// backupUpdateInterval so that rapid successive updates are coalesced into a
+// single write. On shutdown, any pending data is flushed immediately.
+func (p *Brontide) peerStorageWriter() {
+ defer p.wg.Done()
+
+ var data []byte
+
+Loop:
+ p.bMtx.L.Lock()
+ for {
+ p.bMtx.Wait()
+
+ // Store the data in a different variable as we are about to
+ // exit lock.
+ data = p.backupData
+ p.bMtx.L.Unlock()
+
+ // Check if we are to exit, now that we are awake.
+ select {
+ case <-p.quit:
+ if data == nil {
+ return
+ }
+
+ // Store the data immediately and exit.
+ err := p.cfg.PeerDataStore.Store(data)
+ if err != nil {
+ peerLog.Warnf("Failed to store peer "+
+ "backup data: %v", err)
+ }
+
+ return
+
+ default:
+ }
+
+ break
+ }
+
+ // Use a one-shot timer rather than a ticker so that nothing leaks
+ // across iterations of Loop.
+ t := time.NewTimer(backupUpdateInterval)
+
+ select {
+ case <-t.C:
+ // Store the data.
+ err := p.cfg.PeerDataStore.Store(data)
+ if err != nil {
+ peerLog.Criticalf("Failed to store peer "+
+ "backup data: %v", err)
+ }
+
+ goto Loop
+
+ case <-p.quit:
+ t.Stop()
+
+ if data == nil {
+ return
+ }
+
+ // Store the data immediately and exit.
+ err := p.cfg.PeerDataStore.Store(data)
+ if err != nil {
+ peerLog.Warnf("Failed to store peer backup "+
+ "data: %v", err)
+ }
+
+ return
+ }
+}
+ // initGossipSync initializes either a gossip syncer or an initial routing // dump, depending on the negotiated synchronization method. func (p *Brontide) initGossipSync() {
@@ -1231,6 +1351,12 @@ func (p *Brontide) WaitForDisconnect(ready chan struct{}) { p.wg.Wait() }
+
+// IsDisconnected returns true if the peer has been disconnected.
+func (p *Brontide) IsDisconnected() bool {
+ val := atomic.LoadInt32(&p.disconnect)
+
+ return val == 1
+}
+ // Disconnect terminates the connection with the remote peer.
Additionally, a // signal is sent to the server and htlcSwitch indicating the resources // allocated to the peer can now be cleaned up.
@@ -1257,6 +1383,10 @@ func (p *Brontide) Disconnect(reason error) { close(p.quit)
+ // Signal the peerStorageWriter to stop waiting and pick up the close
+ // signal.
+ p.bMtx.Signal()
+ if err := p.pingManager.Stop(); err != nil { p.log.Errorf("couldn't stop pingManager during disconnect: %v", err) }
@@ -1649,7 +1779,7 @@ func (p *Brontide) readHandler() { discStream.Start() defer discStream.Stop() out:
- for atomic.LoadInt32(&p.disconnect) == 0 {
+ for !p.IsDisconnected() {
nextMsg, err := p.readNextMessage() if !idleTimer.Stop() { select {
@@ -1798,6 +1928,19 @@ out: discStream.AddMsg(msg)
+ case *lnwire.PeerStorage:
+ err = p.handlePeerStorageMessage(msg)
+ if err != nil {
+ p.storeError(err)
+ p.log.Errorf("%v", err)
+ }
+
+ case *lnwire.PeerStorageRetrieval:
+ err = p.handlePeerStorageRetrieval()
+ if err != nil {
+ p.storeError(err)
+ p.log.Errorf("%v", err)
+ }
+
case *lnwire.Custom: err := p.handleCustomMessage(msg) if err != nil {
@@ -1898,6 +2041,25 @@ func (p *Brontide) hasChannel(chanID lnwire.ChannelID) bool { // channel with the peer to mitigate a dos vector where a peer costlessly // connects to us and spams us with errors. func (p *Brontide) storeError(err error) {
+ // If we do not have any active channels with the peer, we do not store
+ // errors as a DoS mitigation.
+ if !p.hasActiveChannels() {
+ p.log.Trace("no channels with peer, not storing err")
+ return
+ }
+
+ p.cfg.ErrorBuffer.Add(
+ &TimestampedError{Timestamp: time.Now(), Error: err},
+ )
+}
+
+// hasActiveChannels returns true if the Brontide instance has at least one
+// channel that is open and not pending, and false otherwise.
+func (p *Brontide) hasActiveChannels() bool {
var haveChannels bool p.activeChannels.Range(func(_ lnwire.ChannelID,
@@ -1915,16 +2077,7 @@ func (p *Brontide) storeError(err error) { return false })
- // If we do not have any active channels with the peer, we do not store
- // errors as a dos mitigation.
- if !haveChannels {
- p.log.Trace("no channels with peer, not storing err")
- return
- }
-
- p.cfg.ErrorBuffer.Add(
- &TimestampedError{Timestamp: time.Now(), Error: err},
- )
+ return haveChannels
} // handleWarningOrError processes a warning or error msg and returns true if
@@ -2105,8 +2258,11 @@ func messageSummary(msg lnwire.Message) string { time.Unix(int64(msg.FirstTimestamp), 0), msg.TimestampRange)
- case *lnwire.Custom:
- return fmt.Sprintf("type=%d", msg.Type)
+ case *lnwire.Custom,
+ *lnwire.PeerStorageRetrieval,
+ *lnwire.PeerStorage:
+
+ return fmt.Sprintf("type=%d", msg.MsgType())
} return fmt.Sprintf("unknown msg type=%T", msg)
@@ -4124,6 +4280,45 @@ func (p *Brontide) handleRemovePendingChannel(req *newChannelMsg) { p.addedChannels.Delete(chanID) }
+
+// handlePeerStorageMessage handles a `PeerStorage` message: it caches the
+// peer's blob in memory so that the peerStorageWriter can persist it.
+func (p *Brontide) handlePeerStorageMessage(msg *lnwire.PeerStorage) error {
+ // Check if we have the feature to store peer backup enabled.
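+ // If not, warn the peer and disconnect it, since it should not be
+ // sending us this message.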
+ if !p.LocalFeatures().HasFeature(lnwire.ProvideStorageOptional) {
+ warning := "received peer storage message but not " +
+ "advertising required feature bit"
+
+ if err := p.SendMessage(false, &lnwire.Warning{
+ ChanID: lnwire.ConnectionWideID,
+ Data: []byte(warning),
+ }); err != nil {
+ return err
+ }
+
+ p.Disconnect(errors.New(warning))
+
+ return nil
+ }
+
+ // If we have no active channels with this peer, we return quickly.
+ if !p.hasActiveChannels() {
+ p.log.Tracef("Received peerStorage message from "+
+ "peer(%v) with no active channels -- ignoring",
+ p.String())
+
+ return nil
+ }
+
+ p.log.Debugf("handling peer backup for peer(%s)", p)
+
+ p.bMtx.L.Lock()
+ p.backupData = msg.Blob
+ p.bMtx.Signal()
+ p.bMtx.L.Unlock()
+
+ return nil
+}
+ // sendLinkUpdateMsg sends a message that updates the channel to the // channel's message stream. func (p *Brontide) sendLinkUpdateMsg(cid lnwire.ChannelID, msg lnwire.Message) {
@@ -4148,3 +4343,23 @@ func (p *Brontide) sendLinkUpdateMsg(cid lnwire.ChannelID, msg lnwire.Message) { // continue processing message. chanStream.AddMsg(msg) }
+
+// handlePeerStorageRetrieval sends a warning and disconnects any peer that
+// sends us a `PeerStorageRetrieval` message.
+func (p *Brontide) handlePeerStorageRetrieval() error {
+ peerLog.Tracef("received peerStorageRetrieval message from "+
+ "peer %v", p.Address())
+
+ warning := "received unexpected peerStorageRetrieval message"
+
+ if err := p.SendMessage(false, &lnwire.Warning{
+ ChanID: lnwire.ConnectionWideID,
+ Data: []byte(warning),
+ }); err != nil {
+ return err
+ }
+
+ p.Disconnect(errors.New(warning))
+
+ return nil
+}
diff --git a/peer/brontide_test.go b/peer/brontide_test.go index 8ea8846b2fe..8a3cf1fb1a6 100644 --- a/peer/brontide_test.go +++ b/peer/brontide_test.go
@@ -2,26 +2,25 @@ package peer import ( "bytes"
+ "errors"
"fmt" "testing" "time"
- "github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb"
- "github.com/lightningnetwork/lnd/channelnotifier"
"github.com/lightningnetwork/lnd/contractcourt"
+ "github.com/lightningnetwork/lnd/fn"
"github.com/lightningnetwork/lnd/htlcswitch"
- "github.com/lightningnetwork/lnd/lntest/mock"
"github.com/lightningnetwork/lnd/lntest/wait"
+ "github.com/lightningnetwork/lnd/lnutils"
"github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet/chancloser" "github.com/lightningnetwork/lnd/lnwire"
- "github.com/lightningnetwork/lnd/pool"
"github.com/stretchr/testify/require" )
@@ -39,19 +38,13 @@ var ( func TestPeerChannelClosureShutdownResponseLinkRemoved(t *testing.T) { t.Parallel()
- notifier := &mock.ChainNotifier{
- SpendChan: make(chan *chainntnfs.SpendDetail),
- EpochChan: make(chan *chainntnfs.BlockEpoch),
- ConfChan: make(chan *chainntnfs.TxConfirmation),
- }
- broadcastTxChan := make(chan *wire.MsgTx)
-
- mockSwitch := &mockMessageSwitch{}
+ harness, err := createTestPeerWithChannel(t, noUpdate)
+ require.NoError(t, err, "unable to create test channels")
- alicePeer, bobChan, err := createTestPeer(
- t, notifier, broadcastTxChan, noUpdate, mockSwitch,
+ var (
+ alicePeer = harness.peer
+ bobChan = harness.channel
)
- require.NoError(t, err, "unable to create test channels")
chanPoint := bobChan.ChannelPoint() chanID := lnwire.NewChanIDFromOutPoint(chanPoint)
@@ -87,19 +80,16 @@ func
TestPeerChannelClosureShutdownResponseLinkRemoved(t *testing.T) { func TestPeerChannelClosureAcceptFeeResponder(t *testing.T) { t.Parallel() - notifier := &mock.ChainNotifier{ - SpendChan: make(chan *chainntnfs.SpendDetail), - EpochChan: make(chan *chainntnfs.BlockEpoch), - ConfChan: make(chan *chainntnfs.TxConfirmation), - } - broadcastTxChan := make(chan *wire.MsgTx) - - mockSwitch := &mockMessageSwitch{} + harness, err := createTestPeerWithChannel(t, noUpdate) + require.NoError(t, err, "unable to create test channels") - alicePeer, bobChan, err := createTestPeer( - t, notifier, broadcastTxChan, noUpdate, mockSwitch, + var ( + alicePeer = harness.peer + bobChan = harness.channel + mockSwitch = harness.mockSwitch + broadcastTxChan = harness.publishTx + notifier = harness.notifier ) - require.NoError(t, err, "unable to create test channels") chanPoint := bobChan.ChannelPoint() chanID := lnwire.NewChanIDFromOutPoint(chanPoint) @@ -192,19 +182,16 @@ func TestPeerChannelClosureAcceptFeeResponder(t *testing.T) { func TestPeerChannelClosureAcceptFeeInitiator(t *testing.T) { t.Parallel() - notifier := &mock.ChainNotifier{ - SpendChan: make(chan *chainntnfs.SpendDetail), - EpochChan: make(chan *chainntnfs.BlockEpoch), - ConfChan: make(chan *chainntnfs.TxConfirmation), - } - broadcastTxChan := make(chan *wire.MsgTx) - - mockSwitch := &mockMessageSwitch{} + harness, err := createTestPeerWithChannel(t, noUpdate) + require.NoError(t, err, "unable to create test channels") - alicePeer, bobChan, err := createTestPeer( - t, notifier, broadcastTxChan, noUpdate, mockSwitch, + var ( + bobChan = harness.channel + alicePeer = harness.peer + mockSwitch = harness.mockSwitch + broadcastTxChan = harness.publishTx + notifier = harness.notifier ) - require.NoError(t, err, "unable to create test channels") chanPoint := bobChan.ChannelPoint() chanID := lnwire.NewChanIDFromOutPoint(chanPoint) @@ -316,19 +303,16 @@ func TestPeerChannelClosureAcceptFeeInitiator(t *testing.T) { func TestPeerChannelClosureFeeNegotiationsResponder(t *testing.T) { t.Parallel() - notifier := &mock.ChainNotifier{ - SpendChan: make(chan *chainntnfs.SpendDetail), - EpochChan: make(chan *chainntnfs.BlockEpoch), - ConfChan: make(chan *chainntnfs.TxConfirmation), - } - broadcastTxChan := make(chan *wire.MsgTx) - - mockSwitch := &mockMessageSwitch{} + harness, err := createTestPeerWithChannel(t, noUpdate) + require.NoError(t, err, "unable to create test channels") - alicePeer, bobChan, err := createTestPeer( - t, notifier, broadcastTxChan, noUpdate, mockSwitch, + var ( + bobChan = harness.channel + alicePeer = harness.peer + mockSwitch = harness.mockSwitch + broadcastTxChan = harness.publishTx + notifier = harness.notifier ) - require.NoError(t, err, "unable to create test channels") chanPoint := bobChan.ChannelPoint() chanID := lnwire.NewChanIDFromOutPoint(chanPoint) @@ -503,19 +487,16 @@ func TestPeerChannelClosureFeeNegotiationsResponder(t *testing.T) { func TestPeerChannelClosureFeeNegotiationsInitiator(t *testing.T) { t.Parallel() - notifier := &mock.ChainNotifier{ - SpendChan: make(chan *chainntnfs.SpendDetail), - EpochChan: make(chan *chainntnfs.BlockEpoch), - ConfChan: make(chan *chainntnfs.TxConfirmation), - } - broadcastTxChan := make(chan *wire.MsgTx) - - mockSwitch := &mockMessageSwitch{} + harness, err := createTestPeerWithChannel(t, noUpdate) + require.NoError(t, err, "unable to create test channels") - alicePeer, bobChan, err := createTestPeer( - t, notifier, broadcastTxChan, noUpdate, mockSwitch, + var ( + alicePeer = harness.peer 
+ bobChan = harness.channel + mockSwitch = harness.mockSwitch + broadcastTxChan = harness.publishTx + notifier = harness.notifier ) - require.NoError(t, err, "unable to create test channels") chanPoint := bobChan.ChannelPoint() chanID := lnwire.NewChanIDFromOutPoint(chanPoint) @@ -830,31 +811,27 @@ func TestCustomShutdownScript(t *testing.T) { test := test t.Run(test.name, func(t *testing.T) { - notifier := &mock.ChainNotifier{ - SpendChan: make(chan *chainntnfs.SpendDetail), - EpochChan: make(chan *chainntnfs.BlockEpoch), - ConfChan: make(chan *chainntnfs.TxConfirmation), - } - broadcastTxChan := make(chan *wire.MsgTx) - - mockSwitch := &mockMessageSwitch{} - // Open a channel. - alicePeer, bobChan, err := createTestPeer( - t, notifier, broadcastTxChan, test.update, - mockSwitch, + harness, err := createTestPeerWithChannel( + t, test.update, ) if err != nil { t.Fatalf("unable to create test channels: %v", err) } + var ( + alicePeer = harness.peer + bobChan = harness.channel + mockSwitch = harness.mockSwitch + ) + chanPoint := bobChan.ChannelPoint() chanID := lnwire.NewChanIDFromOutPoint(chanPoint) mockLink := newMockUpdateHandler(chanID) mockSwitch.links = append(mockSwitch.links, mockLink) - // Request initiator to cooperatively close the channel, with - // a specified delivery address. + // Request initiator to cooperatively close the channel, + // with a specified delivery address. updateChan := make(chan interface{}, 1) errChan := make(chan error, 1) closeCommand := htlcswitch.ChanClose{ @@ -1000,35 +977,23 @@ func TestStaticRemoteDowngrade(t *testing.T) { test := test t.Run(test.name, func(t *testing.T) { - writeBufferPool := pool.NewWriteBuffer( - pool.DefaultWriteBufferGCInterval, - pool.DefaultWriteBufferExpiryInterval, - ) + params := createTestPeer(t) - writePool := pool.NewWrite( - writeBufferPool, 1, timeout, + var ( + p = params.peer + mockConn = params.mockConn + writePool = p.cfg.WritePool ) - require.NoError(t, writePool.Start()) - - mockConn := newMockConn(t, 1) - - p := Brontide{ - cfg: Config{ - LegacyFeatures: legacy, - Features: test.features, - Conn: mockConn, - WritePool: writePool, - PongBuf: make([]byte, lnwire.MaxPongBytes), - }, - log: peerLog, - } + // Set feature bits. + p.cfg.LegacyFeatures = legacy + p.cfg.Features = test.features var b bytes.Buffer _, err := lnwire.WriteMessage(&b, test.expectedInit, 0) require.NoError(t, err) - // Send our init message, assert that we write our expected message - // and shutdown our write pool. + // Send our init message, assert that we write our + // expected message and shutdown our write pool. require.NoError(t, p.sendInitMsg(test.legacy)) mockConn.assertWrite(b.Bytes()) require.NoError(t, writePool.Stop()) @@ -1056,100 +1021,19 @@ func genScript(t *testing.T, address string) lnwire.DeliveryAddress { func TestPeerCustomMessage(t *testing.T) { t.Parallel() - // Set up node Alice. 
- dbAlice, err := channeldb.Open(t.TempDir()) - require.NoError(t, err) - - aliceKey, err := btcec.NewPrivateKey() - require.NoError(t, err) - - writeBufferPool := pool.NewWriteBuffer( - pool.DefaultWriteBufferGCInterval, - pool.DefaultWriteBufferExpiryInterval, - ) + params := createTestPeer(t) - writePool := pool.NewWrite( - writeBufferPool, 1, timeout, - ) - require.NoError(t, writePool.Start()) - - readBufferPool := pool.NewReadBuffer( - pool.DefaultReadBufferGCInterval, - pool.DefaultReadBufferExpiryInterval, - ) - - readPool := pool.NewRead( - readBufferPool, 1, timeout, + var ( + mockConn = params.mockConn + alicePeer = params.peer + receivedCustomChan = params.customChan + remoteKey = alicePeer.PubKey() ) - require.NoError(t, readPool.Start()) - - mockConn := newMockConn(t, 1) - - receivedCustomChan := make(chan *customMsg) - - remoteKey := [33]byte{8} - notifier := &mock.ChainNotifier{ - SpendChan: make(chan *chainntnfs.SpendDetail), - EpochChan: make(chan *chainntnfs.BlockEpoch), - ConfChan: make(chan *chainntnfs.TxConfirmation), - } - - // TODO(yy): change ChannelNotifier to be an interface. - channelNotifier := channelnotifier.New(dbAlice.ChannelStateDB()) - require.NoError(t, channelNotifier.Start()) - t.Cleanup(func() { - require.NoError(t, channelNotifier.Stop(), - "stop channel notifier failed") - }) - - alicePeer := NewBrontide(Config{ - PubKeyBytes: remoteKey, - ChannelDB: dbAlice.ChannelStateDB(), - Addr: &lnwire.NetAddress{ - IdentityKey: aliceKey.PubKey(), - }, - PrunePersistentPeerConnection: func([33]byte) {}, - Features: lnwire.EmptyFeatureVector(), - LegacyFeatures: lnwire.EmptyFeatureVector(), - WritePool: writePool, - ReadPool: readPool, - Conn: mockConn, - ChainNotifier: notifier, - HandleCustomMessage: func( - peer [33]byte, msg *lnwire.Custom) error { - - receivedCustomChan <- &customMsg{ - peer: peer, - msg: *msg, - } - return nil - }, - PongBuf: make([]byte, lnwire.MaxPongBytes), - ChannelNotifier: channelNotifier, - }) - - // Set up the init sequence. - go func() { - // Read init message. - <-mockConn.writtenMessages - - // Write the init reply message. - initReplyMsg := lnwire.NewInitMessage( - lnwire.NewRawFeatureVector( - lnwire.DataLossProtectRequired, - ), - lnwire.NewRawFeatureVector(), - ) - var b bytes.Buffer - _, err = lnwire.WriteMessage(&b, initReplyMsg, 0) - require.NoError(t, err) - - mockConn.readMessages <- b.Bytes() - }() - - // Start the peer. - require.NoError(t, alicePeer.Start()) + // Start peer. + startPeerDone := startPeer(t, mockConn, alicePeer) + _, err := fn.RecvOrTimeout(startPeerDone, 2*timeout) + require.NoError(t, err) // Send a custom message. customMsg, err := lnwire.NewCustom( @@ -1185,21 +1069,12 @@ func TestUpdateNextRevocation(t *testing.T) { require := require.New(t) - // TODO(yy): create interface for lnwallet.LightningChannel so we can - // easily mock it without the following setups. - notifier := &mock.ChainNotifier{ - SpendChan: make(chan *chainntnfs.SpendDetail), - EpochChan: make(chan *chainntnfs.BlockEpoch), - ConfChan: make(chan *chainntnfs.TxConfirmation), - } - broadcastTxChan := make(chan *wire.MsgTx) - mockSwitch := &mockMessageSwitch{} - - alicePeer, bobChan, err := createTestPeer( - t, notifier, broadcastTxChan, noUpdate, mockSwitch, - ) + harness, err := createTestPeerWithChannel(t, noUpdate) require.NoError(err, "unable to create test channels") + bobChan := harness.channel + alicePeer := harness.peer + // testChannel is used to test the updateNextRevocation function. 
testChannel := bobChan.State() @@ -1412,30 +1287,22 @@ func TestHandleRemovePendingChannel(t *testing.T) { func TestStartupWriteMessageRace(t *testing.T) { t.Parallel() - // Set up parameters for createTestPeer. - notifier := &mock.ChainNotifier{ - SpendChan: make(chan *chainntnfs.SpendDetail), - EpochChan: make(chan *chainntnfs.BlockEpoch), - ConfChan: make(chan *chainntnfs.TxConfirmation), - } - broadcastTxChan := make(chan *wire.MsgTx) - mockSwitch := &mockMessageSwitch{} - - // Use a callback to extract the channel created by createTestPeer, so - // we can mark it borked below. We can't mark it borked within the - // callback, since the channel hasn't been saved to the DB yet when the - // callback executes. + // Use a callback to extract the channel created by + // createTestPeerWithChannel, so we can mark it borked below. + // We can't mark it borked within the callback, since the channel hasn't + // been saved to the DB yet when the callback executes. var channel *channeldb.OpenChannel getChannels := func(a, b *channeldb.OpenChannel) { channel = a } - // createTestPeer creates a peer and a channel with that peer. - peer, _, err := createTestPeer( - t, notifier, broadcastTxChan, getChannels, mockSwitch, - ) + // createTestPeerWithChannel creates a peer and a channel with that + // peer. + harness, err := createTestPeerWithChannel(t, getChannels) require.NoError(t, err, "unable to create test channel") + peer := harness.peer + // Avoid the need to mock the channel graph by marking the channel // borked. Borked channels still get a reestablish message sent on // reconnect, while skipping channel graph checks and link creation. @@ -1445,58 +1312,21 @@ func TestStartupWriteMessageRace(t *testing.T) { mockConn := newMockConn(t, 2) peer.cfg.Conn = mockConn - // Set up other configuration necessary to successfully execute - // peer.Start(). - peer.cfg.LegacyFeatures = lnwire.EmptyFeatureVector() - writeBufferPool := pool.NewWriteBuffer( - pool.DefaultWriteBufferGCInterval, - pool.DefaultWriteBufferExpiryInterval, - ) - writePool := pool.NewWrite( - writeBufferPool, 1, timeout, - ) - require.NoError(t, writePool.Start()) - peer.cfg.WritePool = writePool - readBufferPool := pool.NewReadBuffer( - pool.DefaultReadBufferGCInterval, - pool.DefaultReadBufferExpiryInterval, - ) - readPool := pool.NewRead( - readBufferPool, 1, timeout, - ) - require.NoError(t, readPool.Start()) - peer.cfg.ReadPool = readPool - // Send a message while starting the peer. As the peer starts up, it // should not trigger a data race between the sending of this message // and the sending of the channel reestablish message. - sendPingDone := make(chan struct{}) + var sendPingDone = make(chan struct{}) go func() { require.NoError(t, peer.SendMessage(true, lnwire.NewPing(0))) close(sendPingDone) }() - // Handle init messages. - go func() { - // Read init message. - <-mockConn.writtenMessages - - // Write the init reply message. - initReplyMsg := lnwire.NewInitMessage( - lnwire.NewRawFeatureVector( - lnwire.DataLossProtectRequired, - ), - lnwire.NewRawFeatureVector(), - ) - var b bytes.Buffer - _, err = lnwire.WriteMessage(&b, initReplyMsg, 0) - require.NoError(t, err) - - mockConn.readMessages <- b.Bytes() - }() - // Start the peer. No data race should occur. - require.NoError(t, peer.Start()) + startPeerDone := startPeer(t, mockConn, peer) + + // Ensure startup is complete. + _, err = fn.RecvOrTimeout(startPeerDone, 2*timeout) + require.NoError(t, err) // Ensure messages were sent during startup. 
<-sendPingDone
@@ -1516,21 +1346,12 @@ func TestRemovePendingChannel(t *testing.T) { t.Parallel()
- // Set up parameters for createTestPeer.
- notifier := &mock.ChainNotifier{
- SpendChan: make(chan *chainntnfs.SpendDetail),
- EpochChan: make(chan *chainntnfs.BlockEpoch),
- ConfChan: make(chan *chainntnfs.TxConfirmation),
- }
- broadcastTxChan := make(chan *wire.MsgTx)
- mockSwitch := &mockMessageSwitch{}
-
- // createTestPeer creates a peer and a channel with that peer.
- peer, _, err := createTestPeer(
- t, notifier, broadcastTxChan, noUpdate, mockSwitch,
- )
+ // createTestPeerWithChannel creates a peer and a channel.
+ harness, err := createTestPeerWithChannel(t, noUpdate)
require.NoError(t, err, "unable to create test channel")
+ peer := harness.peer
+ // Add a pending channel to the peer Alice. errChan := make(chan error, 1) pendingChanID := lnwire.ChannelID{1}
@@ -1593,3 +1414,299 @@ require.NoError(t, err) }
+
+// TestHandlePeerStorageRetrieval tests the `handlePeerStorageRetrieval`
+// brontide method.
+func TestHandlePeerStorageRetrieval(t *testing.T) {
+ harness := createTestPeer(t)
+
+ peer := harness.peer
+
+ // Buffer outgoingQueue to prevent blocking.
+ peer.outgoingQueue = make(chan outgoingMsg, 1)
+
+ // Send signal that the peer is ready and can handle disconnect.
+ close(peer.startReady)
+
+ err := peer.handlePeerStorageRetrieval()
+ require.NoError(t, err)
+
+ // Test that we send a warning to the peer.
+ select {
+ case receivedMsg := <-peer.outgoingQueue:
+ require.IsType(t, &lnwire.Warning{}, receivedMsg.msg)
+
+ case <-time.After(timeout):
+ t.Fatalf("did not receive message " +
+ "as expected.")
+ }
+
+ // Test that we disconnect.
+ require.True(t, peer.IsDisconnected())
+}
+
+// TestHandlePeerStorage tests handling of the peer storage message.
+func TestHandlePeerStorage(t *testing.T) {
+ t.Parallel()
+ rt := require.New(t)
+
+ // Create test data.
+ blob := []byte{0x9c, 0x40, 0x1, 0x2, 0x3}
+
+ testCases := []struct {
+ name string
+ msgTestFunc func(msg lnwire.Message)
+ setUpFunc func(peer *Brontide)
+ noSendMsg bool
+ disconnect bool
+ backupData lnwire.PeerStorageBlob
+ }{
+ {
+ name: "option peer storage disabled",
+ msgTestFunc: func(msg lnwire.Message) {
+ // In this case, we expect to send a warning to
+ // this peer.
+ rt.IsType(msg, &lnwire.Warning{})
+ },
+ disconnect: true,
+ },
+ {
+ name: "option peer storage enabled, active channels " +
+ "present",
+ setUpFunc: func(peer *Brontide) {
+ // Enable option_peer_storage.
+ peer.cfg.Features = lnwire.NewFeatureVector(
+ lnwire.NewRawFeatureVector(
+ lnwire.ProvideStorageOptional,
+ ),
+ lnwire.Features,
+ )
+
+ // Add a dummy active channel.
+ chanID := lnwire.NewChanIDFromOutPoint(
+ wire.OutPoint{Index: 1},
+ )
+ peer.activeChannels.Store(
+ chanID, &lnwallet.LightningChannel{},
+ )
+ },
+ noSendMsg: true,
+ backupData: blob,
+ },
+ {
+ name: "option peer storage enabled, active channels " +
+ "absent",
+ setUpFunc: func(peer *Brontide) {
+ // Enable option_peer_storage.
+ peer.cfg.Features = lnwire.NewFeatureVector(
+ lnwire.NewRawFeatureVector(
+ lnwire.ProvideStorageOptional,
+ ),
+ lnwire.Features,
+ )
+
+ peer.activeChannels = &lnutils.SyncMap[
+ lnwire.ChannelID,
+ *lnwallet.LightningChannel,
+ ]{}
+ },
+ noSendMsg: true,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ params := createTestPeer(t)
+ peer := params.peer
+
+ // The backupData should be nil on initialization.
+ rt.Nil(peer.backupData)
+
+ // Close startReady to signal that the peer is ready, so
+ // that it can be disconnected.
+ close(peer.startReady)
+
+ if tc.setUpFunc != nil {
+ tc.setUpFunc(peer)
+ }
+
+ // Buffer outgoingQueue to prevent blocking.
+ peer.outgoingQueue = make(chan outgoingMsg, 1)
+
+ peerStorageMsg, err := lnwire.NewPeerStorageMsg(
+ blob,
+ )
+ rt.NoError(err)
+
+ // Call the method.
+ err = peer.handlePeerStorageMessage(peerStorageMsg)
+ rt.NoError(err)
+
+ // Test for the backup data stored in memory.
+ rt.Equal(tc.backupData, peer.backupData)
+
+ // Test for expected outgoing messages, if any.
+ select {
+ case receivedMsg := <-peer.outgoingQueue:
+ if !tc.noSendMsg {
+ tc.msgTestFunc(receivedMsg.msg)
+
+ return
+ }
+
+ t.Fatalf("received unexpected "+
+ "message %v",
+ receivedMsg.msg.MsgType())
+
+ case <-time.After(2 * time.Second):
+ if !tc.noSendMsg {
+ t.Fatalf("did not receive message " +
+ "as expected.")
+ }
+ }
+
+ // Check if the peer should be disconnected.
+ rt.Equal(tc.disconnect, peer.IsDisconnected())
+ })
+ }
+}
+
+// TestPeerStorageWriter tests the peerStorageWriter function.
+func TestPeerStorageWriter(t *testing.T) {
+ harness := createTestPeer(t)
+ peer := harness.peer
+
+ // Start the function in a goroutine to test its functionality in
+ // another thread. A successful test means the function started and
+ // performed as expected.
+ peer.wg.Add(1)
+ go peer.peerStorageWriter()
+
+ // Update the backupData every quarter of the backupUpdateInterval.
+ // We do this to test that only the most recent update within this
+ // interval is persisted. This also tests that the storage is indeed
+ // delayed.
+ interval := backupUpdateInterval / 4
+
+ // We update the backupData eight times, once every quarter of the
+ // backupUpdateInterval, so the whole process spans two full
+ // backupUpdateInterval durations.
+ for i := 0; i < 8; i++ {
+ ti := time.NewTicker(interval)
+ select {
+ case <-ti.C:
+ peer.bMtx.L.Lock()
+ peer.backupData = []byte{byte(i)}
+ peer.bMtx.Signal()
+ peer.bMtx.L.Unlock()
+
+ case <-time.After(1 * time.Second):
+ t.Fatalf("did not receive tick as expected.")
+ }
+ }
+
+ // We expect to persist the initial backup data at index 0. After one
+ // full backupUpdateInterval, the next data to be persisted is at index
+ // 4. Between the data at index 4 and the final data at index 7, only
+ // 3/4 of the backupUpdateInterval elapses. The remaining interval is
+ // completed with this sleep, so that the data at index 4 is persisted.
+ time.Sleep(backupUpdateInterval / 4)
+
+ retrievedData, err := peer.cfg.PeerDataStore.Retrieve()
+ require.NoError(t, err)
+
+ // We expect one write per backupUpdateInterval. Since we sent updates
+ // over a duration of two backupUpdateIntervals, we expect exactly two
+ // updates to have been persisted.
+ require.Len(t, retrievedData, 2)
+
+ // Convert the data to its corresponding integer values.
+ convToIntData := func(retrievedData []byte) []int {
+ return fn.Map(func(b byte) int {
+ return int(b)
+ }, retrievedData)
+ }
+
+ // The backup data was updated eight times, each annotated with its
+ // index. Due to the delay, we expect only the data from the zeroth and
+ // fourth updates to be persisted in that order.
+ require.Equal(t, []int{0, 4}, convToIntData(retrievedData))
+
+ // Test that we store remaining data on quit.
+ close(peer.quit)
+ peer.bMtx.Signal()
+
+ // The signal for data at index 7 wasn't picked up because it
+ // was sent during the storage delay for data at index 4.
+ // After sending another signal, we expect the
+ // `peerStorageWriter` to now pick it up, then bypass the
+ // storage delay and persist it immediately as we have now
+ // closed the peer's quit channel as well.
+ //
+ // Wait a bit to allow for storing before we retrieve.
+ time.Sleep(time.Second / 2)
+ retrievedData, err = peer.cfg.PeerDataStore.Retrieve()
+ require.NoError(t, err)
+
+ require.Len(t, retrievedData, 3)
+ require.Equal(t, []int{0, 4, 7}, convToIntData(retrievedData))
+
+ // Wait for the goroutine to exit (good manners).
+ peer.wg.Wait()
+}
+
+// TestPeerBackupReconnect ensures that a peer sends the backup data,
+// if available, upon connection. It verifies the peer's behavior by simulating
+// a reconnection and checking if the expected backup data is sent to the mock
+// connection within a specified timeout.
+func TestPeerBackupReconnect(t *testing.T) {
+ t.Parallel()
+ rt := require.New(t)
+
+ params := createTestPeer(t)
+
+ var (
+ peer = params.peer
+ mockConn = params.mockConn
+ )
+
+ // Enable option_peer_storage.
+ peer.cfg.Features = lnwire.NewFeatureVector(
+ lnwire.NewRawFeatureVector(
+ lnwire.ProvideStorageOptional,
+ ),
+ lnwire.Features,
+ )
+
+ // Create sample backup data.
+ sampleData := []byte{0, 1, 2}
+
+ // Store peer's backup data.
+ rt.NoError(peer.cfg.PeerDataStore.Store(sampleData))
+
+ // Test that we send the data to the peer on startup.
+ donePeer := startPeer(t, mockConn, peer)
+ t.Cleanup(func() {
+ _, err := fn.RecvOrTimeout(donePeer, 2*timeout)
+ require.NoError(t, err)
+ peer.Disconnect(errors.New(""))
+ })
+
+ // Test that we send this peer its backup on startup.
+ select {
+ case rawMsg := <-mockConn.writtenMessages:
+ msgReader := bytes.NewReader(rawMsg)
+ nextMsg, err := lnwire.ReadMessage(msgReader, 0)
+ require.NoError(t, err)
+
+ msg, ok := nextMsg.(*lnwire.PeerStorageRetrieval)
+ require.True(t, ok)
+
+ require.True(t, bytes.Equal(msg.Blob, sampleData))
+
+ case <-time.After(timeout):
+ t.Fatalf("timeout waiting for message")
+ }
+}
diff --git a/peer/peer_storage.go b/peer/peer_storage.go new file mode 100644 index 00000000000..72b41b506e8 --- /dev/null +++ b/peer/peer_storage.go
@@ -0,0 +1,115 @@
+package peer
+
+import (
+ "errors"
+
+ "github.com/lightningnetwork/lnd/kvdb"
+)
+
+var (
+ // peerStorage is the name of the top-level kvdb bucket in which we
+ // store backup data received from peers.
+ peerStorage = []byte("peer-storage")
+
+ // ErrUninitializedDB signifies an error encountered when attempting
+ // to access the database before it has been set up.
+ ErrUninitializedDB = errors.New("uninitialized kvdb peer data store")
+)
+
+// PeerStorageProducer creates a backup storage instance for a peer.
+//
+//nolint:revive
+type PeerStorageProducer struct {
+ // DB represents the key-value database backend.
+ DB kvdb.Backend
+}
+
+// NewPeerStorageProducer initializes the peer storage database by creating a
+// top-level bucket. It returns a new PeerStorageProducer with the specified
+// database backend if successful, or a nil object and an error if not.
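+// Bucket creation is idempotent, so this is safe to call on a database that
+// already contains peer storage data.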
+func NewPeerStorageProducer(db kvdb.Backend) (*PeerStorageProducer, error) {
+ if err := kvdb.Update(db, func(tx kvdb.RwTx) error {
+ _, err := tx.CreateTopLevelBucket(peerStorage)
+
+ return err
+ }, func() {}); err != nil {
+ return nil, err
+ }
+
+ return &PeerStorageProducer{
+ DB: db,
+ }, nil
+}
+
+// PeerStorageDB is the kvdb implementation of the PeerDataStore interface.
+//
+//nolint:revive
+type PeerStorageDB struct {
+ // PeerStorageProducer provides access to the kvdb backend.
+ *PeerStorageProducer
+
+ // pubKey is the public key of the peer associated with this storage
+ // instance.
+ pubKey []byte
+}
+
+// NewPeerStorageDB creates a new PeerStorageDB instance associated with the
+// given public key.
+func (p *PeerStorageProducer) NewPeerStorageDB(key []byte) *PeerStorageDB {
+ return &PeerStorageDB{
+ pubKey: key,
+ PeerStorageProducer: p,
+ }
+}
+
+// Store persists the peer's backup in the storage layer.
+func (k *PeerStorageDB) Store(data []byte) error {
+ return kvdb.Update(k.DB, func(tx kvdb.RwTx) error {
+ bucket := tx.ReadWriteBucket(peerStorage)
+ if bucket == nil {
+ return ErrUninitializedDB
+ }
+
+ return bucket.Put(k.pubKey, data)
+ }, func() {})
+}
+
+// Retrieve fetches the peer's backup from the storage layer.
+func (k *PeerStorageDB) Retrieve() ([]byte, error) {
+ var data []byte
+ if err := kvdb.View(k.DB, func(tx kvdb.RTx) error {
+ bucket := tx.ReadBucket(peerStorage)
+ if bucket == nil {
+ return ErrUninitializedDB
+ }
+
+ blob := bucket.Get(k.pubKey)
+
+ // Copy the data, since the slice returned by Get is only valid
+ // for the life of the transaction (see the walletdb docs for
+ // the ReadBucket interface's Get method).
+ if blob != nil {
+ data = make([]byte, len(blob))
+ copy(data, blob)
+ }
+
+ return nil
+ }, func() {
+ data = nil
+ }); err != nil {
+ return nil, err
+ }
+
+ return data, nil
+}
+
+// Delete removes the peer's backup from the storage layer.
+func (k *PeerStorageDB) Delete() error {
+ return kvdb.Update(k.DB, func(tx kvdb.RwTx) error {
+ bucket := tx.ReadWriteBucket(peerStorage)
+ if bucket == nil {
+ return ErrUninitializedDB
+ }
+
+ return bucket.Delete(k.pubKey)
+ }, func() {})
+}
diff --git a/peer/peer_storage_test.go b/peer/peer_storage_test.go new file mode 100644 index 00000000000..3c51fac19e9 --- /dev/null +++ b/peer/peer_storage_test.go
@@ -0,0 +1,67 @@
+package peer
+
+import (
+ "os"
+ "testing"
+
+ "github.com/btcsuite/btcd/btcec/v2"
+ "github.com/lightningnetwork/lnd/kvdb"
+ "github.com/stretchr/testify/require"
+)
+
+// TestPeerStorageDB tests the PeerStorageDB functionality.
+func TestPeerStorageDB(t *testing.T) {
+ rt := require.New(t)
+
+ file, err := os.CreateTemp("", "*.db")
+ rt.NoError(err)
+ t.Cleanup(func() {
+ rt.NoError(file.Close())
+ rt.NoError(os.Remove(file.Name()))
+ })
+
+ dbPath := file.Name()
+ db, err := kvdb.Open(
+ kvdb.BoltBackendName, dbPath, true, kvdb.DefaultDBTimeout,
+ )
+ rt.NoError(err)
+ t.Cleanup(func() {
+ rt.NoError(db.Close())
+ })
+
+ peerStoreProducer, err := NewPeerStorageProducer(db)
+ rt.NoError(err)
+
+ // Create a sample private key for testing.
+ samplePrivKey, _ := btcec.NewPrivateKey()
+ pubKey := samplePrivKey.PubKey()
+ pubKeyBytes := pubKey.SerializeCompressed()
+
+ // Create a PeerStorageDB instance.
+ peerDataStore := peerStoreProducer.NewPeerStorageDB(pubKeyBytes)
+ rt.NotNil(peerDataStore)
+
+ // Sample byte data.
+ samplePeerData := []byte("sample data")
+
+ // Test store.
+ err = peerDataStore.Store(samplePeerData)
+ rt.NoError(err)
+
+ // Test Retrieve.
+	retrievedData, err := peerDataStore.Retrieve()
+	rt.NoError(err)
+	rt.NotNil(retrievedData)
+
+	// The retrieved data should match the data stored earlier.
+	rt.Equal(samplePeerData, retrievedData)
+
+	// Test Delete.
+	err = peerDataStore.Delete()
+	rt.NoError(err)
+
+	// Test that there is no data stored for that peer after the delete.
+	retrievedData, err = peerDataStore.Retrieve()
+	rt.NoError(err)
+	rt.Nil(retrievedData)
+}
diff --git a/peer/test_utils.go b/peer/test_utils.go
index 05bfe6ad4cc..6234607659b 100644
--- a/peer/test_utils.go
+++ b/peer/test_utils.go
@@ -7,6 +7,7 @@ import (
 	"io"
 	"math/rand"
 	"net"
+	"sync"
 	"sync/atomic"
 	"testing"
 	"time"
@@ -18,6 +19,7 @@ import (
 	"github.com/lightningnetwork/lnd/chainntnfs"
 	"github.com/lightningnetwork/lnd/channeldb"
 	"github.com/lightningnetwork/lnd/channelnotifier"
+	"github.com/lightningnetwork/lnd/fn"
 	"github.com/lightningnetwork/lnd/htlcswitch"
 	"github.com/lightningnetwork/lnd/input"
 	"github.com/lightningnetwork/lnd/keychain"
@@ -27,6 +29,7 @@ import (
 	"github.com/lightningnetwork/lnd/lnwallet/chainfee"
 	"github.com/lightningnetwork/lnd/lnwire"
 	"github.com/lightningnetwork/lnd/netann"
+	"github.com/lightningnetwork/lnd/pool"
 	"github.com/lightningnetwork/lnd/queue"
 	"github.com/lightningnetwork/lnd/shachain"
 	"github.com/stretchr/testify/require"
@@ -48,32 +51,52 @@ var (
 	testKeyLoc = keychain.KeyLocator{Family: keychain.KeyFamilyNodeKey}
 )
 
-// noUpdate is a function which can be used as a parameter in createTestPeer to
-// call the setup code with no custom values on the channels set up.
+// noUpdate is a function which can be used as a parameter in
+// createTestPeerWithChannel to call the setup code with no custom values on
+// the channels set up.
 var noUpdate = func(a, b *channeldb.OpenChannel) {}
 
-// createTestPeer creates a channel between two nodes, and returns a peer for
-// one of the nodes, together with the channel seen from both nodes. It takes
-// an updateChan function which can be used to modify the default values on
-// the channel states for each peer.
-func createTestPeer(t *testing.T, notifier chainntnfs.ChainNotifier,
-	publTx chan *wire.MsgTx, updateChan func(a, b *channeldb.OpenChannel),
-	mockSwitch *mockMessageSwitch) (
-	*Brontide, *lnwallet.LightningChannel, error) {
+type peerTestCtx struct {
+	peer          *Brontide
+	channel       *lnwallet.LightningChannel
+	notifier      *mock.ChainNotifier
+	publishTx     <-chan *wire.MsgTx
+	mockSwitch    *mockMessageSwitch
+	db            *channeldb.DB
+	privKey       *btcec.PrivateKey
+	mockConn      *mockMessageConn
+	customChan    chan *customMsg
+	chanStatusMgr *netann.ChanStatusManager
+}
 
-	nodeKeyLocator := keychain.KeyLocator{
-		Family: keychain.KeyFamilyNodeKey,
-	}
-	aliceKeyPriv, aliceKeyPub := btcec.PrivKeyFromBytes(
-		channels.AlicesPrivKey,
-	)
-	aliceKeySigner := keychain.NewPrivKeyMessageSigner(
-		aliceKeyPriv, nodeKeyLocator,
-	)
-	bobKeyPriv, bobKeyPub := btcec.PrivKeyFromBytes(
-		channels.BobsPrivKey,
+// createTestPeerWithChannel creates a channel between two nodes, and returns a
+// peer for one of the nodes, together with the channel seen from both nodes.
+// It takes an updateChan function which can be used to modify the default
+// values on the channel states for each peer.
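+//
+// Internally it builds on createTestPeer (defined below) and then wires up a
+// test channel between Alice and Bob on top of the returned peer.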
+func createTestPeerWithChannel(t *testing.T, updateChan func(a,
+	b *channeldb.OpenChannel)) (*peerTestCtx, error) {
+
+	params := createTestPeer(t)
+
+	var (
+		publishTx     = params.publishTx
+		mockSwitch    = params.mockSwitch
+		alicePeer     = params.peer
+		notifier      = params.notifier
+		aliceKeyPriv  = params.privKey
+		dbAlice       = params.db
+		chanStatusMgr = params.chanStatusMgr
+	)
+	err := chanStatusMgr.Start()
+	require.NoError(t, err)
+	t.Cleanup(func() {
+		require.NoError(t, chanStatusMgr.Stop())
+	})
+
+	aliceKeyPub := alicePeer.IdentityKey()
+	estimator := alicePeer.cfg.FeeEstimator
+
 	channelCapacity := btcutil.Amount(10 * 1e8)
 	channelBal := channelCapacity / 2
 	aliceDustLimit := btcutil.Amount(200)
@@ -88,6 +111,10 @@ func createTestPeer(t *testing.T, notifier chainntnfs.ChainNotifier,
 	}
 	fundingTxIn := wire.NewTxIn(prevOut, nil, nil)
 
+	bobKeyPriv, bobKeyPub := btcec.PrivKeyFromBytes(
+		channels.BobsPrivKey,
+	)
+
 	aliceCfg := channeldb.ChannelConfig{
 		ChannelConstraints: channeldb.ChannelConstraints{
 			DustLimit: aliceDustLimit,
@@ -141,23 +168,23 @@ func createTestPeer(t *testing.T, notifier chainntnfs.ChainNotifier,
 
 	bobRoot, err := chainhash.NewHash(bobKeyPriv.Serialize())
 	if err != nil {
-		return nil, nil, err
+		return nil, err
 	}
 	bobPreimageProducer := shachain.NewRevocationProducer(*bobRoot)
 	bobFirstRevoke, err := bobPreimageProducer.AtIndex(0)
 	if err != nil {
-		return nil, nil, err
+		return nil, err
 	}
 	bobCommitPoint := input.ComputeCommitmentPoint(bobFirstRevoke[:])
 
 	aliceRoot, err := chainhash.NewHash(aliceKeyPriv.Serialize())
 	if err != nil {
-		return nil, nil, err
+		return nil, err
 	}
 	alicePreimageProducer := shachain.NewRevocationProducer(*aliceRoot)
 	aliceFirstRevoke, err := alicePreimageProducer.AtIndex(0)
 	if err != nil {
-		return nil, nil, err
+		return nil, err
 	}
 	aliceCommitPoint := input.ComputeCommitmentPoint(aliceFirstRevoke[:])
 
@@ -167,29 +194,20 @@ func createTestPeer(t *testing.T, notifier chainntnfs.ChainNotifier,
 		isAliceInitiator, 0,
 	)
 	if err != nil {
-		return nil, nil, err
+		return nil, err
 	}
 
-	dbAlice, err := channeldb.Open(t.TempDir())
-	if err != nil {
-		return nil, nil, err
-	}
-	t.Cleanup(func() {
-		require.NoError(t, dbAlice.Close())
-	})
-
 	dbBob, err := channeldb.Open(t.TempDir())
 	if err != nil {
-		return nil, nil, err
+		return nil, err
 	}
 	t.Cleanup(func() {
 		require.NoError(t, dbBob.Close())
 	})
 
-	estimator := chainfee.NewStaticEstimator(12500, 0)
 	feePerKw, err := estimator.EstimateFeePerKW(1)
 	if err != nil {
-		return nil, nil, err
+		return nil, err
 	}
 
 	// TODO(roasbeef): need to factor in commit fee?
@@ -214,7 +232,7 @@ func createTestPeer(t *testing.T, notifier chainntnfs.ChainNotifier,
 
 	var chanIDBytes [8]byte
 	if _, err := io.ReadFull(crand.Reader, chanIDBytes[:]); err != nil {
-		return nil, nil, err
+		return nil, err
 	}
 
 	shortChanID := lnwire.NewShortChanIDFromInt(
@@ -259,13 +277,9 @@ func createTestPeer(t *testing.T, notifier chainntnfs.ChainNotifier,
 	// Set custom values on the channel states.
 	updateChan(aliceChannelState, bobChannelState)
 
-	aliceAddr := &net.TCPAddr{
-		IP:   net.ParseIP("127.0.0.1"),
-		Port: 18555,
-	}
-
+	aliceAddr := alicePeer.cfg.Addr.Address
 	if err := aliceChannelState.SyncPending(aliceAddr, 0); err != nil {
-		return nil, nil, err
+		return nil, err
 	}
 
 	bobAddr := &net.TCPAddr{
@@ -274,7 +288,7 @@ func createTestPeer(t *testing.T, notifier chainntnfs.ChainNotifier,
 	}
 
 	if err := bobChannelState.SyncPending(bobAddr, 0); err != nil {
-		return nil, nil, err
+		return nil, err
 	}
 
 	aliceSigner := input.NewMockSigner(
@@ -289,7 +303,7 @@ func createTestPeer(t *testing.T, notifier chainntnfs.ChainNotifier,
 		aliceSigner, aliceChannelState, alicePool,
 	)
 	if err != nil {
-		return nil, nil, err
+		return nil, err
 	}
 	_ = alicePool.Start()
 	t.Cleanup(func() {
@@ -301,116 +315,16 @@ func createTestPeer(t *testing.T, notifier chainntnfs.ChainNotifier,
 		bobSigner, bobChannelState, bobPool,
 	)
 	if err != nil {
-		return nil, nil, err
+		return nil, err
 	}
 	_ = bobPool.Start()
 	t.Cleanup(func() {
 		require.NoError(t, bobPool.Stop())
 	})
 
-	chainIO := &mock.ChainIO{
-		BestHeight: broadcastHeight,
-	}
-	wallet := &lnwallet.LightningWallet{
-		WalletController: &mock.WalletController{
-			RootKey:               aliceKeyPriv,
-			PublishedTransactions: publTx,
-		},
-	}
-
-	// If mockSwitch is not set by the caller, set it to the default as the
-	// caller does not need to control it.
-	if mockSwitch == nil {
-		mockSwitch = &mockMessageSwitch{}
-	}
-
-	nodeSignerAlice := netann.NewNodeSigner(aliceKeySigner)
-
-	const chanActiveTimeout = time.Minute
-
-	chanStatusMgr, err := netann.NewChanStatusManager(&netann.ChanStatusConfig{
-		ChanStatusSampleInterval: 30 * time.Second,
-		ChanEnableTimeout:        chanActiveTimeout,
-		ChanDisableTimeout:       2 * time.Minute,
-		DB:                       dbAlice.ChannelStateDB(),
-		Graph:                    dbAlice.ChannelGraph(),
-		MessageSigner:            nodeSignerAlice,
-		OurPubKey:                aliceKeyPub,
-		OurKeyLoc:                testKeyLoc,
-		IsChannelActive:          func(lnwire.ChannelID) bool { return true },
-		ApplyChannelUpdate: func(*lnwire.ChannelUpdate,
-			*wire.OutPoint, bool) error {
-
-			return nil
-		},
-	})
-	if err != nil {
-		return nil, nil, err
-	}
-	if err = chanStatusMgr.Start(); err != nil {
-		return nil, nil, err
-	}
-
-	errBuffer, err := queue.NewCircularBuffer(ErrorBufferSize)
-	if err != nil {
-		return nil, nil, err
-	}
-
-	var pubKey [33]byte
-	copy(pubKey[:], aliceKeyPub.SerializeCompressed())
-
-	cfgAddr := &lnwire.NetAddress{
-		IdentityKey: aliceKeyPub,
-		Address:     aliceAddr,
-		ChainNet:    wire.SimNet,
-	}
-
-	interceptableSwitchNotifier := &mock.ChainNotifier{
-		EpochChan: make(chan *chainntnfs.BlockEpoch, 1),
-	}
-	interceptableSwitchNotifier.EpochChan <- &chainntnfs.BlockEpoch{
-		Height: 1,
-	}
-
-	interceptableSwitch, err := htlcswitch.NewInterceptableSwitch(
-		&htlcswitch.InterceptableSwitchConfig{
-			CltvRejectDelta:    testCltvRejectDelta,
-			CltvInterceptDelta: testCltvRejectDelta + 3,
-			Notifier:           interceptableSwitchNotifier,
-		},
+	alicePeer.remoteFeatures = lnwire.NewFeatureVector(
+		nil, lnwire.Features,
 	)
-	if err != nil {
-		return nil, nil, err
-	}
-
-	// TODO(yy): change ChannelNotifier to be an interface.
-	channelNotifier := channelnotifier.New(dbAlice.ChannelStateDB())
-	require.NoError(t, channelNotifier.Start())
-	t.Cleanup(func() {
-		require.NoError(t, channelNotifier.Stop(),
-			"stop channel notifier failed")
-	})
-
-	cfg := &Config{
-		Addr:              cfgAddr,
-		PubKeyBytes:       pubKey,
-		ErrorBuffer:       errBuffer,
-		ChainIO:           chainIO,
-		Switch:            mockSwitch,
-		ChanActiveTimeout: chanActiveTimeout,
-		InterceptSwitch:   interceptableSwitch,
-		ChannelDB:         dbAlice.ChannelStateDB(),
-		FeeEstimator:      estimator,
-		Wallet:            wallet,
-		ChainNotifier:     notifier,
-		ChanStatusMgr:     chanStatusMgr,
-		Features:          lnwire.NewFeatureVector(nil, lnwire.Features),
-		DisconnectPeer:    func(b *btcec.PublicKey) error { return nil },
-		ChannelNotifier:   channelNotifier,
-	}
-
-	alicePeer := NewBrontide(*cfg)
-	alicePeer.remoteFeatures = lnwire.NewFeatureVector(nil, lnwire.Features)
 
 	chanID := lnwire.NewChanIDFromOutPoint(channelAlice.ChannelPoint())
 	alicePeer.activeChannels.Store(chanID, channelAlice)
@@ -418,7 +332,13 @@ func createTestPeer(t *testing.T, notifier chainntnfs.ChainNotifier,
 	alicePeer.wg.Add(1)
 	go alicePeer.channelManager()
 
-	return alicePeer, channelBob, nil
+	return &peerTestCtx{
+		peer:       alicePeer,
+		channel:    channelBob,
+		notifier:   notifier,
+		publishTx:  publishTx,
+		mockSwitch: mockSwitch,
+	}, nil
 }
 
 // mockMessageSwitch is a mock implementation of the messageSwitch interface
@@ -618,3 +538,281 @@ func (m *mockMessageConn) LocalAddr() net.Addr {
 func (m *mockMessageConn) Close() error {
 	return nil
 }
+
+// mockPeerDataStore is a simple in-memory implementation of the PeerDataStore
+// interface used in tests.
+type mockPeerDataStore struct {
+	data [][]byte
+	mtx  sync.RWMutex
+}
+
+func newMockDataStore() *mockPeerDataStore {
+	return &mockPeerDataStore{}
+}
+
+// Store persists the backup data given to us by peers.
+func (d *mockPeerDataStore) Store(data []byte) error {
+	d.mtx.Lock()
+	defer d.mtx.Unlock()
+
+	d.data = append(d.data, data)
+
+	return nil
+}
+
+// Delete removes all data stored for the peer.
+func (d *mockPeerDataStore) Delete() error {
+	d.mtx.Lock()
+	defer d.mtx.Unlock()
+
+	d.data = nil
+
+	return nil
+}
+
+// Retrieve returns all blobs stored for the peer, flattened into a single
+// byte slice.
+func (d *mockPeerDataStore) Retrieve() ([]byte, error) {
+	d.mtx.RLock()
+	defer d.mtx.RUnlock()
+
+	return fn.Flatten(d.data), nil
+}
+
+// createTestPeer creates a new peer for testing and returns a context struct
+// containing the necessary handles and mock objects for conducting tests on
+// peer functionality.
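+//
+// Typical usage, mirroring the backup tests earlier in this package
+// (startPeer is the helper defined later in this file):
+//
+//	params := createTestPeer(t)
+//	done := startPeer(t, params.mockConn, params.peer)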
+func createTestPeer(t *testing.T) *peerTestCtx {
+	nodeKeyLocator := keychain.KeyLocator{
+		Family: keychain.KeyFamilyNodeKey,
+	}
+
+	aliceKeyPriv, aliceKeyPub := btcec.PrivKeyFromBytes(
+		channels.AlicesPrivKey,
+	)
+
+	aliceKeySigner := keychain.NewPrivKeyMessageSigner(
+		aliceKeyPriv, nodeKeyLocator,
+	)
+
+	aliceAddr := &net.TCPAddr{
+		IP:   net.ParseIP("127.0.0.1"),
+		Port: 18555,
+	}
+	cfgAddr := &lnwire.NetAddress{
+		IdentityKey: aliceKeyPub,
+		Address:     aliceAddr,
+		ChainNet:    wire.SimNet,
+	}
+
+	errBuffer, err := queue.NewCircularBuffer(ErrorBufferSize)
+	require.NoError(t, err)
+
+	chainIO := &mock.ChainIO{
+		BestHeight: broadcastHeight,
+	}
+
+	publishTx := make(chan *wire.MsgTx)
+	wallet := &lnwallet.LightningWallet{
+		WalletController: &mock.WalletController{
+			RootKey:               aliceKeyPriv,
+			PublishedTransactions: publishTx,
+		},
+	}
+
+	const chanActiveTimeout = time.Minute
+
+	dbAlice, err := channeldb.Open(t.TempDir())
+	require.NoError(t, err)
+	t.Cleanup(func() {
+		require.NoError(t, dbAlice.Close())
+	})
+
+	nodeSignerAlice := netann.NewNodeSigner(aliceKeySigner)
+
+	chanStatusMgr, err := netann.NewChanStatusManager(
+		&netann.ChanStatusConfig{
+			ChanStatusSampleInterval: 30 * time.Second,
+			ChanEnableTimeout:        chanActiveTimeout,
+			ChanDisableTimeout:       2 * time.Minute,
+			DB:                       dbAlice.ChannelStateDB(),
+			Graph:                    dbAlice.ChannelGraph(),
+			MessageSigner:            nodeSignerAlice,
+			OurPubKey:                aliceKeyPub,
+			OurKeyLoc:                testKeyLoc,
+			IsChannelActive: func(lnwire.ChannelID) bool {
+				return true
+			},
+			ApplyChannelUpdate: func(*lnwire.ChannelUpdate,
+				*wire.OutPoint, bool) error {

+				return nil
+			},
+		},
+	)
+	require.NoError(t, err)
+
+	interceptableSwitchNotifier := &mock.ChainNotifier{
+		EpochChan: make(chan *chainntnfs.BlockEpoch, 1),
+	}
+	interceptableSwitchNotifier.EpochChan <- &chainntnfs.BlockEpoch{
+		Height: 1,
+	}
+
+	interceptableSwitch, err := htlcswitch.NewInterceptableSwitch(
+		&htlcswitch.InterceptableSwitchConfig{
+			CltvRejectDelta:    testCltvRejectDelta,
+			CltvInterceptDelta: testCltvRejectDelta + 3,
+			Notifier:           interceptableSwitchNotifier,
+		},
+	)
+	require.NoError(t, err)
+
+	// TODO(yy): create interface for lnwallet.LightningChannel so we can
+	// easily mock it without the following setups.
+	notifier := &mock.ChainNotifier{
+		SpendChan: make(chan *chainntnfs.SpendDetail),
+		EpochChan: make(chan *chainntnfs.BlockEpoch),
+		ConfChan:  make(chan *chainntnfs.TxConfirmation),
+	}
+
+	mockSwitch := &mockMessageSwitch{}
+
+	// TODO(yy): change ChannelNotifier to be an interface.
+	channelNotifier := channelnotifier.New(dbAlice.ChannelStateDB())
+	require.NoError(t, channelNotifier.Start())
+	t.Cleanup(func() {
+		require.NoError(t, channelNotifier.Stop(),
+			"stop channel notifier failed")
+	})
+
+	writeBufferPool := pool.NewWriteBuffer(
+		pool.DefaultWriteBufferGCInterval,
+		pool.DefaultWriteBufferExpiryInterval,
+	)
+
+	writePool := pool.NewWrite(
+		writeBufferPool, 1, timeout,
+	)
+	require.NoError(t, writePool.Start())
+
+	readBufferPool := pool.NewReadBuffer(
+		pool.DefaultReadBufferGCInterval,
+		pool.DefaultReadBufferExpiryInterval,
+	)
+
+	readPool := pool.NewRead(
+		readBufferPool, 1, timeout,
+	)
+	require.NoError(t, readPool.Start())
+
+	mockConn := newMockConn(t, 1)
+
+	receivedCustomChan := make(chan *customMsg)
+
+	var pubKey [33]byte
+	copy(pubKey[:], aliceKeyPub.SerializeCompressed())
+
+	estimator := chainfee.NewStaticEstimator(12500, 0)
+
+	cfg := &Config{
+		Addr:              cfgAddr,
+		PubKeyBytes:       pubKey,
+		ErrorBuffer:       errBuffer,
+		ChainIO:           chainIO,
+		Switch:            mockSwitch,
+		ChanActiveTimeout: chanActiveTimeout,
+		InterceptSwitch:   interceptableSwitch,
+		ChannelDB:         dbAlice.ChannelStateDB(),
+		FeeEstimator:      estimator,
+		Wallet:            wallet,
+		ChainNotifier:     notifier,
+		ChanStatusMgr:     chanStatusMgr,
+		Features: lnwire.NewFeatureVector(
+			nil, lnwire.Features,
+		),
+		DisconnectPeer: func(b *btcec.PublicKey) error {
+			return nil
+		},
+		ChannelNotifier:               channelNotifier,
+		PrunePersistentPeerConnection: func([33]byte) {},
+		LegacyFeatures:                lnwire.EmptyFeatureVector(),
+		WritePool:                     writePool,
+		ReadPool:                      readPool,
+		Conn:                          mockConn,
+		HandleCustomMessage: func(
+			peer [33]byte, msg *lnwire.Custom) error {
+
+			receivedCustomChan <- &customMsg{
+				peer: peer,
+				msg:  *msg,
+			}
+
+			return nil
+		},
+		PongBuf:       make([]byte, lnwire.MaxPongBytes),
+		PeerDataStore: newMockDataStore(),
+	}
+
+	alicePeer := NewBrontide(*cfg)
+
+	return &peerTestCtx{
+		publishTx:     publishTx,
+		mockSwitch:    mockSwitch,
+		peer:          alicePeer,
+		notifier:      notifier,
+		db:            dbAlice,
+		privKey:       aliceKeyPriv,
+		mockConn:      mockConn,
+		customChan:    receivedCustomChan,
+		chanStatusMgr: chanStatusMgr,
+	}
+}
+
+// startPeer invokes the `Start` method on the specified peer and handles any
+// initial startup messages for testing.
+func startPeer(t *testing.T, mockConn *mockMessageConn,
+	peer *Brontide) <-chan struct{} {
+
+	// Start the peer in a goroutine so that we can handle and test for
+	// startup messages. Successfully sending and receiving the init
+	// message indicates a successful startup.
+	done := make(chan struct{})
+	go func() {
+		require.NoError(t, peer.Start())
+		close(done)
+	}()
+
+	// Receive the init message, which should be the first message the
+	// peer sends on startup.
+	rawMsg, err := fn.RecvOrTimeout[[]byte](
+		mockConn.writtenMessages, timeout,
+	)
+	require.NoError(t, err)
+
+	msgReader := bytes.NewReader(rawMsg)
+	nextMsg, err := lnwire.ReadMessage(msgReader, 0)
+	require.NoError(t, err)
+
+	_, ok := nextMsg.(*lnwire.Init)
+	require.True(t, ok)
+
+	// Write the reply for the init message to complete the startup.
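+	// The feature bits below are a minimal set that is sufficient for the
+	// handshake in these tests; they are not meant to mirror a production
+	// peer's feature vector.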
+	initReplyMsg := lnwire.NewInitMessage(
+		lnwire.NewRawFeatureVector(
+			lnwire.DataLossProtectRequired,
+			lnwire.GossipQueriesOptional,
+		),
+		lnwire.NewRawFeatureVector(),
+	)
+
+	var b bytes.Buffer
+	_, err = lnwire.WriteMessage(&b, initReplyMsg, 0)
+	require.NoError(t, err)
+
+	ok = fn.SendOrQuit[[]byte, struct{}](
+		mockConn.readMessages, b.Bytes(), make(chan struct{}),
+	)
+	require.True(t, ok)
+
+	return done
+}
diff --git a/sample-lnd.conf b/sample-lnd.conf
index 2783c58fd52..c76c6ca79fe 100644
--- a/sample-lnd.conf
+++ b/sample-lnd.conf
@@ -1306,6 +1306,9 @@
 ; Set to disable blinded route forwarding.
 ; protocol.no-route-blinding=false
 
+; Set to enable storing backup data for our peers.
+; protocol.peer-storage=false
+
 [db]
 ; The selected database backend. The current default backend is "bolt". lnd
diff --git a/server.go b/server.go
index 627f0c0f3fa..25ccc4bf2c9 100644
--- a/server.go
+++ b/server.go
@@ -330,6 +330,10 @@ type server struct {
 	// txPublisher is a publisher with fee-bumping capability.
 	txPublisher *sweep.TxPublisher
 
+	// peerStorageProvider produces the per-peer stores used to persist
+	// the backup data that peers share with us.
+	peerStorageProvider *peer.PeerStorageProducer
+
 	quit chan struct{}
 
 	wg sync.WaitGroup
@@ -550,6 +554,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
 		CustomFeatures: cfg.ProtocolOptions.ExperimentalProtocol.CustomFeatures(),
 		NoTaprootChans: !cfg.ProtocolOptions.TaprootChans,
 		NoRouteBlinding: cfg.ProtocolOptions.NoRouteBlinding(),
+		NoPeerStorage: !cfg.ProtocolOptions.PeerStorage(),
 	})
 	if err != nil {
 		return nil, err
@@ -1518,8 +1523,15 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
 		return nil, err
 	}
 
-	// Assemble a peer notifier which will provide clients with subscriptions
-	// to peer online and offline events.
+	s.peerStorageProvider, err = peer.NewPeerStorageProducer(
+		dbs.PeerStorageDB,
+	)
+	if err != nil {
+		return nil, err
+	}
+
+	// Assemble a peer notifier which will provide clients with
+	// subscriptions to peer online and offline events.
 	s.peerNotifier = peernotifier.New()
 
 	// Create a channel event store which monitors all open channels.
@@ -3799,6 +3811,7 @@ func (s *server) peerConnected(conn net.Conn, connReq *connmgr.ConnReq,
 	brontideConn := conn.(*brontide.Conn)
 	addr := conn.RemoteAddr()
 	pubKey := brontideConn.RemotePub()
+	pubKeyBytes := pubKey.SerializeCompressed()
 
 	srvrLog.Infof("Finalizing connection to %x@%s, inbound=%v",
 		pubKey.SerializeCompressed(), addr, inbound)
@@ -3907,7 +3920,10 @@ func (s *server) peerConnected(conn net.Conn, connReq *connmgr.ConnReq,
 		RequestAlias:          s.aliasMgr.RequestAlias,
 		AddLocalAlias:         s.aliasMgr.AddLocalAlias,
 		DisallowRouteBlinding: s.cfg.ProtocolOptions.NoRouteBlinding(),
-		Quit:                  s.quit,
+		PeerDataStore: s.peerStorageProvider.NewPeerStorageDB(
+			pubKeyBytes,
+		),
+		Quit: s.quit,
 	}
 	copy(pCfg.PubKeyBytes[:], peerAddr.IdentityKey.SerializeCompressed())