Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sphinx.go
Original file line number Diff line number Diff line change
Expand Up @@ -744,7 +744,7 @@ func unwrapPacket(onionPkt *OnionPacket, sharedSecret *Hash256,
// out the payload so we can derive the specified forwarding
// instructions.
hopPayload, err := DecodeHopPayload(
bytes.NewReader(hopInfo), tlvPayloadOnly,
bytes.NewReader(hopInfo[:routingInfoLen]), tlvPayloadOnly,
)
if err != nil {
return nil, nil, err
Expand Down
68 changes: 65 additions & 3 deletions sphinx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -634,7 +634,7 @@ func TestSphinxSingleHop(t *testing.T) {
}
}

func TestSphinxNodeRelpay(t *testing.T) {
func TestSphinxNodeReplay(t *testing.T) {
// We'd like to ensure that the sphinx node itself rejects all replayed
// packets which share the same shared secret.
nodes, _, _, fwdMsg, err := newTestRoute(testLegacyRouteNumHops)
Expand Down Expand Up @@ -662,7 +662,7 @@ func TestSphinxNodeRelpay(t *testing.T) {
}
}

func TestSphinxNodeRelpaySameBatch(t *testing.T) {
func TestSphinxNodeReplaySameBatch(t *testing.T) {
// We'd like to ensure that the sphinx node itself rejects all replayed
// packets which share the same shared secret.
nodes, _, _, fwdMsg, err := newTestRoute(testLegacyRouteNumHops)
Expand Down Expand Up @@ -708,7 +708,7 @@ func TestSphinxNodeRelpaySameBatch(t *testing.T) {
}
}

func TestSphinxNodeRelpayLaterBatch(t *testing.T) {
func TestSphinxNodeReplayLaterBatch(t *testing.T) {
// We'd like to ensure that the sphinx node itself rejects all replayed
// packets which share the same shared secret.
nodes, _, _, fwdMsg, err := newTestRoute(testLegacyRouteNumHops)
Expand Down Expand Up @@ -1439,3 +1439,65 @@ func TestVariablePayloadOnion(t *testing.T) {
"match expected BOLT 4 packet, want: %s, got %s",
hex.EncodeToString(finalPacket), hex.EncodeToString(b.Bytes()))
}

// TestUnwrapPacketBeyondRoutingInfoBoundary tests that unwrapPacket does not
// read into the zero-padding area when processing a malformed onion packet
// with an oversized payload.
func TestUnwrapPacketBeyondRoutingInfoBoundary(t *testing.T) {
t.Parallel()

// Create a router to get a valid key pair.
privKey, err := btcec.NewPrivateKey()
require.NoError(t, err)

sessionKey, err := btcec.NewPrivateKey()
require.NoError(t, err)

// Compute shared secret as unwrapPacket would.
sessionKeyECDH := &PrivKeyECDH{PrivKey: sessionKey}
sharedSecretArr, err := sessionKeyECDH.ECDH(privKey.PubKey())
require.NoError(t, err)
sharedSecret := Hash256(sharedSecretArr)

// Generate the rho key and stream bytes for encryption.
rhoKey := generateKey("rho", &sharedSecret)
streamBytes := generateCipherStream(rhoKey, uint(MaxRoutingPayloadSize))

// Create routing info with a malicious payload size.
// 0xfd 0x04 0xf2 encodes 1266 in BigSize format.
// With 3-byte length + 1266-byte payload + 32-byte HMAC = 1301 bytes,
// exceeding the 1300-byte boundary.
routingInfo := make([]byte, MaxRoutingPayloadSize)

// Build the plaintext payload that will be encrypted.
plaintext := make([]byte, MaxRoutingPayloadSize)
plaintext[0] = 0xfd
plaintext[1] = 0x04
plaintext[2] = 0xf2

// Fill payload area with recognizable pattern.
for i := 3; i < MaxRoutingPayloadSize; i++ {
plaintext[i] = 0xaa
}

// Encrypt the routing info by XORing with stream bytes.
xor(routingInfo, plaintext, streamBytes[:MaxRoutingPayloadSize])

// Compute valid HMAC for the packet.
muKey := generateKey("mu", &sharedSecret)
headerMAC := calcMac(muKey, routingInfo)

// Create the onion packet.
onionPkt := &OnionPacket{
Version: baseVersion,
EphemeralKey: sessionKey.PubKey(),
RoutingInfo: routingInfo,
HeaderMAC: headerMAC,
}

// Process the packet - this should fail because the payload size
// exceeds the routing info boundary.
_, _, err = unwrapPacket(onionPkt, &sharedSecret, nil, true)
expectedErr := fmt.Errorf("%w: %w", ErrIOReadFull, io.ErrUnexpectedEOF)
require.Equal(t, expectedErr, err)
}