Skip to content

Commit 6bb7f1c

Browse files
authored
Merge pull request #1412 from lightninglabs/chan-open-courier-check
tapchannel: validate proof courier before opening or accepting channels
2 parents a724d38 + bab648f commit 6bb7f1c

File tree

9 files changed

+394
-28
lines changed

9 files changed

+394
-28
lines changed

docs/examples/basic-price-oracle/go.mod

+1
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ require (
9999
github.com/lightninglabs/neutrino/cache v1.1.2 // indirect
100100
github.com/lightningnetwork/lightning-onion v1.2.1-0.20240712235311-98bd56499dfb // indirect
101101
github.com/lightningnetwork/lnd v0.19.0-beta.rc1 // indirect
102+
github.com/lightningnetwork/lnd/cert v1.2.2 // indirect
102103
github.com/lightningnetwork/lnd/clock v1.1.1 // indirect
103104
github.com/lightningnetwork/lnd/fn/v2 v2.0.8 // indirect
104105
github.com/lightningnetwork/lnd/healthcheck v1.2.6 // indirect

internal/test/grpc.go

+103
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
package test
2+
3+
import (
4+
"context"
5+
"crypto/tls"
6+
"fmt"
7+
"net"
8+
"testing"
9+
"time"
10+
11+
"github.com/lightningnetwork/lnd/cert"
12+
"github.com/lightningnetwork/lnd/lntest/port"
13+
"github.com/stretchr/testify/require"
14+
"golang.org/x/sync/errgroup"
15+
"google.golang.org/grpc"
16+
)
17+
18+
var (
19+
// ListenAddrTemplate is the template for the address the mock server
20+
// listens on.
21+
ListenAddrTemplate = "localhost:%d"
22+
23+
// StartupWaitTime is the time we wait for the server to start up.
24+
StartupWaitTime = 50 * time.Millisecond
25+
)
26+
27+
// StartMockGRPCServer starts a mock gRPC server on a free port and returns the
28+
// address it's listening on. The caller should clean up the server by calling
29+
// the cleanup function.
30+
func StartMockGRPCServer(t *testing.T, grpcServer *grpc.Server,
31+
withTLS bool) (string, func(), error) {
32+
33+
nextPort := port.NextAvailablePort()
34+
listenAddr := fmt.Sprintf(ListenAddrTemplate, nextPort)
35+
36+
grpcListener, err := net.Listen("tcp", listenAddr)
37+
if err != nil {
38+
return "", nil, fmt.Errorf("mock RPC server unable to listen "+
39+
"on %s", listenAddr)
40+
}
41+
42+
listener := grpcListener
43+
if withTLS {
44+
listener = tls.NewListener(grpcListener, genCert(t))
45+
}
46+
47+
// Create an errgroup with an associated context. If the goroutine
48+
// errors, the context is closed.
49+
g, ctx := errgroup.WithContext(context.Background())
50+
51+
// Channel to signal that the Serve goroutine has started.
52+
startupSignal := make(chan struct{})
53+
54+
g.Go(func() error {
55+
// The goroutine has started, signal the main goroutine.
56+
close(startupSignal)
57+
58+
err := grpcServer.Serve(listener)
59+
if err != nil {
60+
return fmt.Errorf("mock RPC server unable to serve "+
61+
"on %s: %v", listenAddr, err)
62+
}
63+
64+
return nil
65+
})
66+
67+
// We wait until the goroutine has started before returning the
68+
// listener address.
69+
<-startupSignal
70+
71+
// Use a timeout to check for any immediate errors.
72+
select {
73+
case <-ctx.Done():
74+
// If the context is canceled, an error occurred during startup.
75+
return "", nil, ctx.Err()
76+
77+
case <-time.After(StartupWaitTime):
78+
// No error was reported within the startup wait time, we can
79+
// assume the server is running now.
80+
}
81+
82+
// The cleanup function stops the server and waits for the goroutine to
83+
// finish.
84+
cleanup := func() {
85+
grpcServer.Stop()
86+
_ = listener.Close()
87+
_ = g.Wait()
88+
}
89+
90+
return listenAddr, cleanup, nil
91+
}
92+
93+
func genCert(t *testing.T) *tls.Config {
94+
certBytes, keyBytes, err := cert.GenCertPair(
95+
"tapd autogenerated cert", nil, nil, false, time.Minute,
96+
)
97+
require.NoError(t, err)
98+
99+
tlsCert, _, err := cert.LoadCertFromBytes(certBytes, keyBytes)
100+
require.NoError(t, err)
101+
102+
return cert.TLSConfFromCert(tlsCert)
103+
}

proof/courier.go

+34-1
Original file line numberDiff line numberDiff line change
@@ -1267,7 +1267,12 @@ func (c *UniverseRpcCourier) ensureConnect(ctx context.Context) error {
12671267
c.client = unirpc.NewUniverseClient(conn)
12681268
c.rawConn = conn
12691269

1270-
return nil
1270+
// Make sure we initiate the connection. The GetInfo RPC method is in
1271+
// the base macaroon white list, so it doesn't require any
1272+
// authentication, independent of the universe's configuration.
1273+
_, err = c.client.Info(ctx, &unirpc.InfoRequest{})
1274+
1275+
return err
12711276
}
12721277

12731278
// DeliverProof attempts to delivery a proof file to the receiver.
@@ -1666,3 +1671,31 @@ func FetchProofProvenance(ctx context.Context, localArchive Archiver,
16661671

16671672
return proofFile, nil
16681673
}
1674+
1675+
// CheckUniverseRpcCourierConnection checks if the universe RPC courier at the
1676+
// given URL is reachable. It returns an error if the connection cannot be
1677+
// established within the given timeout duration.
1678+
func CheckUniverseRpcCourierConnection(ctx context.Context,
1679+
timeout time.Duration, courierURL *url.URL) error {
1680+
1681+
// We now also make a quick test connection.
1682+
ctxt, cancel := context.WithTimeout(ctx, timeout)
1683+
defer cancel()
1684+
courier, err := NewUniverseRpcCourier(
1685+
ctxt, &UniverseRpcCourierCfg{}, nil, nil, courierURL,
1686+
false,
1687+
)
1688+
if err != nil {
1689+
return fmt.Errorf("unable to test connection proof courier "+
1690+
"'%v': %v", courierURL.String(), err)
1691+
}
1692+
1693+
err = courier.Close()
1694+
if err != nil {
1695+
// We only log any disconnect errors, as they're not critical.
1696+
log.Warnf("Unable to disconnect from proof courier '%v': %v",
1697+
courierURL.String(), err)
1698+
}
1699+
1700+
return nil
1701+
}

proof/courier_test.go

+74
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,18 @@ package proof
33
import (
44
"bytes"
55
"context"
6+
"fmt"
7+
"net/url"
68
"testing"
79

810
"github.com/lightninglabs/taproot-assets/asset"
911
"github.com/lightninglabs/taproot-assets/fn"
1012
"github.com/lightninglabs/taproot-assets/internal/test"
13+
"github.com/lightninglabs/taproot-assets/taprpc/universerpc"
14+
"github.com/lightningnetwork/lnd/lntest/port"
1115
"github.com/stretchr/testify/require"
16+
"google.golang.org/grpc"
17+
"google.golang.org/grpc/credentials/insecure"
1218
)
1319

1420
// TestUniverseRpcCourierLocalArchiveShortCut tests that the local archive is
@@ -72,3 +78,71 @@ func TestUniverseRpcCourierLocalArchiveShortCut(t *testing.T) {
7278
})
7379
require.ErrorContains(t, err, "is missing outpoint")
7480
}
81+
82+
// TestCheckUniverseRpcCourierConnection tests that we can connect to the
83+
// universe rpc courier. We also test that we fail to connect to a
84+
// universe rpc courier that is not listening on the given address.
85+
func TestCheckUniverseRpcCourierConnection(t *testing.T) {
86+
serverOpts := []grpc.ServerOption{
87+
grpc.Creds(insecure.NewCredentials()),
88+
}
89+
grpcServer := grpc.NewServer(serverOpts...)
90+
91+
server := MockUniverseServer{}
92+
universerpc.RegisterUniverseServer(grpcServer, &server)
93+
94+
// We also grab a port that is free to listen on for our negative test.
95+
// Since we know the port is free, and we don't listen on it, we expect
96+
// the connection to fail.
97+
noConnectPort := port.NextAvailablePort()
98+
noConnectAddr := fmt.Sprintf(test.ListenAddrTemplate, noConnectPort)
99+
100+
mockServerAddr, cleanup, err := test.StartMockGRPCServer(
101+
t, grpcServer, true,
102+
)
103+
require.NoError(t, err)
104+
t.Cleanup(cleanup)
105+
106+
tests := []struct {
107+
name string
108+
courierAddr *url.URL
109+
expectErr string
110+
}{
111+
{
112+
name: "valid universe rpc courier",
113+
courierAddr: MockCourierURL(
114+
t, UniverseRpcCourierType, mockServerAddr,
115+
),
116+
},
117+
{
118+
name: "valid universe rpc courier, but can't connect",
119+
courierAddr: MockCourierURL(
120+
t, UniverseRpcCourierType, noConnectAddr,
121+
),
122+
expectErr: "unable to connect to courier service",
123+
},
124+
}
125+
126+
for _, tt := range tests {
127+
t.Run(tt.name, func(t *testing.T) {
128+
// We use a short timeout here, since we don't want to
129+
// wait for the full default timeout of the funding
130+
// controller
131+
ctxt, cancel := context.WithTimeout(
132+
context.Background(), test.StartupWaitTime*2,
133+
)
134+
defer cancel()
135+
136+
err := CheckUniverseRpcCourierConnection(
137+
ctxt, test.StartupWaitTime, tt.courierAddr,
138+
)
139+
if tt.expectErr != "" {
140+
require.ErrorContains(t, err, tt.expectErr)
141+
142+
return
143+
}
144+
145+
require.NoError(t, err)
146+
})
147+
}
148+
}

