diff --git a/messaging/common/message_sender.go b/messaging/common/message_sender.go index 2263514674f..11737daf536 100644 --- a/messaging/common/message_sender.go +++ b/messaging/common/message_sender.go @@ -48,15 +48,15 @@ const ( var RekeyCompatibility = true type MessageSender struct { - identity *ecdsa.PrivateKey - datasync *datasync.DataSync - datasyncPersistence datasyncnode.Persistence - transport *transport.Transport - protocol *encryption.Protocol - logger *zap.Logger - persistence messagingtypes.MessageSenderPersistence - segmentationPersistence segmentation.Persistence - publisher *pubsub.Publisher + identity *ecdsa.PrivateKey + datasync *datasync.DataSync + datasyncPersistence datasyncnode.Persistence + transport *transport.Transport + segmenter *segmentation.Segmenter + protocol *encryption.Protocol + logger *zap.Logger + persistence messagingtypes.MessageSenderPersistence + publisher *pubsub.Publisher datasyncEnabled bool @@ -79,16 +79,16 @@ func NewMessageSender( logger *zap.Logger, ) (*MessageSender, error) { p := &MessageSender{ - identity: identity, - datasyncPersistence: datasyncPersistence, - datasyncEnabled: true, // FIXME - protocol: enc, - persistence: persistence, - segmentationPersistence: segmentationPersistence, - publisher: pubsub.NewPublisher(), - transport: transport, - logger: logger, - ephemeralKeys: make(map[string]*ecdsa.PrivateKey), + identity: identity, + datasyncPersistence: datasyncPersistence, + datasyncEnabled: true, // FIXME + transport: transport, + segmenter: segmentation.NewSegmenter(segmentationPersistence, logger), + protocol: enc, + persistence: persistence, + publisher: pubsub.NewPublisher(), + logger: logger, + ephemeralKeys: make(map[string]*ecdsa.PrivateKey), } return p, nil @@ -837,15 +837,13 @@ func (s *MessageSender) HandleMessages(msg *messagingtypes.ReceivedMessage) (*me return nil, s.persistence.SaveHashRatchetMessage(info.GroupID, info.KeyID, msg) } - // The current message segment has been successfully retrieved. - // However, the collection of segments is not yet complete. - if err == ErrMessageSegmentsIncomplete { - return nil, nil - } - return nil, err } + if response == nil { + return nil, nil + } + // Process queued hash ratchet messages for _, hashRatchetInfo := range response.Message.EncryptionLayer.HashRatchetInfo { messages, err := s.persistence.GetHashRatchetMessages(hashRatchetInfo.KeyID) @@ -906,8 +904,7 @@ func (h *handleMessageResponse) Messages() []*messagingtypes.Message { } func (s *MessageSender) handleMessage(receivedMsg *messagingtypes.ReceivedMessage) (*handleMessageResponse, error) { - logger := s.logger.With(zap.String("site", "handleMessage")) - hlogger := logger.With(zap.String("hash", types.EncodeHex(receivedMsg.Hash))) + hlogger := s.logger.Named("handleMessage").With(zap.String("hash", types.EncodeHex(receivedMsg.Hash))) message := &messagingtypes.Message{} @@ -924,18 +921,14 @@ func (s *MessageSender) handleMessage(receivedMsg *messagingtypes.ReceivedMessag return nil, err } - err = s.handleSegmentationLayer(message) + isSegmentMessage, completed, err := s.handleSegmentationLayer(message) if err != nil { - // Segments not completed yet, stop processing - if err == ErrMessageSegmentsIncomplete { - return nil, err - } - // Segments already completed, stop processing - if err == ErrMessageSegmentsAlreadyCompleted { - return nil, err - } + return nil, err + } - // Not a critical error; message wasn't segmented, proceed with next layers. + // Segments not completed yet, stop processing + if isSegmentMessage && !completed { + return nil, nil } err = s.handleEncryptionLayer(context.Background(), message) diff --git a/messaging/common/message_sender_segmentation.go b/messaging/common/message_sender_segmentation.go new file mode 100644 index 00000000000..08dcd071015 --- /dev/null +++ b/messaging/common/message_sender_segmentation.go @@ -0,0 +1,80 @@ +package common + +import ( + "time" + + "github.com/jinzhu/copier" + "go.uber.org/zap" + + "github.com/status-im/status-go/messaging/layers/segmentation" + "github.com/status-im/status-go/messaging/types" + wakutypes "github.com/status-im/status-go/messaging/waku/types" +) + +// reducedMaxMessageSize returns the max message size reduced to 3/4 to leave room for segment metadata +func (s *MessageSender) reducedMaxMessageSize() uint32 { + return s.transport.MaxMessageSize() * 3 / 4 +} + +func (s *MessageSender) segmentMessage(newMessage *wakutypes.NewMessage) ([]*wakutypes.NewMessage, error) { + return s.segmentMessageWithSize(newMessage, int(s.reducedMaxMessageSize())) +} + +func (s *MessageSender) segmentMessageWithSize(newMessage *wakutypes.NewMessage, segmentSize int) ([]*wakutypes.NewMessage, error) { + segments, err := s.segmenter.Segment(newMessage.Payload, segmentSize) + if err != nil { + return nil, err + } + + replicateMessage := func(payload []byte) (*wakutypes.NewMessage, error) { + copy := &wakutypes.NewMessage{} + err := copier.Copy(copy, newMessage) + if err != nil { + return nil, err + } + + copy.Payload = payload + return copy, nil + } + + newMessages := make([]*wakutypes.NewMessage, 0, len(segments)) + for _, segment := range segments { + segmentMessage, err := replicateMessage(segment) + if err != nil { + return nil, err + } + newMessages = append(newMessages, segmentMessage) + } + + s.logger.Debug("message segmented", zap.Int("segments", len(newMessages))) + + return newMessages, err +} + +// handleSegmentationLayer is capable of reconstructing the message from both complete and partial sets of data segments. +func (s *MessageSender) handleSegmentationLayer(message *types.Message) (segmented, completed bool, err error) { + var reconstructedPayload []byte + reconstructedPayload, err = s.segmenter.Reconstruct(message.TransportLayer.Payload, message.TransportLayer.SigPubKey) + + switch err { + case nil: + message.TransportLayer.Payload = reconstructedPayload + segmented = true + completed = true + case segmentation.ErrIncomplete: + segmented = true + completed = false + err = nil + case segmentation.ErrInvalidPayload: + segmented = false + completed = false + err = nil + } + + return +} + +func (s *MessageSender) CleanupSegments() error { + monthAgo := time.Now().AddDate(0, -1, 0) + return s.segmenter.CleanupStaleSegments(monthAgo) +} diff --git a/messaging/common/message_sender_test.go b/messaging/common/message_sender_test.go index cde9cf36ec8..78a02b0eac9 100644 --- a/messaging/common/message_sender_test.go +++ b/messaging/common/message_sender_test.go @@ -352,7 +352,7 @@ func (s *MessageSenderSuite) TestHandleSegmentMessages() { wrappedPayload, err := v1protocol.WrapMessageV1(encodedPayload, protobuf.ApplicationMetadataMessage_CHAT_MESSAGE, authorKey) s.Require().NoError(err) - segmentedMessages, err := segmentMessage(&wakutypes.NewMessage{Payload: wrappedPayload}, int(math.Ceil(float64(len(wrappedPayload))/2))) + segmentedMessages, err := s.sender.segmentMessageWithSize(&wakutypes.NewMessage{Payload: wrappedPayload}, int(math.Ceil(float64(len(wrappedPayload))/2))) s.Require().NoError(err) s.Require().Len(segmentedMessages, 2) @@ -379,7 +379,7 @@ func (s *MessageSenderSuite) TestHandleSegmentMessages() { // Receiving another segment after the message has been reassembled is considered an error _, err = s.sender.HandleMessages(message) - s.Require().ErrorIs(err, ErrMessageSegmentsAlreadyCompleted) + s.Require().ErrorIs(err, segmentation.ErrAlreadyCompleted) } func (s *MessageSenderSuite) TestGetEphemeralKey() { diff --git a/messaging/types/segment_message.go b/messaging/layers/segmentation/message.go similarity index 73% rename from messaging/types/segment_message.go rename to messaging/layers/segmentation/message.go index 04526c9236c..cccec80d52d 100644 --- a/messaging/types/segment_message.go +++ b/messaging/layers/segmentation/message.go @@ -1,12 +1,12 @@ -package types +package segmentation -import "github.com/status-im/status-go/protocol/protobuf" +import "github.com/status-im/status-go/messaging/layers/segmentation/protobuf" -type SegmentMessage struct { +type Message struct { *protobuf.SegmentMessage } -func (s *SegmentMessage) IsValid() bool { +func (s *Message) IsValid() bool { // Check if the hash length is valid (32 bytes for Keccak256) if len(s.EntireMessageHash) != 32 { return false @@ -25,6 +25,6 @@ func (s *SegmentMessage) IsValid() bool { return s.SegmentsCount >= 2 || s.ParitySegmentsCount > 0 } -func (s *SegmentMessage) IsParityMessage() bool { +func (s *Message) IsParityMessage() bool { return s.SegmentsCount == 0 && s.ParitySegmentsCount > 0 } diff --git a/messaging/layers/segmentation/persistence.go b/messaging/layers/segmentation/persistence.go index 2b6085eea24..a5507845bd9 100644 --- a/messaging/layers/segmentation/persistence.go +++ b/messaging/layers/segmentation/persistence.go @@ -2,14 +2,12 @@ package segmentation import ( "crypto/ecdsa" - - "github.com/status-im/status-go/messaging/types" ) type Persistence interface { IsMessageAlreadyCompleted(hash []byte) (bool, error) - SaveMessageSegment(segment *types.SegmentMessage, sigPubKey *ecdsa.PublicKey, timestamp int64) error - GetMessageSegments(hash []byte, sigPubKey *ecdsa.PublicKey) ([]*types.SegmentMessage, error) + SaveMessageSegment(segment *Message, sigPubKey *ecdsa.PublicKey, timestamp int64) error + GetMessageSegments(hash []byte, sigPubKey *ecdsa.PublicKey) ([]*Message, error) CompleteMessageSegments(hash []byte, sigPubKey *ecdsa.PublicKey, timestamp int64) error RemoveMessageSegmentsOlderThan(timestamp int64) error RemoveMessageSegmentsCompletedOlderThan(timestamp int64) error diff --git a/messaging/layers/segmentation/protobuf/generate.go b/messaging/layers/segmentation/protobuf/generate.go new file mode 100644 index 00000000000..661732a820a --- /dev/null +++ b/messaging/layers/segmentation/protobuf/generate.go @@ -0,0 +1,3 @@ +package protobuf + +//go:generate protoc --go_out=. ./segment_message.proto diff --git a/protocol/protobuf/segment_message.proto b/messaging/layers/segmentation/protobuf/segment_message.proto similarity index 100% rename from protocol/protobuf/segment_message.proto rename to messaging/layers/segmentation/protobuf/segment_message.proto diff --git a/messaging/common/message_segmentation.go b/messaging/layers/segmentation/segmenter.go similarity index 55% rename from messaging/common/message_segmentation.go rename to messaging/layers/segmentation/segmenter.go index 3f55f4bc965..ef244b21cdf 100644 --- a/messaging/common/message_segmentation.go +++ b/messaging/layers/segmentation/segmenter.go @@ -1,67 +1,57 @@ -package common +package segmentation import ( "bytes" + "crypto/ecdsa" "math" "time" + "github.com/cockroachdb/errors" + "github.com/ethereum/go-ethereum/crypto" "github.com/golang/protobuf/proto" - "github.com/jinzhu/copier" "github.com/klauspost/reedsolomon" - "github.com/pkg/errors" "go.uber.org/zap" - "github.com/status-im/status-go/crypto" cryptotypes "github.com/status-im/status-go/crypto/types" - "github.com/status-im/status-go/messaging/types" - wakutypes "github.com/status-im/status-go/messaging/waku/types" - "github.com/status-im/status-go/protocol/protobuf" + "github.com/status-im/status-go/messaging/layers/segmentation/protobuf" ) -var ErrMessageSegmentsIncomplete = errors.New("message segments incomplete") -var ErrMessageSegmentsAlreadyCompleted = errors.New("message segments already completed") -var ErrMessageSegmentsInvalidPayload = errors.New("invalid segment payload") -var ErrMessageSegmentsHashMismatch = errors.New("hash of entire payload does not match") -var ErrMessageSegmentsInvalidParity = errors.New("invalid parity segments") - const ( segmentsParityRate = 0.125 segmentsReedsolomonMaxCount = 256 ) -func (s *MessageSender) segmentMessage(newMessage *wakutypes.NewMessage) ([]*wakutypes.NewMessage, error) { - // We set the max message size to 3/4 of the allowed message size, to leave - // room for segment message metadata. - newMessages, err := segmentMessage(newMessage, int(s.transport.MaxMessageSize()/4*3)) - s.logger.Debug("message segmented", zap.Int("segments", len(newMessages))) - return newMessages, err +var ErrIncomplete = errors.New("message segments incomplete") +var ErrAlreadyCompleted = errors.New("message segments already completed") +var ErrInvalidPayload = errors.New("invalid segment payload") +var ErrHashMismatch = errors.New("hash of entire payload does not match") +var ErrInvalidParity = errors.New("invalid parity segments") + +type Segmenter struct { + persistence Persistence + logger *zap.Logger } -func replicateMessageWithNewPayload(message *wakutypes.NewMessage, payload []byte) (*wakutypes.NewMessage, error) { - copy := &wakutypes.NewMessage{} - err := copier.Copy(copy, message) - if err != nil { - return nil, err +func NewSegmenter(persistence Persistence, logger *zap.Logger) *Segmenter { + return &Segmenter{ + persistence: persistence, + logger: logger.Named("segmentation"), } - - copy.Payload = payload - return copy, nil } -// Segments message into smaller chunks if the size exceeds segmentSize. -func segmentMessage(newMessage *wakutypes.NewMessage, segmentSize int) ([]*wakutypes.NewMessage, error) { - if len(newMessage.Payload) <= segmentSize { - return []*wakutypes.NewMessage{newMessage}, nil +func (s *Segmenter) Segment(payload []byte, segmentSize int) ([][]byte, error) { + if len(payload) <= segmentSize { + return [][]byte{payload}, nil } - entireMessageHash := crypto.Keccak256(newMessage.Payload) - entirePayloadSize := len(newMessage.Payload) + entireMessageHash := crypto.Keccak256(payload) + entirePayloadSize := len(payload) segmentsCount := int(math.Ceil(float64(entirePayloadSize) / float64(segmentSize))) paritySegmentsCount := int(math.Floor(float64(segmentsCount) * segmentsParityRate)) segmentPayloads := make([][]byte, segmentsCount+paritySegmentsCount) - segmentMessages := make([]*wakutypes.NewMessage, segmentsCount) + segmentMessages := make([][]byte, segmentsCount) for start, index := 0, 0; start < entirePayloadSize; start += segmentSize { end := start + segmentSize @@ -69,7 +59,7 @@ func segmentMessage(newMessage *wakutypes.NewMessage, segmentSize int) ([]*wakut end = entirePayloadSize } - segmentPayload := newMessage.Payload[start:end] + segmentPayload := payload[start:end] segmentWithMetadata := &protobuf.SegmentMessage{ EntireMessageHash: entireMessageHash, Index: uint32(index), @@ -80,13 +70,9 @@ func segmentMessage(newMessage *wakutypes.NewMessage, segmentSize int) ([]*wakut if err != nil { return nil, err } - segmentMessage, err := replicateMessageWithNewPayload(newMessage, marshaledSegmentWithMetadata) - if err != nil { - return nil, err - } segmentPayloads[index] = segmentPayload - segmentMessages[index] = segmentMessage + segmentMessages[index] = marshaledSegmentWithMetadata index++ } @@ -129,62 +115,50 @@ func segmentMessage(newMessage *wakutypes.NewMessage, segmentSize int) ([]*wakut if err != nil { return nil, err } - segmentMessage, err := replicateMessageWithNewPayload(newMessage, marshaledSegmentWithMetadata) - if err != nil { - return nil, err - } - segmentMessages = append(segmentMessages, segmentMessage) + segmentMessages = append(segmentMessages, marshaledSegmentWithMetadata) index++ } return segmentMessages, nil } -// handleSegmentationLayer is capable of reconstructing the message from both complete and partial sets of data segments. -// It has capability to perform forward error correction. -func (s *MessageSender) handleSegmentationLayer(message *types.Message) error { - logger := s.logger.Named("handleSegmentationLayer").With(zap.String("hash", cryptotypes.HexBytes(message.TransportLayer.Hash).String())) - - segmentMessage := &types.SegmentMessage{ +func (s *Segmenter) Reconstruct(payload []byte, sigPubKey *ecdsa.PublicKey) ([]byte, error) { + segmentMessage := &Message{ SegmentMessage: &protobuf.SegmentMessage{}, } - err := proto.Unmarshal(message.TransportLayer.Payload, segmentMessage.SegmentMessage) - if err != nil { - return errors.Wrap(err, "failed to unmarshal SegmentMessage") - } - - if !segmentMessage.IsValid() { - return ErrMessageSegmentsInvalidPayload + err := proto.Unmarshal(payload, segmentMessage.SegmentMessage) + if err != nil || !segmentMessage.IsValid() { + return nil, ErrInvalidPayload } - logger.Debug("handling message segment", + s.logger.Debug("handling message segment", zap.String("EntireMessageHash", cryptotypes.HexBytes(segmentMessage.EntireMessageHash).String()), zap.Uint32("Index", segmentMessage.Index), zap.Uint32("SegmentsCount", segmentMessage.SegmentsCount), zap.Uint32("ParitySegmentIndex", segmentMessage.ParitySegmentIndex), zap.Uint32("ParitySegmentsCount", segmentMessage.ParitySegmentsCount)) - alreadyCompleted, err := s.segmentationPersistence.IsMessageAlreadyCompleted(segmentMessage.EntireMessageHash) + alreadyCompleted, err := s.persistence.IsMessageAlreadyCompleted(segmentMessage.EntireMessageHash) if err != nil { - return err + return nil, err } if alreadyCompleted { - return ErrMessageSegmentsAlreadyCompleted + return nil, ErrAlreadyCompleted } - err = s.segmentationPersistence.SaveMessageSegment(segmentMessage, message.TransportLayer.SigPubKey, time.Now().Unix()) + err = s.persistence.SaveMessageSegment(segmentMessage, sigPubKey, time.Now().Unix()) if err != nil { - return err + return nil, err } - segments, err := s.segmentationPersistence.GetMessageSegments(segmentMessage.EntireMessageHash, message.TransportLayer.SigPubKey) + segments, err := s.persistence.GetMessageSegments(segmentMessage.EntireMessageHash, sigPubKey) if err != nil { - return err + return nil, err } if len(segments) == 0 { - return errors.New("unexpected state: no segments found after save operation") // This should theoretically never occur. + return nil, errors.New("unexpected state: no segments found after save operation") // This should theoretically never occur. } firstSegmentMessage := segments[0] @@ -192,7 +166,7 @@ func (s *MessageSender) handleSegmentationLayer(message *types.Message) error { // First segment message must not be a parity message. if firstSegmentMessage.IsParityMessage() || len(segments) != int(firstSegmentMessage.SegmentsCount) { - return ErrMessageSegmentsIncomplete + return nil, ErrIncomplete } payloads := make([][]byte, firstSegmentMessage.SegmentsCount+lastSegmentMessage.ParitySegmentsCount) @@ -206,7 +180,7 @@ func (s *MessageSender) handleSegmentationLayer(message *types.Message) error { } else { enc, err := reedsolomon.New(int(firstSegmentMessage.SegmentsCount), int(lastSegmentMessage.ParitySegmentsCount)) if err != nil { - return err + return nil, err } var lastNonParitySegmentPayload []byte @@ -227,15 +201,15 @@ func (s *MessageSender) handleSegmentationLayer(message *types.Message) error { err = enc.Reconstruct(payloads) if err != nil { - return err + return nil, err } ok, err := enc.Verify(payloads) if err != nil { - return err + return nil, err } if !ok { - return ErrMessageSegmentsInvalidParity + return nil, ErrInvalidParity } if lastNonParitySegmentPayload != nil { @@ -248,35 +222,31 @@ func (s *MessageSender) handleSegmentationLayer(message *types.Message) error { for i := 0; i < int(firstSegmentMessage.SegmentsCount); i++ { _, err := entirePayload.Write(payloads[i]) if err != nil { - return errors.Wrap(err, "failed to write segment payload") + return nil, errors.Wrap(err, "failed to write segment payload") } } // Sanity check. entirePayloadHash := crypto.Keccak256(entirePayload.Bytes()) if !bytes.Equal(entirePayloadHash, segmentMessage.EntireMessageHash) { - return ErrMessageSegmentsHashMismatch + return nil, ErrHashMismatch } - err = s.segmentationPersistence.CompleteMessageSegments(segmentMessage.EntireMessageHash, message.TransportLayer.SigPubKey, time.Now().Unix()) + err = s.persistence.CompleteMessageSegments(segmentMessage.EntireMessageHash, sigPubKey, time.Now().Unix()) if err != nil { - return err + return nil, err } - message.TransportLayer.Payload = entirePayload.Bytes() - - return nil + return entirePayload.Bytes(), nil } -func (s *MessageSender) CleanupSegments() error { - monthAgo := time.Now().AddDate(0, -1, 0).Unix() - - err := s.segmentationPersistence.RemoveMessageSegmentsOlderThan(monthAgo) +func (s *Segmenter) CleanupStaleSegments(olderThan time.Time) error { + err := s.persistence.RemoveMessageSegmentsOlderThan(olderThan.Unix()) if err != nil { return err } - err = s.segmentationPersistence.RemoveMessageSegmentsCompletedOlderThan(monthAgo) + err = s.persistence.RemoveMessageSegmentsCompletedOlderThan(olderThan.Unix()) if err != nil { return err } diff --git a/messaging/common/message_segmentation_test.go b/messaging/layers/segmentation/segmenter_test.go similarity index 78% rename from messaging/common/message_segmentation_test.go rename to messaging/layers/segmentation/segmenter_test.go index 741d378fdf8..e31b379b384 100644 --- a/messaging/common/message_segmentation_test.go +++ b/messaging/layers/segmentation/segmenter_test.go @@ -1,21 +1,17 @@ -package common +package segmentation import ( _ "embed" - "math" "testing" + "github.com/ethereum/go-ethereum/crypto" "github.com/golang/protobuf/proto" bindata "github.com/status-im/migrate/v4/source/go_bindata" "github.com/stretchr/testify/suite" "go.uber.org/zap" - "github.com/status-im/status-go/crypto" - "github.com/status-im/status-go/messaging/layers/segmentation" - segmentationmigrations "github.com/status-im/status-go/messaging/layers/segmentation/migrations" - "github.com/status-im/status-go/messaging/types" - wakutypes "github.com/status-im/status-go/messaging/waku/types" - "github.com/status-im/status-go/protocol/protobuf" + "github.com/status-im/status-go/messaging/layers/segmentation/migrations" + "github.com/status-im/status-go/messaging/layers/segmentation/protobuf" "github.com/status-im/status-go/t/helpers" ) @@ -26,9 +22,8 @@ func TestMessageSegmentationSuite(t *testing.T) { type MessageSegmentationSuite struct { suite.Suite - sender *MessageSender + segmenter *Segmenter testPayload []byte - logger *zap.Logger } func (s *MessageSegmentationSuite) SetupSuite() { @@ -39,30 +34,18 @@ func (s *MessageSegmentationSuite) SetupSuite() { } func (s *MessageSegmentationSuite) SetupTest() { - identity, err := crypto.GenerateKey() - s.Require().NoError(err) - - s.logger, err = zap.NewDevelopment() - s.Require().NoError(err) - db, err := helpers.SetupTestMemorySQLDB(helpers.NewTestDBInitializer([]*bindata.AssetSource{ { - Names: segmentationmigrations.AssetNames(), - AssetFunc: segmentationmigrations.Asset, + Names: migrations.AssetNames(), + AssetFunc: migrations.Asset, }, })) s.Require().NoError(err) - s.sender, err = NewMessageSender( - identity, - nil, - nil, - segmentation.NewSQLitePersistence(db), - nil, - nil, - s.logger, + s.segmenter = NewSegmenter( + NewSQLitePersistence(db), + zap.Must(zap.NewDevelopment()), ) - s.Require().NoError(err) } func (s *MessageSegmentationSuite) SetupSubTest() { @@ -146,13 +129,13 @@ func (s *MessageSegmentationSuite) TestHandleSegmentationLayer() { for _, tc := range testCases { s.Run(tc.name, func() { - segmentedMessages, err := segmentMessage(&wakutypes.NewMessage{Payload: s.testPayload}, int(math.Ceil(float64(len(s.testPayload))/float64(tc.segmentsCount)))) + signer, err := crypto.GenerateKey() s.Require().NoError(err) - s.Require().Len(segmentedMessages, tc.segmentsCount+tc.expectedParitySegmentsCount) - message := &types.Message{TransportLayer: types.TransportLayer{ - SigPubKey: &s.sender.identity.PublicKey, - }} + segSize := (len(s.testPayload) + tc.segmentsCount - 1) / tc.segmentsCount // ceil[len(testPayload)/segmentsCount] + segmentedMessages, err := s.segmenter.Segment(s.testPayload, segSize) + s.Require().NoError(err) + s.Require().Len(segmentedMessages, tc.segmentsCount+tc.expectedParitySegmentsCount) messageRecreated := false handledSegments := []int{} @@ -160,20 +143,17 @@ func (s *MessageSegmentationSuite) TestHandleSegmentationLayer() { for i, segmentIndex := range tc.retrievedSegments { s.T().Log("i=", i, "segmentIndex=", segmentIndex) - message.TransportLayer.Payload = segmentedMessages[segmentIndex].Payload - - err = s.sender.handleSegmentationLayer(message) - + reconstructedPayload, err := s.segmenter.Reconstruct(segmentedMessages[segmentIndex], &signer.PublicKey) handledSegments = append(handledSegments, segmentIndex) if len(handledSegments) < tc.segmentsCount { - s.Require().ErrorIs(err, ErrMessageSegmentsIncomplete) + s.Require().ErrorIs(err, ErrIncomplete) } else if len(handledSegments) == tc.segmentsCount { s.Require().NoError(err) - s.Require().ElementsMatch(s.testPayload, message.TransportLayer.Payload) + s.Require().ElementsMatch(s.testPayload, reconstructedPayload) messageRecreated = true } else { - s.Require().ErrorIs(err, ErrMessageSegmentsAlreadyCompleted) + s.Require().ErrorIs(err, ErrAlreadyCompleted) } } @@ -191,7 +171,7 @@ func (s *MessageSegmentationSuite) TestProtobufMissDecoding() { // any byte sequence, and if the structure coincidentally matches valid encoding // patterns (e.g., varint or byte fields), it produces seemingly valid but incorrect results. - segmentedMessage := types.SegmentMessage{ + segmentedMessage := Message{ SegmentMessage: &protobuf.SegmentMessage{}, } diff --git a/messaging/layers/segmentation/sqlite_persistence.go b/messaging/layers/segmentation/sqlite_persistence.go index 28263481d8e..f4eb5ada061 100644 --- a/messaging/layers/segmentation/sqlite_persistence.go +++ b/messaging/layers/segmentation/sqlite_persistence.go @@ -7,8 +7,7 @@ import ( "github.com/ethereum/go-ethereum/crypto" - "github.com/status-im/status-go/messaging/types" - "github.com/status-im/status-go/protocol/protobuf" + "github.com/status-im/status-go/messaging/layers/segmentation/protobuf" ) type SQLitePersistence struct { @@ -28,7 +27,7 @@ func (s *SQLitePersistence) IsMessageAlreadyCompleted(hash []byte) (bool, error) return alreadyCompleted > 0, nil } -func (s *SQLitePersistence) SaveMessageSegment(segment *types.SegmentMessage, sigPubKey *ecdsa.PublicKey, timestamp int64) error { +func (s *SQLitePersistence) SaveMessageSegment(segment *Message, sigPubKey *ecdsa.PublicKey, timestamp int64) error { sigPubKeyBlob := crypto.CompressPubkey(sigPubKey) _, err := s.db.Exec("INSERT INTO message_segments (hash, segment_index, segments_count, parity_segment_index, parity_segments_count, sig_pub_key, payload, timestamp) VALUES (?, ?, ?, ?, ?, ?, ?, ?)", @@ -38,7 +37,7 @@ func (s *SQLitePersistence) SaveMessageSegment(segment *types.SegmentMessage, si } // Get ordered message segments for given hash -func (s *SQLitePersistence) GetMessageSegments(hash []byte, sigPubKey *ecdsa.PublicKey) ([]*types.SegmentMessage, error) { +func (s *SQLitePersistence) GetMessageSegments(hash []byte, sigPubKey *ecdsa.PublicKey) ([]*Message, error) { sigPubKeyBlob := crypto.CompressPubkey(sigPubKey) rows, err := s.db.Query(` @@ -58,9 +57,9 @@ func (s *SQLitePersistence) GetMessageSegments(hash []byte, sigPubKey *ecdsa.Pub } defer rows.Close() - var segments []*types.SegmentMessage + var segments []*Message for rows.Next() { - segment := &types.SegmentMessage{ + segment := &Message{ SegmentMessage: &protobuf.SegmentMessage{}, } err := rows.Scan(&segment.EntireMessageHash, &segment.Index, &segment.SegmentsCount, &segment.ParitySegmentIndex, &segment.ParitySegmentsCount, &segment.Payload) diff --git a/messaging/common/testdata/segmentationProtobufMissDecoding.bin b/messaging/layers/segmentation/testdata/segmentationProtobufMissDecoding.bin similarity index 100% rename from messaging/common/testdata/segmentationProtobufMissDecoding.bin rename to messaging/layers/segmentation/testdata/segmentationProtobufMissDecoding.bin diff --git a/protocol/protobuf/service.go b/protocol/protobuf/service.go index b1adc62c3cd..bdb05ceaee2 100644 --- a/protocol/protobuf/service.go +++ b/protocol/protobuf/service.go @@ -4,7 +4,7 @@ import ( "github.com/golang/protobuf/proto" ) -//go:generate protoc --go_out=. ./chat_message.proto ./application_metadata_message.proto ./membership_update_message.proto ./command.proto ./contact.proto ./pairing.proto ./push_notifications.proto ./emoji_reaction.proto ./enums.proto ./shard.proto ./group_chat_invitation.proto ./chat_identity.proto ./communities.proto ./pin_message.proto ./anon_metrics.proto ./status_update.proto ./sync_settings.proto ./contact_verification.proto ./community_update.proto ./community_shard_key.proto ./url_data.proto ./community_privileged_user_sync_message.proto ./profile_showcase.proto ./segment_message.proto ./messenger_local_backup.proto ./wallet_local_backup.proto ./accounts_local_backup.proto +//go:generate protoc --go_out=. ./chat_message.proto ./application_metadata_message.proto ./membership_update_message.proto ./command.proto ./contact.proto ./pairing.proto ./push_notifications.proto ./emoji_reaction.proto ./enums.proto ./shard.proto ./group_chat_invitation.proto ./chat_identity.proto ./communities.proto ./pin_message.proto ./anon_metrics.proto ./status_update.proto ./sync_settings.proto ./contact_verification.proto ./community_update.proto ./community_shard_key.proto ./url_data.proto ./community_privileged_user_sync_message.proto ./profile_showcase.proto ./messenger_local_backup.proto ./wallet_local_backup.proto ./accounts_local_backup.proto func Unmarshal(payload []byte) (*ApplicationMetadataMessage, error) { var message ApplicationMetadataMessage