From 7d9705de4e676948613295a0f4fee57d86b00ce5 Mon Sep 17 00:00:00 2001 From: Yacov Manevich Date: Tue, 12 May 2026 21:40:16 +0200 Subject: [PATCH 1/3] Implement MSM verification Signed-off-by: Yacov Manevich --- msm/misc.go | 5 +- msm/misc_test.go | 69 +++++- msm/msm.go | 562 +++++++++++++++++++++++++++++++++++++---------- msm/msm_test.go | 132 +++++++++-- 4 files changed, 623 insertions(+), 145 deletions(-) diff --git a/msm/misc.go b/msm/misc.go index 62b1630c..267df4e3 100644 --- a/msm/misc.go +++ b/msm/misc.go @@ -5,6 +5,7 @@ package metadata import ( "context" + "errors" "fmt" "math" "math/big" @@ -15,9 +16,11 @@ import ( // but are not imported here to prevent us from importing the entire Avalanchego codebase. // Once we incorporate Simplex into Avalanchego, we can remove this file and import the relevant code from Avalanchego instead. +var errOverflow = errors.New("overflow") + func safeAdd(a, b uint64) (uint64, error) { if a > math.MaxUint64-b { - return 0, fmt.Errorf("overflow: %d + %d > maxuint64", a, b) + return 0, fmt.Errorf("%w: %d + %d > maxuint64", errOverflow, a, b) } return a + b, nil } diff --git a/msm/misc_test.go b/msm/misc_test.go index b78d2cd3..ba798adb 100644 --- a/msm/misc_test.go +++ b/msm/misc_test.go @@ -9,6 +9,7 @@ import ( "crypto/rand" "crypto/sha256" "encoding/asn1" + "errors" "fmt" "maps" "math" @@ -25,7 +26,7 @@ func TestSafeAdd(t *testing.T) { name string a, b uint64 sum uint64 - err string + err error }{ { name: "zero plus zero", @@ -50,12 +51,12 @@ func TestSafeAdd(t *testing.T) { { name: "overflow by one", a: math.MaxUint64, b: 1, - err: "overflow", + err: errOverflow, }, { name: "overflow both large", a: math.MaxUint64 - 5, b: 10, - err: "overflow", + err: errOverflow, }, { name: "max uint64 boundary no overflow", @@ -65,8 +66,8 @@ func TestSafeAdd(t *testing.T) { } { t.Run(tc.name, func(t *testing.T) { result, err := safeAdd(tc.a, tc.b) - if tc.err != "" { - require.ErrorContains(t, err, tc.err) + if tc.err != nil { + 
require.ErrorIs(t, err, tc.err) } else { require.NoError(t, err) require.Equal(t, tc.sum, result) @@ -487,10 +488,66 @@ func (failingAggregator) Aggregate([]simplex.Signature) (simplex.QuorumCertifica panic("unused in tests") } +var errTestAggregationFailed = errors.New("aggregation failed") + func (failingAggregator) AppendSignatures([]byte, ...[]byte) ([]byte, error) { - return nil, fmt.Errorf("aggregation failed") + return nil, errTestAggregationFailed } func (failingAggregator) IsQuorum([]simplex.NodeID) bool { return false } + +type testBlockStore map[uint64]StateMachineBlock + +func (bs testBlockStore) getBlock(seq uint64, _ [32]byte) (StateMachineBlock, *simplex.Finalization, error) { + blk, ok := bs[seq] + if !ok { + return StateMachineBlock{}, nil, fmt.Errorf("%w: block %d", simplex.ErrBlockNotFound, seq) + } + return blk, nil, nil +} + +type testVMBlock struct { + bytes []byte + height uint64 +} + +func (b *testVMBlock) Digest() [32]byte { + return sha256.Sum256(b.bytes) +} + +func (b *testVMBlock) Height() uint64 { + return b.height +} + +func (b *testVMBlock) Timestamp() time.Time { + return time.Now() +} + +func (b *testVMBlock) Verify(_ context.Context) error { + return nil +} + +type testSigVerifier struct { + err error +} + +func (sv *testSigVerifier) VerifySignature(_, _, _ []byte) error { + return sv.err +} + +type testKeyAggregator struct { + err error +} + +func (ka *testKeyAggregator) AggregateKeys(keys ...[]byte) ([]byte, error) { + if ka.err != nil { + return nil, ka.err + } + var agg []byte + for _, k := range keys { + agg = append(agg, k...) 
+ } + return agg, nil +} diff --git a/msm/msm.go b/msm/msm.go index 2b9c548b..fdf46da9 100644 --- a/msm/msm.go +++ b/msm/msm.go @@ -6,6 +6,8 @@ package metadata import ( "context" "crypto/sha256" + "encoding/binary" + "errors" "fmt" "time" @@ -13,6 +15,36 @@ import ( "go.uber.org/zap" ) +var ( + errLastNonSimplexInnerBlockNil = errors.New("failed constructing zero block: last non-Simplex inner block is nil") + errInvalidProtocolMetadataSeq = errors.New("invalid ProtocolMetadata sequence number: should be > 0") + errUnknownState = errors.New("unknown state") + errNilInnerBlock = errors.New("InnerBlock is nil") + errBuiltGenesisInnerBlock = errors.New("received a genesis block") + errZeroBlockParentNoInnerBlock = errors.New("failed constructing zero block: parent block has no inner block") + errNilBlock = errors.New("block is nil") + errParentInnerBlockHasNoInnerBlock = errors.New("parent inner block has no inner block") + errInvalidPChainHeight = errors.New("invalid P-chain height") + errInvalidSimplexEpochInfo = errors.New("invalid SimplexEpochInfo") + errZeroBlockHasInnerBlock = errors.New("zero block must not have an inner block") + errZeroBlockInnerDigestMismatch = errors.New("zero block inner block digest does not match last non-Simplex inner block digest") + errZeroBlockTimestampMismatch = errors.New("zero block timestamp does not match last non-Simplex inner block timestamp") + errPrevSealingBlockNotFinalized = errors.New("previous sealing InnerBlock is not finalized") + errFirstEverSimplexBlockNotSet = errors.New("first ever Simplex block is not set, but attempted to create a sealing block for the first epoch") + errSealingBlockSeqUnset = errors.New("cannot build epoch sealed block: sealing block sequence is 0 or undefined") + errNilNextEpochApprovals = errors.New("next epoch approvals is nil") +) + +var ( + errPChainReferenceHeightMismatch = errors.New("unexpected P-chain reference height") + errPChainReferenceHeightDecreased = errors.New("P-chain reference 
height is decreasing") + errValidatorSetUnchanged = errors.New("validator set unchanged; next P-chain reference height should not have advanced") + errPChainHeightNotReached = errors.New("haven't reached referenced P-chain height yet") + errUnknownBlockType = errors.New("unknown block type") + errPChainHeightTooBig = errors.New("invalid P-chain height: greater than current") + errPChainHeightSmallerThanParent = errors.New("invalid P-chain height: smaller than parent block's") +) + // A StateMachineBlock is a representation of a parsed OuterBlock, containing the inner block and the metadata. type StateMachineBlock struct { // InnerBlock is the VM-level block, or nil if this is a block without an inner block (e.g., a Telock block). @@ -148,7 +180,7 @@ const ( func NewStateMachine(config *Config) (*StateMachine, error) { if config.LastNonSimplexInnerBlock == nil { config.Logger.Error("Last non-Simplex inner block is nil, cannot build zero block with correct metadata") - return nil, fmt.Errorf("failed constructing zero block: last non-Simplex inner block is nil") + return nil, errLastNonSimplexInnerBlockNil } sm := StateMachine{Config: config} return &sm, nil @@ -158,7 +190,7 @@ func NewStateMachine(config *Config) (*StateMachine, error) { func (sm *StateMachine) BuildBlock(ctx context.Context, metadata simplex.ProtocolMetadata, blacklist *simplex.Blacklist) (*StateMachineBlock, error) { // The zero sequence number is reserved for the genesis block, which should never be built. 
if metadata.Seq == 0 { - return nil, fmt.Errorf("invalid ProtocolMetadata sequence number: should be > 0, got %d", metadata.Seq) + return nil, fmt.Errorf("%w: got %d", errInvalidProtocolMetadataSeq, metadata.Seq) } prevBlockSeq := metadata.Seq - 1 @@ -206,7 +238,7 @@ func (sm *StateMachine) BuildBlock(ctx context.Context, metadata simplex.Protoco case stateBuildBlockEpochSealed: return sm.buildBlockEpochSealed(ctx, parentBlock, simplexMetadataBytes, simplexBlacklistBytes, prevBlockSeq) default: - return nil, fmt.Errorf("unknown state %d", currentState) + return nil, fmt.Errorf("%w: %d", errUnknownState, currentState) } } @@ -214,7 +246,7 @@ func (sm *StateMachine) BuildBlock(ctx context.Context, metadata simplex.Protoco // and inner block against the previous block and the current state. func (sm *StateMachine) VerifyBlock(ctx context.Context, block *StateMachineBlock) error { if block == nil { - return fmt.Errorf("InnerBlock is nil") + return errNilInnerBlock } pmd, err := simplex.ProtocolMetadataFromBytes(block.Metadata.SimplexProtocolMetadata) @@ -225,7 +257,7 @@ func (sm *StateMachine) VerifyBlock(ctx context.Context, block *StateMachineBloc seq := pmd.Seq if seq == 0 { - return fmt.Errorf("attempted to build a genesis inner block") + return errBuiltGenesisInnerBlock } prevBlock, _, err := sm.GetBlock(seq-1, pmd.Prev) @@ -240,49 +272,25 @@ func (sm *StateMachine) VerifyBlock(ctx context.Context, block *StateMachineBloc case stateFirstSimplexBlock: err = sm.verifyBlockZero(block, prevBlock) default: - err = sm.verifyNonZeroBlock(ctx, block, prevBlock.Metadata, currentState, seq-1) + err = sm.verifyNonZeroBlock(ctx, block, &prevBlock, seq-1) } return err } -func (sm *StateMachine) verifyNonZeroBlock(ctx context.Context, block *StateMachineBlock, prevBlockMD StateMachineMetadata, state state, prevSeq uint64) error { - blockType := IdentifyBlockType(block.Metadata, prevBlockMD, prevSeq) - sm.Logger.Debug("Identified block type", - zap.Stringer("blockType", 
blockType), - zap.Bool("nextHasBVD", block.Metadata.SimplexEpochInfo.BlockValidationDescriptor != nil), - zap.Uint64("nextEpochNumber", block.Metadata.SimplexEpochInfo.EpochNumber), - zap.Bool("prevHasBVD", prevBlockMD.SimplexEpochInfo.BlockValidationDescriptor != nil), - zap.Uint64("prevEpochNumber", prevBlockMD.SimplexEpochInfo.EpochNumber), - zap.Uint64("prevNextPChainRefHeight", prevBlockMD.SimplexEpochInfo.NextPChainReferenceHeight), - zap.Uint64("prevSealingBlockSeq", prevBlockMD.SimplexEpochInfo.SealingBlockSeq), - zap.Uint64("prevSeq", prevSeq), - ) - - var innerBlockTimestamp time.Time - if block.InnerBlock != nil { - innerBlockTimestamp = block.InnerBlock.Timestamp() - } - - for _, verifier := range sm.verifiers { - if err := verifier.Verify(verificationInput{ - proposedBlockMD: block.Metadata, - nextBlockType: blockType, - prevMD: prevBlockMD, - state: state, - prevBlockSeq: prevSeq, - hasInnerBlock: block.InnerBlock != nil, - innerBlockTimestamp: innerBlockTimestamp, - }); err != nil { - sm.Logger.Debug("Invalid block", zap.Error(err)) - return err - } - } +func (sm *StateMachine) verifyNonZeroBlock(ctx context.Context, block, prevBlock *StateMachineBlock, prevSeq uint64) error { + prevBlockMD := prevBlock.Metadata + currentState := prevBlockMD.SimplexEpochInfo.NextState() - if block.InnerBlock == nil { - return nil + switch currentState { + case stateBuildBlockNormalOp: + return sm.verifyNormalBlock(ctx, *prevBlock, block, prevSeq) + case stateBuildCollectingApprovals: + return sm.verifyCollectingApprovalsBlock(ctx, *prevBlock, block, prevSeq) + case stateBuildBlockEpochSealed: + return sm.verifyBlockEpochSealed(ctx, *prevBlock, block, prevSeq) + default: + return fmt.Errorf("%w: %d", errUnknownBlockType, currentState) } - - return block.InnerBlock.Verify(ctx) } // buildBlockNormalOp builds a block while potentially also transitioning to a new epoch, depending on the P-chain. 
@@ -329,6 +337,137 @@ func (sm *StateMachine) buildBlockOrTransitionEpoch(ctx context.Context, parentB return sm.wrapBlock(parentBlock, innerBlock, newSimplexEpochInfo, decisionToBuildBlock.pChainHeight, simplexMetadata, simplexBlacklist), nil } +func (sm *StateMachine) verifyNormalBlock(ctx context.Context, parentBlock StateMachineBlock, nextBlock *StateMachineBlock, prevBlockSeq uint64) error { + newSimplexEpochInfo := SimplexEpochInfo{ + PChainReferenceHeight: parentBlock.Metadata.SimplexEpochInfo.PChainReferenceHeight, + EpochNumber: parentBlock.Metadata.SimplexEpochInfo.EpochNumber, + PrevVMBlockSeq: computePrevVMBlockSeq(parentBlock, prevBlockSeq), + } + + currentPChainHeight := sm.GetPChainHeight() + prevPChainHeight := parentBlock.Metadata.PChainHeight + proposedPChainHeight := nextBlock.Metadata.PChainHeight + + if err := verifyPChainHeight(proposedPChainHeight, currentPChainHeight, prevPChainHeight); err != nil { + return fmt.Errorf("failed to verify P-chain height: %w", err) + } + + if err := sm.verifyNextPChainRefHeightNormal(parentBlock.Metadata, nextBlock.Metadata.SimplexEpochInfo); err != nil { + return fmt.Errorf("failed to verify next P-chain reference height for normal block: %w", err) + } + newSimplexEpochInfo.NextPChainReferenceHeight = nextBlock.Metadata.SimplexEpochInfo.NextPChainReferenceHeight + + if nextBlock.InnerBlock != nil { + if err := nextBlock.InnerBlock.Verify(ctx); err != nil { + return err + } + } + + expectedBlock := sm.wrapBlock(parentBlock, nextBlock.InnerBlock, newSimplexEpochInfo, proposedPChainHeight, nextBlock.Metadata.SimplexProtocolMetadata, nextBlock.Metadata.SimplexBlacklist) + if expectedBlock.Digest() != nextBlock.Digest() { + return fmt.Errorf("expected block digest %s does not match proposed block digest %s", expectedBlock.Digest(), nextBlock.Digest()) + } + return nil +} + +func verifyPChainHeight(proposedPChainHeight uint64, currentPChainHeight uint64, prevPChainHeight uint64) error { + if proposedPChainHeight > 
currentPChainHeight { + return fmt.Errorf("%w: proposed %d, current %d", + errPChainHeightTooBig, proposedPChainHeight, currentPChainHeight) + } + + if prevPChainHeight > proposedPChainHeight { + return fmt.Errorf("%w: proposed %d, parent %d", + errPChainHeightSmallerThanParent, proposedPChainHeight, prevPChainHeight) + } + return nil +} + +func (sm *StateMachine) verifyNextPChainRefHeightNormal(prevMD StateMachineMetadata, next SimplexEpochInfo) error { + prev := prevMD.SimplexEpochInfo + // Next P-chain height can only increase, not decrease. + if next.NextPChainReferenceHeight > 0 && prev.PChainReferenceHeight > next.NextPChainReferenceHeight { + return fmt.Errorf("%w: previous P-chain reference height is %d and the proposed P-chain reference height is %d", errPChainReferenceHeightDecreased, prev.PChainReferenceHeight, next.NextPChainReferenceHeight) + } + + // If the previous block already has a next P-chain reference height, + // we should keep the same next P-chain reference height until we reach it. + if prev.NextPChainReferenceHeight > 0 { + if next.NextPChainReferenceHeight != prev.NextPChainReferenceHeight { + return fmt.Errorf("%w: expected %d but got %d", errPChainReferenceHeightMismatch, prev.NextPChainReferenceHeight, next.NextPChainReferenceHeight) + } + return nil + } + + // If we reached here, then prev.NextPChainReferenceHeight == 0. + // It might be that this block is the first block that has set the next P-chain reference height for the epoch, + // so check if it has done so correctly by observing whether the validator set has indeed changed. + + currentValidatorSet, err := sm.GetValidatorSet(prevMD.SimplexEpochInfo.PChainReferenceHeight) + if err != nil { + return err + } + + newValidatorSet, err := sm.GetValidatorSet(next.NextPChainReferenceHeight) + if err != nil { + return err + } + + // If the validator set doesn't change, we shouldn't have increased the next P-chain reference height. 
+ if currentValidatorSet.Equal(newValidatorSet) && next.NextPChainReferenceHeight > 0 { + return fmt.Errorf("%w: validator set at proposed next P-chain reference height %d matches previous block's P-chain reference height %d", + errValidatorSetUnchanged, next.NextPChainReferenceHeight, prev.PChainReferenceHeight) + } + + // Else, either the validator set has changed, or the next P-chain reference height is still 0. + // Both of these cases are fine, but we should verify that we have observed the next P-chain reference height if it is > 0. + + pChainHeight := sm.GetPChainHeight() + + if pChainHeight < next.NextPChainReferenceHeight { + return fmt.Errorf("%w: target %d, current %d", errPChainHeightNotReached, next.NextPChainReferenceHeight, pChainHeight) + } + + return nil +} + +// verifyNextPChainRefHeightForNewEpoch validates the proposed NextPChainReferenceHeight on the +// first block of a new epoch. The parent's NextPChainReferenceHeight describes the transition +// that just completed, so we cannot reuse verifyNextPChainRefHeightNormal here — the baseline +// for the validator-set change check is the new epoch's PChainReferenceHeight, not the parent's. 
+func (sm *StateMachine) verifyNextPChainRefHeightForNewEpoch(newEpoch SimplexEpochInfo, next SimplexEpochInfo) error { + if next.NextPChainReferenceHeight == 0 { + return nil + } + + if next.NextPChainReferenceHeight < newEpoch.PChainReferenceHeight { + return fmt.Errorf("%w: new epoch P-chain reference height is %d and the proposed next P-chain reference height is %d", + errPChainReferenceHeightDecreased, newEpoch.PChainReferenceHeight, next.NextPChainReferenceHeight) + } + + currentValidatorSet, err := sm.GetValidatorSet(newEpoch.PChainReferenceHeight) + if err != nil { + return err + } + + newValidatorSet, err := sm.GetValidatorSet(next.NextPChainReferenceHeight) + if err != nil { + return err + } + + if currentValidatorSet.Equal(newValidatorSet) { + return fmt.Errorf("%w: validator set at proposed next P-chain reference height %d matches new epoch's P-chain reference height %d", + errValidatorSetUnchanged, next.NextPChainReferenceHeight, newEpoch.PChainReferenceHeight) + } + + pChainHeight := sm.GetPChainHeight() + if pChainHeight < next.NextPChainReferenceHeight { + return fmt.Errorf("%w: target %d, current %d", errPChainHeightNotReached, next.NextPChainReferenceHeight, pChainHeight) + } + + return nil +} + func (sm *StateMachine) createBlockBuildingDecider(pChainReferenceHeight uint64) blockBuildingDecider { blockBuildingDecider := blockBuildingDecider{ logger: sm.Logger, @@ -389,7 +528,7 @@ func (sm *StateMachine) buildBlockZero(parentBlock StateMachineBlock, simplexMet // We can only have blocks without inner blocks in Simplex blocks, but this is the first Simplex block. // Therefore, the parent block must have an inner block. 
sm.Logger.Error("Parent block has no inner block, cannot determine previous VM block sequence for zero block") - return nil, fmt.Errorf("failed constructing zero block: parent block has no inner block") + return nil, errZeroBlockParentNoInnerBlock } timestamp := sm.LastNonSimplexInnerBlock.Timestamp().UnixMilli() @@ -415,29 +554,25 @@ func (sm *StateMachine) buildBlockZero(parentBlock StateMachineBlock, simplexMet func (sm *StateMachine) verifyBlockZero(block *StateMachineBlock, prevBlock StateMachineBlock) error { if block == nil { - return fmt.Errorf("block is nil") + return errNilBlock } simplexEpochInfo := block.Metadata.SimplexEpochInfo - if simplexEpochInfo.EpochNumber != 1 { - return fmt.Errorf("invalid epoch number (%d), should be 1", simplexEpochInfo.EpochNumber) - } - if prevBlock.InnerBlock == nil { - return fmt.Errorf("parent inner block (%s) has no inner block", prevBlock.Digest()) + return fmt.Errorf("%w: parent digest %s", errParentInnerBlockHasNoInnerBlock, prevBlock.Digest()) } pChainHeight := sm.LastNonSimplexBlockPChainHeight prevVMBlockSeq := prevBlock.InnerBlock.Height() if block.Metadata.PChainHeight != pChainHeight { - return fmt.Errorf("invalid P-chain height (%d), expected to be %d", - block.Metadata.PChainHeight, pChainHeight) + return fmt.Errorf("%w: got %d, expected %d", + errInvalidPChainHeight, block.Metadata.PChainHeight, pChainHeight) } var expectedValidatorSet NodeBLSMappings - if prevBlock.InnerBlock.Height() == 0 { + if prevVMBlockSeq == 0 { expectedValidatorSet = sm.GenesisValidatorSet } else { var err error @@ -447,40 +582,159 @@ func (sm *StateMachine) verifyBlockZero(block *StateMachineBlock, prevBlock Stat } } - if simplexEpochInfo.BlockValidationDescriptor == nil { - return fmt.Errorf("invalid BlockValidationDescriptor: should not be nil") - } - - membership := simplexEpochInfo.BlockValidationDescriptor.AggregatedMembership.Members - if !NodeBLSMappings(membership).Equal(expectedValidatorSet) { - return fmt.Errorf("invalid 
BlockValidationDescriptor: should match validator set at P-chain height %d", pChainHeight) - } - // If we have compared all fields so far, the rest of the fields we compare by constructing an explicit expected SimplexEpochInfo expectedSimplexEpochInfo := constructSimplexZeroBlockSimplexEpochInfo(pChainHeight, expectedValidatorSet, prevVMBlockSeq) if !expectedSimplexEpochInfo.Equal(&simplexEpochInfo) { - return fmt.Errorf("invalid SimplexEpochInfo: expected %v, got %v", expectedSimplexEpochInfo, simplexEpochInfo) + return fmt.Errorf("%w: expected %v, got %v", errInvalidSimplexEpochInfo, expectedSimplexEpochInfo, simplexEpochInfo) } // The InnerBlock must match the last non-Simplex inner block. if block.InnerBlock != nil { - return fmt.Errorf("zero block must not have an inner block") + return errZeroBlockHasInnerBlock } if prevBlock.InnerBlock.Digest() != sm.LastNonSimplexInnerBlock.Digest() { - return fmt.Errorf("zero block inner block digest does not match last non-Simplex inner block digest") + return errZeroBlockInnerDigestMismatch } // The timestamp must equal the last non-Simplex inner block's timestamp. 
expectedTimestamp := uint64(sm.LastNonSimplexInnerBlock.Timestamp().UnixMilli()) if block.Metadata.Timestamp != expectedTimestamp { - return fmt.Errorf("expected timestamp to be %d but got %d", expectedTimestamp, block.Metadata.Timestamp) + return fmt.Errorf("%w: expected %d but got %d", errZeroBlockTimestampMismatch, expectedTimestamp, block.Metadata.Timestamp) } return nil } func (sm *StateMachine) buildBlockCollectingApprovals(ctx context.Context, parentBlock StateMachineBlock, simplexMetadata, simplexBlacklist []byte, prevBlockSeq uint64) (*StateMachineBlock, error) { + newApprovals, err := sm.computeNewApprovals(parentBlock) + if err != nil { + return nil, err + } + + newSimplexEpochInfo := computeSimplexEpochInfoForCollectingApprovalsBlock(parentBlock, prevBlockSeq, newApprovals) + + pChainHeight := parentBlock.Metadata.PChainHeight + + // We might not have enough approvals to seal the current epoch, + // in which case we just carry over the approvals we have so far to the next block, + // so that eventually we'll have enough approvals to seal the epoch. + if !newApprovals.canSeal { + sm.Logger.Debug("Not enough approvals to seal epoch, building block without sealing the epoch") + return sm.buildBlockImpatiently(ctx, parentBlock, simplexMetadata, simplexBlacklist, newSimplexEpochInfo, pChainHeight) + } + + sm.Logger.Debug("Have enough approvals to seal epoch, building sealing block") + + // Else, we have enough approvals to seal the epoch, so we create the sealing block. 
+ return sm.createSealingBlock(ctx, parentBlock, simplexMetadata, simplexBlacklist, newSimplexEpochInfo, pChainHeight) +} + +func (sm *StateMachine) verifyCollectingApprovalsBlock(ctx context.Context, parentBlock StateMachineBlock, nextBlock *StateMachineBlock, prevBlockSeq uint64) error { + nextMD := nextBlock.Metadata + newApprovals := nextMD.SimplexEpochInfo.NextEpochApprovals + if newApprovals == nil { + return errNilNextEpochApprovals + } + + prevEpochInfo := parentBlock.Metadata.SimplexEpochInfo + nextEpochInfo := nextBlock.Metadata.SimplexEpochInfo + + validators, err := sm.GetValidatorSet(prevEpochInfo.NextPChainReferenceHeight) + if err != nil { + return err + } + + err = sm.verifyNextEpochApprovalsSignature(prevEpochInfo, nextEpochInfo, validators) + if err != nil { + return err + } + + // A node cannot remove other nodes' approvals, only add its own approval if it wasn't included in the previous block. + // So the set of signers in next.NextEpochApprovals should be a superset of the set of signers in prev.NextEpochApprovals. 
+ if err := areNextEpochApprovalsSignersSupersetOfApprovalsOfPrevBlock(prevEpochInfo, nextEpochInfo); err != nil { + return err + } + + newSimplexEpochInfo := computeSimplexEpochInfoForCollectingApprovalsBlock(parentBlock, prevBlockSeq, &approvals{ + nodeIDs: newApprovals.NodeIDs, + signature: newApprovals.Signature, + }) + + sigAggr := sm.SignatureAggregatorCreator(validators.NodeWeights()) + approvals := bitmaskFromBytes(newApprovals.NodeIDs) + canSeal := sigAggr.IsQuorum(validators.SelectSubset(approvals)) + + if nextBlock.InnerBlock != nil { + if err := nextBlock.InnerBlock.Verify(ctx); err != nil { + sm.Logger.Debug("Failed verifying inner block", zap.Error(err)) + return err + } + } + + blacklist := nextMD.SimplexBlacklist + protocolMD := nextMD.SimplexProtocolMetadata + pChainHeight := parentBlock.Metadata.PChainHeight + + if !canSeal { + expectedBlock := sm.wrapBlock(parentBlock, nextBlock.InnerBlock, newSimplexEpochInfo, pChainHeight, protocolMD, blacklist) + if expectedBlock.Digest() != nextBlock.Digest() { + return fmt.Errorf("expected block digest %s does not match proposed block digest %s", expectedBlock.Digest(), nextBlock.Digest()) + } + return nil + } + + // Else, we verify the sealing block. 
+ newSimplexEpochInfo, err = sm.computeSimplexEpochInfoForSealingBlock(newSimplexEpochInfo) + if err != nil { + return fmt.Errorf("failed to compute simplex epoch info for sealing block: %w", err) + } + + expectedBlock := sm.wrapBlock(parentBlock, nextBlock.InnerBlock, newSimplexEpochInfo, pChainHeight, protocolMD, blacklist) + if expectedBlock.Digest() != nextBlock.Digest() { + return fmt.Errorf("expected block digest %s does not match proposed block digest %s", expectedBlock.Digest(), nextBlock.Digest()) + } + return nil +} + +func (sm *StateMachine) verifyNextEpochApprovalsSignature(prev SimplexEpochInfo, next SimplexEpochInfo, validators NodeBLSMappings) error { + // First figure out which validators are approving the next epoch by looking at the bitmask of approving nodes, + // and then aggregate their public keys together to verify the signature. + + nodeIDsBitmask := next.NextEpochApprovals.NodeIDs + aggPK, err := sm.aggregatePubKeysForBitmask(nodeIDsBitmask, validators) + if err != nil { + return err + } + + pChainHeight := prev.NextPChainReferenceHeight + pChainHeightBuff := make([]byte, 8) + binary.BigEndian.PutUint64(pChainHeightBuff, pChainHeight) + + if err := sm.SignatureVerifier.VerifySignature(next.NextEpochApprovals.Signature, pChainHeightBuff, aggPK); err != nil { + return fmt.Errorf("failed to verify signature: %w", err) + } + return nil +} + +func (sm *StateMachine) aggregatePubKeysForBitmask(nodeIDsBitmask []byte, validators NodeBLSMappings) ([]byte, error) { + approvingNodes := bitmaskFromBytes(nodeIDsBitmask) + publicKeys := make([][]byte, 0, len(validators)) + for i := range validators { + if !approvingNodes.Contains(i) { + continue + } + publicKeys = append(publicKeys, validators[i].BLSKey) + } + + aggPK, err := sm.KeyAggregator.AggregateKeys(publicKeys...) 
+ if err != nil { + return nil, fmt.Errorf("failed to aggregate public keys: %w", err) + } + return aggPK, nil +} + +func computeSimplexEpochInfoForCollectingApprovalsBlock(parentBlock StateMachineBlock, prevBlockSeq uint64, newApprovals *approvals) SimplexEpochInfo { // The P-chain reference height and epoch number should remain the same until we transition to the new epoch. // The next P-chain reference height should have been set in the previous block, // which is the reason why we are collecting approvals in the first place. @@ -491,6 +745,18 @@ func (sm *StateMachine) buildBlockCollectingApprovals(ctx context.Context, paren PrevVMBlockSeq: computePrevVMBlockSeq(parentBlock, prevBlockSeq), } + // This might be the first time we created approvals for the next epoch, + // so we need to initialize the NextEpochApprovals. + if newSimplexEpochInfo.NextEpochApprovals == nil { + newSimplexEpochInfo.NextEpochApprovals = &NextEpochApprovals{} + } + // The node IDs and signature are aggregated across all past and present approvals. + newSimplexEpochInfo.NextEpochApprovals.NodeIDs = newApprovals.nodeIDs + newSimplexEpochInfo.NextEpochApprovals.Signature = newApprovals.signature + return newSimplexEpochInfo +} + +func (sm *StateMachine) computeNewApprovals(parentBlock StateMachineBlock) (*approvals, error) { // We prepare information that is needed to compute the approvals for the new epoch, // such as the validator set for the next epoch, and the approvals from peers. validators, err := sm.GetValidatorSet(parentBlock.Metadata.SimplexEpochInfo.NextPChainReferenceHeight) @@ -498,43 +764,21 @@ func (sm *StateMachine) buildBlockCollectingApprovals(ctx context.Context, paren return nil, err } + sigAggr := sm.SignatureAggregatorCreator(validators.NodeWeights()) + // We retrieve approvals that validators have sent us for the next epoch. // These approvals are signed by validators of the next epoch. 
approvalsFromPeers := sm.ApprovalsRetriever.Approvals() sm.Logger.Debug("Retrieved approvals from peers", zap.Int("numApprovals", len(approvalsFromPeers))) - nextPChainHeight := newSimplexEpochInfo.NextPChainReferenceHeight + nextPChainHeight := parentBlock.Metadata.SimplexEpochInfo.NextPChainReferenceHeight prevNextEpochApprovals := parentBlock.Metadata.SimplexEpochInfo.NextEpochApprovals - sigAggr := sm.SignatureAggregatorCreator(validators.NodeWeights()) - newApprovals, err := computeNewApprovals(prevNextEpochApprovals, approvalsFromPeers, nextPChainHeight, sigAggr, validators, sm.Logger) if err != nil { return nil, err } - - // This might be the first time we created approvals for the next epoch, - // so we need to initialize the NextEpochApprovals. - if newSimplexEpochInfo.NextEpochApprovals == nil { - newSimplexEpochInfo.NextEpochApprovals = &NextEpochApprovals{} - } - // The node IDs and signature are aggregated across all past and present approvals. - newSimplexEpochInfo.NextEpochApprovals.NodeIDs = newApprovals.nodeIDs - newSimplexEpochInfo.NextEpochApprovals.Signature = newApprovals.signature - pChainHeight := parentBlock.Metadata.PChainHeight - - // We might not have enough approvals to seal the current epoch, - // in which case we just carry over the approvals we have so far to the next block, - // so that eventually we'll have enough approvals to seal the epoch. - if !newApprovals.canSeal { - sm.Logger.Debug("Not enough approvals to seal epoch, building block without sealing the epoch") - return sm.buildBlockImpatiently(ctx, parentBlock, simplexMetadata, simplexBlacklist, newSimplexEpochInfo, pChainHeight) - } - - sm.Logger.Debug("Have enough approvals to seal epoch, building sealing block") - - // Else, we have enough approvals to seal the epoch, so we create the sealing block. 
- return sm.createSealingBlock(ctx, parentBlock, simplexMetadata, simplexBlacklist, newSimplexEpochInfo, pChainHeight) + return newApprovals, nil } // buildBlockImpatiently builds a block by waiting for the VM to build a block until MaxBlockBuildingWaitTime. @@ -561,9 +805,17 @@ func (sm *StateMachine) buildBlockImpatiently(ctx context.Context, parentBlock S } func (sm *StateMachine) createSealingBlock(ctx context.Context, parentBlock StateMachineBlock, simplexMetadata []byte, simplexBlacklist []byte, simplexEpochInfo SimplexEpochInfo, pChainHeight uint64) (*StateMachineBlock, error) { + simplexEpochInfo, err := sm.computeSimplexEpochInfoForSealingBlock(simplexEpochInfo) + if err != nil { + return nil, fmt.Errorf("failed to compute simplex epoch info for sealing block: %w", err) + } + return sm.buildBlockImpatiently(ctx, parentBlock, simplexMetadata, simplexBlacklist, simplexEpochInfo, pChainHeight) +} + +func (sm *StateMachine) computeSimplexEpochInfoForSealingBlock(simplexEpochInfo SimplexEpochInfo) (SimplexEpochInfo, error) { validators, err := sm.GetValidatorSet(simplexEpochInfo.NextPChainReferenceHeight) if err != nil { - return nil, err + return SimplexEpochInfo{}, err } if simplexEpochInfo.BlockValidationDescriptor == nil { simplexEpochInfo.BlockValidationDescriptor = &BlockValidationDescriptor{} @@ -575,22 +827,22 @@ func (sm *StateMachine) createSealingBlock(ctx context.Context, parentBlock Stat prevSealingBlock, finalization, err := sm.GetBlock(simplexEpochInfo.EpochNumber, [32]byte{}) if err != nil { sm.Logger.Error("Error retrieving previous sealing block", zap.Uint64("seq", simplexEpochInfo.EpochNumber), zap.Error(err)) - return nil, fmt.Errorf("failed to retrieve previous sealing InnerBlock at epoch %d: %w", simplexEpochInfo.EpochNumber-1, err) + return SimplexEpochInfo{}, fmt.Errorf("failed to retrieve previous sealing InnerBlock at epoch %d: %w", simplexEpochInfo.EpochNumber-1, err) } if finalization == nil { sm.Logger.Error("Previous sealing block 
is not finalized", zap.Uint64("seq", simplexEpochInfo.EpochNumber)) - return nil, fmt.Errorf("previous sealing InnerBlock at epoch %d is not finalized", simplexEpochInfo.EpochNumber-1) + return SimplexEpochInfo{}, fmt.Errorf("%w: epoch %d", errPrevSealingBlockNotFinalized, simplexEpochInfo.EpochNumber-1) } simplexEpochInfo.PrevSealingBlockHash = prevSealingBlock.Digest() } else { // Else, this is the first epoch, so we use the hash of the first ever Simplex block. firstSimplexBlock := sm.FirstEverSimplexBlock() if firstSimplexBlock == nil { - return nil, fmt.Errorf("first ever Simplex block is not set, but attempted to create a sealing block for the first epoch") + return SimplexEpochInfo{}, errFirstEverSimplexBlockNotSet } simplexEpochInfo.PrevSealingBlockHash = firstSimplexBlock.Digest() } - return sm.buildBlockImpatiently(ctx, parentBlock, simplexMetadata, simplexBlacklist, simplexEpochInfo, pChainHeight) + return simplexEpochInfo, nil } // wrapBlock creates a new StateMachineBlock by wrapping the VM block (if applicable) and adding the appropriate metadata. @@ -617,11 +869,7 @@ func (sm *StateMachine) wrapBlock(parentBlock StateMachineBlock, childBlock VMBl } } -// buildBlockEpochSealed builds a block where the epoch is being sealed due to a sealing block already created in this epoch. -func (sm *StateMachine) buildBlockEpochSealed(ctx context.Context, parentBlock StateMachineBlock, simplexMetadata, simplexBlacklist []byte, prevBlockSeq uint64) (*StateMachineBlock, error) { - // We check if the sealing block has already been finalized. - // If not, we build a Telock block. - +func (sm *StateMachine) isSealingBlockFinalized(parentBlock StateMachineBlock, prevBlockSeq uint64) (bool, uint64, StateMachineBlock, error) { sealingBlockSeq := parentBlock.Metadata.SimplexEpochInfo.SealingBlockSeq // If the sealing block sequence is still 0, it means previous block was the sealing block. 
@@ -630,23 +878,27 @@ func (sm *StateMachine) buildBlockEpochSealed(ctx context.Context, parentBlock S } if sealingBlockSeq == 0 { - return nil, fmt.Errorf("cannot build epoch sealed block: sealing block sequence is 0 or undefined") + return false, 0, StateMachineBlock{}, errSealingBlockSeqUnset } - newSimplexEpochInfo := SimplexEpochInfo{ - PChainReferenceHeight: parentBlock.Metadata.SimplexEpochInfo.PChainReferenceHeight, - EpochNumber: parentBlock.Metadata.SimplexEpochInfo.EpochNumber, - NextPChainReferenceHeight: parentBlock.Metadata.SimplexEpochInfo.NextPChainReferenceHeight, - SealingBlockSeq: sealingBlockSeq, - PrevVMBlockSeq: computePrevVMBlockSeq(parentBlock, prevBlockSeq), + sealingBlock, finalization, err := sm.GetBlock(sealingBlockSeq, [32]byte{}) + if err != nil { + return false, 0, StateMachineBlock{}, fmt.Errorf("failed to retrieve sealing block at sequence %d: %w", sealingBlockSeq, err) } - sealingBlock, finalization, err := sm.GetBlock(sealingBlockSeq, [32]byte{}) + return finalization != nil, sealingBlockSeq, sealingBlock, nil +} + +// buildBlockEpochSealed builds a block where the epoch is being sealed due to a sealing block already created in this epoch. +func (sm *StateMachine) buildBlockEpochSealed(ctx context.Context, parentBlock StateMachineBlock, simplexMetadata, simplexBlacklist []byte, prevBlockSeq uint64) (*StateMachineBlock, error) { + // We check if the sealing block has already been finalized. + // If not, we build a Telock block. 
+ isSealingBlockFinalized, sealingBlockSeq, sealingBlock, err := sm.isSealingBlockFinalized(parentBlock, prevBlockSeq) if err != nil { - return nil, fmt.Errorf("failed to retrieve sealing block at sequence %d: %w", sealingBlockSeq, err) + return nil, err } - isSealingBlockFinalized := finalization != nil + newSimplexEpochInfo := computeSimplexEpochInfoForTelock(parentBlock, sealingBlockSeq, prevBlockSeq) if !isSealingBlockFinalized { pChainHeight := parentBlock.Metadata.PChainHeight @@ -654,6 +906,15 @@ func (sm *StateMachine) buildBlockEpochSealed(ctx context.Context, parentBlock S } // Else, we build a block for the new epoch. + newSimplexEpochInfo = computeSimplexEpochInfoForNewEpoch(newSimplexEpochInfo, parentBlock, sealingBlockSeq, prevBlockSeq) + + // TODO: This P-chain height should be taken from the ICM epoch + + return sm.buildBlockOrTransitionEpoch(ctx, sealingBlock, simplexMetadata, simplexBlacklist, newSimplexEpochInfo) + +} + +func computeSimplexEpochInfoForNewEpoch(newSimplexEpochInfo SimplexEpochInfo, parentBlock StateMachineBlock, sealingBlockSeq uint64, prevBlockSeq uint64) SimplexEpochInfo { newSimplexEpochInfo = SimplexEpochInfo{ // P-chain reference height is previous block's NextPChainReferenceHeight. 
PChainReferenceHeight: parentBlock.Metadata.SimplexEpochInfo.NextPChainReferenceHeight, @@ -661,9 +922,70 @@ func (sm *StateMachine) buildBlockEpochSealed(ctx context.Context, parentBlock S EpochNumber: sealingBlockSeq, PrevVMBlockSeq: computePrevVMBlockSeq(parentBlock, prevBlockSeq), } + return newSimplexEpochInfo +} + +func computeSimplexEpochInfoForTelock(parentBlock StateMachineBlock, sealingBlockSeq uint64, prevBlockSeq uint64) SimplexEpochInfo { + newSimplexEpochInfo := SimplexEpochInfo{ + PChainReferenceHeight: parentBlock.Metadata.SimplexEpochInfo.PChainReferenceHeight, + EpochNumber: parentBlock.Metadata.SimplexEpochInfo.EpochNumber, + NextPChainReferenceHeight: parentBlock.Metadata.SimplexEpochInfo.NextPChainReferenceHeight, + SealingBlockSeq: sealingBlockSeq, + PrevVMBlockSeq: computePrevVMBlockSeq(parentBlock, prevBlockSeq), + } + return newSimplexEpochInfo +} + +func (sm *StateMachine) verifyBlockEpochSealed(ctx context.Context, parentBlock StateMachineBlock, nextBlock *StateMachineBlock, prevBlockSeq uint64) error { + isSealingBlockFinalized, sealingBlockSeq, _, err := sm.isSealingBlockFinalized(parentBlock, prevBlockSeq) + if err != nil { + return err + } + + newSimplexEpochInfo := computeSimplexEpochInfoForTelock(parentBlock, sealingBlockSeq, prevBlockSeq) + + simplexMetadata := nextBlock.Metadata.SimplexProtocolMetadata + simplexBlacklist := nextBlock.Metadata.SimplexBlacklist + pChainHeight := parentBlock.Metadata.PChainHeight + + if !isSealingBlockFinalized { + expectedBlock := sm.wrapBlock(parentBlock, nil, newSimplexEpochInfo, pChainHeight, simplexMetadata, simplexBlacklist) + if expectedBlock.Digest() != nextBlock.Digest() { + return fmt.Errorf("expected block digest %s does not match proposed block digest %s", expectedBlock.Digest(), nextBlock.Digest()) + } + return nil + } + + // Else, it's a new epoch. 
+ newSimplexEpochInfo = computeSimplexEpochInfoForNewEpoch(newSimplexEpochInfo, parentBlock, sealingBlockSeq, prevBlockSeq) + + // The first block of the new epoch may itself transition again, so trust and validate + // the proposed pchain height and (optional) next pchain reference height, mirroring + // what buildBlockOrTransitionEpoch does on the build side. + proposedPChainHeight := nextBlock.Metadata.PChainHeight + currentPChainHeight := sm.GetPChainHeight() + prevPChainHeight := parentBlock.Metadata.PChainHeight + if err := verifyPChainHeight(proposedPChainHeight, currentPChainHeight, prevPChainHeight); err != nil { + return fmt.Errorf("failed to verify P-chain height: %w", err) + } + + if err := sm.verifyNextPChainRefHeightForNewEpoch(newSimplexEpochInfo, nextBlock.Metadata.SimplexEpochInfo); err != nil { + return fmt.Errorf("failed to verify next P-chain reference height for new epoch block: %w", err) + } + newSimplexEpochInfo.NextPChainReferenceHeight = nextBlock.Metadata.SimplexEpochInfo.NextPChainReferenceHeight // TODO: This P-chain height should be taken from the ICM epoch - return sm.buildBlockOrTransitionEpoch(ctx, sealingBlock, simplexMetadata, simplexBlacklist, newSimplexEpochInfo) + if nextBlock.InnerBlock != nil { + if err := nextBlock.InnerBlock.Verify(ctx); err != nil { + return err + } + } + + expectedBlock := sm.wrapBlock(parentBlock, nextBlock.InnerBlock, newSimplexEpochInfo, proposedPChainHeight, simplexMetadata, simplexBlacklist) + if expectedBlock.Digest() != nextBlock.Digest() { + return fmt.Errorf("expected block digest %s does not match proposed block digest %s", expectedBlock.Digest(), nextBlock.Digest()) + } + return nil } // constructSimplexZeroBlockSimplexEpochInfo constructs the SimplexEpochInfo for the zero block, which is the first ever block built by Simplex. 
@@ -738,7 +1060,7 @@ func computeNewApproverSignaturesAndSigners( logger simplex.Logger, ) ([]byte, bitmask, error) { if nextEpochApprovals == nil { - return nil, bitmask{}, fmt.Errorf("next epoch approvals is nil") + return nil, bitmask{}, errNilNextEpochApprovals } // Prepare the new signatures from the new approvals that haven't approved yet and that agree with our candidate auxiliary info digest and P-Chain height. newSignatures := make([][]byte, 0, len(approvalsFromPeers)+1) @@ -826,21 +1148,17 @@ func computePrevVMBlockSeq(parentBlock StateMachineBlock, prevBlockSeq uint64) u } var ( - errSignerSetShrunk = fmt.Errorf("some signers from parent block are missing from next epoch approvals of proposed block") - errNextEpochApprovalsShrunk = fmt.Errorf("previous block has next epoch approvals but proposed block doesn't have next epoch approvals") + errSignerSetShrunk = errors.New("some signers from parent block are missing from next epoch approvals of proposed block") + errNextEpochApprovalsShrunk = errors.New("previous block has next epoch approvals but proposed block doesn't have next epoch approvals") ) -func ensureNextEpochApprovalsSignersSupersetOfApprovalsOfPrevBlock(prev SimplexEpochInfo, next SimplexEpochInfo) error { - if prev.NextEpochApprovals == nil { - // Condition satisfied vacuously. +func areNextEpochApprovalsSignersSupersetOfApprovalsOfPrevBlock(prev SimplexEpochInfo, next SimplexEpochInfo) error { + if prev.NextEpochApprovals == nil || len(prev.NextEpochApprovals.NodeIDs) == 0 { return nil } - // Else, prev.NextEpochApprovals is not nil. - // If next.NextEpochApprovals is nil, condition is not satisfied. if next.NextEpochApprovals == nil { - return errNextEpochApprovalsShrunk + return fmt.Errorf("%w: previous block has next epoch approvals but proposed block doesn't have next epoch approvals", errNextEpochApprovalsShrunk) } - // Make sure that previous signers are still there. 
prevSigners := bitmaskFromBytes(prev.NextEpochApprovals.NodeIDs) nextSigners := bitmaskFromBytes(next.NextEpochApprovals.NodeIDs) diff --git a/msm/msm_test.go b/msm/msm_test.go index eff624da..dc843705 100644 --- a/msm/msm_test.go +++ b/msm/msm_test.go @@ -6,6 +6,7 @@ package metadata import ( "context" "crypto/rand" + "errors" "fmt" "testing" "time" @@ -15,6 +16,8 @@ import ( "github.com/stretchr/testify/require" ) +var errBlockDigestMismatch = errors.New("does not match proposed block digest") + func TestMSMBuildAndVerifyBlocksAfterGenesis(t *testing.T) { validMD := simplex.ProtocolMetadata{ Round: 1, @@ -26,7 +29,7 @@ func TestMSMBuildAndVerifyBlocksAfterGenesis(t *testing.T) { for _, testCase := range []struct { name string md simplex.ProtocolMetadata - err string + err error configure func(*StateMachine, *testConfig) mutateBlock func(*StateMachineBlock) }{ @@ -43,7 +46,7 @@ func TestMSMBuildAndVerifyBlocksAfterGenesis(t *testing.T) { md.Seq = 0 block.Metadata.SimplexProtocolMetadata = md.Bytes() }, - err: "attempted to build a genesis inner block", + err: errBuiltGenesisInnerBlock, }, { name: "previous block not found", @@ -51,7 +54,7 @@ func TestMSMBuildAndVerifyBlocksAfterGenesis(t *testing.T) { configure: func(_ *StateMachine, tc *testConfig) { delete(tc.blockStore, 0) }, - err: "failed to retrieve previous (0) inner block", + err: simplex.ErrBlockNotFound, }, { name: "parent has no inner block", @@ -61,7 +64,7 @@ func TestMSMBuildAndVerifyBlocksAfterGenesis(t *testing.T) { block: StateMachineBlock{}, } }, - err: "parent inner block (", + err: errParentInnerBlockHasNoInnerBlock, }, { name: "wrong epoch number", @@ -69,7 +72,7 @@ func TestMSMBuildAndVerifyBlocksAfterGenesis(t *testing.T) { mutateBlock: func(block *StateMachineBlock) { block.Metadata.SimplexEpochInfo.EpochNumber = 2 }, - err: "invalid epoch number (2), should be 1", + err: errInvalidSimplexEpochInfo, }, { name: "P-chain height too big", @@ -77,7 +80,7 @@ func 
TestMSMBuildAndVerifyBlocksAfterGenesis(t *testing.T) { mutateBlock: func(block *StateMachineBlock) { block.Metadata.PChainHeight = 110 }, - err: "invalid P-chain height (110), expected to be 100", + err: errInvalidPChainHeight, }, { name: "P-chain height smaller than parent", @@ -85,7 +88,7 @@ func TestMSMBuildAndVerifyBlocksAfterGenesis(t *testing.T) { configure: func(sm *StateMachine, tc *testConfig) { sm.LastNonSimplexBlockPChainHeight = 99 }, - err: "invalid P-chain height (100), expected to be 99", + err: errInvalidPChainHeight, }, { name: "nil BlockValidationDescriptor", @@ -93,7 +96,7 @@ func TestMSMBuildAndVerifyBlocksAfterGenesis(t *testing.T) { mutateBlock: func(block *StateMachineBlock) { block.Metadata.SimplexEpochInfo.BlockValidationDescriptor = nil }, - err: "invalid BlockValidationDescriptor: should not be nil", + err: errInvalidSimplexEpochInfo, }, { name: "membership mismatch", @@ -103,7 +106,7 @@ func TestMSMBuildAndVerifyBlocksAfterGenesis(t *testing.T) { {BLSKey: []byte{1}, Weight: 1}, } }, - err: "invalid BlockValidationDescriptor: should match validator set", + err: errInvalidSimplexEpochInfo, }, { name: "SimplexEpochInfo mismatch", @@ -111,7 +114,7 @@ func TestMSMBuildAndVerifyBlocksAfterGenesis(t *testing.T) { mutateBlock: func(block *StateMachineBlock) { block.Metadata.SimplexEpochInfo.PrevVMBlockSeq = 999 }, - err: "invalid SimplexEpochInfo", + err: errInvalidSimplexEpochInfo, }, } { t.Run(testCase.name, func(t *testing.T) { @@ -131,8 +134,8 @@ func TestMSMBuildAndVerifyBlocksAfterGenesis(t *testing.T) { } err = sm2.VerifyBlock(context.Background(), block) - if testCase.err != "" { - require.ErrorContains(t, err, testCase.err) + if testCase.err != nil { + require.ErrorIs(t, err, testCase.err) return } require.NoError(t, err) @@ -212,6 +215,8 @@ func TestMSMNormalOp(t *testing.T) { for _, testCase := range []struct { name string setup func(*StateMachine, *testConfig) + mutateBlock func(*StateMachineBlock) + err error expectedPChainHeight 
uint64 expectedNextPChainRefHeight uint64 }{ @@ -219,6 +224,82 @@ func TestMSMNormalOp(t *testing.T) { name: "correct information", expectedPChainHeight: 100, }, + { + name: "trying to build a genesis block", + mutateBlock: func(block *StateMachineBlock) { + md, err := simplex.ProtocolMetadataFromBytes(block.Metadata.SimplexProtocolMetadata) + require.NoError(t, err) + md.Seq = 0 + block.Metadata.SimplexProtocolMetadata = md.Bytes() + }, + err: errBuiltGenesisInnerBlock, + }, + { + name: "previous block not found", + mutateBlock: func(block *StateMachineBlock) { + md, err := simplex.ProtocolMetadataFromBytes(block.Metadata.SimplexProtocolMetadata) + require.NoError(t, err) + md.Seq = 999 + block.Metadata.SimplexProtocolMetadata = md.Bytes() + }, + err: simplex.ErrBlockNotFound, + }, + { + name: "P-chain height too big", + mutateBlock: func(block *StateMachineBlock) { + block.Metadata.PChainHeight = 110 + }, + err: errPChainHeightTooBig, + }, + { + name: "P-chain height smaller than parent", + mutateBlock: func(block *StateMachineBlock) { + block.Metadata.PChainHeight = 0 + }, + err: errPChainHeightSmallerThanParent, + }, + { + name: "wrong epoch number", + mutateBlock: func(block *StateMachineBlock) { + block.Metadata.SimplexEpochInfo.EpochNumber = 2 + }, + err: errBlockDigestMismatch, + }, + { + name: "non-nil BlockValidationDescriptor", + mutateBlock: func(block *StateMachineBlock) { + block.Metadata.SimplexEpochInfo.BlockValidationDescriptor = &BlockValidationDescriptor{} + }, + err: errBlockDigestMismatch, + }, + { + name: "non-zero sealing block seq", + mutateBlock: func(block *StateMachineBlock) { + block.Metadata.SimplexEpochInfo.SealingBlockSeq = 5 + }, + err: errBlockDigestMismatch, + }, + { + name: "wrong PChainReferenceHeight", + mutateBlock: func(block *StateMachineBlock) { + block.Metadata.SimplexEpochInfo.PChainReferenceHeight = 50 + }, + err: errBlockDigestMismatch, + }, + { + name: "non-empty PrevSealingBlockHash", + mutateBlock: func(block 
*StateMachineBlock) { + block.Metadata.SimplexEpochInfo.PrevSealingBlockHash = [32]byte{1, 2, 3} + }, + err: errBlockDigestMismatch, + }, + { + name: "wrong PrevVMBlockSeq", + mutateBlock: func(block *StateMachineBlock) { + block.Metadata.SimplexEpochInfo.PrevVMBlockSeq = 999 + }, + err: errBlockDigestMismatch, + }, { name: "validator set change detected", setup: func(sm *StateMachine, tc *testConfig) { @@ -234,9 +315,11 @@ func TestMSMNormalOp(t *testing.T) { t.Run(testCase.name, func(t *testing.T) { chain := makeChain(t, 5, 10) sm1, testConfig1 := newStateMachine(t) + sm2, testConfig2 := newStateMachine(t) for i, block := range chain { testConfig1.blockStore[uint64(i)] = &outerBlock{block: block} + testConfig2.blockStore[uint64(i)] = &outerBlock{block: block} } lastBlock := chain[len(chain)-1] @@ -264,13 +347,29 @@ func TestMSMNormalOp(t *testing.T) { if testCase.setup != nil { testCase.setup(sm1, testConfig1) + testCase.setup(sm2, testConfig2) } block1, err := sm1.BuildBlock(context.Background(), *md, &blacklist) require.NoError(t, err) require.NotNil(t, block1) - require.Equal(t, &StateMachineBlock{ + if testCase.mutateBlock != nil { + testCase.mutateBlock(block1) + } + + err = sm2.VerifyBlock(context.Background(), block1) + if testCase.err != nil { + if testCase.err == errBlockDigestMismatch { + require.ErrorContains(t, err, testCase.err.Error()) + } else { + require.ErrorIs(t, err, testCase.err) + } + return + } + require.NoError(t, err) + + expected := &StateMachineBlock{ InnerBlock: &InnerBlock{ TS: blockTime, BlockHeight: lastBlock.InnerBlock.Height(), @@ -288,7 +387,8 @@ func TestMSMNormalOp(t *testing.T) { NextPChainReferenceHeight: testCase.expectedNextPChainRefHeight, }, }, - }, block1) + } + require.Equal(t, expected.Digest(), block1.Digest()) }) } } @@ -775,7 +875,7 @@ func TestAreNextEpochApprovalsSignersSupersetOfApprovalsOfPrevBlock(t *testing.T }, } { t.Run(tc.name, func(t *testing.T) { - err := 
ensureNextEpochApprovalsSignersSupersetOfApprovalsOfPrevBlock(tc.prev, tc.next) + err := areNextEpochApprovalsSignersSupersetOfApprovalsOfPrevBlock(tc.prev, tc.next) if tc.err != nil { require.ErrorIs(t, err, tc.err) } else { @@ -983,6 +1083,6 @@ func TestComputeNewApproverSignaturesAndSigners(t *testing.T) { } _, _, err := computeNewApproverSignaturesAndSigners(prevApprovals, peers, oldApproving, nodeID2Index, failingAggregator{}, logger) - require.ErrorContains(t, err, "aggregation failed") + require.ErrorIs(t, err, errTestAggregationFailed) }) } From 2c540083a2947a1aff6d1aae96420e3b5f750454 Mon Sep 17 00:00:00 2001 From: Yacov Manevich Date: Wed, 13 May 2026 20:42:04 +0200 Subject: [PATCH 2/3] Add ApprovalStore for validator approvals MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Also self-sign during aggregation - Introduces ApprovalStore (msm/approvals.go) — an in-memory store of validator approvals for epoch transitions, keyed by (NodeID, PChainHeight). It verifies BLS signatures on ingest, deduplicates by timestamp (newer wins, older is dropped), and prunes per-node entries to at most len(validators) by evicting the oldest timestamp. - Adds Timestamp to ValidatorSetApproval (msm/encoding.go) so the store can order/evict approvals deterministically. - computeNewApprovals (msm/msm.go) now optimistically self-signs the next epoch's P-chain reference height each round and appends its own ValidatorSetApproval to the peer set since the store deduplicates it later. 
Signed-off-by: Yacov Manevich --- msm/approvals.go | 139 +++++++++++++++++++++ msm/approvals_test.go | 283 ++++++++++++++++++++++++++++++++++++++++++ msm/encoding.go | 1 + msm/misc_test.go | 60 +-------- msm/msm.go | 28 +++-- testutil/util.go | 10 +- 6 files changed, 454 insertions(+), 67 deletions(-) create mode 100644 msm/approvals.go create mode 100644 msm/approvals_test.go diff --git a/msm/approvals.go b/msm/approvals.go new file mode 100644 index 00000000..80e37793 --- /dev/null +++ b/msm/approvals.go @@ -0,0 +1,139 @@ +// Copyright (C) 2019-2025, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package metadata + +import ( + "encoding/binary" + "fmt" + + "github.com/ava-labs/simplex" + "go.uber.org/zap" +) + +type approvalsByPChainHeight map[uint64]*ValidatorSetApproval + +type ApprovalStore struct { + signatureVerifier SignatureVerifier + validators NodeBLSMappings + logger simplex.Logger + pkByNodeID map[nodeID][]byte + approvalsByNodes map[nodeID]approvalsByPChainHeight + storedCount int +} + +func NewApprovalStore(signatureVerifier SignatureVerifier, validators NodeBLSMappings, logger simplex.Logger) *ApprovalStore { + pkByNodeID := make(map[nodeID][]byte) + for _, vdr := range validators { + pkByNodeID[vdr.NodeID] = vdr.BLSKey + } + + approvalsByNodes := make(map[nodeID]approvalsByPChainHeight, len(validators)) + for _, vdr := range validators { + approvalsByNodes[vdr.NodeID] = make(approvalsByPChainHeight) + } + + return &ApprovalStore{ + signatureVerifier: signatureVerifier, + validators: validators, + pkByNodeID: pkByNodeID, + logger: logger, + approvalsByNodes: approvalsByNodes, + } +} + +func (as *ApprovalStore) Approvals() ValidatorSetApprovals { + approvals := make(ValidatorSetApprovals, 0, as.storedCount) + for _, approvalsByHeight := range as.approvalsByNodes { + for _, approval := range approvalsByHeight { + approvals = append(approvals, *approval) + } + } + return approvals +} + +func (as *ApprovalStore) 
HandleApproval(approval *ValidatorSetApproval) error { + // First thing we check is if the node that sent this approval is a validator. + pk, exists := as.getPKOfNode(approval.NodeID) + if !exists { + as.logger.Debug("Received an approval from a node that is not a validator", zap.String("nodeID", + fmt.Sprintf("%x", approval.NodeID)), zap.Uint64("pChainHeight", approval.PChainHeight)) + return nil + } + + // Second thing we check is if we already have an approval for this height from this node. + if as.approvalExistsAndUpToDate(approval) { + as.logger.Debug("Already have an approval from the node", zap.String("nodeID", + fmt.Sprintf("%x", approval.NodeID)), zap.Uint64("pChainHeight", approval.PChainHeight)) + return nil + } + + // Third thing we check is if the signature of the approval is valid. + // We need it to be valid in order for nodes to be able to aggregate it later on along with other approvals. + if err := as.checkApprovalSignature(approval, pk); err != nil { + as.logger.Debug("Received an approval with an invalid signature", zap.String("nodeID", + fmt.Sprintf("%x", approval.NodeID)), zap.Uint64("pChainHeight", approval.PChainHeight)) + return nil + } + + // Store the approval. + oldApproval := as.approvalsByNodes[approval.NodeID][approval.PChainHeight] + as.approvalsByNodes[approval.NodeID][approval.PChainHeight] = approval + + if oldApproval == nil { + as.storedCount++ + } + + // We only store the last |as.validators| of approvals for each node, + // so we need to delete old approvals if we have more than |as.validators| approvals stored for this node. + as.maybePruneOldApprovals(approval) + + return nil +} + +func (as *ApprovalStore) maybePruneOldApprovals(approval *ValidatorSetApproval) { + for len(as.approvalsByNodes[approval.NodeID]) > len(as.validators) { + // Find the oldest approval and delete it. 
+ var oldestApproval *ValidatorSetApproval + for _, approval := range as.approvalsByNodes[approval.NodeID] { + if oldestApproval == nil || approval.Timestamp < oldestApproval.Timestamp { + oldestApproval = approval + } + } + + if oldestApproval != nil { + as.logger.Debug("Deleting old approval from node", + zap.String("nodeID", fmt.Sprintf("%x", oldestApproval.NodeID)), + zap.String("oldestApprovalPChainHeight", + fmt.Sprintf("%d", oldestApproval.PChainHeight)), zap.Uint64("oldestApprovalTimestamp", oldestApproval.Timestamp)) + delete(as.approvalsByNodes[approval.NodeID], oldestApproval.PChainHeight) + as.storedCount-- + } + } +} + +func (as *ApprovalStore) checkApprovalSignature(approval *ValidatorSetApproval, pk []byte) error { + pChainHeight := approval.PChainHeight + pChainHeightBuff := make([]byte, 8) + binary.BigEndian.PutUint64(pChainHeightBuff, pChainHeight) + + // We check if the signature is valid before we store the approval. + return as.signatureVerifier.VerifySignature(approval.Signature, pChainHeightBuff, pk) +} + +func (as *ApprovalStore) getPKOfNode(nodeID nodeID) ([]byte, bool) { + pk, exists := as.pkByNodeID[nodeID] + return pk, exists +} + +func (as *ApprovalStore) approvalExistsAndUpToDate(approval *ValidatorSetApproval) bool { + if as.approvalsByNodes[approval.NodeID] == nil { + return false + } + existingApproval := as.approvalsByNodes[approval.NodeID][approval.PChainHeight] + if existingApproval == nil { + return false + } + + return existingApproval.Timestamp >= approval.Timestamp +} diff --git a/msm/approvals_test.go b/msm/approvals_test.go new file mode 100644 index 00000000..534af61a --- /dev/null +++ b/msm/approvals_test.go @@ -0,0 +1,283 @@ +// Copyright (C) 2019-2025, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. 
+ +package metadata + +import ( + "errors" + "math" + "testing" + + "github.com/ava-labs/simplex/testutil" + "github.com/stretchr/testify/require" +) + +func makeNodeID(seed byte) nodeID { + var n nodeID + n[0] = seed + return n +} + +func makeValidators(n int) NodeBLSMappings { + vdrs := make(NodeBLSMappings, n) + for i := 0; i < n; i++ { + vdrs[i] = NodeBLSMapping{ + NodeID: makeNodeID(byte(i + 1)), + BLSKey: []byte{byte(i + 1)}, + Weight: 1, + } + } + return vdrs +} + +func newApprovalStoreForTest(t *testing.T, validators NodeBLSMappings, sigErr error) *ApprovalStore { + t.Helper() + return NewApprovalStore(&signatureVerifier{err: sigErr}, validators, testutil.MakeLogger(t)) +} + +func TestApprovalStoreHandleApprovalUnknownNode(t *testing.T) { + // Verifies that an approval from a node that is not in the validator set is silently dropped. + + as := newApprovalStoreForTest(t, makeValidators(3), nil) + require.NoError(t, as.HandleApproval(&ValidatorSetApproval{ + NodeID: makeNodeID(99), + PChainHeight: 1, + Timestamp: 1, + })) + require.Empty(t, as.Approvals()) + require.Equal(t, 0, as.storedCount) +} + +func TestApprovalStoreHandleApprovalInvalidSignature(t *testing.T) { + // Verifies that an approval from a known validator whose signature fails verification is dropped without being stored. + vdrs := makeValidators(3) + as := newApprovalStoreForTest(t, vdrs, errors.New("bad sig")) + + require.NoError(t, as.HandleApproval(&ValidatorSetApproval{ + NodeID: vdrs[0].NodeID, + PChainHeight: 1, + Timestamp: 1, + Signature: []byte{0xAA}, + })) + require.Empty(t, as.Approvals()) + require.Equal(t, 0, as.storedCount) +} + +func TestApprovalStoreHandleApprovalStoresValidApproval(t *testing.T) { + // happy path: an approval from a known validator with a valid signature is stored and + //is retrievable via Approvals() and also increases storedCount. 
+ vdrs := makeValidators(3) + as := newApprovalStoreForTest(t, vdrs, nil) + + a := &ValidatorSetApproval{ + NodeID: vdrs[0].NodeID, + PChainHeight: 7, + Timestamp: 100, + Signature: []byte{0x01}, + } + require.NoError(t, as.HandleApproval(a)) + + got := as.Approvals() + require.Len(t, got, 1) + require.Equal(t, *a, got[0]) + require.Equal(t, 1, as.storedCount) +} + +func TestApprovalStoreHandleApprovalDuplicateSameTimestamp(t *testing.T) { + // Verifies that handing the same (NodeID, PChainHeight, Timestamp) twice is a no-op. + // The store keeps exactly one copy and storedCount does not double-count. + vdrs := makeValidators(3) + as := newApprovalStoreForTest(t, vdrs, nil) + + a := &ValidatorSetApproval{ + NodeID: vdrs[0].NodeID, + PChainHeight: 7, + Timestamp: 100, + Signature: []byte{0x01}, + } + require.NoError(t, as.HandleApproval(a)) + require.NoError(t, as.HandleApproval(a)) + require.Len(t, as.Approvals(), 1) + require.Equal(t, 1, as.storedCount) +} + +func TestApprovalStoreHandleApprovalOlderTimestampIgnored(t *testing.T) { + // Verifies that when a newer approval is already stored for a (NodeID, PChainHeight), a subsequent + // approval with an older Timestamp is dropped and does not overwrite it. 
+ vdrs := makeValidators(3) + as := newApprovalStoreForTest(t, vdrs, nil) + + newer := &ValidatorSetApproval{ + NodeID: vdrs[0].NodeID, + PChainHeight: 7, + Timestamp: 200, + Signature: []byte{0x02}, + } + require.NoError(t, as.HandleApproval(newer)) + + older := &ValidatorSetApproval{ + NodeID: vdrs[0].NodeID, + PChainHeight: 7, + Timestamp: 100, + Signature: []byte{0x01}, + } + require.NoError(t, as.HandleApproval(older)) + + got := as.Approvals() + require.Len(t, got, 1) + require.Equal(t, *newer, got[0]) + require.Equal(t, 1, as.storedCount) +} + +func TestApprovalStoreHandleApprovalNewerTimestampReplaces(t *testing.T) { + // verifies that a newer approval at the same (NodeID, PChainHeight) replaces the previously stored + // one in place without changing storedCount. + vdrs := makeValidators(3) + as := newApprovalStoreForTest(t, vdrs, nil) + + older := &ValidatorSetApproval{ + NodeID: vdrs[0].NodeID, + PChainHeight: 7, + Timestamp: 100, + Signature: []byte{0x01}, + } + require.NoError(t, as.HandleApproval(older)) + + newer := &ValidatorSetApproval{ + NodeID: vdrs[0].NodeID, + PChainHeight: 7, + Timestamp: 200, + Signature: []byte{0x02}, + } + require.NoError(t, as.HandleApproval(newer)) + + got := as.Approvals() + require.Len(t, got, 1) + require.Equal(t, *newer, got[0]) + require.Equal(t, 1, as.storedCount) +} + +func TestApprovalStoreHandleApprovalMultipleNodesAndHeights(t *testing.T) { + // Verifies that the store keeps independent entries per (NodeID, PChainHeight) tuple. + // Different validators and different heights all coexist. 
+ vdrs := makeValidators(3) + as := newApprovalStoreForTest(t, vdrs, nil) + + for i, v := range vdrs { + for _, h := range []uint64{1, 2} { + require.NoError(t, as.HandleApproval(&ValidatorSetApproval{ + NodeID: v.NodeID, + PChainHeight: h, + Timestamp: uint64(i*10) + h, + Signature: []byte{byte(i)}, + })) + } + } + + require.Len(t, as.Approvals(), len(vdrs)*2) + require.Equal(t, len(vdrs)*2, as.storedCount) +} + +func TestApprovalStoreHandleApprovalPrunesOldestWhenOverCap(t *testing.T) { + // Verifies that once a node accumulates more approvals than len(validators), + // the entry with the oldest Timestamp is evicted and storedCount stays in sync. + vdrs := makeValidators(2) + as := newApprovalStoreForTest(t, vdrs, nil) + + node := vdrs[0].NodeID + for _, a := range []*ValidatorSetApproval{ + {NodeID: node, PChainHeight: 1, Timestamp: 10, Signature: []byte{1}}, + {NodeID: node, PChainHeight: 2, Timestamp: 20, Signature: []byte{2}}, + {NodeID: node, PChainHeight: 3, Timestamp: 30, Signature: []byte{3}}, + } { + require.NoError(t, as.HandleApproval(a)) + } + + got := as.Approvals() + require.Len(t, got, 2, "store should be pruned down to cap=len(validators)=2") + require.Equal(t, 2, as.storedCount) + + timestamps := map[uint64]bool{} + for _, a := range got { + timestamps[a.Timestamp] = true + } + require.False(t, timestamps[10], "oldest (ts=10) should have been pruned") + require.True(t, timestamps[20]) + require.True(t, timestamps[30]) +} + +func TestApprovalStoreHandleApprovalPruningIsPerNode(t *testing.T) { + // Verifies that the cap is applied per-NodeID, not globally: filling one node up to its cap does not + // affect another node's approvals. 
+ + vdrs := makeValidators(2) + as := newApprovalStoreForTest(t, vdrs, nil) + + require.NoError(t, as.HandleApproval(&ValidatorSetApproval{ + NodeID: vdrs[1].NodeID, + PChainHeight: 1, + Timestamp: 100, + })) + + for h := uint64(1); h <= 10; h++ { + require.NoError(t, as.HandleApproval(&ValidatorSetApproval{ + NodeID: vdrs[0].NodeID, + PChainHeight: h, + Timestamp: h, + })) + } + require.Len(t, as.Approvals().UniqueByNodeID(), 2) +} + +func TestApprovalStoreHandleApprovalMaxUint64Timestamp(t *testing.T) { + // Verifies that an approval with the maximum uint64 timestamp is stored, + // and that a subsequent approval at the same (NodeID, PChainHeight) with any + // smaller timestamp is treated as older and does not replace it. + vdrs := makeValidators(3) + as := newApprovalStoreForTest(t, vdrs, nil) + + maxTS := &ValidatorSetApproval{ + NodeID: vdrs[0].NodeID, + PChainHeight: 7, + Timestamp: math.MaxUint64, + Signature: []byte{0xFF}, + } + require.NoError(t, as.HandleApproval(maxTS)) + + got := as.Approvals() + require.Len(t, got, 1) + require.Equal(t, *maxTS, got[0]) + require.Equal(t, 1, as.storedCount) + + older := &ValidatorSetApproval{ + NodeID: vdrs[0].NodeID, + PChainHeight: 7, + Timestamp: math.MaxUint64 - 1, + Signature: []byte{0x01}, + } + require.NoError(t, as.HandleApproval(older)) + + got = as.Approvals() + require.Len(t, got, 1) + require.Equal(t, *maxTS, got[0]) + require.Equal(t, 1, as.storedCount) +} + +func TestApprovalStoreHandleApprovalStoredCountStaysConsistent(t *testing.T) { + // runs a mixed workload (insert, duplicate, replace, new height, prune) + // and asserts that storedCount equals len(Approvals()) after every step. 
+ vdrs := makeValidators(2) + as := newApprovalStoreForTest(t, vdrs, nil) + node := vdrs[0].NodeID + + for _, a := range []*ValidatorSetApproval{ + {NodeID: node, PChainHeight: 1, Timestamp: 10}, + {NodeID: node, PChainHeight: 1, Timestamp: 10}, // duplicate + {NodeID: node, PChainHeight: 1, Timestamp: 20}, // replaces + {NodeID: node, PChainHeight: 2, Timestamp: 30}, // new height + {NodeID: node, PChainHeight: 3, Timestamp: 40}, // triggers prune + } { + require.NoError(t, as.HandleApproval(a)) + } + require.Equal(t, 2, len(as.Approvals())) +} diff --git a/msm/encoding.go b/msm/encoding.go index 7fd22189..b5c527e9 100644 --- a/msm/encoding.go +++ b/msm/encoding.go @@ -294,6 +294,7 @@ type ValidatorSetApproval struct { AuxInfoSeqDigest [32]byte `canoto:"fixed bytes,2"` PChainHeight uint64 `canoto:"uint,3"` Signature []byte `canoto:"bytes,4"` + Timestamp uint64 `canoto:"uint,5"` canotoData canotoData_ValidatorSetApproval } diff --git a/msm/misc_test.go b/msm/misc_test.go index ba798adb..40d86a55 100644 --- a/msm/misc_test.go +++ b/msm/misc_test.go @@ -416,6 +416,10 @@ func newStateMachine(t *testing.T) (*StateMachine, *testConfig) { bs := make(blockStore) bs[0] = &outerBlock{block: genesisBlock} + var myNodeID nodeID + _, err := rand.Read(myNodeID[:]) + require.NoError(t, err) + var testConfig testConfig testConfig.blockStore = bs testConfig.validatorSetRetriever.result = NodeBLSMappings{ @@ -458,6 +462,8 @@ func newStateMachine(t *testing.T) (*StateMachine, *testConfig) { GetValidatorSet: testConfig.validatorSetRetriever.getValidatorSet, PChainProgressListener: &noOpPChainListener{}, LastNonSimplexInnerBlock: genesisBlock.InnerBlock, + MyNodeID: myNodeID[:], + Signer: &testutil.TestSigner{}, } sm, err := NewStateMachine(&smConfig) @@ -497,57 +503,3 @@ func (failingAggregator) AppendSignatures([]byte, ...[]byte) ([]byte, error) { func (failingAggregator) IsQuorum([]simplex.NodeID) bool { return false } - -type testBlockStore map[uint64]StateMachineBlock - -func (bs 
testBlockStore) getBlock(seq uint64, _ [32]byte) (StateMachineBlock, *simplex.Finalization, error) { - blk, ok := bs[seq] - if !ok { - return StateMachineBlock{}, nil, fmt.Errorf("%w: block %d", simplex.ErrBlockNotFound, seq) - } - return blk, nil, nil -} - -type testVMBlock struct { - bytes []byte - height uint64 -} - -func (b *testVMBlock) Digest() [32]byte { - return sha256.Sum256(b.bytes) -} - -func (b *testVMBlock) Height() uint64 { - return b.height -} - -func (b *testVMBlock) Timestamp() time.Time { - return time.Now() -} - -func (b *testVMBlock) Verify(_ context.Context) error { - return nil -} - -type testSigVerifier struct { - err error -} - -func (sv *testSigVerifier) VerifySignature(_, _, _ []byte) error { - return sv.err -} - -type testKeyAggregator struct { - err error -} - -func (ka *testKeyAggregator) AggregateKeys(keys ...[]byte) ([]byte, error) { - if ka.err != nil { - return nil, ka.err - } - var agg []byte - for _, k := range keys { - agg = append(agg, k...) - } - return agg, nil -} diff --git a/msm/msm.go b/msm/msm.go index fdf46da9..de294c90 100644 --- a/msm/msm.go +++ b/msm/msm.go @@ -111,16 +111,8 @@ type verificationInput struct { state state } -type verifier interface { - Verify(in verificationInput) error -} - // StateMachine manages block building and verification across epoch transitions. type StateMachine struct { - // verifiers is the list of verifiers used to verify proposed blocks. - // Each verifier is responsible for verifying a specific aspect of the block's metadata. - verifiers []verifier - *Config } @@ -166,6 +158,10 @@ type Config struct { LastNonSimplexInnerBlock VMBlock // GenesisValidatorSet is the validator set used for the genesis block. 
GenesisValidatorSet NodeBLSMappings + // MyNodeID is the ID of this node; it is recorded as the NodeID of the validator set approvals this node signs. + MyNodeID simplex.NodeID + // Signer signs this node's approval of the next epoch's P-chain reference height. + Signer simplex.Signer } type state uint8 @@ -771,6 +767,22 @@ func (sm *StateMachine) computeNewApprovals(parentBlock StateMachineBlock) (*app approvalsFromPeers := sm.ApprovalsRetriever.Approvals() sm.Logger.Debug("Retrieved approvals from peers", zap.Int("numApprovals", len(approvalsFromPeers))) + // Optimistically sign the epoch transition even if we have already done so in a previous round. + // We'll just deduplicate this approval later on. + pChainHeightBuff := make([]byte, 8) + binary.BigEndian.PutUint64(pChainHeightBuff, parentBlock.Metadata.SimplexEpochInfo.NextPChainReferenceHeight) + sig, err := sm.Signer.Sign(pChainHeightBuff) + if err != nil { + return nil, fmt.Errorf("failed to sign approval: %w", err) + } + + approvalsFromPeers = append(approvalsFromPeers, ValidatorSetApproval{ + NodeID: nodeID(sm.MyNodeID), + PChainHeight: parentBlock.Metadata.SimplexEpochInfo.NextPChainReferenceHeight, + Timestamp: uint64(sm.GetTime().UnixMilli()), + Signature: sig, + }) + nextPChainHeight := parentBlock.Metadata.SimplexEpochInfo.NextPChainReferenceHeight prevNextEpochApprovals := parentBlock.Metadata.SimplexEpochInfo.NextEpochApprovals diff --git a/testutil/util.go b/testutil/util.go index 4be23174..5ab5af05 100644 --- a/testutil/util.go +++ b/testutil/util.go @@ -27,7 +27,7 @@ func DefaultTestNodeEpochConfig(t *testing.T, nodeID simplex.NodeID, comm simple Comm: comm, Logger: l, ID: nodeID, - Signer: &testSigner{}, + Signer: &TestSigner{}, WAL: wal, Verifier: &testVerifier{}, Storage: storage, @@ -51,7 +51,7 @@ func NewTestVote(block AnyBlock, id simplex.NodeID) (*simplex.Vote, error) { vote := simplex.ToBeSignedVote{ BlockHeader: block.BlockHeader(), } - sig, err := vote.Sign(&testSigner{}) + sig, err := vote.Sign(&TestSigner{}) if err != nil { return nil, err } @@ -76,7 +76,7 @@ func InjectTestVote(t *testing.T, e *simplex.Epoch, block simplex.VerifiedBlock, func 
NewTestFinalizeVote(t *testing.T, block simplex.VerifiedBlock, id simplex.NodeID) *simplex.FinalizeVote { f := simplex.ToBeSignedFinalization{BlockHeader: block.BlockHeader()} - sig, err := f.Sign(&testSigner{}) + sig, err := f.Sign(&TestSigner{}) require.NoError(t, err) return &simplex.FinalizeVote{ Signature: simplex.Signature{ @@ -182,10 +182,10 @@ func (t TestQC) Bytes() []byte { return bytes } -type testSigner struct { +type TestSigner struct { } -func (t *testSigner) Sign([]byte) ([]byte, error) { +func (t *TestSigner) Sign([]byte) ([]byte, error) { return []byte{1, 2, 3}, nil } From c1ffb03ba56af1825867fb726c525a3fed4eef78 Mon Sep 17 00:00:00 2001 From: Yacov Manevich Date: Wed, 20 May 2026 23:39:27 +0200 Subject: [PATCH 3/3] Add inlined documentation about state transitions Signed-off-by: Yacov Manevich --- msm/encoding.go | 11 ++-- msm/msm.go | 138 +++++++++++++++++++++++++++++++++++++++--------- 2 files changed, 122 insertions(+), 27 deletions(-) diff --git a/msm/encoding.go b/msm/encoding.go index b5c527e9..9b80dde9 100644 --- a/msm/encoding.go +++ b/msm/encoding.go @@ -29,12 +29,14 @@ type StateMachineMetadata struct { // For Simplex epoching, the P-Chain height that matters is the PChainReferenceHeight in the SimplexEpochInfo. PChainHeight uint64 `canoto:"uint,4"` // Timestamp is the time when the block is being built, in milliseconds since Unix epoch. + // It is derived from the timestamp of the inner block, otherwise it is inherited from the previous block. Timestamp uint64 `canoto:"uint,5"` canotoData canotoData_StateMachineMetadata } // SimplexEpochInfo is metadata used by the StateMachine. + type SimplexEpochInfo struct { // PChainReferenceHeight is the P-Chain height that the StateMachine uses as a reference for the current epoch. // The validator set is determined based on the validators on the P-Chain at the PChainReferenceHeight. @@ -44,18 +46,21 @@ type SimplexEpochInfo struct { // of the sealing block of the previous epoch. 
EpochNumber uint64 `canoto:"uint,2"` // PrevSealingBlockHash is the hash of the sealing block of the previous epoch. - // It is empty for the first epoch, and the second epoch has the PrevSealingBlockHash set to be - // the hash of the first ever block built by the StateMachine. + // It is set to the hash of the zero block in the first epoch, and in subsequent epochs it is set to be + // the hash of the sealing block of the previous epoch. + // This makes it possible to quickly fetch and verify the sealing blocks without having to retrieve the intervening blocks, + // which allows bootstrapping the BLS keys of the validator set for each epoch before fully syncing the intervening blocks. PrevSealingBlockHash [32]byte `canoto:"fixed bytes,3"` // NextPChainReferenceHeight is the P-Chain height that the StateMachine uses as a reference for the next epoch. // When the NextPChainReferenceHeight is > 0, it means the StateMachine is on its way to transition to a new epoch // in which the validator set will be based on the given P-chain height. + // It sets the PChainReferenceHeight for the next epoch. NextPChainReferenceHeight uint64 `canoto:"uint,4"` // PrevVMBlockSeq is the block sequence of the previous block that has a VM block (inner block). // This is used to know on which VM block to build the next block. PrevVMBlockSeq uint64 `canoto:"uint,5"` // BlockValidationDescriptor is the metadata that describes the validator set of the next epoch. - // It is only set in the sealing block, and nil in all other blocks. + // It is only set in the sealing block and zero block, and nil in all other blocks. BlockValidationDescriptor *BlockValidationDescriptor `canoto:"pointer,6"` // NextEpochApprovals is the metadata that contains the approvals from validators for the next epoch. // It is set only in the sealing block and the blocks preceding it starting from a block that has a NextPChainReferenceHeight set. 
diff --git a/msm/msm.go b/msm/msm.go index de294c90..a3e20705 100644 --- a/msm/msm.go +++ b/msm/msm.go @@ -15,6 +15,51 @@ import ( "go.uber.org/zap" ) +// state encodes the different stages of the epoch transition process, which determines how we build and verify blocks. +// +// SimplexEpochInfo.NextState() inspects the parent block's metadata to perform the following state transitions: +// +// (initial state: No Simplex blocks yet) +// │ +// ▼ +// ┌───────────────────────────────────┐ +// │ stateFirstSimplexBlock │ builds the zero block (no inner block); +// │ │ creates epoch 1 with the initial validator set +// └─────────────────┬─────────────────┘ +// │ +// ▼ +// ┌───────────────────────────────────┐ ◀── validator set unchanged ──┐ +// │ stateBuildBlockNormalOp │ │ +// │ builds inner blocks within the │ ──────────────────────────────┘ +// │ current epoch │ ◀────────────────────────────────────────────┐ +// └─────────────────┬─────────────────┘ │ +// │ validator set changed │ +// │ (sets NextPChainReferenceHeight > 0) │ +// ▼ │ +// ┌───────────────────────────────────┐ ◀── not enough approvals ─────┐ │ +// │ stateBuildCollectingApprovals │ │ │ +// │ aggregates approvals from │ ──────────────────────────────┘ │ +// │ the next epoch's validator set │ │ +// └─────────────────┬─────────────────┘ │ +// │ quorum reached: emit sealing block │ +// │ (BlockValidationDescriptor set) │ +// ▼ │ +// ┌───────────────────────────────────┐ ◀── sealing block ────────────┐ │ +// │ stateBuildBlockEpochSealed │ not finalized yet │ │ +// │ emits Telock (no inner block) │ ──────────────────────────────┘ │ +// │ until the sealing block is │ │ +// │ finalized; then opens the new │ ─── new epoch (EpochNumber advanced) ─────────┘ +// │ epoch │ +// └───────────────────────────────────┘ +type state uint8 + +const ( + stateFirstSimplexBlock state = iota + 1 + stateBuildBlockNormalOp + stateBuildCollectingApprovals + stateBuildBlockEpochSealed +) + var ( errLastNonSimplexInnerBlockNil = 
errors.New("failed constructing zero block: last non-Simplex inner block is nil") errInvalidProtocolMetadataSeq = errors.New("invalid ProtocolMetadata sequence number: should be > 0") @@ -101,16 +146,6 @@ type BlockBuilder interface { WaitForPendingBlock(ctx context.Context) } -type verificationInput struct { - prevMD StateMachineMetadata - proposedBlockMD StateMachineMetadata - hasInnerBlock bool - innerBlockTimestamp time.Time // only set when hasInnerBlock is true - prevBlockSeq uint64 - nextBlockType BlockType - state state -} - // StateMachine manages block building and verification across epoch transitions. type StateMachine struct { *Config @@ -164,15 +199,6 @@ type Config struct { Signer simplex.Signer } -type state uint8 - -const ( - stateFirstSimplexBlock state = iota + 1 - stateBuildBlockNormalOp - stateBuildCollectingApprovals - stateBuildBlockEpochSealed -) - func NewStateMachine(config *Config) (*StateMachine, error) { if config.LastNonSimplexInnerBlock == nil { config.Logger.Error("Last non-Simplex inner block is nil, cannot build zero block with correct metadata") @@ -290,6 +316,17 @@ func (sm *StateMachine) verifyNonZeroBlock(ctx context.Context, block, prevBlock } // buildBlockNormalOp builds a block while potentially also transitioning to a new epoch, depending on the P-chain. 
+// +// Relevant SimplexEpochInfo fields (PCH = PChainReferenceHeight, +// EN = EpochNumber, NPCH = NextPChainReferenceHeight): +// +// parent (NormalOp) validator set unchanged validator set changed at p' +// ┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐ +// │ PCH = p │ ───► │ PCH = p (copy) │ │ PCH = p (copy) │ +// │ EN = e │ OR │ EN = e (copy) │ OR │ EN = e (copy) │ +// │ NPCH = 0 │ │ NPCH = 0 │ │ NPCH = p' (> 0) │ +// └─────────────────┘ └─────────────────┘ └─────────────────┘ +// → stays NormalOp → CollectingApprovals func (sm *StateMachine) buildBlockNormalOp(ctx context.Context, parentBlock StateMachineBlock, simplexMetadata, simplexBlacklist []byte, prevBlockSeq uint64) (*StateMachineBlock, error) { // Since in the previous block, we were not transitioning to a new epoch, // the P-chain reference height and epoch of the new block should remain the same. @@ -503,6 +540,28 @@ func (sm *StateMachine) createBlockBuildingDecider(pChainReferenceHeight uint64) // buildBlockZero builds the first ever block for Simplex, // which is a special block that introduces the first validator set and starts the first epoch. +// +// How EpochNumber (EN), PrevSealingBlockHash (PSH), and SealingBlockSeq (SBS) +// evolve along the block chain (Seq = block sequence number; h(n) = digest of +// the block at sequence n): +// +// ────────────────── Epoch 1 ────────────────────────────────────│─── Epoch s ──── +// │ +// Seq: z ... s s+1 ... s+x │ s+1 (Telocks get pruned) ... +// ┌──────┐ ┌────────┐ ┌──────┐ ┌──────┐ │ ┌────────────┐ +// │ Zero │ ... │Sealing │ │Telock│ ... │Telock│ │ │first block │ ... 
+// │ block│ │ block │ │ │ │ │ │ │ of epoch s │ +// └──────┘ └────────┘ └──────┘ └──────┘ │ └────────────┘ +// EN = 1 EN = 1 EN = 1 EN = 1 │ EN = s +// SBS = 0 SBS = 0 SBS = s SBS = s │ SBS = 0 +// PSH = 0 PSH = h(z) PSH = 0 PSH = 0 │ PSH = 0 +// +// - EN : copied within an epoch; on the first block of a new epoch, EN +// equals the sequence number of the previous epoch's sealing block. +// - PSH : only set on a sealing block. In epoch 1 it points to the zero block; +// in epoch e > 1 it points to the previous epoch's sealing block. +// - SBS : 0 except on Telocks of a sealed-but-not-yet-finalized epoch, where +// it equals the sequence number of that epoch's sealing block. func (sm *StateMachine) buildBlockZero(parentBlock StateMachineBlock, simplexMetadata, simplexBlacklist []byte) (*StateMachineBlock, error) { pChainHeight := sm.LastNonSimplexBlockPChainHeight @@ -602,6 +661,22 @@ func (sm *StateMachine) verifyBlockZero(block *StateMachineBlock, prevBlock Stat return nil } +// buildBlockCollectingApprovals builds either another collecting-approvals block (if not enough approvals yet) +// or a sealing block (if quorum is reached). 
+// +// Relevant SimplexEpochInfo fields (EN = EpochNumber, NPCH = NextPChainReferenceHeight, +// NEA = NextEpochApprovals, BVD = BlockValidationDescriptor, PSH = PrevSealingBlockHash): +// +// parent (Collecting) not enough approvals yet quorum of approvals reached: sealing block +// ┌──────────────────┐ ┌────────────────────┐ ┌────────────────────────────┐ +// │ EN = e │ │ EN = e │ │ EN = e │ +// │ NPCH = p' │ ────► │ NPCH = p' │ │ NPCH = p' │ +// │ NEA = A_old │ │ NEA = A_old ∪ new │ OR │ NEA = A_old ∪ new │ +// │ BVD = nil │ │ BVD = nil │ │ BVD = validator set at p' │ +// │ │ │ │ │ PSH = h(prev epoch's │ +// │ │ │ │ │ sealing block) │ +// └──────────────────┘ └────────────────────┘ └────────────────────────────┘ +// → stays Collecting → BuildBlockEpochSealed func (sm *StateMachine) buildBlockCollectingApprovals(ctx context.Context, parentBlock StateMachineBlock, simplexMetadata, simplexBlacklist []byte, prevBlockSeq uint64) (*StateMachineBlock, error) { newApprovals, err := sm.computeNewApprovals(parentBlock) if err != nil { @@ -858,19 +933,19 @@ func (sm *StateMachine) computeSimplexEpochInfoForSealingBlock(simplexEpochInfo } // wrapBlock creates a new StateMachineBlock by wrapping the VM block (if applicable) and adding the appropriate metadata. 
-func (sm *StateMachine) wrapBlock(parentBlock StateMachineBlock, childBlock VMBlock, newSimplexEpochInfo SimplexEpochInfo, pChainHeight uint64, simplexMetadata, simplexBlacklist []byte) *StateMachineBlock { +func (sm *StateMachine) wrapBlock(parentBlock StateMachineBlock, innerBlock VMBlock, newSimplexEpochInfo SimplexEpochInfo, pChainHeight uint64, simplexMetadata, simplexBlacklist []byte) *StateMachineBlock { timestamp := parentBlock.Metadata.Timestamp - hasChildBlock := childBlock != nil + hasInnerBlock := innerBlock != nil var newTimestamp time.Time - if hasChildBlock { - newTimestamp = childBlock.Timestamp() + if hasInnerBlock { + newTimestamp = innerBlock.Timestamp() timestamp = uint64(newTimestamp.UnixMilli()) } return &StateMachineBlock{ - InnerBlock: childBlock, + InnerBlock: innerBlock, Metadata: StateMachineMetadata{ Timestamp: timestamp, SimplexProtocolMetadata: simplexMetadata, @@ -902,6 +977,21 @@ func (sm *StateMachine) isSealingBlockFinalized(parentBlock StateMachineBlock, p } // buildBlockEpochSealed builds a block where the epoch is being sealed due to a sealing block already created in this epoch. 
+// +// Relevant SimplexEpochInfo fields (PCH = PChainReferenceHeight, EN = EpochNumber, +// NPCH = NextPChainReferenceHeight, SBS = SealingBlockSeq, BVD = BlockValidationDescriptor): +// +// parent (sealing block) sealing block NOT finalized sealing block IS finalized +// → emit Telock (no inner block) → first block of new epoch +// ┌──────────────────┐ ┌──────────────────┐ ┌──────────────────────────┐ +// │ Seq = s │ │ Seq = s+1 │ │ Seq = s+1 │ +// │ PCH = p │ │ PCH = p (copy) │ │ PCH = p' (was NPCH) │ +// │ EN = e │ ──► │ EN = e (copy) │ OR │ EN = s │ +// │ NPCH = p' │ OR │ NPCH = p' (copy) │ │ NPCH = 0 (reset) │ +// │ SBS = 0 │ │ SBS = s │ │ SBS = 0 │ +// │ BVD = vset@p' │ │ BVD = nil │ │ BVD = nil │ +// └──────────────────┘ └──────────────────┘ └──────────────────────────┘ +// → stays EpochSealed → NormalOp (new epoch) func (sm *StateMachine) buildBlockEpochSealed(ctx context.Context, parentBlock StateMachineBlock, simplexMetadata, simplexBlacklist []byte, prevBlockSeq uint64) (*StateMachineBlock, error) { // We check if the sealing block has already been finalized. // If not, we build a Telock block.