diff --git a/error.go b/error.go index 14cb777..51da779 100644 --- a/error.go +++ b/error.go @@ -53,4 +53,11 @@ var ( // ErrIOReadFull is returned when an io read full operation fails. ErrIOReadFull = errors.New("io read full error") + + // ErrMaxHopPayloadSizeExceeded is returned when decoding a TLV payload + // whose encoded length exceeds MaxHopPayloadSize. This prevents the + // payload from overflowing the routing info boundary into the + // zero-padding area reserved for subsequent hops. + ErrMaxHopPayloadSizeExceeded = errors.New("hop payload size exceeds " + + "maximum allowed") ) diff --git a/payload.go b/payload.go index c02e351..d5361fc 100644 --- a/payload.go +++ b/payload.go @@ -120,6 +120,10 @@ func DecodeHopPayload(r io.Reader, tlvGuaranteed bool) (*HopPayload, error) { if err != nil { return nil, err } + if payloadSize > MaxHopPayloadSize { + return nil, fmt.Errorf("payload size %v, limit %v: %w", + payloadSize, MaxHopPayloadSize, ErrMaxHopPayloadSizeExceeded) + } } // Now that we know the payload size, we'll create a new buffer to read diff --git a/sphinx.go b/sphinx.go index d56959f..3f77c53 100644 --- a/sphinx.go +++ b/sphinx.go @@ -53,6 +53,20 @@ const ( // onion messaging jumbo onion packet. MaxOnionMessagePayloadSize = 32768 + // BigSizeLenMaxBytes is the maximum number of bytes required to encode + // a payload length using the BigSize variable-length encoding format. + // For TLV payloads, the length prefix can be 1, 3, 5, or 9 bytes + // depending on the value. Since MaxHopPayloadSize fits within a 16-bit + // value, 3 bytes (0xfd + 2 bytes) is the maximum needed. + BigSizeLenMaxBytes = 3 + + // MaxHopPayloadSize is the maximum size of the payload data for a + // single hop, excluding the BigSize length prefix and the trailing + // HMAC. When decoding a TLV payload, the encoded length must not + // exceed this value, otherwise the payload would overflow the routing + // info boundary into the zero-padding area used by later hops. + MaxHopPayloadSize = MaxRoutingPayloadSize - BigSizeLenMaxBytes - HMACSize + // keyLen is the length of the keys used to generate cipher streams and // encrypt payloads. Since we use SHA256 to generate the keys, the // maximum length currently is 32 bytes. @@ -744,7 +758,7 @@ func unwrapPacket(onionPkt *OnionPacket, sharedSecret *Hash256, // out the payload so we can derive the specified forwarding // instructions. hopPayload, err := DecodeHopPayload( - bytes.NewReader(hopInfo), tlvPayloadOnly, + bytes.NewReader(hopInfo[:routingInfoLen]), tlvPayloadOnly, ) if err != nil { return nil, nil, err diff --git a/sphinx_test.go b/sphinx_test.go index 221bddc..a64c443 100644 --- a/sphinx_test.go +++ b/sphinx_test.go @@ -634,7 +634,7 @@ func TestSphinxSingleHop(t *testing.T) { } } -func TestSphinxNodeRelpay(t *testing.T) { +func TestSphinxNodeReplay(t *testing.T) { // We'd like to ensure that the sphinx node itself rejects all replayed // packets which share the same shared secret. nodes, _, _, fwdMsg, err := newTestRoute(testLegacyRouteNumHops) @@ -662,7 +662,7 @@ func TestSphinxNodeRelpay(t *testing.T) { } } -func TestSphinxNodeRelpaySameBatch(t *testing.T) { +func TestSphinxNodeReplaySameBatch(t *testing.T) { // We'd like to ensure that the sphinx node itself rejects all replayed // packets which share the same shared secret. nodes, _, _, fwdMsg, err := newTestRoute(testLegacyRouteNumHops) @@ -708,7 +708,7 @@ func TestSphinxNodeRelpaySameBatch(t *testing.T) { } } -func TestSphinxNodeRelpayLaterBatch(t *testing.T) { +func TestSphinxNodeReplayLaterBatch(t *testing.T) { // We'd like to ensure that the sphinx node itself rejects all replayed // packets which share the same shared secret. nodes, _, _, fwdMsg, err := newTestRoute(testLegacyRouteNumHops) @@ -1439,3 +1439,66 @@ func TestVariablePayloadOnion(t *testing.T) { "match expected BOLT 4 packet, want: %s, got %s", hex.EncodeToString(finalPacket), hex.EncodeToString(b.Bytes())) } + +// TestUnwrapPacketBeyondRoutingInfoBoundary tests that unwrapPacket does not +// read into the zero-padding area when processing a malformed onion packet +// with an oversized payload. +func TestUnwrapPacketBeyondRoutingInfoBoundary(t *testing.T) { + t.Parallel() + + // Create a router to get a valid key pair. + privKey, err := btcec.NewPrivateKey() + require.NoError(t, err) + + sessionKey, err := btcec.NewPrivateKey() + require.NoError(t, err) + + // Compute shared secret as unwrapPacket would. + sessionKeyECDH := &PrivKeyECDH{PrivKey: sessionKey} + sharedSecretArr, err := sessionKeyECDH.ECDH(privKey.PubKey()) + require.NoError(t, err) + sharedSecret := Hash256(sharedSecretArr) + + // Generate the rho key and stream bytes for encryption. + rhoKey := generateKey("rho", &sharedSecret) + streamBytes := generateCipherStream(rhoKey, uint(MaxRoutingPayloadSize)) + + // Create routing info with a malicious payload size. + // 0xfd 0x04 0xf2 encodes 1266 in BigSize format. + // With 3-byte length + 1266-byte payload + 32-byte HMAC = 1301 bytes, + // exceeding the 1300-byte boundary. + routingInfo := make([]byte, MaxRoutingPayloadSize) + + // Build the plaintext payload that will be encrypted. + plaintext := make([]byte, MaxRoutingPayloadSize) + plaintext[0] = 0xfd + plaintext[1] = 0x04 + plaintext[2] = 0xf2 + + // Fill payload area with recognizable pattern. + for i := 3; i < MaxRoutingPayloadSize; i++ { + plaintext[i] = 0xaa + } + + // Encrypt the routing info by XORing with stream bytes. + xor(routingInfo, plaintext, streamBytes[:MaxRoutingPayloadSize]) + + // Compute valid HMAC for the packet. + muKey := generateKey("mu", &sharedSecret) + headerMAC := calcMac(muKey, routingInfo) + + // Create the onion packet. + onionPkt := &OnionPacket{ + Version: baseVersion, + EphemeralKey: sessionKey.PubKey(), + RoutingInfo: routingInfo, + HeaderMAC: headerMAC, + } + + // Process the packet - this should fail because the payload size + // exceeds the routing info boundary. + _, _, err = unwrapPacket(onionPkt, &sharedSecret, nil, true) + expectedErr := fmt.Errorf("payload size %v, limit %v: %w", + 1266, MaxHopPayloadSize, ErrMaxHopPayloadSizeExceeded) + require.Equal(t, expectedErr, err) +}