proof/mock.go

+25
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"github.com/lightninglabs/taproot-assets/commitment"
2121
"github.com/lightninglabs/taproot-assets/fn"
2222
"github.com/lightninglabs/taproot-assets/internal/test"
23+
"github.com/lightninglabs/taproot-assets/taprpc/universerpc"
2324
"github.com/lightningnetwork/lnd/keychain"
2425
"github.com/lightningnetwork/lnd/lnutils"
2526
"github.com/lightningnetwork/lnd/lnwire"
@@ -1046,3 +1047,27 @@ func newMockIgnoreChecker(ignoreAll bool,
10461047
func (m *mockIgnoreChecker) IsIgnored(assetPoint AssetPoint) bool {
10471048
return m.ignoreAll || m.ignoredAssetPoints.Contains(assetPoint)
10481049
}
1050+
1051+
// MockUniverseServer is a mock implementation of the UniverseServer
1052+
// interface. It implements the GetInfo RPC method, which returns an empty
1053+
// InfoResponse.
1054+
type MockUniverseServer struct {
1055+
universerpc.UnimplementedUniverseServer
1056+
}
1057+
1058+
// Info is a mock implementation of the GetInfo RPC.
1059+
func (m *MockUniverseServer) Info(context.Context,
1060+
*universerpc.InfoRequest) (*universerpc.InfoResponse, error) {
1061+
1062+
return &universerpc.InfoResponse{}, nil
1063+
}
1064+
1065+
// MockCourierURL creates a new mock proof courier URL for the given protocol
1066+
// and address.
1067+
func MockCourierURL(t *testing.T, protocol, addr string) *url.URL {
1068+
urlString := fmt.Sprintf("%s://%s", protocol, addr)
1069+
proofCourierAddr, err := ParseCourierAddress(urlString)
1070+
require.NoError(t, err)
1071+
1072+
return proofCourierAddr
1073+
}

rfq/oracle_test.go

+9-22
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ import (
44
"bytes"
55
"context"
66
"fmt"
7-
"net"
87
"testing"
98
"time"
109

@@ -22,10 +21,6 @@ import (
2221
)
2322

2423
const (
25-
// testServiceAddress is the address of the mock RPC price oracle
26-
// service.
27-
testServiceAddress = "localhost:8095"
28-
2924
// testAssetRate is the asset units to BTC rate used in the test cases.
3025
testAssetRate uint64 = 42_000
3126
)
@@ -109,15 +104,15 @@ func validateAssetRatesRequest(
109104

110105
// startBackendRPC starts the given RPC server and blocks until the server is
111106
// shut down.
112-
func startBackendRPC(grpcServer *grpc.Server) error {
107+
func startBackendRPC(t *testing.T, grpcServer *grpc.Server) string {
113108
server := mockRpcPriceOracleServer{}
114109
priceoraclerpc.RegisterPriceOracleServer(grpcServer, &server)
115-
grpcListener, err := net.Listen("tcp", testServiceAddress)
116-
if err != nil {
117-
return fmt.Errorf("RPC server unable to listen on %s",
118-
testServiceAddress)
119-
}
120-
return grpcServer.Serve(grpcListener)
110+
mockAddr, cleanup, err := test.StartMockGRPCServer(t, grpcServer, false)
111+
require.NoError(t, err)
112+
113+
t.Cleanup(cleanup)
114+
115+
return mockAddr
121116
}
122117

123118
// testCaseQueryAskPrice is a test case for the RPC price oracle client
@@ -140,11 +135,7 @@ func runQueryAskPriceTest(t *testing.T, tc *testCaseQueryAskPrice) {
140135
grpc.Creds(insecure.NewCredentials()),
141136
}
142137
backendService := grpc.NewServer(serverOpts...)
143-
go func() { _ = startBackendRPC(backendService) }()
144-
defer backendService.Stop()
145-
146-
// Wait for the server to start.
147-
time.Sleep(200 * time.Millisecond)
138+
testServiceAddress := startBackendRPC(t, backendService)
148139

149140
// Create a new RPC price oracle client and connect to the mock service.
150141
serviceAddr := fmt.Sprintf("rfqrpc://%s", testServiceAddress)
@@ -252,11 +243,7 @@ func runQueryBidPriceTest(t *testing.T, tc *testCaseQueryBidPrice) {
252243
grpc.Creds(insecure.NewCredentials()),
253244
}
254245
backendService := grpc.NewServer(serverOpts...)
255-
go func() { _ = startBackendRPC(backendService) }()
256-
defer backendService.Stop()
257-
258-
// Wait for the server to start.
259-
time.Sleep(2 * time.Second)
246+
testServiceAddress := startBackendRPC(t, backendService)
260247

261248
// Create a new RPC price oracle client and connect to the mock service.
262249
serviceAddr := fmt.Sprintf("rfqrpc://%s", testServiceAddress)

tapchannel/aux_closer.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -578,7 +578,7 @@ func (a *AuxChanCloser) ShutdownBlob(
578578
return lfn.Some[lnwire.CustomRecords](records), nil
579579
}
580580

581-
// shipChannelTxn takes a chanenl transaction, an output commitment, and the
581+
// shipChannelTxn takes a channel transaction, an output commitment, and the
582582
// set of vPackets used to make the output commitment and ships a complete
583583
// pre-singed package off to the porter. This'll insert a transfer for the
584584
// channel, send the final transaction to the network, and update any

0 commit comments

Comments
 (0)