diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 682e2b3847c2..4b7885bb20ec 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -234,6 +234,11 @@ jobs: with: name: "${{ github.sha }}-store-coverage" continue-on-error: true + - uses: actions/download-artifact@v4 + if: env.GIT_DIFF + with: + name: "${{ github.sha }}-blockstm-coverage" + continue-on-error: true - uses: actions/download-artifact@v4 if: env.GIT_DIFF with: @@ -253,7 +258,7 @@ jobs: if: env.GIT_DIFF uses: codecov/codecov-action@v5 with: - files: ./00profile.out,./01profile.out,./02profile.out,./03profile.out,./integration-profile.out,./e2e-profile.out,./client/v2/coverage.out,./core/coverage.out,./depinject/coverage.out,./errors/coverage.out,./math/coverage.out,./schema/coverage.out,./collections/coverage.out,./tools/cosmovisor/coverage.out,./tools/confix/coverage.out,./store/coverage.out,./log/coverage.out,./x/tx/coverage.out,./tools/benchmark/coverage.out + files: ./00profile.out,./01profile.out,./02profile.out,./03profile.out,./integration-profile.out,./e2e-profile.out,./client/v2/coverage.out,./core/coverage.out,./depinject/coverage.out,./errors/coverage.out,./math/coverage.out,./schema/coverage.out,./collections/coverage.out,./tools/cosmovisor/coverage.out,./tools/confix/coverage.out,./store/coverage.out,./log/coverage.out,./x/tx/coverage.out,./tools/benchmark/coverage.out,./blockstm/coverage.out fail_ci_if_error: false verbose: true token: ${{ secrets.CODECOV_TOKEN }} @@ -570,6 +575,33 @@ jobs: with: name: "${{ github.sha }}-store-coverage" path: ./store/coverage.out + test-blockstm: + runs-on: depot-ubuntu-22.04-4 + steps: + - uses: actions/checkout@v5 + - uses: actions/setup-go@v6 + with: + go-version: "1.25" + check-latest: true + cache: true + cache-dependency-path: store/go.sum + - uses: technote-space/get-diff-action@v6.1.2 + id: git_diff + with: + PATTERNS: | + blockstm/**/*.go + blockstm/go.mod + blockstm/go.sum + - name: tests + if: env.GIT_DIFF + run: | + cd blockstm + go test -mod=readonly -timeout 30m -coverprofile=coverage.out -covermode=atomic -coverpkg=./,./tree -tags='norace ledger test_ledger_mock' ./... + - uses: actions/upload-artifact@v4 + if: env.GIT_DIFF + with: + name: "${{ github.sha }}-blockstm-coverage" + path: ./blockstm/coverage.out test-log: runs-on: depot-ubuntu-22.04-4 diff --git a/blockstm/README.md b/blockstm/README.md new file mode 100644 index 000000000000..eaedabf6dfa1 --- /dev/null +++ b/blockstm/README.md @@ -0,0 +1,41 @@ +`blockstm` implements the [block-stm algorithm](https://arxiv.org/abs/2203.06871), it follows the paper pseudocode pretty closely. + +The main API is a simple function call `ExecuteBlock`: + +```golang +type ExecuteFn func(TxnIndex, MultiStore) +func ExecuteBlock( + ctx context.Context, // context for cancellation + blockSize int, // the number of the transactions to be executed + stores []storetypes.StoreKey, // the list of store keys to support + storage MultiStore, // the parent storage, after all transactions are executed, the whole change sets are written into parent storage at once + executors int, // how many concurrent executors to spawn + executeFn ExecuteFn, // callback function to actually execute a transaction with a wrapped `MultiStore`. +) error +``` + +The main deviations from the paper are: + +### Optimisation + +We applied the optimization described in section 4 of the paper: + +``` +Block-STM calls add_dependency from the VM itself, and can thus re-read and continue execution when false is returned. 
+``` + +When the VM execution reads an `ESTIMATE` mark, it'll hang on a `CondVar`, so it can resume execution after the dependency is resolved, +much more efficient than abortion and rerun. + +### Support Deletion, Iteration, and MultiStore + +These features are necessary for integration with cosmos-sdk. + +The multi-version data structure is implemented with nested btree for easier iteration support, +the `WriteSet` is also implemented with a btree, and it takes advantage of ordered property to optimize some logic. + +The internal data structures are also adapted with multiple stores in mind. + +### Attribution + +This package was originally authored in [go-block-stm](https://github.com/crypto-org-chain/go-block-stm). We have brought the full source tree into the SDK so that we can natively incorporate the library and required changes into the SDK. Over time we expect to incoporate optimizations and deviations from the upstream implementation. diff --git a/blockstm/bench_test.go b/blockstm/bench_test.go new file mode 100644 index 000000000000..ddbe3c752acf --- /dev/null +++ b/blockstm/bench_test.go @@ -0,0 +1,54 @@ +package blockstm + +import ( + "context" + "strconv" + "testing" + + "github.com/test-go/testify/require" + + storetypes "cosmossdk.io/store/types" +) + +func BenchmarkBlockSTM(b *testing.B) { + stores := map[storetypes.StoreKey]int{StoreKeyAuth: 0, StoreKeyBank: 1} + for i := 0; i < 26; i++ { + key := storetypes.NewKVStoreKey(strconv.FormatInt(int64(i), 10)) + stores[key] = i + 2 + } + storage := NewMultiMemDB(stores) + testCases := []struct { + name string + block *MockBlock + }{ + {"random-10000/100", testBlock(10000, 100)}, + {"no-conflict-10000", noConflictBlock(10000)}, + {"worst-case-10000", worstCaseBlock(10000)}, + {"iterate-10000/100", iterateBlock(10000, 100)}, + } + for _, tc := range testCases { + b.Run(tc.name+"-sequential", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + runSequential(storage, tc.block) + } + }) + for _, worker := range []int{1, 5, 10, 15, 20} { + b.Run(tc.name+"-worker-"+strconv.Itoa(worker), func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + require.NoError( + b, + ExecuteBlock(context.Background(), tc.block.Size(), stores, storage, worker, tc.block.ExecuteTx), + ) + } + }) + } + } +} + +func runSequential(storage MultiStore, block *MockBlock) { + for i, tx := range block.Txs { + block.Results[i] = tx(storage) + } +} diff --git a/blockstm/condvar.go b/blockstm/condvar.go new file mode 100644 index 000000000000..543de1d7188c --- /dev/null +++ b/blockstm/condvar.go @@ -0,0 +1,30 @@ +package blockstm + +import "sync" + +type Condvar struct { + sync.Mutex + notified bool + cond sync.Cond +} + +func NewCondvar() *Condvar { + c := &Condvar{} + c.cond = *sync.NewCond(c) + return c +} + +func (cv *Condvar) Wait() { + cv.Lock() + for !cv.notified { + cv.cond.Wait() + } + cv.Unlock() +} + +func (cv *Condvar) Notify() { + cv.Lock() + cv.notified = true + cv.Unlock() + cv.cond.Signal() +} diff --git a/blockstm/executor.go b/blockstm/executor.go new file mode 100644 index 000000000000..4014c4ead966 --- /dev/null +++ b/blockstm/executor.go @@ -0,0 +1,89 @@ +package blockstm + +import ( + "context" + "fmt" +) + +// Executor fields are not mutated during execution. 
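+// One Executor is the work loop of a single worker goroutine; ExecuteBlock is
+// expected to spawn `executors` of them over a shared Scheduler and MVMemory.
+// An illustrative sketch of how a single worker is driven (not the exact wiring
+// inside ExecuteBlock):
+//
+//	e := NewExecutor(ctx, scheduler, txExecutor, mvMemory, 0)
+//	if err := e.Run(); err != nil {
+//		// Run only fails on an unknown task kind; cancellation returns nil.
+//	}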
+type Executor struct { + ctx context.Context // context for cancellation + scheduler *Scheduler // scheduler for task management + txExecutor TxExecutor // callback to actually execute a transaction + mvMemory *MVMemory // multi-version memory for the executor + + // index of the executor, used for debugging output + i int +} + +func NewExecutor( + ctx context.Context, + scheduler *Scheduler, + txExecutor TxExecutor, + mvMemory *MVMemory, + i int, +) *Executor { + return &Executor{ + ctx: ctx, + scheduler: scheduler, + txExecutor: txExecutor, + mvMemory: mvMemory, + i: i, + } +} + +// Run executes all tasks until completion +// Invariant `num_active_tasks`: +// - `NextTask` increases it if returns a valid task. +// - `TryExecute` and `NeedsReexecution` don't change it if it returns a new valid task to run, +// otherwise it decreases it. +func (e *Executor) Run() error { + var kind TaskKind + version := InvalidTxnVersion + for !e.scheduler.Done() { + if !version.Valid() { + // check for cancellation + select { + case <-e.ctx.Done(): + return nil + default: + } + + version, kind = e.scheduler.NextTask() + continue + } + + switch kind { + case TaskKindExecution: + version, kind = e.TryExecute(version) + case TaskKindValidation: + version, kind = e.NeedsReexecution(version) + default: + return fmt.Errorf("unknown task kind %v", kind) + } + } + return nil +} + +func (e *Executor) TryExecute(version TxnVersion) (TxnVersion, TaskKind) { + e.scheduler.executedTxns.Add(1) + view := e.execute(version.Index) + wroteNewLocation := e.mvMemory.Record(version, view) + return e.scheduler.FinishExecution(version, wroteNewLocation) +} + +func (e *Executor) NeedsReexecution(version TxnVersion) (TxnVersion, TaskKind) { + e.scheduler.validatedTxns.Add(1) + valid := e.mvMemory.ValidateReadSet(version.Index) + aborted := !valid && e.scheduler.TryValidationAbort(version) + if aborted { + e.mvMemory.ConvertWritesToEstimates(version.Index) + } + return e.scheduler.FinishValidation(version.Index, aborted) +} + +func (e *Executor) execute(txn TxnIndex) *MultiMVMemoryView { + view := e.mvMemory.View(txn) + e.txExecutor(txn, view) + return view +} diff --git a/blockstm/memdb.go b/blockstm/memdb.go new file mode 100644 index 000000000000..9fa95b5fc19d --- /dev/null +++ b/blockstm/memdb.go @@ -0,0 +1,218 @@ +package blockstm + +import ( + "io" + + "github.com/tidwall/btree" + + "cosmossdk.io/store/cachekv" + "cosmossdk.io/store/tracekv" + storetypes "cosmossdk.io/store/types" + + "github.com/cosmos/cosmos-sdk/blockstm/tree" +) + +type ( + MemDB = GMemDB[[]byte] + ObjMemDB = GMemDB[any] +) + +var ( + _ storetypes.KVStore = (*MemDB)(nil) + _ storetypes.ObjKVStore = (*ObjMemDB)(nil) +) + +// NewMemDB constructs a new in memory store over a []byte value type. +func NewMemDB() *MemDB { + return NewGMemDB(storetypes.BytesIsZero, storetypes.BytesValueLen) +} + +// NewObjMemDB constructs a new in memory store over a generic any type. +func NewObjMemDB() *ObjMemDB { + return NewGMemDB(storetypes.AnyIsZero, storetypes.AnyValueLen) +} + +// GMemDB is a generic implementation of an in memory Store backed by tidwall/btree. +type GMemDB[V any] struct { + btree.BTreeG[memdbItem[V]] + isZero func(V) bool + valueLen func(V) int +} + +// NewGMemDB is the generic constructor for a GMemDB. 
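+// The two callbacks describe the value type V: isZero reports whether a value
+// means "absent" (used by Has and to encode deletions in overlays) and valueLen
+// reports its size; both are forwarded to the cachekv wrapper in CacheWrap.
+// For example, the byte-valued store above is simply
+//
+//	db := NewGMemDB(storetypes.BytesIsZero, storetypes.BytesValueLen)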
+func NewGMemDB[V any]( + isZero func(V) bool, + valueLen func(V) int, +) *GMemDB[V] { + return &GMemDB[V]{ + BTreeG: *btree.NewBTreeG[memdbItem[V]](tree.KeyItemLess), + isZero: isZero, + valueLen: valueLen, + } +} + +// NewGMemDBNonConcurrent returns a new BTree which is not concurrency safe. +func NewGMemDBNonConcurrent[V any]( + isZero func(V) bool, + valueLen func(V) int, +) *GMemDB[V] { + return &GMemDB[V]{ + BTreeG: *btree.NewBTreeGOptions[memdbItem[V]](tree.KeyItemLess, btree.Options{ + NoLocks: true, + }), + isZero: isZero, + valueLen: valueLen, + } +} + +func (db *GMemDB[V]) Scan(cb func(key Key, value V) bool) { + db.BTreeG.Scan(func(item memdbItem[V]) bool { + return cb(item.key, item.value) + }) +} + +func (db *GMemDB[V]) Get(key []byte) V { + item, ok := db.BTreeG.Get(memdbItem[V]{key: key}) + if !ok { + var empty V + return empty + } + return item.value +} + +func (db *GMemDB[V]) Has(key []byte) bool { + return !db.isZero(db.Get(key)) +} + +func (db *GMemDB[V]) Set(key []byte, value V) { + if db.isZero(value) { + panic("nil value not allowed") + } + db.BTreeG.Set(memdbItem[V]{key: key, value: value}) +} + +func (db *GMemDB[V]) Delete(key []byte) { + db.BTreeG.Delete(memdbItem[V]{key: key}) +} + +// OverlayGet returns a value from the btree and true if we found a value. +// When used as an overlay (e.g. WriteSet), it stores the `nil` value to represent deleted keys, +// so we return separate bool value for found status. +func (db *GMemDB[V]) OverlayGet(key Key) (V, bool) { + item, ok := db.BTreeG.Get(memdbItem[V]{key: key}) + if !ok { + var zero V + return zero, false + } + return item.value, true +} + +// OverlaySet sets a value in the btree +// When used as an overlay (e.g. WriteSet), it stores the `nil` value to represent deleted keys, +func (db *GMemDB[V]) OverlaySet(key Key, value V) { + db.BTreeG.Set(memdbItem[V]{key: key, value: value}) +} + +func (db *GMemDB[V]) Iterator(start, end []byte) storetypes.GIterator[V] { + return db.iterator(start, end, true) +} + +func (db *GMemDB[V]) ReverseIterator(start, end []byte) storetypes.GIterator[V] { + return db.iterator(start, end, false) +} + +func (db *GMemDB[V]) iterator(start, end Key, ascending bool) storetypes.GIterator[V] { + return NewMemDBIterator(start, end, db.Iter(), ascending) +} + +func (db *GMemDB[V]) GetStoreType() storetypes.StoreType { + return storetypes.StoreTypeIAVL +} + +// CacheWrap implements types.KVStore. +func (db *GMemDB[V]) CacheWrap() storetypes.CacheWrap { + return cachekv.NewGStore(db, db.isZero, db.valueLen) +} + +// CacheWrapWithTrace implements types.KVStore. +func (db *GMemDB[V]) CacheWrapWithTrace(w io.Writer, tc storetypes.TraceContext) storetypes.CacheWrap { + if store, ok := any(db).(*GMemDB[[]byte]); ok { + return cachekv.NewGStore(tracekv.NewStore(store, w, tc), store.isZero, store.valueLen) + } + return db.CacheWrap() +} + +// MemDBIterator wraps a generic BTreeIteratorG over a memdbItem. +// It is used as an iterator over a GMemDB implementation. +type MemDBIterator[V any] struct { + tree.BTreeIteratorG[memdbItem[V]] +} + +var _ storetypes.Iterator = (*MemDBIterator[[]byte])(nil) + +func NewMemDBIterator[V any](start, end Key, iter btree.IterG[memdbItem[V]], ascending bool) *MemDBIterator[V] { + return &MemDBIterator[V]{*tree.NewBTreeIteratorG( + memdbItem[V]{key: start}, + memdbItem[V]{key: end}, + iter, + ascending, + )} +} + +// NewNoopIterator constructs a storetypes.GIterator with an invalidated wrapped iterator. 
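+// GMVMemoryView uses it when a transaction has no write set yet, so the three-way
+// merge iterator always has an (empty) overlay to merge against.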
+func NewNoopIterator[V any](start, end Key, ascending bool) storetypes.GIterator[V] { + return &MemDBIterator[V]{tree.NewNoopBTreeIteratorG[memdbItem[V]]( + start, + end, + ascending, + false, + )} +} + +func (it *MemDBIterator[V]) Value() V { + return it.Item().value +} + +type memdbItem[V any] struct { + key Key + value V +} + +var _ tree.KeyItem = memdbItem[[]byte]{} + +func (item memdbItem[V]) GetKey() []byte { + return item.key +} + +type MultiMemDB struct { + dbs map[storetypes.StoreKey]storetypes.Store +} + +var _ MultiStore = (*MultiMemDB)(nil) + +func NewMultiMemDB(stores map[storetypes.StoreKey]int) *MultiMemDB { + dbs := make(map[storetypes.StoreKey]storetypes.Store, len(stores)) + for name := range stores { + switch name.(type) { + case *storetypes.ObjectStoreKey: + dbs[name] = NewObjMemDB() + default: + dbs[name] = NewMemDB() + } + } + return &MultiMemDB{ + dbs: dbs, + } +} + +func (mmdb *MultiMemDB) GetStore(store storetypes.StoreKey) storetypes.Store { + return mmdb.dbs[store] +} + +func (mmdb *MultiMemDB) GetKVStore(store storetypes.StoreKey) storetypes.KVStore { + return mmdb.GetStore(store).(storetypes.KVStore) +} + +func (mmdb *MultiMemDB) GetObjKVStore(store storetypes.StoreKey) storetypes.ObjKVStore { + return mmdb.GetStore(store).(storetypes.ObjKVStore) +} diff --git a/blockstm/memdb_test.go b/blockstm/memdb_test.go new file mode 100644 index 000000000000..85a529a2b064 --- /dev/null +++ b/blockstm/memdb_test.go @@ -0,0 +1,63 @@ +package blockstm + +import ( + "testing" + + "github.com/test-go/testify/require" + + "cosmossdk.io/store/cachekv" + storetypes "cosmossdk.io/store/types" +) + +type ( + foo struct { + a bool + b bool + } + bar struct { + one int + two int + } +) + +func TestObjMemDB(t *testing.T) { + t.Parallel() + obj1 := foo{true, true} + obj2 := bar{1, 2} + storeKey := storetypes.NewObjectStoreKey("foobar") + + // attach to a new multistore + mmdb := NewMultiMemDB(map[storetypes.StoreKey]int{storeKey: 0}) + + // get the memdb + storage := mmdb.GetObjKVStore(storeKey) + + require.Equal(t, storetypes.StoreTypeIAVL, storage.GetStoreType()) + + // initial value + storage.Set([]byte("foo"), obj1) + storage.Set([]byte("bar"), obj2) + + require.True(t, storage.Has([]byte("foo"))) + require.True(t, storage.Has([]byte("bar"))) + require.False(t, storage.Has([]byte("baz"))) + require.Equal(t, storage.Get([]byte("foo")), obj1) + require.Equal(t, storage.Get([]byte("bar")), obj2) +} + +func TestCacheWraps(t *testing.T) { + t.Parallel() + storeKey := storetypes.NewObjectStoreKey("foobar") + + // attach to a new multistore + mmdb := NewMultiMemDB(map[storetypes.StoreKey]int{storeKey: 0}) + + // get the memdb + storage := mmdb.GetObjKVStore(storeKey) + // attempt to cachewrap + cacheWrapper := storage.CacheWrap() + require.IsType(t, &cachekv.GStore[any]{}, cacheWrapper) + + cacheWrappedWithTrace := storage.CacheWrapWithTrace(nil, nil) + require.IsType(t, &cachekv.GStore[any]{}, cacheWrappedWithTrace) +} diff --git a/blockstm/mergeiterator.go b/blockstm/mergeiterator.go new file mode 100644 index 000000000000..a0b55506be56 --- /dev/null +++ b/blockstm/mergeiterator.go @@ -0,0 +1,247 @@ +package blockstm + +import ( + "bytes" + "errors" + + "cosmossdk.io/store/types" +) + +// cacheMergeIterator merges a parent Iterator and a cache Iterator. +// The cache iterator may return nil keys to signal that an item +// had been deleted (but not deleted in the parent). +// If the cache iterator has the same key as the parent, the +// cache shadows (overrides) the parent. 
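+// For example, if the parent yields (a=1, b=2) and the cache yields
+// (b=<zero>, c=3), the merged iteration yields (a=1, c=3): b is treated as
+// deleted by the cache and c as an overlay write.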
+// +// TODO: Optimize by memoizing. +type cacheMergeIterator[V any] struct { + parent types.GIterator[V] + cache types.GIterator[V] + onClose func(types.GIterator[V]) + isZero func(V) bool + + ascending bool + valid bool +} + +var _ types.Iterator = (*cacheMergeIterator[[]byte])(nil) + +func NewCacheMergeIterator[V any]( + parent, cache types.GIterator[V], + ascending bool, onClose func(types.GIterator[V]), + isZero func(V) bool, +) types.GIterator[V] { + iter := &cacheMergeIterator[V]{ + parent: parent, + cache: cache, + ascending: ascending, + onClose: onClose, + isZero: isZero, + } + + iter.valid = iter.skipUntilExistsOrInvalid() + return iter +} + +// Domain implements Iterator. +// Returns parent domain because cache and parent domains are the same. +func (iter *cacheMergeIterator[V]) Domain() (start, end []byte) { + return iter.parent.Domain() +} + +// Valid implements Iterator. +func (iter *cacheMergeIterator[V]) Valid() bool { + return iter.valid +} + +// Next implements Iterator +func (iter *cacheMergeIterator[V]) Next() { + iter.assertValid() + + switch { + case !iter.parent.Valid(): + // If parent is invalid, get the next cache item. + iter.cache.Next() + case !iter.cache.Valid(): + // If cache is invalid, get the next parent item. + iter.parent.Next() + default: + // Both are valid. Compare keys. + keyP, keyC := iter.parent.Key(), iter.cache.Key() + switch iter.compare(keyP, keyC) { + case -1: // parent < cache + iter.parent.Next() + case 0: // parent == cache + iter.parent.Next() + iter.cache.Next() + case 1: // parent > cache + iter.cache.Next() + } + } + iter.valid = iter.skipUntilExistsOrInvalid() +} + +// Key implements Iterator +func (iter *cacheMergeIterator[V]) Key() []byte { + iter.assertValid() + + // If parent is invalid, get the cache key. + if !iter.parent.Valid() { + return iter.cache.Key() + } + + // If cache is invalid, get the parent key. + if !iter.cache.Valid() { + return iter.parent.Key() + } + + // Both are valid. Compare keys. + keyP, keyC := iter.parent.Key(), iter.cache.Key() + + cmp := iter.compare(keyP, keyC) + switch cmp { + case -1: // parent < cache + return keyP + case 0: // parent == cache + return keyP + case 1: // parent > cache + return keyC + default: + panic("invalid compare result") + } +} + +// Value implements Iterator +func (iter *cacheMergeIterator[V]) Value() V { + iter.assertValid() + + // If parent is invalid, get the cache value. + if !iter.parent.Valid() { + return iter.cache.Value() + } + + // If cache is invalid, get the parent value. + if !iter.cache.Valid() { + return iter.parent.Value() + } + + // Both are valid. Compare keys. + keyP, keyC := iter.parent.Key(), iter.cache.Key() + + cmp := iter.compare(keyP, keyC) + switch cmp { + case -1: // parent < cache + return iter.parent.Value() + case 0: // parent == cache + return iter.cache.Value() + case 1: // parent > cache + return iter.cache.Value() + default: + panic("invalid comparison result") + } +} + +// Close implements Iterator +func (iter *cacheMergeIterator[V]) Close() error { + if iter.onClose != nil { + iter.onClose(iter) + } + + err1 := iter.cache.Close() + if err := iter.parent.Close(); err != nil { + return err + } + + return err1 +} + +// Error returns an error if the cacheMergeIterator is invalid defined by the +// Valid method. +func (iter *cacheMergeIterator[V]) Error() error { + if !iter.Valid() { + return errors.New("invalid cacheMergeIterator") + } + + return nil +} + +// If not valid, panics. +// NOTE: May have side-effect of iterating over cache. 
+func (iter *cacheMergeIterator[V]) assertValid() { + if err := iter.Error(); err != nil { + panic(err) + } +} + +// Like bytes.Compare but opposite if not ascending. +func (iter *cacheMergeIterator[V]) compare(a, b []byte) int { + if iter.ascending { + return bytes.Compare(a, b) + } + + return bytes.Compare(a, b) * -1 +} + +// Skip all delete-items from the cache w/ `key < until`. After this function, +// current cache item is a non-delete-item, or `until <= key`. +// If the current cache item is not a delete item, does nothing. +// If `until` is nil, there is no limit, and cache may end up invalid. +// CONTRACT: cache is valid. +func (iter *cacheMergeIterator[V]) skipCacheDeletes(until []byte) { + for iter.cache.Valid() && + iter.isZero(iter.cache.Value()) && + (until == nil || iter.compare(iter.cache.Key(), until) < 0) { + iter.cache.Next() + } +} + +// Fast forwards cache (or parent+cache in case of deleted items) until current +// item exists, or until iterator becomes invalid. +// Returns whether the iterator is valid. +func (iter *cacheMergeIterator[V]) skipUntilExistsOrInvalid() bool { + for { + // If parent is invalid, fast-forward cache. + if !iter.parent.Valid() { + iter.skipCacheDeletes(nil) + return iter.cache.Valid() + } + // Parent is valid. + + if !iter.cache.Valid() { + return true + } + // Parent is valid, cache is valid. + + // Compare parent and cache. + keyP := iter.parent.Key() + keyC := iter.cache.Key() + + switch iter.compare(keyP, keyC) { + case -1: // parent < cache. + return true + + case 0: // parent == cache. + // Skip over if cache item is a delete. + valueC := iter.cache.Value() + if iter.isZero(valueC) { + iter.parent.Next() + iter.cache.Next() + + continue + } + // Cache is not a delete. + + return true // cache exists. + case 1: // cache < parent + // Skip over if cache item is a delete. + valueC := iter.cache.Value() + if iter.isZero(valueC) { + iter.skipCacheDeletes(keyP) + continue + } + // Cache is not a delete. + + return true // cache exists. 
+ } + } +} diff --git a/blockstm/mock_block.go b/blockstm/mock_block.go new file mode 100644 index 000000000000..6649e2166ab7 --- /dev/null +++ b/blockstm/mock_block.go @@ -0,0 +1,169 @@ +package blockstm + +import ( + cryptorand "crypto/rand" + "encoding/binary" + "fmt" + "strings" + + "github.com/cometbft/cometbft/crypto/secp256k1" + + storetypes "cosmossdk.io/store/types" +) + +var ( + StoreKeyAuth = storetypes.NewKVStoreKey("acc") + StoreKeyBank = storetypes.NewKVStoreKey("bank") +) + +type Tx func(MultiStore) error + +type MockBlock struct { + Txs []Tx + Results []error +} + +func NewMockBlock(txs []Tx) *MockBlock { + return &MockBlock{ + Txs: txs, + Results: make([]error, len(txs)), + } +} + +func (b *MockBlock) Size() int { + return len(b.Txs) +} + +func (b *MockBlock) ExecuteTx(txn TxnIndex, store MultiStore) { + b.Results[txn] = b.Txs[txn](store) +} + +// Simulated transaction logic for tests and benchmarks + +// NoopTx verifies a signature and increases the nonce of the sender +func NoopTx(i int, sender string) Tx { + verifySig := genRandomSignature() + return func(store MultiStore) error { + verifySig() + return increaseNonce(i, sender, store.GetKVStore(StoreKeyAuth)) + } +} + +func BankTransferTx(i int, sender, receiver string, amount uint64) Tx { + base := NoopTx(i, sender) + return func(store MultiStore) error { + if err := base(store); err != nil { + return err + } + + return bankTransfer(i, sender, receiver, amount, store.GetKVStore(StoreKeyBank)) + } +} + +func IterateTx(i int, sender, receiver string, amount uint64) Tx { + base := BankTransferTx(i, sender, receiver, amount) + return func(store MultiStore) error { + if err := base(store); err != nil { + return err + } + + // find a nearby account, do a bank transfer + accStore := store.GetKVStore(StoreKeyAuth) + + { + it := accStore.Iterator([]byte("nonce"+sender), nil) + defer it.Close() + + var j int + for ; it.Valid(); it.Next() { + j++ + if j > 5 { + recipient := strings.TrimPrefix(string(it.Key()), "nonce") + return bankTransfer(i, sender, recipient, amount, store.GetKVStore(StoreKeyBank)) + } + } + } + + { + it := accStore.ReverseIterator([]byte("nonce"), []byte("nonce"+sender)) + defer it.Close() + + var j int + for ; it.Valid(); it.Next() { + j++ + if j > 5 { + recipient := strings.TrimPrefix(string(it.Key()), "nonce") + return bankTransfer(i, sender, recipient, amount, store.GetKVStore(StoreKeyBank)) + } + } + } + + return nil + } +} + +func genRandomSignature() func() { + privKey := secp256k1.GenPrivKey() + signBytes := make([]byte, 1024) + if _, err := cryptorand.Read(signBytes); err != nil { + panic(err) + } + sig, _ := privKey.Sign(signBytes) + pubKey := privKey.PubKey() + + return func() { + pubKey.VerifySignature(signBytes, sig) + } +} + +func increaseNonce(i int, sender string, store storetypes.KVStore) error { + nonceKey := []byte("nonce" + sender) + var nonce uint64 + v := store.Get(nonceKey) + if v != nil { + nonce = binary.BigEndian.Uint64(v) + } + + var bz [8]byte + binary.BigEndian.PutUint64(bz[:], nonce+1) + store.Set(nonceKey, bz[:]) + + v = store.Get(nonceKey) + if binary.BigEndian.Uint64(v) != nonce+1 { + return fmt.Errorf("nonce not incremented: %d", binary.BigEndian.Uint64(v)) + } + + return nil +} + +func bankTransfer(i int, sender, receiver string, amount uint64, store storetypes.KVStore) error { + senderKey := []byte("balance" + sender) + receiverKey := []byte("balance" + receiver) + + var senderBalance, receiverBalance uint64 + v := store.Get(senderKey) + if v != nil { + senderBalance = 
binary.BigEndian.Uint64(v) + } + + v = store.Get(receiverKey) + if v != nil { + receiverBalance = binary.BigEndian.Uint64(v) + } + + if senderBalance >= amount { + // avoid the failure + senderBalance -= amount + } + + receiverBalance += amount + + var bz1, bz2 [8]byte + binary.BigEndian.PutUint64(bz1[:], senderBalance) + store.Set(senderKey, bz1[:]) + + binary.BigEndian.PutUint64(bz2[:], receiverBalance) + store.Set(receiverKey, bz2[:]) + + return nil +} diff --git a/blockstm/multimvview.go b/blockstm/multimvview.go new file mode 100644 index 000000000000..af7523b50d1f --- /dev/null +++ b/blockstm/multimvview.go @@ -0,0 +1,65 @@ +package blockstm + +import storetypes "cosmossdk.io/store/types" + +const ViewsPreAllocate = 4 + +// MultiMVMemoryView don't need to be thread-safe, there's a dedicated instance for each tx execution. +type MultiMVMemoryView struct { + stores map[storetypes.StoreKey]int + views map[storetypes.StoreKey]MVView + newMVView func(storetypes.StoreKey, TxnIndex) MVView + txn TxnIndex +} + +var _ MultiStore = (*MultiMVMemoryView)(nil) + +func NewMultiMVMemoryView( + stores map[storetypes.StoreKey]int, + newMVView func(storetypes.StoreKey, TxnIndex) MVView, + txn TxnIndex, +) *MultiMVMemoryView { + return &MultiMVMemoryView{ + stores: stores, + views: make(map[storetypes.StoreKey]MVView, ViewsPreAllocate), + newMVView: newMVView, + txn: txn, + } +} + +func (mv *MultiMVMemoryView) getViewOrInit(name storetypes.StoreKey) MVView { + view, ok := mv.views[name] + if !ok { + view = mv.newMVView(name, mv.txn) + mv.views[name] = view + } + return view +} + +func (mv *MultiMVMemoryView) GetStore(name storetypes.StoreKey) storetypes.Store { + return mv.getViewOrInit(name) +} + +func (mv *MultiMVMemoryView) GetKVStore(name storetypes.StoreKey) storetypes.KVStore { + return mv.GetStore(name).(storetypes.KVStore) +} + +func (mv *MultiMVMemoryView) GetObjKVStore(name storetypes.StoreKey) storetypes.ObjKVStore { + return mv.GetStore(name).(storetypes.ObjKVStore) +} + +func (mv *MultiMVMemoryView) ReadSet() *MultiReadSet { + rs := make(MultiReadSet, len(mv.views)) + for key, view := range mv.views { + rs[mv.stores[key]] = view.ReadSet() + } + return &rs +} + +func (mv *MultiMVMemoryView) ApplyWriteSet(version TxnVersion) MultiLocations { + newLocations := make(MultiLocations, len(mv.views)) + for key, view := range mv.views { + newLocations[mv.stores[key]] = view.ApplyWriteSet(version) + } + return newLocations +} diff --git a/blockstm/mvdata.go b/blockstm/mvdata.go new file mode 100644 index 000000000000..7934642e0ee9 --- /dev/null +++ b/blockstm/mvdata.go @@ -0,0 +1,235 @@ +package blockstm + +import ( + "bytes" + + storetypes "cosmossdk.io/store/types" + + "github.com/cosmos/cosmos-sdk/blockstm/tree" +) + +const ( + OuterBTreeDegree = 4 // Since we do copy-on-write a lot, smaller degree means smaller allocations + InnerBTreeDegree = 4 +) + +type MVData = GMVData[[]byte] + +func NewMVData() *MVData { + return NewGMVData(storetypes.BytesIsZero, storetypes.BytesValueLen) +} + +type GMVData[V any] struct { + tree.BTree[dataItem[V]] + isZero func(V) bool + valueLen func(V) int +} + +func NewMVStore(key storetypes.StoreKey) MVStore { + switch key.(type) { + case *storetypes.ObjectStoreKey: + return NewGMVData(storetypes.AnyIsZero, storetypes.AnyValueLen) + default: + return NewGMVData(storetypes.BytesIsZero, storetypes.BytesValueLen) + } +} + +func NewGMVData[V any](isZero func(V) bool, valueLen func(V) int) *GMVData[V] { + return &GMVData[V]{ + BTree: 
*tree.NewBTree(tree.KeyItemLess[dataItem[V]], OuterBTreeDegree), + isZero: isZero, + valueLen: valueLen, + } +} + +// getTree returns `nil` if not found +func (d *GMVData[V]) getTree(key Key) *tree.BTree[secondaryDataItem[V]] { + outer, _ := d.Get(dataItem[V]{Key: key}) + return outer.Tree +} + +// getTreeOrDefault set a new tree atomically if not found. +func (d *GMVData[V]) getTreeOrDefault(key Key) *tree.BTree[secondaryDataItem[V]] { + return d.GetOrDefault(dataItem[V]{Key: key}, func(item *dataItem[V]) { + if item.Tree == nil { + item.Tree = tree.NewBTree(secondaryLesser[V], InnerBTreeDegree) + } + }).Tree +} + +func (d *GMVData[V]) Write(key Key, value V, version TxnVersion) { + tree := d.getTreeOrDefault(key) + tree.Set(secondaryDataItem[V]{Index: version.Index, Incarnation: version.Incarnation, Value: value}) +} + +func (d *GMVData[V]) WriteEstimate(key Key, txn TxnIndex) { + tree := d.getTreeOrDefault(key) + tree.Set(secondaryDataItem[V]{Index: txn, Estimate: true}) +} + +func (d *GMVData[V]) Delete(key Key, txn TxnIndex) { + tree := d.getTreeOrDefault(key) + tree.Delete(secondaryDataItem[V]{Index: txn}) +} + +// Read returns the value and the version of the value that's less than the given txn. +// If the key is not found, returns `(nil, InvalidTxnVersion, false)`. +// If the key is found but value is an estimate, returns `(nil, BlockingTxn, true)`. +// If the key is found, returns `(value, version, false)`, `value` can be `nil` which means deleted. +func (d *GMVData[V]) Read(key Key, txn TxnIndex) (V, TxnVersion, bool) { + var zero V + if txn == 0 { + return zero, InvalidTxnVersion, false + } + + tree := d.getTree(key) + if tree == nil { + return zero, InvalidTxnVersion, false + } + + // find the closing txn that's less than the given txn + item, ok := seekClosestTxn(tree, txn) + if !ok { + return zero, InvalidTxnVersion, false + } + + return item.Value, item.Version(), item.Estimate +} + +func (d *GMVData[V]) Iterator( + opts IteratorOptions, txn TxnIndex, + waitFn func(TxnIndex), +) *MVIterator[V] { + return NewMVIterator(opts, txn, d.Iter(), waitFn) +} + +// ValidateReadSet validates the read descriptors, +// returns true if valid. +func (d *GMVData[V]) ValidateReadSet(txn TxnIndex, rs *ReadSet) bool { + for _, desc := range rs.Reads { + _, version, estimate := d.Read(desc.Key, txn) + if estimate { + // previously read entry from data, now ESTIMATE + return false + } + if version != desc.Version { + // previously read entry from data, now NOT_FOUND, + // or read some entry, but not the same version as before + return false + } + } + + for _, desc := range rs.Iterators { + if !d.validateIterator(desc, txn) { + return false + } + } + + return true +} + +// validateIterator validates the iteration descriptor by replaying and compare the recorded reads. +// returns true if valid. +func (d *GMVData[V]) validateIterator(desc IteratorDescriptor, txn TxnIndex) bool { + it := NewMVIterator(desc.IteratorOptions, txn, d.Iter(), nil) + defer it.Close() + + var i int + for ; it.Valid(); it.Next() { + if desc.Stop != nil { + if BytesBeyond(it.Key(), desc.Stop, desc.Ascending) { + break + } + } + + if i >= len(desc.Reads) { + return false + } + + read := desc.Reads[i] + if read.Version != it.Version() || !bytes.Equal(read.Key, it.Key()) { + return false + } + + i++ + } + + // we read an estimate value, fail the validation. 
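+	// (in validation mode the replay iterator invalidates itself when it hits an
+	// ESTIMATE entry, so a truncated replay is inconclusive rather than a match)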
+ if it.ReadEstimateValue() { + return false + } + + return i == len(desc.Reads) +} + +func (d *GMVData[V]) Snapshot() (snapshot []GKVPair[V]) { + d.SnapshotTo(func(key Key, value V) bool { + snapshot = append(snapshot, GKVPair[V]{key, value}) + return true + }) + return snapshot +} + +func (d *GMVData[V]) SnapshotTo(cb func(Key, V) bool) { + d.Scan(func(outer dataItem[V]) bool { + item, ok := outer.Tree.Max() + if !ok { + return true + } + + if item.Estimate { + return true + } + + return cb(outer.Key, item.Value) + }) +} + +func (d *GMVData[V]) SnapshotToStore(store storetypes.Store) { + kv := store.(storetypes.GKVStore[V]) + d.SnapshotTo(func(key Key, value V) bool { + if d.isZero(value) { + kv.Delete(key) + } else { + kv.Set(key, value) + } + return true + }) +} + +type GKVPair[V any] struct { + Key Key + Value V +} +type KVPair = GKVPair[[]byte] + +type dataItem[V any] struct { + Key Key + Tree *tree.BTree[secondaryDataItem[V]] +} + +var _ tree.KeyItem = dataItem[[]byte]{} + +func (item dataItem[V]) GetKey() []byte { + return item.Key +} + +type secondaryDataItem[V any] struct { + Index TxnIndex + Incarnation Incarnation + Value V + Estimate bool +} + +func secondaryLesser[V any](a, b secondaryDataItem[V]) bool { + return a.Index < b.Index +} + +func (item secondaryDataItem[V]) Version() TxnVersion { + return TxnVersion{Index: item.Index, Incarnation: item.Incarnation} +} + +// seekClosestTxn returns the closest txn that's less than the given txn. +func seekClosestTxn[V any](tree *tree.BTree[secondaryDataItem[V]], txn TxnIndex) (secondaryDataItem[V], bool) { + return tree.ReverseSeek(secondaryDataItem[V]{Index: txn - 1}) +} diff --git a/blockstm/mvdata_test.go b/blockstm/mvdata_test.go new file mode 100644 index 000000000000..1cf8dd3a8b38 --- /dev/null +++ b/blockstm/mvdata_test.go @@ -0,0 +1,111 @@ +package blockstm + +import ( + "errors" + "fmt" + "testing" + + "github.com/test-go/testify/require" +) + +func TestEmptyMVData(t *testing.T) { + data := NewMVData() + value, _, estimate := data.Read([]byte("a"), 1) + require.False(t, estimate) + require.Nil(t, value) +} + +func TestMVData(t *testing.T) { + data := NewMVData() + + // read closest version + data.Write([]byte("a"), []byte("1"), TxnVersion{Index: 1, Incarnation: 1}) + data.Write([]byte("a"), []byte("2"), TxnVersion{Index: 2, Incarnation: 1}) + data.Write([]byte("a"), []byte("3"), TxnVersion{Index: 3, Incarnation: 1}) + data.Write([]byte("b"), []byte("2"), TxnVersion{Index: 2, Incarnation: 1}) + + // read closest version + value, _, estimate := data.Read([]byte("a"), 1) + require.False(t, estimate) + require.Nil(t, value) + + // read closest version + value, version, estimate := data.Read([]byte("a"), 4) + require.False(t, estimate) + require.Equal(t, []byte("3"), value) + require.Equal(t, TxnVersion{Index: 3, Incarnation: 1}, version) + + // read closest version + value, version, estimate = data.Read([]byte("a"), 3) + require.False(t, estimate) + require.Equal(t, []byte("2"), value) + require.Equal(t, TxnVersion{Index: 2, Incarnation: 1}, version) + + // read closest version + value, version, estimate = data.Read([]byte("b"), 3) + require.False(t, estimate) + require.Equal(t, []byte("2"), value) + require.Equal(t, TxnVersion{Index: 2, Incarnation: 1}, version) + + // new incarnation overrides old + data.Write([]byte("a"), []byte("3-2"), TxnVersion{Index: 3, Incarnation: 2}) + value, version, estimate = data.Read([]byte("a"), 4) + require.False(t, estimate) + require.Equal(t, []byte("3-2"), value) + require.Equal(t, 
TxnVersion{Index: 3, Incarnation: 2}, version) + + // read estimate + data.WriteEstimate([]byte("a"), 3) + _, version, estimate = data.Read([]byte("a"), 4) + require.True(t, estimate) + require.Equal(t, TxnIndex(3), version.Index) + + // delete value + data.Delete([]byte("a"), 3) + value, version, estimate = data.Read([]byte("a"), 4) + require.False(t, estimate) + require.Equal(t, []byte("2"), value) + require.Equal(t, TxnVersion{Index: 2, Incarnation: 1}, version) + + data.Delete([]byte("b"), 2) + value, _, estimate = data.Read([]byte("b"), 4) + require.False(t, estimate) + require.Nil(t, value) +} + +func TestReadErrConversion(t *testing.T) { + err := fmt.Errorf("wrap: %w", ErrReadError{BlockingTxn: 1}) + var readErr ErrReadError + require.True(t, errors.As(err, &readErr)) + require.Equal(t, TxnIndex(1), readErr.BlockingTxn) +} + +func TestSnapshot(t *testing.T) { + storage := NewMemDB() + // initial value + storage.Set([]byte("a"), []byte("0")) + storage.Set([]byte("d"), []byte("0")) + + data := NewMVData() + // read closest version + data.Write([]byte("a"), []byte("1"), TxnVersion{Index: 1, Incarnation: 1}) + data.Write([]byte("a"), []byte("2"), TxnVersion{Index: 2, Incarnation: 1}) + data.Write([]byte("a"), []byte("3"), TxnVersion{Index: 3, Incarnation: 1}) + data.Write([]byte("b"), []byte("2"), TxnVersion{Index: 2, Incarnation: 1}) + data.Write([]byte("d"), []byte("1"), TxnVersion{Index: 2, Incarnation: 1}) + // delete the key "d" in tx 3 + data.Write([]byte("d"), nil, TxnVersion{Index: 3, Incarnation: 1}) + data.WriteEstimate([]byte("c"), 2) + + require.Equal(t, []KVPair{ + {[]byte("a"), []byte("3")}, + {[]byte("b"), []byte("2")}, + {[]byte("d"), nil}, + }, data.Snapshot()) + + data.SnapshotToStore(storage) + require.Equal(t, []byte("3"), storage.Get([]byte("a"))) + require.Equal(t, []byte("2"), storage.Get([]byte("b"))) + require.Nil(t, storage.Get([]byte("d"))) + require.Equal(t, 2, storage.Len()) +} diff --git a/blockstm/mviterator.go b/blockstm/mviterator.go new file mode 100644 index 000000000000..ce2f2fdfcc37 --- /dev/null +++ b/blockstm/mviterator.go @@ -0,0 +1,126 @@ +package blockstm + +import ( + "github.com/tidwall/btree" + + storetypes "cosmossdk.io/store/types" + + "github.com/cosmos/cosmos-sdk/blockstm/tree" +) + +// MVIterator is an iterator for a multi-versioned store. +type MVIterator[V any] struct { + tree.BTreeIteratorG[dataItem[V]] + txn TxnIndex + + // cache current found value and version + value V + version TxnVersion + + // record the observed reads during iteration during execution + reads []ReadDescriptor + // blocking call to wait for dependent transaction to finish, `nil` in validation mode + waitFn func(TxnIndex) + // signal the validation to fail + readEstimateValue bool +} + +var _ storetypes.Iterator = (*MVIterator[[]byte])(nil) + +func NewMVIterator[V any]( + opts IteratorOptions, txn TxnIndex, iter btree.IterG[dataItem[V]], + waitFn func(TxnIndex), +) *MVIterator[V] { + it := &MVIterator[V]{ + BTreeIteratorG: *tree.NewBTreeIteratorG( + dataItem[V]{Key: opts.Start}, + dataItem[V]{Key: opts.End}, + iter, + opts.Ascending, + ), + txn: txn, + waitFn: waitFn, + } + it.resolveValue() + return it +} + +// Executing returns if the iterator is running in execution mode. 
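+// In execution mode (waitFn != nil) every resolved key is appended to the read set
+// and an ESTIMATE entry blocks on waitFn; in validation mode an ESTIMATE instead
+// invalidates the iterator and sets ReadEstimateValue.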
+func (it *MVIterator[V]) Executing() bool { + return it.waitFn != nil +} + +func (it *MVIterator[V]) Next() { + it.BTreeIteratorG.Next() + it.resolveValue() +} + +func (it *MVIterator[V]) Value() V { + return it.value +} + +func (it *MVIterator[V]) Version() TxnVersion { + return it.version +} + +func (it *MVIterator[V]) Reads() []ReadDescriptor { + return it.reads +} + +func (it *MVIterator[V]) ReadEstimateValue() bool { + return it.readEstimateValue +} + +// resolveValue skips the non-exist values in the iterator based on the txn index, and caches the first existing one. +func (it *MVIterator[V]) resolveValue() { + inner := &it.BTreeIteratorG + for ; inner.Valid(); inner.Next() { + v, ok := it.resolveValueInner(inner.Item().Tree) + if !ok { + // abort the iterator + it.Invalidate() + // signal the validation to fail + it.readEstimateValue = true + return + } + if v == nil { + continue + } + + it.value = v.Value + it.version = v.Version() + if it.Executing() { + it.reads = append(it.reads, ReadDescriptor{ + Key: inner.Item().Key, + Version: it.version, + }) + } + return + } +} + +// resolveValueInner loop until we find a value that is not an estimate, +// wait for dependency if gets an ESTIMATE. +// returns: +// - (nil, true) if the value is not found +// - (nil, false) if the value is an estimate and we should fail the validation +// - (v, true) if the value is found +func (it *MVIterator[V]) resolveValueInner(tree *tree.BTree[secondaryDataItem[V]]) (*secondaryDataItem[V], bool) { + for { + v, ok := seekClosestTxn(tree, it.txn) + if !ok { + return nil, true + } + + if v.Estimate { + if it.Executing() { + it.waitFn(v.Index) + continue + } + // in validation mode, it should fail validation immediately + return nil, false + } + + return &v, true + } +} diff --git a/blockstm/mvmemory.go b/blockstm/mvmemory.go new file mode 100644 index 000000000000..f93186695adb --- /dev/null +++ b/blockstm/mvmemory.go @@ -0,0 +1,149 @@ +package blockstm + +import ( + "sync/atomic" + + storetypes "cosmossdk.io/store/types" +) + +type ( + Locations []Key // keys are sorted + MultiLocations map[int]Locations +) + +// MVMemory implements `Algorithm 2 The MVMemory module` +type MVMemory struct { + storage MultiStore + scheduler *Scheduler + stores map[storetypes.StoreKey]int + data []MVStore + lastWrittenLocations []atomic.Pointer[MultiLocations] + lastReadSet []atomic.Pointer[MultiReadSet] +} + +func NewMVMemory( + block_size int, stores map[storetypes.StoreKey]int, + storage MultiStore, scheduler *Scheduler, +) *MVMemory { + return NewMVMemoryWithEstimates(block_size, stores, storage, scheduler, nil) +} + +func NewMVMemoryWithEstimates( + block_size int, stores map[storetypes.StoreKey]int, + storage MultiStore, scheduler *Scheduler, estimates []MultiLocations, +) *MVMemory { + data := make([]MVStore, len(stores)) + for key, i := range stores { + data[i] = NewMVStore(key) + } + + mv := &MVMemory{ + storage: storage, + scheduler: scheduler, + stores: stores, + data: data, + lastWrittenLocations: make([]atomic.Pointer[MultiLocations], block_size), + lastReadSet: make([]atomic.Pointer[MultiReadSet], block_size), + } + + // init with pre-estimates + for txn, est := range estimates { + mv.rcuUpdateWrittenLocations(TxnIndex(txn), est) + mv.ConvertWritesToEstimates(TxnIndex(txn)) + } + + return mv +} + +func (mv *MVMemory) Record(version TxnVersion, view *MultiMVMemoryView) bool { + newLocations := view.ApplyWriteSet(version) + wroteNewLocation := mv.rcuUpdateWrittenLocations(version.Index, newLocations) + 
mv.lastReadSet[version.Index].Store(view.ReadSet()) + return wroteNewLocation +} + +// newLocations are sorted +func (mv *MVMemory) rcuUpdateWrittenLocations(txn TxnIndex, newLocations MultiLocations) bool { + var wroteNewLocation bool + + prevLocations := mv.readLastWrittenLocations(txn) + for i, newLoc := range newLocations { + prevLoc, ok := prevLocations[i] + if !ok { + if len(newLocations[i]) > 0 { + wroteNewLocation = true + } + continue + } + + DiffOrderedList(prevLoc, newLoc, func(key Key, is_new bool) bool { + if is_new { + wroteNewLocation = true + } else { + mv.data[i].Delete(key, txn) + } + return true + }) + } + + // delete all the keys in un-touched stores + for i, prevLoc := range prevLocations { + if _, ok := newLocations[i]; ok { + continue + } + + for _, key := range prevLoc { + mv.data[i].Delete(key, txn) + } + } + + mv.lastWrittenLocations[txn].Store(&newLocations) + return wroteNewLocation +} + +func (mv *MVMemory) ConvertWritesToEstimates(txn TxnIndex) { + for i, locations := range mv.readLastWrittenLocations(txn) { + for _, key := range locations { + mv.data[i].WriteEstimate(key, txn) + } + } +} + +func (mv *MVMemory) ValidateReadSet(txn TxnIndex) bool { + // Invariant: at least one `Record` call has been made for `txn` + rs := *mv.lastReadSet[txn].Load() + for store, readSet := range rs { + if !mv.data[store].ValidateReadSet(txn, readSet) { + return false + } + } + return true +} + +func (mv *MVMemory) readLastWrittenLocations(txn TxnIndex) MultiLocations { + p := mv.lastWrittenLocations[txn].Load() + if p != nil { + return *p + } + return nil +} + +func (mv *MVMemory) WriteSnapshot(storage MultiStore) { + for name, i := range mv.stores { + mv.data[i].SnapshotToStore(storage.GetStore(name)) + } +} + +// View creates a view for a particular transaction. 
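+// The usual cycle, as driven by Executor.TryExecute, is roughly:
+//
+//	view := mv.View(version.Index)
+//	txExecutor(version.Index, view) // the caller's TxExecutor callback
+//	wroteNew := mv.Record(version, view)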
+func (mv *MVMemory) View(txn TxnIndex) *MultiMVMemoryView { + return NewMultiMVMemoryView(mv.stores, mv.newMVView, txn) +} + +func (mv *MVMemory) newMVView(name storetypes.StoreKey, txn TxnIndex) MVView { + i := mv.stores[name] + return NewMVView(i, mv.storage.GetStore(name), mv.GetMVStore(i), mv.scheduler, txn) +} + +func (mv *MVMemory) GetMVStore(i int) MVStore { + return mv.data[i] +} diff --git a/blockstm/mvmemory_test.go b/blockstm/mvmemory_test.go new file mode 100644 index 000000000000..ec7d3939aff1 --- /dev/null +++ b/blockstm/mvmemory_test.go @@ -0,0 +1,219 @@ +package blockstm + +import ( + "testing" + + "github.com/test-go/testify/require" + + storetypes "cosmossdk.io/store/types" +) + +func TestMVMemoryRecord(t *testing.T) { + stores := map[storetypes.StoreKey]int{StoreKeyAuth: 0} + storage := NewMultiMemDB(stores) + scheduler := NewScheduler(16) + mv := NewMVMemory(16, stores, storage, scheduler) + + var views []*MultiMVMemoryView + for i := TxnIndex(0); i < 3; i++ { + version := TxnVersion{i, 0} + view := mv.View(version.Index) + store := view.GetKVStore(StoreKeyAuth) + + _ = store.Get([]byte("a")) + _ = store.Get([]byte("d")) + store.Set([]byte("a"), []byte("1")) + store.Set([]byte("b"), []byte("1")) + store.Set([]byte("c"), []byte("1")) + + views = append(views, view) + } + + for i, view := range views { + wroteNewLocation := mv.Record(TxnVersion{TxnIndex(i), 0}, view) + require.True(t, wroteNewLocation) + } + + require.True(t, mv.ValidateReadSet(0)) + require.False(t, mv.ValidateReadSet(1)) + require.False(t, mv.ValidateReadSet(2)) + + // abort 2 and 3 + mv.ConvertWritesToEstimates(1) + mv.ConvertWritesToEstimates(2) + + resultCh := make(chan struct{}, 1) + go func() { + view := mv.View(3) + store := view.GetKVStore(StoreKeyAuth) + // will wait for tx 2 + store.Get([]byte("a")) + wroteNewLocation := mv.Record(TxnVersion{3, 1}, view) + require.False(t, wroteNewLocation) + require.True(t, mv.ValidateReadSet(3)) + resultCh <- struct{}{} + }() + + { + data := mv.GetMVStore(0).(*MVData) + value, version, estimate := data.Read(Key("a"), 1) + require.False(t, estimate) + require.Equal(t, []byte("1"), value) + require.Equal(t, TxnVersion{0, 0}, version) + + _, version, estimate = data.Read(Key("a"), 2) + require.True(t, estimate) + require.Equal(t, TxnIndex(1), version.Index) + + _, version, estimate = data.Read(Key("a"), 3) + require.True(t, estimate) + require.Equal(t, TxnIndex(2), version.Index) + } + + // rerun tx 1 + { + view := mv.View(1) + store := view.GetKVStore(StoreKeyAuth) + + _ = store.Get([]byte("a")) + _ = store.Get([]byte("d")) + store.Set([]byte("a"), []byte("2")) + store.Set([]byte("b"), []byte("2")) + store.Set([]byte("c"), []byte("2")) + + wroteNewLocation := mv.Record(TxnVersion{1, 1}, view) + require.False(t, wroteNewLocation) + require.True(t, mv.ValidateReadSet(1)) + } + + // rerun tx 2 + // don't write `c` this time + { + version := TxnVersion{2, 1} + view := mv.View(version.Index) + store := view.GetKVStore(StoreKeyAuth) + + _ = store.Get([]byte("a")) + _ = store.Get([]byte("d")) + store.Set([]byte("a"), []byte("3")) + store.Set([]byte("b"), []byte("3")) + + wroteNewLocation := mv.Record(version, view) + require.False(t, wroteNewLocation) + require.True(t, mv.ValidateReadSet(2)) + + scheduler.FinishExecution(version, wroteNewLocation) + + // wait for dependency to finish + <-resultCh + } + + // run tx 3 + { + view := mv.View(3) + store := view.GetKVStore(StoreKeyAuth) + + _ = store.Get([]byte("a")) + + wroteNewLocation := mv.Record(TxnVersion{3, 1}, 
view) + require.False(t, wroteNewLocation) + require.True(t, mv.ValidateReadSet(3)) + } + + { + data := mv.GetMVStore(0).(*MVData) + value, version, estimate := data.Read(Key("a"), 2) + require.False(t, estimate) + require.Equal(t, []byte("2"), value) + require.Equal(t, TxnVersion{1, 1}, version) + + value, version, estimate = data.Read(Key("a"), 3) + require.False(t, estimate) + require.Equal(t, []byte("3"), value) + require.Equal(t, TxnVersion{2, 1}, version) + + value, version, estimate = data.Read(Key("c"), 3) + require.False(t, estimate) + require.Equal(t, []byte("2"), value) + require.Equal(t, TxnVersion{1, 1}, version) + } +} + +func TestMVMemoryDelete(t *testing.T) { + nonceKey, balanceKey := []byte("nonce"), []byte("balance") + + stores := map[storetypes.StoreKey]int{StoreKeyAuth: 0, StoreKeyBank: 1} + storage := NewMultiMemDB(stores) + { + // genesis state + authStore := storage.GetKVStore(StoreKeyAuth) + authStore.Set(nonceKey, []byte{0}) + bankStore := storage.GetKVStore(StoreKeyBank) + bankStore.Set(balanceKey, []byte{100}) + } + scheduler := NewScheduler(16) + mv := NewMVMemory(16, stores, storage, scheduler) + + genMockTx := func(txNonce int) func(*MultiMVMemoryView) bool { + return func(view *MultiMVMemoryView) bool { + bankStore := view.GetKVStore(StoreKeyBank) + balance := int(bankStore.Get(balanceKey)[0]) + if balance < 50 { + // insurfficient balance + return false + } + + authStore := view.GetKVStore(StoreKeyAuth) + nonce := int(authStore.Get(nonceKey)[0]) + // do a set no matter what + authStore.Set(nonceKey, []byte{byte(nonce)}) + if nonce != txNonce { + // invalid nonce + return false + } + + authStore.Set(nonceKey, []byte{byte(nonce + 1)}) + bankStore.Set(balanceKey, []byte{byte(balance - 50)}) + return true + } + } + + tx0, tx1, tx2 := genMockTx(0), genMockTx(1), genMockTx(2) + + view0 := mv.View(0) + require.True(t, tx0(view0)) + view1 := mv.View(1) + require.False(t, tx1(view1)) + view2 := mv.View(2) + require.False(t, tx2(view2)) + + require.True(t, mv.Record(TxnVersion{1, 0}, view1)) + require.True(t, mv.Record(TxnVersion{2, 0}, view2)) + require.True(t, mv.Record(TxnVersion{0, 0}, view0)) + + require.True(t, mv.ValidateReadSet(0)) + require.False(t, mv.ValidateReadSet(1)) + mv.ConvertWritesToEstimates(1) + require.False(t, mv.ValidateReadSet(2)) + mv.ConvertWritesToEstimates(2) + + // re-execute tx 1 and 2 + view1 = mv.View(1) + require.True(t, tx1(view1)) + mv.Record(TxnVersion{1, 1}, view1) + require.True(t, mv.ValidateReadSet(1)) + + view2 = mv.View(2) + // tx 2 fail due to insufficient balance, but stm validation is successful. 
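+	// (block-stm only requires the recorded reads to be consistent on re-read;
+	// an application-level failure is a deterministic, perfectly valid outcome)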
+ require.False(t, tx2(view2)) + mv.Record(TxnVersion{2, 1}, view2) + require.True(t, mv.ValidateReadSet(2)) + + mv.WriteSnapshot(storage) + { + authStore := storage.GetKVStore(StoreKeyAuth) + require.Equal(t, []byte{2}, authStore.Get(nonceKey)) + bankStore := storage.GetKVStore(StoreKeyBank) + require.Equal(t, []byte{0}, bankStore.Get(balanceKey)) + } +} diff --git a/blockstm/mvview.go b/blockstm/mvview.go new file mode 100644 index 000000000000..5254542cd9b4 --- /dev/null +++ b/blockstm/mvview.go @@ -0,0 +1,204 @@ +package blockstm + +import ( + "io" + + "cosmossdk.io/store/cachekv" + "cosmossdk.io/store/tracekv" + storetypes "cosmossdk.io/store/types" +) + +var ( + _ storetypes.KVStore = (*GMVMemoryView[[]byte])(nil) + _ storetypes.ObjKVStore = (*GMVMemoryView[any])(nil) + _ MVView = (*GMVMemoryView[[]byte])(nil) + _ MVView = (*GMVMemoryView[any])(nil) +) + +// GMVMemoryView wraps `MVMemory` for execution of a single transaction. +type GMVMemoryView[V any] struct { + storage storetypes.GKVStore[V] + mvData *GMVData[V] + scheduler *Scheduler + store int + + txn TxnIndex + readSet *ReadSet + writeSet *GMemDB[V] +} + +func NewMVView(store int, storage storetypes.Store, mvData MVStore, scheduler *Scheduler, txn TxnIndex) MVView { + switch data := mvData.(type) { + case *GMVData[any]: + return NewGMVMemoryView(store, storage.(storetypes.ObjKVStore), data, scheduler, txn) + case *GMVData[[]byte]: + return NewGMVMemoryView(store, storage.(storetypes.KVStore), data, scheduler, txn) + default: + panic("unsupported value type") + } +} + +func NewGMVMemoryView[V any](store int, storage storetypes.GKVStore[V], mvData *GMVData[V], scheduler *Scheduler, txn TxnIndex) *GMVMemoryView[V] { + return &GMVMemoryView[V]{ + store: store, + storage: storage, + mvData: mvData, + scheduler: scheduler, + txn: txn, + readSet: new(ReadSet), + } +} + +func (s *GMVMemoryView[V]) init() { + if s.writeSet == nil { + s.writeSet = NewGMemDBNonConcurrent(s.mvData.isZero, s.mvData.valueLen) + } +} + +func (s *GMVMemoryView[V]) waitFor(txn TxnIndex) { + cond := s.scheduler.WaitForDependency(s.txn, txn) + if cond != nil { + cond.Wait() + } +} + +func (s *GMVMemoryView[V]) ApplyWriteSet(version TxnVersion) Locations { + if s.writeSet == nil || s.writeSet.Len() == 0 { + return nil + } + + newLocations := make([]Key, 0, s.writeSet.Len()) + s.writeSet.Scan(func(key Key, value V) bool { + s.mvData.Write(key, value, version) + newLocations = append(newLocations, key) + return true + }) + + return newLocations +} + +func (s *GMVMemoryView[V]) ReadSet() *ReadSet { + return s.readSet +} + +func (s *GMVMemoryView[V]) Get(key []byte) V { + if s.writeSet != nil { + if value, found := s.writeSet.OverlayGet(key); found { + // value written by this txn + // nil value means deleted + return value + } + } + + for { + value, version, estimate := s.mvData.Read(key, s.txn) + if estimate { + // read ESTIMATE mark, wait for the blocking txn to finish + s.waitFor(version.Index) + continue + } + + // record the read version, invalid version is ⊥. + // if not found, record version ⊥ when reading from storage. 
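+		// (if a lower-indexed txn later writes this key, re-validation sees a real
+		// version instead of ⊥ and fails, forcing a re-execution)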
+ s.readSet.Reads = append(s.readSet.Reads, ReadDescriptor{key, version}) + if !version.Valid() { + return s.storage.Get(key) + } + return value + } +} + +func (s *GMVMemoryView[V]) Has(key []byte) bool { + return !s.mvData.isZero(s.Get(key)) +} + +func (s *GMVMemoryView[V]) Set(key []byte, value V) { + if s.mvData.isZero(value) { + panic("nil value is not allowed") + } + s.init() + s.writeSet.OverlaySet(key, value) +} + +func (s *GMVMemoryView[V]) Delete(key []byte) { + var empty V + s.init() + s.writeSet.OverlaySet(key, empty) +} + +func (s *GMVMemoryView[V]) Iterator(start, end []byte) storetypes.GIterator[V] { + return s.iterator(IteratorOptions{Start: start, End: end, Ascending: true}) +} + +func (s *GMVMemoryView[V]) ReverseIterator(start, end []byte) storetypes.GIterator[V] { + return s.iterator(IteratorOptions{Start: start, End: end, Ascending: false}) +} + +func (s *GMVMemoryView[V]) iterator(opts IteratorOptions) storetypes.GIterator[V] { + mvIter := s.mvData.Iterator(opts, s.txn, s.waitFor) + + var parentIter, wsIter storetypes.GIterator[V] + + if s.writeSet == nil { + wsIter = NewNoopIterator[V](opts.Start, opts.End, opts.Ascending) + } else { + wsIter = s.writeSet.iterator(opts.Start, opts.End, opts.Ascending) + } + + if opts.Ascending { + parentIter = s.storage.Iterator(opts.Start, opts.End) + } else { + parentIter = s.storage.ReverseIterator(opts.Start, opts.End) + } + + onClose := func(iter storetypes.GIterator[V]) { + reads := mvIter.Reads() + + var stopKey Key + if iter.Valid() { + stopKey = iter.Key() + + // if the iterator is not exhausted, the merge iterator may have read one more key which is not observed by + // the caller, in that case we remove that read descriptor. + if len(reads) > 0 { + lastRead := reads[len(reads)-1].Key + if BytesBeyond(lastRead, stopKey, opts.Ascending) { + reads = reads[:len(reads)-1] + } + } + } + + s.readSet.Iterators = append(s.readSet.Iterators, IteratorDescriptor{ + IteratorOptions: opts, + Stop: stopKey, + Reads: reads, + }) + } + + // three-way merge iterator + return NewCacheMergeIterator( + NewCacheMergeIterator(parentIter, mvIter, opts.Ascending, nil, s.mvData.isZero), + wsIter, + opts.Ascending, + onClose, + s.mvData.isZero, + ) +} + +// CacheWrap implements types.Store. +func (s *GMVMemoryView[V]) CacheWrap() storetypes.CacheWrap { + return cachekv.NewGStore(s, s.mvData.isZero, s.mvData.valueLen) +} + +// CacheWrapWithTrace implements types.Store. +func (s *GMVMemoryView[V]) CacheWrapWithTrace(w io.Writer, tc storetypes.TraceContext) storetypes.CacheWrap { + if store, ok := any(s).(*GMVMemoryView[[]byte]); ok { + return cachekv.NewGStore(tracekv.NewStore(store, w, tc), store.mvData.isZero, store.mvData.valueLen) + } + return s.CacheWrap() +} + +// GetStoreType implements types.Store. 
+func (s *GMVMemoryView[V]) GetStoreType() storetypes.StoreType { + return s.storage.GetStoreType() +} diff --git a/blockstm/mvview_test.go b/blockstm/mvview_test.go new file mode 100644 index 000000000000..edacf88d1d90 --- /dev/null +++ b/blockstm/mvview_test.go @@ -0,0 +1,181 @@ +package blockstm + +import ( + "fmt" + "testing" + + "github.com/test-go/testify/require" + + storetypes "cosmossdk.io/store/types" +) + +func TestMVMemoryViewDelete(t *testing.T) { + stores := map[storetypes.StoreKey]int{ + StoreKeyAuth: 0, + } + storage := NewMultiMemDB(stores) + mv := NewMVMemory(16, stores, storage, nil) + + mview := mv.View(0) + view := mview.GetKVStore(StoreKeyAuth) + view.Set(Key("a"), []byte("1")) + view.Set(Key("b"), []byte("1")) + view.Set(Key("c"), []byte("1")) + require.True(t, mv.Record(TxnVersion{0, 0}, mview)) + + mview = mv.View(1) + view = mview.GetKVStore(StoreKeyAuth) + view.Delete(Key("a")) + view.Set(Key("b"), []byte("2")) + require.True(t, mv.Record(TxnVersion{1, 0}, mview)) + + mview = mv.View(2) + view = mview.GetKVStore(StoreKeyAuth) + require.Nil(t, view.Get(Key("a"))) + require.False(t, view.Has(Key("a"))) +} + +func TestMVMemoryViewIteration(t *testing.T) { + stores := map[storetypes.StoreKey]int{StoreKeyAuth: 0} + storage := NewMultiMemDB(stores) + mv := NewMVMemory(16, stores, storage, nil) + { + parentState := []KVPair{ + {Key("a"), []byte("1")}, + {Key("A"), []byte("1")}, + } + parent := storage.GetKVStore(StoreKeyAuth) + for _, kv := range parentState { + parent.Set(kv.Key, kv.Value) + } + } + + sets := [][]KVPair{ + {{Key("a"), []byte("1")}, {Key("b"), []byte("1")}, {Key("c"), []byte("1")}}, + {{Key("b"), []byte("2")}, {Key("c"), []byte("2")}, {Key("d"), []byte("2")}}, + {{Key("c"), []byte("3")}, {Key("d"), []byte("3")}, {Key("e"), []byte("3")}}, + {{Key("d"), []byte("4")}, {Key("f"), []byte("4")}}, + {{Key("e"), []byte("5")}, {Key("f"), []byte("5")}, {Key("g"), []byte("5")}}, + {{Key("f"), []byte("6")}, {Key("g"), []byte("6")}, {Key("a"), []byte("6")}}, + } + deletes := [][]Key{ + {}, + {}, + {Key("a")}, + {Key("A"), Key("e")}, + {}, + {Key("b"), Key("c"), Key("d")}, + } + + for i, pairs := range sets { + mview := mv.View(TxnIndex(i)) + view := mview.GetKVStore(StoreKeyAuth) + for _, kv := range pairs { + view.Set(kv.Key, kv.Value) + } + for _, key := range deletes[i] { + view.Delete(key) + } + require.True(t, mv.Record(TxnVersion{TxnIndex(i), 0}, mview)) + } + + testCases := []struct { + index TxnIndex + start, end Key + ascending bool + expect []KVPair + }{ + {2, nil, nil, true, []KVPair{ + {Key("A"), []byte("1")}, + {Key("a"), []byte("1")}, + {Key("b"), []byte("2")}, + {Key("c"), []byte("2")}, + {Key("d"), []byte("2")}, + }}, + {3, nil, nil, true, []KVPair{ + {Key("A"), []byte("1")}, + {Key("b"), []byte("2")}, + {Key("c"), []byte("3")}, + {Key("d"), []byte("3")}, + {Key("e"), []byte("3")}, + }}, + {3, nil, nil, false, []KVPair{ + {Key("e"), []byte("3")}, + {Key("d"), []byte("3")}, + {Key("c"), []byte("3")}, + {Key("b"), []byte("2")}, + {Key("A"), []byte("1")}, + }}, + {4, nil, nil, true, []KVPair{ + {Key("b"), []byte("2")}, + {Key("c"), []byte("3")}, + {Key("d"), []byte("4")}, + {Key("f"), []byte("4")}, + }}, + {5, nil, nil, true, []KVPair{ + {Key("b"), []byte("2")}, + {Key("c"), []byte("3")}, + {Key("d"), []byte("4")}, + {Key("e"), []byte("5")}, + {Key("f"), []byte("5")}, + {Key("g"), []byte("5")}, + }}, + {6, nil, nil, true, []KVPair{ + {Key("a"), []byte("6")}, + {Key("e"), []byte("5")}, + {Key("f"), []byte("6")}, + {Key("g"), []byte("6")}, + }}, + {6, 
Key("e"), Key("g"), true, []KVPair{ + {Key("e"), []byte("5")}, + {Key("f"), []byte("6")}, + }}, + {6, Key("e"), Key("g"), false, []KVPair{ + {Key("f"), []byte("6")}, + {Key("e"), []byte("5")}, + }}, + {6, Key("b"), nil, true, []KVPair{ + {Key("e"), []byte("5")}, + {Key("f"), []byte("6")}, + {Key("g"), []byte("6")}, + }}, + {6, Key("b"), nil, false, []KVPair{ + {Key("g"), []byte("6")}, + {Key("f"), []byte("6")}, + {Key("e"), []byte("5")}, + }}, + {6, nil, Key("g"), true, []KVPair{ + {Key("a"), []byte("6")}, + {Key("e"), []byte("5")}, + {Key("f"), []byte("6")}, + }}, + {6, nil, Key("g"), false, []KVPair{ + {Key("f"), []byte("6")}, + {Key("e"), []byte("5")}, + {Key("a"), []byte("6")}, + }}, + } + + for _, tc := range testCases { + t.Run(fmt.Sprintf("version-%d", tc.index), func(t *testing.T) { + view := mv.View(tc.index).GetKVStore(StoreKeyAuth) + var iter storetypes.Iterator + if tc.ascending { + iter = view.Iterator(tc.start, tc.end) + } else { + iter = view.ReverseIterator(tc.start, tc.end) + } + require.Equal(t, tc.expect, CollectIterator(iter)) + require.NoError(t, iter.Close()) + }) + } +} + +func CollectIterator[V any](iter storetypes.GIterator[V]) []GKVPair[V] { + var res []GKVPair[V] + for iter.Valid() { + res = append(res, GKVPair[V]{iter.Key(), iter.Value()}) + iter.Next() + } + return res +} diff --git a/blockstm/scheduler.go b/blockstm/scheduler.go new file mode 100644 index 000000000000..930ddd1d8c9b --- /dev/null +++ b/blockstm/scheduler.go @@ -0,0 +1,215 @@ +package blockstm + +import ( + "fmt" + "runtime" + "sync" + "sync/atomic" +) + +type TaskKind int + +const ( + TaskKindExecution TaskKind = iota + TaskKindValidation +) + +type TxDependency struct { + sync.Mutex + dependents []TxnIndex +} + +func (t *TxDependency) Swap(new []TxnIndex) []TxnIndex { + t.Lock() + old := t.dependents + t.dependents = new + t.Unlock() + return old +} + +// Scheduler implements the scheduler for the block-stm +// ref: `Algorithm 4 The Scheduler module, variables, utility APIs and next task logic` +type Scheduler struct { + blockSize int + + // An index that tracks the next transaction to try and execute. + executionIdx atomic.Uint64 + // A similar index for tracking validation. + validationIdx atomic.Uint64 + // Number of times validationIdx or executionIdx was decreased + decreaseCnt atomic.Uint64 + // Number of ongoing validation and execution tasks + numActiveTasks atomic.Uint64 + // Marker for completion + doneMarker atomic.Bool + + // txnIdx to a mutex-protected set of dependent transaction indices + txnDependency []TxDependency + // txnIdx to a mutex-protected pair (incarnationNumber, status), where status ∈ {READY_TO_EXECUTE, EXECUTING, EXECUTED, ABORTING}. 
+ txnStatus []StatusEntry + + // metrics + executedTxns atomic.Int64 + validatedTxns atomic.Int64 +} + +func NewScheduler(blockSize int) *Scheduler { + return &Scheduler{ + blockSize: blockSize, + txnDependency: make([]TxDependency, blockSize), + txnStatus: make([]StatusEntry, blockSize), + } +} + +func (s *Scheduler) Done() bool { + return s.doneMarker.Load() +} + +func (s *Scheduler) DecreaseValidationIdx(target TxnIndex) { + StoreMin(&s.validationIdx, uint64(target)) + s.decreaseCnt.Add(1) +} + +func (s *Scheduler) CheckDone() { + observedCnt := s.decreaseCnt.Load() + if s.executionIdx.Load() >= uint64(s.blockSize) && + s.validationIdx.Load() >= uint64(s.blockSize) && + s.numActiveTasks.Load() == 0 { + if observedCnt == s.decreaseCnt.Load() { + s.doneMarker.Store(true) + } + } + // avoid busy waiting + runtime.Gosched() +} + +// TryIncarnate tries to incarnate a transaction index to execute. +// Returns the transaction version if successful, otherwise returns invalid version. +// +// Invariant `numActiveTasks`: decreased if an invalid task is returned. +func (s *Scheduler) TryIncarnate(idx TxnIndex) TxnVersion { + if int(idx) < s.blockSize { + if incarnation, ok := s.txnStatus[idx].TrySetExecuting(); ok { + return TxnVersion{idx, incarnation} + } + } + DecrAtomic(&s.numActiveTasks) + return InvalidTxnVersion +} + +// NextVersionToExecute get the next transaction index to execute, +// returns invalid version if no task is available +// +// Invariant `numActiveTasks`: increased if a valid task is returned. +func (s *Scheduler) NextVersionToExecute() TxnVersion { + if s.executionIdx.Load() >= uint64(s.blockSize) { + s.CheckDone() + return InvalidTxnVersion + } + IncrAtomic(&s.numActiveTasks) + idxToExecute := s.executionIdx.Add(1) - 1 + return s.TryIncarnate(TxnIndex(idxToExecute)) +} + +// NextVersionToValidate get the next transaction index to validate, +// returns invalid version if no task is available. +// +// Invariant `numActiveTasks`: increased if a valid task is returned. +func (s *Scheduler) NextVersionToValidate() TxnVersion { + if s.validationIdx.Load() >= uint64(s.blockSize) { + s.CheckDone() + return InvalidTxnVersion + } + IncrAtomic(&s.numActiveTasks) + idxToValidate := FetchIncr(&s.validationIdx) + if idxToValidate < uint64(s.blockSize) { + if ok, incarnation := s.txnStatus[idxToValidate].IsExecuted(); ok { + return TxnVersion{TxnIndex(idxToValidate), incarnation} + } + } + + DecrAtomic(&s.numActiveTasks) + return InvalidTxnVersion +} + +// NextTask returns the transaction index and task kind for the next task to execute or validate, +// returns invalid version if no task is available. +// +// Invariant `numActiveTasks`: increased if a valid task is returned. 
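+// Validation is preferred whenever validationIdx is behind executionIdx.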
+func (s *Scheduler) NextTask() (TxnVersion, TaskKind) { + validationIdx := s.validationIdx.Load() + executionIdx := s.executionIdx.Load() + if validationIdx < executionIdx { + return s.NextVersionToValidate(), TaskKindValidation + } else { + return s.NextVersionToExecute(), TaskKindExecution + } +} + +func (s *Scheduler) WaitForDependency(txn, blockingTxn TxnIndex) *Condvar { + cond := NewCondvar() + entry := &s.txnDependency[blockingTxn] + entry.Lock() + + // thread holds 2 locks + if ok, _ := s.txnStatus[blockingTxn].IsExecuted(); ok { + // dependency resolved before locking in Line 148 + entry.Unlock() + return nil + } + + s.txnStatus[txn].Suspend(cond) + entry.dependents = append(entry.dependents, txn) + entry.Unlock() + + return cond +} + +func (s *Scheduler) ResumeDependencies(txns []TxnIndex) { + for _, txn := range txns { + s.txnStatus[txn].Resume() + } +} + +// FinishExecution marks an execution task as complete. +// Invariant `numActiveTasks`: decreased if an invalid task is returned. +func (s *Scheduler) FinishExecution(version TxnVersion, wroteNewPath bool) (TxnVersion, TaskKind) { + s.txnStatus[version.Index].SetExecuted() + + deps := s.txnDependency[version.Index].Swap(nil) + s.ResumeDependencies(deps) + if s.validationIdx.Load() > uint64(version.Index) { // otherwise index already small enough + if !wroteNewPath { + // schedule validation for current tx only, don't decrease numActiveTasks + return version, TaskKindValidation + } + // schedule validation for txnIdx and higher txns + s.DecreaseValidationIdx(version.Index) + } + DecrAtomic(&s.numActiveTasks) + return InvalidTxnVersion, 0 +} + +func (s *Scheduler) TryValidationAbort(version TxnVersion) bool { + return s.txnStatus[version.Index].TryValidationAbort(version.Incarnation) +} + +// FinishValidation marks a validation task as complete. +// Invariant `numActiveTasks`: decreased if an invalid task is returned. +func (s *Scheduler) FinishValidation(txn TxnIndex, aborted bool) (TxnVersion, TaskKind) { + if aborted { + s.txnStatus[txn].SetReadyStatus() + s.DecreaseValidationIdx(txn + 1) + if s.executionIdx.Load() > uint64(txn) { + return s.TryIncarnate(txn), TaskKindExecution + } + } + + DecrAtomic(&s.numActiveTasks) + return InvalidTxnVersion, 0 +} + +func (s *Scheduler) Stats() string { + return fmt.Sprintf("executed: %d, validated: %d", + s.executedTxns.Load(), s.validatedTxns.Load()) +} diff --git a/blockstm/status.go b/blockstm/status.go new file mode 100644 index 000000000000..f3d84828e051 --- /dev/null +++ b/blockstm/status.go @@ -0,0 +1,118 @@ +package blockstm + +import "sync" + +type Status uint + +const ( + StatusReadyToExecute Status = iota + StatusExecuting + StatusExecuted + StatusAborting + StatusSuspended +) + +// StatusEntry is a state machine for the status of a transaction, all the transitions are atomic protected by a mutex. 
+// +// ```mermaid +// stateDiagram-v2 +// +// [*] --> ReadyToExecute +// ReadyToExecute --> Executing: TrySetExecuting() +// Executing --> Executed: SetExecuted() +// Executing --> Suspended: Suspend(cond)\nset cond +// Executed --> Aborting: TryValidationAbort(incarnation) +// Aborting --> ReadyToExecute: SetReadyStatus()\nincarnation++ +// Suspended --> Executing: Resume() +// +// ``` +type StatusEntry struct { + sync.Mutex + + incarnation Incarnation + status Status + + cond *Condvar +} + +func (s *StatusEntry) IsExecuted() (ok bool, incarnation Incarnation) { + s.Lock() + + if s.status == StatusExecuted { + ok = true + incarnation = s.incarnation + } + + s.Unlock() + return ok, incarnation +} + +func (s *StatusEntry) TrySetExecuting() (Incarnation, bool) { + s.Lock() + + if s.status == StatusReadyToExecute { + s.status = StatusExecuting + incarnation := s.incarnation + + s.Unlock() + return incarnation, true + } + + s.Unlock() + return 0, false +} + +func (s *StatusEntry) setStatus(status Status) { + s.Lock() + s.status = status + s.Unlock() +} + +func (s *StatusEntry) Resume() { + // status must be SUSPENDED and cond != nil + s.Lock() + + s.status = StatusExecuting + s.cond.Notify() + s.cond = nil + + s.Unlock() +} + +func (s *StatusEntry) SetExecuted() { + // status must have been EXECUTING + s.setStatus(StatusExecuted) +} + +func (s *StatusEntry) TryValidationAbort(incarnation Incarnation) bool { + s.Lock() + + if s.incarnation == incarnation && s.status == StatusExecuted { + s.status = StatusAborting + + s.Unlock() + return true + } + + s.Unlock() + return false +} + +func (s *StatusEntry) SetReadyStatus() { + s.Lock() + + s.incarnation++ + // status must be ABORTING + s.status = StatusReadyToExecute + + s.Unlock() +} + +func (s *StatusEntry) Suspend(cond *Condvar) { + s.Lock() + + s.cond = cond + s.status = StatusSuspended + + s.Unlock() +} diff --git a/blockstm/stm.go b/blockstm/stm.go new file mode 100644 index 000000000000..16f35b3815f7 --- /dev/null +++ b/blockstm/stm.go @@ -0,0 +1,80 @@ +package blockstm + +import ( + "context" + "errors" + "fmt" + "runtime" + + "golang.org/x/sync/errgroup" + + storetypes "cosmossdk.io/store/types" + + "github.com/cosmos/cosmos-sdk/telemetry" +) + +func ExecuteBlock( + ctx context.Context, + blockSize int, + stores map[storetypes.StoreKey]int, + storage MultiStore, + executors int, + txExecutor TxExecutor, +) error { + return ExecuteBlockWithEstimates( + ctx, blockSize, stores, storage, executors, + nil, txExecutor, + ) +} + +func ExecuteBlockWithEstimates( + ctx context.Context, + blockSize int, + stores map[storetypes.StoreKey]int, + storage MultiStore, + executors int, + estimates []MultiLocations, // txn -> multi-locations + txExecutor TxExecutor, +) error { + if executors < 0 { + return fmt.Errorf("invalid number of executors: %d", executors) + } + if executors == 0 { + executors = maxParallelism() + } + + // Create a new scheduler + scheduler := NewScheduler(blockSize) + mvMemory := NewMVMemoryWithEstimates(blockSize, stores, storage, scheduler, estimates) + + // var wg sync.WaitGroup + var wg errgroup.Group + wg.SetLimit(executors) + for i := 0; i < executors; i++ { + e := NewExecutor(ctx, scheduler, txExecutor, mvMemory, i) + wg.Go(e.Run) + } + if err := wg.Wait(); err != nil { + return err + } + + if !scheduler.Done() { + if ctx.Err() != nil { + // canceled + return ctx.Err() + } + + return errors.New("scheduler did not complete") + } + + telemetry.SetGauge(float32(scheduler.executedTxns.Load()), TelemetrySubsystem, KeyExecutedTxs) 
+ telemetry.SetGauge(float32(scheduler.validatedTxns.Load()), TelemetrySubsystem, KeyValidatedTxs) + + // Write the snapshot into the storage + mvMemory.WriteSnapshot(storage) + return nil +} + +func maxParallelism() int { + return min(runtime.GOMAXPROCS(0), runtime.NumCPU()) +} diff --git a/blockstm/stm_test.go b/blockstm/stm_test.go new file mode 100644 index 000000000000..485bafd4bf28 --- /dev/null +++ b/blockstm/stm_test.go @@ -0,0 +1,175 @@ +package blockstm + +import ( + "bytes" + "context" + "encoding/binary" + "fmt" + "math/rand" + "testing" + + "github.com/test-go/testify/require" + + storetypes "cosmossdk.io/store/types" +) + +func accountName(i int64) string { + return fmt.Sprintf("account%05d", i) +} + +func testBlock(size, accounts int) *MockBlock { + txs := make([]Tx, size) + g := rand.New(rand.NewSource(0)) + for i := 0; i < size; i++ { + sender := g.Int63n(int64(accounts)) + receiver := g.Int63n(int64(accounts)) + txs[i] = BankTransferTx(i, accountName(sender), accountName(receiver), 1) + } + return NewMockBlock(txs) +} + +func iterateBlock(size, accounts int) *MockBlock { + txs := make([]Tx, size) + g := rand.New(rand.NewSource(0)) + for i := 0; i < size; i++ { + sender := g.Int63n(int64(accounts)) + receiver := g.Int63n(int64(accounts)) + txs[i] = IterateTx(i, accountName(sender), accountName(receiver), 1) + } + return NewMockBlock(txs) +} + +func noConflictBlock(size int) *MockBlock { + txs := make([]Tx, size) + for i := 0; i < size; i++ { + sender := accountName(int64(i)) + txs[i] = BankTransferTx(i, sender, sender, 1) + } + return NewMockBlock(txs) +} + +func worstCaseBlock(size int) *MockBlock { + txs := make([]Tx, size) + for i := 0; i < size; i++ { + // all transactions are from the same account + sender := "account0" + txs[i] = BankTransferTx(i, sender, sender, 1) + } + return NewMockBlock(txs) +} + +func determisticBlock() *MockBlock { + return NewMockBlock([]Tx{ + NoopTx(0, "account0"), + NoopTx(1, "account1"), + NoopTx(2, "account1"), + NoopTx(3, "account1"), + NoopTx(4, "account3"), + NoopTx(5, "account1"), + NoopTx(6, "account4"), + NoopTx(7, "account5"), + NoopTx(8, "account6"), + }) +} + +func TestSTM(t *testing.T) { + stores := map[storetypes.StoreKey]int{StoreKeyAuth: 0, StoreKeyBank: 1} + testCases := []struct { + name string + blk *MockBlock + executors int + }{ + { + name: "testBlock(100,80),10", + blk: testBlock(100, 80), + executors: 10, + }, + { + name: "testBlock(100,3),10", + blk: testBlock(100, 3), + executors: 10, + }, + { + name: "determisticBlock(),5", + blk: determisticBlock(), + executors: 5, + }, + { + name: "noConflictBlock(100),5", + blk: noConflictBlock(100), + executors: 5, + }, + { + name: "worstCaseBlock(100),5", + blk: worstCaseBlock(100), + executors: 5, + }, + { + name: "iterateBlock(100,80),10", + blk: iterateBlock(100, 80), + executors: 10, + }, + { + name: "iterateBlock(100,10),10", + blk: iterateBlock(100, 10), + executors: 10, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + storage := NewMultiMemDB(stores) + require.NoError(t, + ExecuteBlock(context.Background(), tc.blk.Size(), stores, storage, tc.executors, tc.blk.ExecuteTx), + ) + for _, err := range tc.blk.Results { + require.NoError(t, err) + } + + crossCheck := NewMultiMemDB(stores) + runSequential(crossCheck, tc.blk) + + // check parallel execution matches sequential execution + for store := range stores { + require.True(t, StoreEqual(crossCheck.GetKVStore(store), storage.GetKVStore(store))) + } + + // check total nonce increased the 
same amount as the number of transactions + var total uint64 + store := storage.GetKVStore(StoreKeyAuth) + it := store.Iterator(nil, nil) + defer it.Close() + + for ; it.Valid(); it.Next() { + if !bytes.HasPrefix(it.Key(), []byte("nonce")) { + continue + } + total += binary.BigEndian.Uint64(it.Value()) + continue + } + require.Equal(t, uint64(tc.blk.Size()), total) + }) + } +} + +func StoreEqual(a, b storetypes.KVStore) bool { + // compare with iterators + iter1 := a.Iterator(nil, nil) + iter2 := b.Iterator(nil, nil) + defer iter1.Close() + defer iter2.Close() + + for { + if !iter1.Valid() && !iter2.Valid() { + return true + } + if !iter1.Valid() || !iter2.Valid() { + return false + } + if !bytes.Equal(iter1.Key(), iter2.Key()) || !bytes.Equal(iter1.Value(), iter2.Value()) { + return false + } + iter1.Next() + iter2.Next() + } +} diff --git a/blockstm/tree/btree.go b/blockstm/tree/btree.go new file mode 100644 index 000000000000..f15ae1b977fd --- /dev/null +++ b/blockstm/tree/btree.go @@ -0,0 +1,91 @@ +package tree + +import ( + "sync/atomic" + + "github.com/tidwall/btree" +) + +// BTree wraps an atomic pointer to an unsafe btree.BTreeG +type BTree[T any] struct { + atomic.Pointer[btree.BTreeG[T]] +} + +// NewBTree returns a new BTree. +func NewBTree[T any](less func(a, b T) bool, degree int) *BTree[T] { + tree := btree.NewBTreeGOptions(less, btree.Options{ + NoLocks: true, + ReadOnly: true, + Degree: degree, + }) + t := &BTree[T]{} + t.Store(tree) + return t +} + +func (bt *BTree[T]) Get(item T) (result T, ok bool) { + return bt.Load().Get(item) +} + +func (bt *BTree[T]) GetOrDefault(item T, fillDefaults func(*T)) T { + for { + t := bt.Load() + result, ok := t.Get(item) + if ok { + return result + } + fillDefaults(&item) + c := t.Copy() + c.Set(item) + c.Freeze() + if bt.CompareAndSwap(t, c) { + return item + } + } +} + +func (bt *BTree[T]) Set(item T) (prev T, ok bool) { + for { + t := bt.Load() + c := t.Copy() + prev, ok = c.Set(item) + c.Freeze() + if bt.CompareAndSwap(t, c) { + return prev, ok + } + } +} + +func (bt *BTree[T]) Delete(item T) (prev T, ok bool) { + for { + t := bt.Load() + c := t.Copy() + prev, ok = c.Delete(item) + c.Freeze() + if bt.CompareAndSwap(t, c) { + return prev, ok + } + } +} + +func (bt *BTree[T]) Scan(iter func(item T) bool) { + bt.Load().Scan(iter) +} + +func (bt *BTree[T]) Max() (T, bool) { + return bt.Load().Max() +} + +func (bt *BTree[T]) Iter() btree.IterG[T] { + return bt.Load().Iter() +} + +// ReverseSeek returns the first item that is less than or equal to the pivot +func (bt *BTree[T]) ReverseSeek(pivot T) (result T, ok bool) { + bt.Load().Descend(pivot, func(item T) bool { + result = item + ok = true + return false + }) + return result, ok +} diff --git a/blockstm/tree/btreeiterator.go b/blockstm/tree/btreeiterator.go new file mode 100644 index 000000000000..dc1743e1a913 --- /dev/null +++ b/blockstm/tree/btreeiterator.go @@ -0,0 +1,138 @@ +package tree + +import ( + "bytes" + "errors" + + "github.com/tidwall/btree" +) + +// BTreeIteratorG iterates over btree. +// Implements Iterator. 
+type BTreeIteratorG[T KeyItem] struct { + iter btree.IterG[T] + + start []byte + end []byte + ascending bool + valid bool +} + +func NewNoopBTreeIteratorG[T KeyItem]( + start, end []byte, + ascending bool, + valid bool, +) BTreeIteratorG[T] { + return BTreeIteratorG[T]{ + start: start, + end: end, + ascending: ascending, + valid: valid, + } +} + +func NewBTreeIteratorG[T KeyItem]( + startItem, endItem T, + iter btree.IterG[T], + ascending bool, +) *BTreeIteratorG[T] { + start := startItem.GetKey() + end := endItem.GetKey() + + var valid bool + if ascending { + if start != nil { + valid = iter.Seek(startItem) + } else { + valid = iter.First() + } + } else { + if end != nil { + valid = iter.Seek(endItem) + if !valid { + valid = iter.Last() + } else { + // end is exclusive + valid = iter.Prev() + } + } else { + valid = iter.Last() + } + } + + mi := &BTreeIteratorG[T]{ + iter: iter, + start: start, + end: end, + ascending: ascending, + valid: valid, + } + + if mi.valid { + mi.valid = mi.keyInRange(mi.Key()) + } + + return mi +} + +func (mi *BTreeIteratorG[T]) Domain() (start, end []byte) { + return mi.start, mi.end +} + +func (mi *BTreeIteratorG[T]) Close() error { + mi.iter.Release() + return nil +} + +func (mi *BTreeIteratorG[T]) Error() error { + if !mi.Valid() { + return errors.New("invalid memIterator") + } + return nil +} + +func (mi *BTreeIteratorG[T]) Valid() bool { + return mi.valid +} + +func (mi *BTreeIteratorG[T]) Invalidate() { + mi.valid = false +} + +func (mi *BTreeIteratorG[T]) Next() { + mi.assertValid() + + if mi.ascending { + mi.valid = mi.iter.Next() + } else { + mi.valid = mi.iter.Prev() + } + + if mi.valid { + mi.valid = mi.keyInRange(mi.Key()) + } +} + +func (mi *BTreeIteratorG[T]) keyInRange(key []byte) bool { + if mi.ascending && mi.end != nil && bytes.Compare(key, mi.end) >= 0 { + return false + } + if !mi.ascending && mi.start != nil && bytes.Compare(key, mi.start) < 0 { + return false + } + return true +} + +func (mi *BTreeIteratorG[T]) Item() T { + return mi.iter.Item() +} + +func (mi *BTreeIteratorG[T]) Key() []byte { + return mi.Item().GetKey() +} + +func (mi *BTreeIteratorG[T]) assertValid() { + if err := mi.Error(); err != nil { + panic(err) + } +} diff --git a/blockstm/tree/types.go b/blockstm/tree/types.go new file mode 100644 index 000000000000..32ce82d2a104 --- /dev/null +++ b/blockstm/tree/types.go @@ -0,0 +1,11 @@ +package tree + +import "bytes" + +type KeyItem interface { + GetKey() []byte +} + +func KeyItemLess[T KeyItem](a, b T) bool { + return bytes.Compare(a.GetKey(), b.GetKey()) < 0 +} diff --git a/blockstm/txnrunner.go b/blockstm/txnrunner.go new file mode 100644 index 000000000000..ad4c32cdf1ec --- /dev/null +++ b/blockstm/txnrunner.go @@ -0,0 +1,220 @@ +package blockstm + +import ( + "context" + "sync" + "sync/atomic" + + abci "github.com/cometbft/cometbft/abci/types" + + "cosmossdk.io/collections" + storetypes "cosmossdk.io/store/types" + + sdk "github.com/cosmos/cosmos-sdk/types" + sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" +) + +var ( + _ sdk.TxRunner = DefaultRunner{} + _ sdk.TxRunner = STMRunner{} +) + +func NewDefaultRunner(txDecoder sdk.TxDecoder) *DefaultRunner { + return &DefaultRunner{ + txDecoder: txDecoder, + } +} + +// DefaultRunner default executor without parallelism +type DefaultRunner struct { + txDecoder sdk.TxDecoder +} + +func (d DefaultRunner) Run(ctx context.Context, _ storetypes.MultiStore, txs [][]byte, deliverTx sdk.DeliverTxFunc) ([]*abci.ExecTxResult, error) { + // Fallback to the default execution logic + 
txResults := make([]*abci.ExecTxResult, 0, len(txs)) + for i, rawTx := range txs { + var response *abci.ExecTxResult + + if _, err := d.txDecoder(rawTx); err == nil { + response = deliverTx(rawTx, nil, i, nil) + } else { + // In the case where a transaction included in a block proposal is malformed, + // we still want to return a default response to comet. This is because comet + // expects a response for each transaction included in a block proposal. + response = sdkerrors.ResponseExecTxResultWithEvents( + sdkerrors.ErrTxDecode, + 0, + 0, + nil, + false, + ) + } + + // check after every tx if we should abort + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + // continue + } + + txResults = append(txResults, response) + } + return txResults, nil +} + +func NewSTMRunner( + txDecoder sdk.TxDecoder, + stores []storetypes.StoreKey, + workers int, estimate bool, coinDenom string, +) *STMRunner { + return &STMRunner{ + txDecoder: txDecoder, + stores: stores, + workers: workers, + estimate: estimate, + coinDenom: coinDenom, + } +} + +// STMRunner simple implementation of block-stm +type STMRunner struct { + txDecoder sdk.TxDecoder + stores []storetypes.StoreKey + workers int + estimate bool + coinDenom string +} + +func (e STMRunner) Run(ctx context.Context, ms storetypes.MultiStore, txs [][]byte, deliverTx sdk.DeliverTxFunc) ([]*abci.ExecTxResult, error) { + var authStore, bankStore int + index := make(map[storetypes.StoreKey]int, len(e.stores)) + for i, k := range e.stores { + switch k.Name() { + case "acc": + authStore = i + case "bank": + bankStore = i + } + index[k] = i + } + + blockSize := len(txs) + if blockSize == 0 { + return nil, nil + } + results := make([]*abci.ExecTxResult, blockSize) + incarnationCache := make([]atomic.Pointer[map[string]any], blockSize) + for i := 0; i < blockSize; i++ { + m := make(map[string]any) + incarnationCache[i].Store(&m) + } + + var ( + estimates []MultiLocations + memTxs [][]byte + ) + + if e.estimate { + memTxs, estimates = preEstimates(txs, e.workers, authStore, bankStore, e.coinDenom, e.txDecoder) + } + + if err := ExecuteBlockWithEstimates( + ctx, + blockSize, + index, + stmMultiStoreWrapper{ms}, + e.workers, + estimates, + func(txn TxnIndex, ms MultiStore) { + var cache map[string]any + + // only one of the concurrent incarnations gets the cache if there are any, otherwise execute without + // cache, concurrent incarnations should be rare. + v := incarnationCache[txn].Swap(nil) + if v != nil { + cache = *v + } + + var memTx []byte + if memTxs != nil { + memTx = memTxs[txn] + } + results[txn] = deliverTx(memTx, msWrapper{ms}, int(txn), cache) + + if v != nil { + incarnationCache[txn].Store(v) + } + }, + ); err != nil { + return nil, err + } + + return results, nil +} + +// preEstimates returns a static estimation of the written keys for each transaction. +// NOTE: make sure it sync with the latest sdk logic when sdk upgrade. 
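+// Only the fee payer's account key and fee-denom balance key are estimated; other writes are discovered during execution.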
+func preEstimates(txs [][]byte, workers, authStore, bankStore int, coinDenom string, txDecoder sdk.TxDecoder) ([][]byte, []MultiLocations) { + memTxs := make([][]byte, len(txs)) + estimates := make([]MultiLocations, len(txs)) + + job := func(start, end int) { + for i := start; i < end; i++ { + rawTx := txs[i] + tx, err := txDecoder(rawTx) + if err != nil { + continue + } + memTxs[i] = rawTx + + feeTx, ok := tx.(sdk.FeeTx) + if !ok { + continue + } + feePayer := sdk.AccAddress(feeTx.FeePayer()) + + // account key + accKey, err := collections.EncodeKeyWithPrefix( + collections.NewPrefix(1), + sdk.AccAddressKey, + feePayer, + ) + if err != nil { + continue + } + + // balance key + balanceKey, err := collections.EncodeKeyWithPrefix( + collections.NewPrefix(2), + collections.PairKeyCodec(sdk.AccAddressKey, collections.StringKey), + collections.Join(feePayer, coinDenom), + ) + if err != nil { + continue + } + + estimates[i] = MultiLocations{ + authStore: {accKey}, + bankStore: {balanceKey}, + } + } + } + + blockSize := len(txs) + chunk := (blockSize + workers - 1) / workers + var wg sync.WaitGroup + for i := 0; i < blockSize; i += chunk { + start := i + end := min(i+chunk, blockSize) + wg.Add(1) + go func() { + defer wg.Done() + job(start, end) + }() + } + wg.Wait() + + return memTxs, estimates +} diff --git a/blockstm/txnrunner_test.go b/blockstm/txnrunner_test.go new file mode 100644 index 000000000000..76791b0e6951 --- /dev/null +++ b/blockstm/txnrunner_test.go @@ -0,0 +1,701 @@ +package blockstm + +import ( + "context" + "errors" + "sync/atomic" + "testing" + "time" + + abci "github.com/cometbft/cometbft/abci/types" + "github.com/stretchr/testify/require" + protov2 "google.golang.org/protobuf/proto" + + "cosmossdk.io/collections" + storetypes "cosmossdk.io/store/types" + + sdk "github.com/cosmos/cosmos-sdk/types" + sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" +) + +// Mock TxDecoder for testing +func mockTxDecoder(txBytes []byte) (sdk.Tx, error) { + if len(txBytes) == 0 { + return nil, errors.New("empty tx") + } + // Valid transaction if first byte is not 0xFF + if txBytes[0] == 0xFF { + return nil, errors.New("invalid tx") + } + return &mockTx{txBytes: txBytes}, nil +} + +type mockTx struct { + txBytes []byte +} + +func (m *mockTx) GetMsgs() []sdk.Msg { + return nil +} + +func (m *mockTx) GetMsgsV2() ([]protov2.Message, error) { + return nil, nil +} + +func (m *mockTx) ValidateBasic() error { + return nil +} + +type mockFeeTx struct { + mockTx + feePayer sdk.AccAddress +} + +func (m *mockFeeTx) FeePayer() []byte { + return m.feePayer +} + +func (m *mockFeeTx) GetFee() sdk.Coins { + return nil +} + +func (m *mockFeeTx) GetGas() uint64 { + return 0 +} + +func mockTxDecoderWithFeeTx(txBytes []byte) (sdk.Tx, error) { + if len(txBytes) == 0 { + return nil, errors.New("empty tx") + } + if txBytes[0] == 0xFF { + return nil, errors.New("invalid tx") + } + // Use the tx bytes as the fee payer address for testing + feePayer := sdk.AccAddress(txBytes[:min(len(txBytes), 20)]) + return &mockFeeTx{ + mockTx: mockTx{txBytes: txBytes}, + feePayer: feePayer, + }, nil +} + +// TestNewDefaultRunner tests the constructor +func TestNewDefaultRunner(t *testing.T) { + decoder := mockTxDecoder + runner := NewDefaultRunner(decoder) + + require.NotNil(t, runner) + require.NotNil(t, runner.txDecoder) +} + +// TestDefaultRunner_Run_Success tests successful execution of transactions +func TestDefaultRunner_Run_Success(t *testing.T) { + decoder := mockTxDecoder + runner := NewDefaultRunner(decoder) + + txs := 
[][]byte{ + {0x01, 0x02, 0x03}, + {0x04, 0x05, 0x06}, + {0x07, 0x08, 0x09}, + } + + executionCount := atomic.Int32{} + deliverTx := func(tx []byte, ms storetypes.MultiStore, txIndex int, cache map[string]any) *abci.ExecTxResult { + executionCount.Add(1) + return &abci.ExecTxResult{ + Code: 0, + Data: tx, + } + } + + ctx := context.Background() + results, err := runner.Run(ctx, nil, txs, deliverTx) + + require.NoError(t, err) + require.Len(t, results, len(txs)) + require.Equal(t, int32(len(txs)), executionCount.Load()) + + for i, result := range results { + require.Equal(t, uint32(0), result.Code) + require.Equal(t, txs[i], result.Data) + } +} + +// TestDefaultRunner_Run_EmptyTxs tests execution with no transactions +func TestDefaultRunner_Run_EmptyTxs(t *testing.T) { + decoder := mockTxDecoder + runner := NewDefaultRunner(decoder) + + deliverTx := func(tx []byte, ms storetypes.MultiStore, txIndex int, cache map[string]any) *abci.ExecTxResult { + t.Fatal("deliverTx should not be called for empty txs") + return nil + } + + ctx := context.Background() + results, err := runner.Run(ctx, nil, [][]byte{}, deliverTx) + + require.NoError(t, err) + require.Empty(t, results) +} + +// TestDefaultRunner_Run_InvalidTx tests handling of invalid transactions +func TestDefaultRunner_Run_InvalidTx(t *testing.T) { + decoder := mockTxDecoder + runner := NewDefaultRunner(decoder) + + txs := [][]byte{ + {0x01, 0x02, 0x03}, // valid + {0xFF, 0xFF, 0xFF}, // invalid (0xFF marker) + {0x07, 0x08, 0x09}, // valid + } + + validTxCount := atomic.Int32{} + deliverTx := func(tx []byte, ms storetypes.MultiStore, txIndex int, cache map[string]any) *abci.ExecTxResult { + validTxCount.Add(1) + return &abci.ExecTxResult{Code: 0} + } + + ctx := context.Background() + results, err := runner.Run(ctx, nil, txs, deliverTx) + + require.NoError(t, err) + require.Len(t, results, len(txs)) + // Only 2 valid transactions should be executed + require.Equal(t, int32(2), validTxCount.Load()) + + // The invalid tx should get an error response + require.Equal(t, sdkerrors.ErrTxDecode.ABCICode(), results[1].Code) +} + +// TestDefaultRunner_Run_ContextCancellation tests that execution stops on context cancellation +func TestDefaultRunner_Run_ContextCancellation(t *testing.T) { + decoder := mockTxDecoder + runner := NewDefaultRunner(decoder) + + txs := [][]byte{ + {0x01, 0x02, 0x03}, + {0x04, 0x05, 0x06}, + {0x07, 0x08, 0x09}, + {0x0A, 0x0B, 0x0C}, + } + + ctx, cancel := context.WithCancel(context.Background()) + + executionCount := atomic.Int32{} + deliverTx := func(tx []byte, ms storetypes.MultiStore, txIndex int, cache map[string]any) *abci.ExecTxResult { + count := executionCount.Add(1) + // Cancel after second transaction + if count == 2 { + cancel() + } + return &abci.ExecTxResult{Code: 0} + } + + _, err := runner.Run(ctx, nil, txs, deliverTx) + + require.Error(t, err) + require.Equal(t, context.Canceled, err) + // Results may be nil or partial depending on when cancellation occurs + // The key assertion is that execution was stopped + require.LessOrEqual(t, executionCount.Load(), int32(len(txs))) +} + +// TestDefaultRunner_Run_MultiStoreIsNil tests that nil multistore is handled correctly +func TestDefaultRunner_Run_MultiStoreIsNil(t *testing.T) { + decoder := mockTxDecoder + runner := NewDefaultRunner(decoder) + + txs := [][]byte{{0x01}} + + deliverTx := func(tx []byte, ms storetypes.MultiStore, txIndex int, cache map[string]any) *abci.ExecTxResult { + require.Nil(t, ms, "multistore should be nil for DefaultRunner") + require.Nil(t, 
cache, "cache should be nil for DefaultRunner") + return &abci.ExecTxResult{Code: 0} + } + + ctx := context.Background() + results, err := runner.Run(ctx, nil, txs, deliverTx) + + require.NoError(t, err) + require.Len(t, results, 1) +} + +// TestNewSTMRunner tests the STMRunner constructor +func TestNewSTMRunner(t *testing.T) { + decoder := mockTxDecoder + stores := []storetypes.StoreKey{StoreKeyAuth, StoreKeyBank} + workers := 4 + estimate := true + coinDenom := "stake" + + runner := NewSTMRunner(decoder, stores, workers, estimate, coinDenom) + + require.NotNil(t, runner) + require.NotNil(t, runner.txDecoder) + require.Equal(t, stores, runner.stores) + require.Equal(t, workers, runner.workers) + require.Equal(t, estimate, runner.estimate) + require.Equal(t, coinDenom, runner.coinDenom) +} + +// TestSTMRunner_Run_EmptyBlock tests STMRunner with empty block +func TestSTMRunner_Run_EmptyBlock(t *testing.T) { + decoder := mockTxDecoder + stores := []storetypes.StoreKey{StoreKeyAuth, StoreKeyBank} + runner := NewSTMRunner(decoder, stores, 4, false, "stake") + + ctx := context.Background() + ms := msWrapper{NewMultiMemDB(map[storetypes.StoreKey]int{ + StoreKeyAuth: 0, + StoreKeyBank: 1, + })} + + deliverTx := func(tx []byte, ms storetypes.MultiStore, txIndex int, cache map[string]any) *abci.ExecTxResult { + t.Fatal("deliverTx should not be called for empty block") + return nil + } + + results, err := runner.Run(ctx, ms, [][]byte{}, deliverTx) + + require.NoError(t, err) + require.Nil(t, results) +} + +// TestSTMRunner_Run_WithoutEstimation tests STMRunner without pre-estimation +func TestSTMRunner_Run_WithoutEstimation(t *testing.T) { + decoder := mockTxDecoder + stores := []storetypes.StoreKey{StoreKeyAuth, StoreKeyBank} + runner := NewSTMRunner(decoder, stores, 2, false, "stake") + + ctx := context.Background() + storeIndex := map[storetypes.StoreKey]int{ + StoreKeyAuth: 0, + StoreKeyBank: 1, + } + ms := msWrapper{NewMultiMemDB(storeIndex)} + + txs := [][]byte{ + {0x01}, + {0x02}, + {0x03}, + } + + executionCount := atomic.Int32{} + deliverTx := func(tx []byte, ms storetypes.MultiStore, txIndex int, cache map[string]any) *abci.ExecTxResult { + executionCount.Add(1) + require.NotNil(t, ms) + return &abci.ExecTxResult{Code: 0} + } + + results, err := runner.Run(ctx, ms, txs, deliverTx) + + require.NoError(t, err) + require.Len(t, results, len(txs)) + // STM may execute transactions multiple times due to conflicts + require.True(t, executionCount.Load() >= int32(len(txs))) +} + +// TestSTMRunner_Run_WithEstimation tests STMRunner with pre-estimation enabled +func TestSTMRunner_Run_WithEstimation(t *testing.T) { + decoder := mockTxDecoderWithFeeTx + stores := []storetypes.StoreKey{StoreKeyAuth, StoreKeyBank} + runner := NewSTMRunner(decoder, stores, 2, true, "stake") + + ctx := context.Background() + storeIndex := map[storetypes.StoreKey]int{ + StoreKeyAuth: 0, + StoreKeyBank: 1, + } + ms := msWrapper{NewMultiMemDB(storeIndex)} + + // Create transactions with valid structure for estimation + addr1 := []byte("addr1") + addr2 := []byte("addr2") + txs := [][]byte{ + append(addr1, 0x01), + append(addr2, 0x02), + } + + executionCount := atomic.Int32{} + deliverTx := func(tx []byte, ms storetypes.MultiStore, txIndex int, cache map[string]any) *abci.ExecTxResult { + executionCount.Add(1) + require.NotNil(t, ms) + return &abci.ExecTxResult{Code: 0} + } + + results, err := runner.Run(ctx, ms, txs, deliverTx) + + require.NoError(t, err) + require.Len(t, results, len(txs)) +} + +// 
TestSTMRunner_Run_IncarnationCache tests that incarnation cache is properly managed +func TestSTMRunner_Run_IncarnationCache(t *testing.T) { + decoder := mockTxDecoder + stores := []storetypes.StoreKey{StoreKeyAuth, StoreKeyBank} + runner := NewSTMRunner(decoder, stores, 2, false, "stake") + + ctx := context.Background() + storeIndex := map[storetypes.StoreKey]int{ + StoreKeyAuth: 0, + StoreKeyBank: 1, + } + ms := msWrapper{NewMultiMemDB(storeIndex)} + + txs := [][]byte{ + {0x01}, + {0x02}, + } + + cacheReceived := make([]bool, len(txs)) + deliverTx := func(tx []byte, ms storetypes.MultiStore, txIndex int, cache map[string]any) *abci.ExecTxResult { + if cache != nil { + cacheReceived[txIndex] = true + } + return &abci.ExecTxResult{Code: 0} + } + + results, err := runner.Run(ctx, ms, txs, deliverTx) + + require.NoError(t, err) + require.Len(t, results, len(txs)) + // Each transaction should receive a cache (even if empty) + for i, received := range cacheReceived { + require.True(t, received, "transaction %d should receive cache", i) + } +} + +// TestSTMRunner_Run_StoreIndexMapping tests that store keys are correctly mapped +func TestSTMRunner_Run_StoreIndexMapping(t *testing.T) { + decoder := mockTxDecoder + stores := []storetypes.StoreKey{StoreKeyAuth, StoreKeyBank} + runner := NewSTMRunner(decoder, stores, 2, false, "stake") + + ctx := context.Background() + storeIndex := map[storetypes.StoreKey]int{ + StoreKeyAuth: 0, + StoreKeyBank: 1, + } + ms := msWrapper{NewMultiMemDB(storeIndex)} + + txs := [][]byte{{0x01}} + + deliverTx := func(tx []byte, ms storetypes.MultiStore, txIndex int, cache map[string]any) *abci.ExecTxResult { + // Verify we can access both stores + authStore := ms.GetKVStore(StoreKeyAuth) + bankStore := ms.GetKVStore(StoreKeyBank) + require.NotNil(t, authStore) + require.NotNil(t, bankStore) + return &abci.ExecTxResult{Code: 0} + } + + results, err := runner.Run(ctx, ms, txs, deliverTx) + + require.NoError(t, err) + require.Len(t, results, 1) +} + +// TestSTMRunner_Run_ContextCancellation tests context cancellation for STMRunner +func TestSTMRunner_Run_ContextCancellation(t *testing.T) { + decoder := mockTxDecoder + stores := []storetypes.StoreKey{StoreKeyAuth, StoreKeyBank} + runner := NewSTMRunner(decoder, stores, 2, false, "stake") + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond) + defer cancel() + + storeIndex := map[storetypes.StoreKey]int{ + StoreKeyAuth: 0, + StoreKeyBank: 1, + } + ms := msWrapper{NewMultiMemDB(storeIndex)} + + // Create a large block to ensure context timeout + txs := make([][]byte, 1000) + for i := range txs { + txs[i] = []byte{byte(i % 256)} + } + + deliverTx := func(tx []byte, ms storetypes.MultiStore, txIndex int, cache map[string]any) *abci.ExecTxResult { + time.Sleep(1 * time.Millisecond) // Slow down execution + return &abci.ExecTxResult{Code: 0} + } + + results, err := runner.Run(ctx, ms, txs, deliverTx) + + // Should error due to context cancellation + require.Error(t, err) + require.Nil(t, results) +} + +// TestPreEstimates tests the preEstimates function +func TestPreEstimates(t *testing.T) { + t.Run("empty transactions", func(t *testing.T) { + decoder := mockTxDecoderWithFeeTx + memTxs, estimates := preEstimates([][]byte{}, 2, 0, 1, "stake", decoder) + + require.Empty(t, memTxs) + require.Empty(t, estimates) + }) + + t.Run("valid transactions with estimation", func(t *testing.T) { + decoder := mockTxDecoderWithFeeTx + + // Create test addresses + addr1 := sdk.AccAddress([]byte("address1")) + addr2 := 
sdk.AccAddress([]byte("address2")) + + txs := [][]byte{ + append(addr1, 0x01), + append(addr2, 0x02), + } + + memTxs, estimates := preEstimates(txs, 2, 0, 1, "stake", decoder) + + require.Len(t, memTxs, len(txs)) + require.Len(t, estimates, len(txs)) + + // Check that estimates are generated for valid transactions + for i, estimate := range estimates { + if estimate != nil { + // Should have auth store estimate (index 0) + require.Contains(t, estimate, 0, "transaction %d should have auth store estimate", i) + // Should have bank store estimate (index 1) + require.Contains(t, estimate, 1, "transaction %d should have bank store estimate", i) + } + } + }) + + t.Run("invalid transactions", func(t *testing.T) { + decoder := mockTxDecoderWithFeeTx + + txs := [][]byte{ + {0xFF, 0xFF}, // invalid + {0x01, 0x02}, // valid + } + + memTxs, estimates := preEstimates(txs, 2, 0, 1, "stake", decoder) + + require.Len(t, memTxs, len(txs)) + require.Len(t, estimates, len(txs)) + + // Invalid transaction should not have memTx or estimates + require.Nil(t, memTxs[0]) + require.Nil(t, estimates[0]) + + // Valid transaction should have memTx + require.NotNil(t, memTxs[1]) + }) + + t.Run("parallel processing with multiple workers", func(t *testing.T) { + decoder := mockTxDecoderWithFeeTx + + // Create many transactions + txs := make([][]byte, 100) + for i := range txs { + addr := sdk.AccAddress([]byte{byte(i)}) + txs[i] = append(addr, byte(i)) + } + + memTxs, estimates := preEstimates(txs, 4, 0, 1, "stake", decoder) + + require.Len(t, memTxs, len(txs)) + require.Len(t, estimates, len(txs)) + }) + + t.Run("non-FeeTx transactions", func(t *testing.T) { + // Use decoder that doesn't return FeeTx + decoder := mockTxDecoder + + txs := [][]byte{ + {0x01, 0x02}, + {0x03, 0x04}, + } + + memTxs, estimates := preEstimates(txs, 2, 0, 1, "stake", decoder) + + require.Len(t, memTxs, len(txs)) + require.Len(t, estimates, len(txs)) + + // Non-FeeTx should not have estimates + for _, estimate := range estimates { + require.Nil(t, estimate) + } + }) +} + +// TestPreEstimates_KeyEncoding tests that account and balance keys are correctly encoded +func TestPreEstimates_KeyEncoding(t *testing.T) { + decoder := mockTxDecoderWithFeeTx + + addr := sdk.AccAddress([]byte("testaddress12345")) + tx := append(addr, 0x01) + + memTxs, estimates := preEstimates([][]byte{tx}, 1, 0, 1, "stake", decoder) + + require.Len(t, memTxs, 1) + require.Len(t, estimates, 1) + + if estimates[0] != nil { + // Verify account key encoding + authEstimate := estimates[0][0] + require.NotEmpty(t, authEstimate) + + // The key should be properly encoded + expectedAccKey, err := collections.EncodeKeyWithPrefix( + collections.NewPrefix(1), + sdk.AccAddressKey, + addr, + ) + require.NoError(t, err) + require.Contains(t, authEstimate, expectedAccKey) + + // Verify balance key encoding + bankEstimate := estimates[0][1] + require.NotEmpty(t, bankEstimate) + + expectedBalanceKey, err := collections.EncodeKeyWithPrefix( + collections.NewPrefix(2), + collections.PairKeyCodec(sdk.AccAddressKey, collections.StringKey), + collections.Join(addr, "stake"), + ) + require.NoError(t, err) + require.Contains(t, bankEstimate, expectedBalanceKey) + } +} + +// TestTxRunnerInterface tests that both runners implement TxRunner interface +func TestTxRunnerInterface(t *testing.T) { + decoder := mockTxDecoder + + var _ sdk.TxRunner = NewDefaultRunner(decoder) + var _ sdk.TxRunner = NewSTMRunner(decoder, []storetypes.StoreKey{}, 1, false, "") +} + +// TestSTMRunner_Integration tests 
integration between STMRunner and actual block execution +func TestSTMRunner_Integration(t *testing.T) { + decoder := mockTxDecoder + stores := []storetypes.StoreKey{StoreKeyAuth, StoreKeyBank} + runner := NewSTMRunner(decoder, stores, 4, false, "stake") + + ctx := context.Background() + storeIndex := map[storetypes.StoreKey]int{ + StoreKeyAuth: 0, + StoreKeyBank: 1, + } + ms := msWrapper{NewMultiMemDB(storeIndex)} + + // Create a mock block with actual transactions + blk := testBlock(20, 10) + + // Use STMRunner to execute + var results []*abci.ExecTxResult + deliverTx := func(tx []byte, mstore storetypes.MultiStore, txIndex int, cache map[string]any) *abci.ExecTxResult { + // Execute using the mock block's transaction logic + if txIndex < blk.Size() { + // Convert multistore wrapper to MultiStore for block execution + if wrapper, ok := mstore.(msWrapper); ok { + blk.ExecuteTx(TxnIndex(txIndex), wrapper.MultiStore) + } + } + return &abci.ExecTxResult{Code: 0} + } + + // Create raw tx bytes for the runner + txs := make([][]byte, blk.Size()) + for i := range txs { + txs[i] = []byte{byte(i)} + } + + results, err := runner.Run(ctx, ms, txs, deliverTx) + + require.NoError(t, err) + require.Len(t, results, blk.Size()) +} + +// TestDefaultRunner_Integration tests integration with sequential execution +func TestDefaultRunner_Integration(t *testing.T) { + decoder := mockTxDecoder + runner := NewDefaultRunner(decoder) + + ctx := context.Background() + storeIndex := map[storetypes.StoreKey]int{ + StoreKeyAuth: 0, + StoreKeyBank: 1, + } + msRaw := NewMultiMemDB(storeIndex) + + // Create a mock block + blk := noConflictBlock(10) + + deliverTx := func(tx []byte, _ storetypes.MultiStore, txIndex int, cache map[string]any) *abci.ExecTxResult { + if txIndex < blk.Size() { + blk.ExecuteTx(TxnIndex(txIndex), msRaw) + } + return &abci.ExecTxResult{Code: 0} + } + + txs := make([][]byte, blk.Size()) + for i := range txs { + txs[i] = []byte{byte(i)} + } + + results, err := runner.Run(ctx, nil, txs, deliverTx) + + require.NoError(t, err) + require.Len(t, results, blk.Size()) + + // Verify all transactions succeeded + for i, err := range blk.Results { + require.NoError(t, err, "transaction %d should succeed", i) + } +} + +// TestRunnerComparison tests that both DefaultRunner and STMRunner can execute successfully +func TestRunnerComparison(t *testing.T) { + decoder := mockTxDecoder + + txs := [][]byte{ + {0x01, 0x02}, + {0x03, 0x04}, + {0x05, 0x06}, + } + + executionCount := atomic.Int32{} + deliverTx := func(tx []byte, _ storetypes.MultiStore, txIndex int, cache map[string]any) *abci.ExecTxResult { + executionCount.Add(1) + return &abci.ExecTxResult{Code: 0, Data: tx} + } + + ctx := context.Background() + + // Test DefaultRunner + t.Run("DefaultRunner", func(t *testing.T) { + runner := NewDefaultRunner(decoder) + executionCount.Store(0) + + results, err := runner.Run(ctx, nil, txs, deliverTx) + + require.NoError(t, err) + require.Len(t, results, len(txs)) + require.Equal(t, int32(len(txs)), executionCount.Load()) + }) + + // Test STMRunner + t.Run("STMRunner", func(t *testing.T) { + stores := []storetypes.StoreKey{StoreKeyAuth, StoreKeyBank} + runner := NewSTMRunner(decoder, stores, 2, false, "stake") + storeIndex := map[storetypes.StoreKey]int{ + StoreKeyAuth: 0, + StoreKeyBank: 1, + } + ms := msWrapper{NewMultiMemDB(storeIndex)} + executionCount.Store(0) + + results, err := runner.Run(ctx, ms, txs, deliverTx) + + require.NoError(t, err) + require.Len(t, results, len(txs)) + // STM may execute more times due to 
conflicts + require.True(t, executionCount.Load() >= int32(len(txs))) + }) +} diff --git a/blockstm/types.go b/blockstm/types.go new file mode 100644 index 000000000000..873647b2f077 --- /dev/null +++ b/blockstm/types.go @@ -0,0 +1,84 @@ +package blockstm + +import ( + storetypes "cosmossdk.io/store/types" +) + +const ( + TelemetrySubsystem = "blockstm" + KeyExecutedTxs = "executed_txs" + KeyValidatedTxs = "validated_txs" +) + +type ( + TxnIndex int + Incarnation uint +) + +type TxnVersion struct { + Index TxnIndex + Incarnation Incarnation +} + +var InvalidTxnVersion = TxnVersion{-1, 0} + +func (v TxnVersion) Valid() bool { + return v.Index >= 0 +} + +type Key []byte + +type ReadDescriptor struct { + Key Key + // invalid Version means the key is read from storage + Version TxnVersion +} + +type IteratorOptions struct { + // [Start, End) is the range of the iterator + Start Key + End Key + Ascending bool +} + +type IteratorDescriptor struct { + IteratorOptions + // Stop is not `nil` if the iteration is not exhausted and stops at a key before reaching the end of the range, + // the effective range is `[start, stop]`. + // when replaying, it should also stops at the stop key. + Stop Key + // Reads is the list of keys that is observed by the iterator. + Reads []ReadDescriptor +} + +type ReadSet struct { + Reads []ReadDescriptor + Iterators []IteratorDescriptor +} + +type MultiReadSet = map[int]*ReadSet + +// TxExecutor executes transactions on top of a multi-version memory view. +type TxExecutor func(TxnIndex, MultiStore) + +type MultiStore interface { + GetStore(storetypes.StoreKey) storetypes.Store + GetKVStore(storetypes.StoreKey) storetypes.KVStore + GetObjKVStore(storetypes.StoreKey) storetypes.ObjKVStore +} + +// MVStore is a value type agnostic interface for `MVData`, to keep `MVMemory` value type agnostic. +type MVStore interface { + Delete(Key, TxnIndex) + WriteEstimate(Key, TxnIndex) + ValidateReadSet(TxnIndex, *ReadSet) bool + SnapshotToStore(storetypes.Store) +} + +// MVView is a value type agnostic interface for `MVMemoryView`, to keep `MultiMVMemoryView` value type agnostic. +type MVView interface { + storetypes.Store + + ApplyWriteSet(TxnVersion) Locations + ReadSet() *ReadSet +} diff --git a/blockstm/utils.go b/blockstm/utils.go new file mode 100644 index 000000000000..4909d8f6993e --- /dev/null +++ b/blockstm/utils.go @@ -0,0 +1,84 @@ +package blockstm + +import ( + "bytes" + "fmt" + "sync/atomic" +) + +type ErrReadError struct { + BlockingTxn TxnIndex +} + +func (e ErrReadError) Error() string { + return fmt.Sprintf("read error: blocked by txn %d", e.BlockingTxn) +} + +// StoreMin implements a compare-and-swap operation that stores the minimum of the current value and the given value. 
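+// The scheduler uses it to lower validationIdx when earlier transactions need revalidation.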
+func StoreMin(a *atomic.Uint64, b uint64) { + for { + old := a.Load() + if old <= b { + return + } + if a.CompareAndSwap(old, b) { + return + } + } +} + +// DecrAtomic decreases the atomic value by 1 +func DecrAtomic(a *atomic.Uint64) { + a.Add(^uint64(0)) +} + +// IncrAtomic increases the atomic value by 1 +func IncrAtomic(a *atomic.Uint64) { + a.Add(1) +} + +// FetchIncr increaes the atomic value by 1 and returns the old value +func FetchIncr(a *atomic.Uint64) uint64 { + return a.Add(1) - 1 +} + +// DiffOrderedList compares two ordered lists +// callback arguments: (value, is_new) +func DiffOrderedList(old, new []Key, callback func(Key, bool) bool) { + i, j := 0, 0 + for i < len(old) && j < len(new) { + switch bytes.Compare(old[i], new[j]) { + case -1: + if !callback(old[i], false) { + return + } + i++ + case 1: + if !callback(new[j], true) { + return + } + j++ + default: + i++ + j++ + } + } + for ; i < len(old); i++ { + if !callback(old[i], false) { + return + } + } + for ; j < len(new); j++ { + if !callback(new[j], true) { + return + } + } +} + +// BytesBeyond returns if a is beyond b in specified iteration order +func BytesBeyond(a, b []byte, ascending bool) bool { + if ascending { + return bytes.Compare(a, b) > 0 + } + return bytes.Compare(a, b) < 0 +} diff --git a/blockstm/utils_test.go b/blockstm/utils_test.go new file mode 100644 index 000000000000..cc7d644b5689 --- /dev/null +++ b/blockstm/utils_test.go @@ -0,0 +1,81 @@ +package blockstm + +import ( + "testing" + + "github.com/test-go/testify/require" +) + +type DiffEntry struct { + Key Key + IsNew bool +} + +func TestDiffOrderedList(t *testing.T) { + testCases := []struct { + name string + old []Key + new []Key + expected []DiffEntry + }{ + { + name: "empty lists", + old: []Key{}, + new: []Key{}, + expected: []DiffEntry{}, + }, + { + name: "old is longer", + old: []Key{ + []byte("a"), + []byte("b"), + []byte("c"), + []byte("d"), + []byte("e"), + }, + new: []Key{ + []byte("b"), + []byte("c"), + []byte("f"), + }, + expected: []DiffEntry{ + {Key: []byte("a"), IsNew: false}, + {Key: []byte("d"), IsNew: false}, + {Key: []byte("e"), IsNew: false}, + {Key: []byte("f"), IsNew: true}, + }, + }, + { + name: "new is longer", + old: []Key{ + []byte("a"), + []byte("c"), + []byte("e"), + }, + new: []Key{ + []byte("b"), + []byte("c"), + []byte("d"), + []byte("e"), + []byte("f"), + }, + expected: []DiffEntry{ + {Key: []byte("a"), IsNew: false}, + {Key: []byte("b"), IsNew: true}, + {Key: []byte("d"), IsNew: true}, + {Key: []byte("f"), IsNew: true}, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := []DiffEntry{} + DiffOrderedList(tc.old, tc.new, func(key Key, leftOrRight bool) bool { + result = append(result, DiffEntry{key, leftOrRight}) + return true + }) + require.Equal(t, tc.expected, result) + }) + } +} diff --git a/blockstm/wrappers.go b/blockstm/wrappers.go new file mode 100644 index 000000000000..bc5b2d667216 --- /dev/null +++ b/blockstm/wrappers.go @@ -0,0 +1,89 @@ +package blockstm + +import ( + "io" + + "cosmossdk.io/store/cachemulti" + storetypes "cosmossdk.io/store/types" +) + +var ( + _ storetypes.MultiStore = msWrapper{} + _ MultiStore = stmMultiStoreWrapper{} +) + +type msWrapper struct { + MultiStore +} + +func (ms msWrapper) CacheWrapWithTrace(w io.Writer, tc storetypes.TraceContext) storetypes.CacheWrap { + // TODO implement me + panic("implement me") +} + +func (ms msWrapper) CacheMultiStoreWithVersion(version int64) (storetypes.CacheMultiStore, error) { + // TODO implement me + 
panic("implement me") +} + +func (ms msWrapper) LatestVersion() int64 { + // TODO implement me + panic("implement me") +} + +func (ms msWrapper) getCacheWrapper(key storetypes.StoreKey) storetypes.CacheWrapper { + return ms.GetStore(key) +} + +func (ms msWrapper) GetStore(key storetypes.StoreKey) storetypes.Store { + return ms.MultiStore.GetStore(key) +} + +func (ms msWrapper) GetKVStore(key storetypes.StoreKey) storetypes.KVStore { + return ms.MultiStore.GetKVStore(key) +} + +func (ms msWrapper) GetObjKVStore(key storetypes.StoreKey) storetypes.ObjKVStore { + return ms.MultiStore.GetObjKVStore(key) +} + +func (ms msWrapper) CacheMultiStore() storetypes.CacheMultiStore { + return cachemulti.NewFromParent(ms.getCacheWrapper, nil, nil) +} + +// CacheWrap Implements CacheWrapper. +func (ms msWrapper) CacheWrap() storetypes.CacheWrap { + return ms.CacheMultiStore().(storetypes.CacheWrap) +} + +// GetStoreType returns the type of the store. +func (ms msWrapper) GetStoreType() storetypes.StoreType { + return storetypes.StoreTypeMulti +} + +// SetTracer Implements interface MultiStore +func (ms msWrapper) SetTracer(io.Writer) storetypes.MultiStore { + return nil +} + +// SetTracingContext Implements interface MultiStore +func (ms msWrapper) SetTracingContext(storetypes.TraceContext) storetypes.MultiStore { + return nil +} + +// TracingEnabled Implements interface MultiStore +func (ms msWrapper) TracingEnabled() bool { + return false +} + +type stmMultiStoreWrapper struct { + storetypes.MultiStore +} + +func (ms stmMultiStoreWrapper) GetStore(key storetypes.StoreKey) storetypes.Store { + return ms.MultiStore.GetStore(key) +} + +func (ms stmMultiStoreWrapper) GetKVStore(key storetypes.StoreKey) storetypes.KVStore { + return ms.MultiStore.GetKVStore(key) +} diff --git a/go.mod b/go.mod index 1477bf9c3a64..01f4b5db4754 100644 --- a/go.mod +++ b/go.mod @@ -55,6 +55,8 @@ require ( github.com/spf13/viper v1.21.0 github.com/stretchr/testify v1.11.1 github.com/tendermint/go-amino v0.16.0 + github.com/test-go/testify v1.1.4 + github.com/tidwall/btree v1.8.1 go.uber.org/mock v0.6.0 golang.org/x/crypto v0.43.0 golang.org/x/sync v0.17.0 @@ -195,7 +197,6 @@ require ( github.com/subosito/gotenv v1.6.0 // indirect github.com/supranational/blst v0.3.16 // indirect github.com/syndtr/goleveldb v1.0.1-0.20220721030215-126854af5e6d // indirect - github.com/tidwall/btree v1.8.1 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ulikunitz/xz v0.5.15 // indirect github.com/zeebo/errs v1.4.0 // indirect @@ -247,6 +248,8 @@ replace ( github.com/gin-gonic/gin => github.com/gin-gonic/gin v1.9.1 // replace broken goleveldb github.com/syndtr/goleveldb => github.com/syndtr/goleveldb v1.0.1-0.20210819022825-2ae1ddf74ef7 + // BlockSTM requires patches to the btree package + github.com/tidwall/btree => github.com/cosmos/btree v0.0.0-20250924232609-2c6195d95951 ) replace cosmossdk.io/store => ./store diff --git a/go.sum b/go.sum index 122b1d42ee97..2888a8635af6 100644 --- a/go.sum +++ b/go.sum @@ -217,6 +217,8 @@ github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSV github.com/coreos/pkg v0.0.0-20160727233714-3ac0863d7acf/go.mod h1:E3G3o1h8I7cfcXa63jLwjI0eiQQMgzzUDFVpN/nH/eA= github.com/cosmos/btcutil v1.0.5 h1:t+ZFcX77LpKtDBhjucvnOH8C2l2ioGsBNEQ3jef8xFk= github.com/cosmos/btcutil v1.0.5/go.mod h1:IyB7iuqZMJlthe2tkIFL33xPyzbFYP0XVdS8P5lUPis= +github.com/cosmos/btree v0.0.0-20250924232609-2c6195d95951 h1:dC3GJcS8bJiSEe7VAFDDFgFnVM1G9nBdGOgqJsmsZwM= +github.com/cosmos/btree 
v0.0.0-20250924232609-2c6195d95951/go.mod h1:jBbTdUWhSZClZWoDg54VnvV7/54modSOzDN7VXftj1A= github.com/cosmos/cosmos-db v1.1.3 h1:7QNT77+vkefostcKkhrzDK9uoIEryzFrU9eoMeaQOPY= github.com/cosmos/cosmos-db v1.1.3/go.mod h1:kN+wGsnwUJZYn8Sy5Q2O0vCYA99MJllkKASbs6Unb9U= github.com/cosmos/cosmos-proto v1.0.0-beta.5 h1:eNcayDLpip+zVLRLYafhzLvQlSmyab+RC5W7ZfmxJLA= @@ -796,8 +798,8 @@ github.com/syndtr/goleveldb v1.0.1-0.20210819022825-2ae1ddf74ef7 h1:epCh84lMvA70 github.com/syndtr/goleveldb v1.0.1-0.20210819022825-2ae1ddf74ef7/go.mod h1:q4W45IWZaF22tdD+VEXcAWRA037jwmWEB5VWYORlTpc= github.com/tendermint/go-amino v0.16.0 h1:GyhmgQKvqF82e2oZeuMSp9JTN0N09emoSZlb2lyGa2E= github.com/tendermint/go-amino v0.16.0/go.mod h1:TQU0M1i/ImAo+tYpZi73AU3V/dKeCoMC9Sphe2ZwGME= -github.com/tidwall/btree v1.8.1 h1:27ehoXvm5AG/g+1VxLS1SD3vRhp/H7LuEfwNvddEdmA= -github.com/tidwall/btree v1.8.1/go.mod h1:jBbTdUWhSZClZWoDg54VnvV7/54modSOzDN7VXftj1A= +github.com/test-go/testify v1.1.4 h1:Tf9lntrKUMHiXQ07qBScBTSA0dhYQlu83hswqelv1iE= +github.com/test-go/testify v1.1.4/go.mod h1:rH7cfJo/47vWGdi4GPj16x3/t1xGOj2YxzmNQzk2ghU= github.com/tmc/grpc-websocket-proxy v0.0.0-20170815181823-89b8d40f7ca8/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= github.com/tv42/httpunix v0.0.0-20150427012821-b75d8614f926/go.mod h1:9ESjWnEqriFuLhtthL60Sar/7RFoluCcXsuvEwTV5KM= github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= diff --git a/types/abci.go b/types/abci.go index b403da6e9a1e..06ca95d7b620 100644 --- a/types/abci.go +++ b/types/abci.go @@ -1,7 +1,11 @@ package types import ( + "context" + abci "github.com/cometbft/cometbft/abci/types" + + storetypes "cosmossdk.io/store/types" ) // ABCIHandlers aggregates all ABCI handlers needed for an application. @@ -94,5 +98,15 @@ func (r ResponsePreBlock) IsConsensusParamsChanged() bool { type RunTx = func(txBytes []byte, tx Tx) (gInfo GasInfo, result *Result, anteEvents []abci.Event, err error) +// DeliverTxFunc is the function called for each transaction in order to produce a single ExecTxResult +type DeliverTxFunc func(tx []byte, ms storetypes.MultiStore, txIndex int, incarnationCache map[string]any) *abci.ExecTxResult + +// TxRunner defines an interface for types which can be used to execute the DeliverTxFunc. +// It should return an array of *abci.ExecTxResult corresponding to the result of executing each transaction +// provided to the Run function. +type TxRunner interface { + Run(ctx context.Context, ms storetypes.MultiStore, txs [][]byte, deliverTx DeliverTxFunc) ([]*abci.ExecTxResult, error) +} + // PeerFilter responds to p2p filtering queries from Tendermint type PeerFilter func(info string) *abci.ResponseQuery
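
For orientation only (not part of the diff above), here is a minimal sketch of how an application might drive the new `TxRunner` interface with the STM runner. The `blockstm` import path, the store key names, the worker count, and the fee denom are assumptions; a real integration would plug its ante handlers and message execution into the `DeliverTxFunc` callback.

```golang
package example

import (
	"context"

	abci "github.com/cometbft/cometbft/abci/types"

	storetypes "cosmossdk.io/store/types"

	"github.com/cosmos/cosmos-sdk/blockstm" // assumed import path
	sdk "github.com/cosmos/cosmos-sdk/types"
)

// runBlock sketches wiring a block of raw txs through the TxRunner interface.
// txDecoder, ms, and txs are assumed to come from the application.
func runBlock(ctx context.Context, txDecoder sdk.TxDecoder, ms storetypes.MultiStore, txs [][]byte) ([]*abci.ExecTxResult, error) {
	// STMRunner looks up the "acc" and "bank" store names for fee-payer pre-estimation.
	stores := []storetypes.StoreKey{
		storetypes.NewKVStoreKey("acc"),
		storetypes.NewKVStoreKey("bank"),
	}

	var runner sdk.TxRunner = blockstm.NewSTMRunner(
		txDecoder,
		stores,
		4,       // number of concurrent executors
		true,    // enable fee-payer pre-estimation
		"stake", // fee denom used for the balance-key estimate (assumption)
	)

	return runner.Run(ctx, ms, txs, func(tx []byte, txms storetypes.MultiStore, txIndex int, cache map[string]any) *abci.ExecTxResult {
		// a real application would run its ante handlers and message execution here,
		// reading and writing through the per-transaction MultiStore view (txms)
		return &abci.ExecTxResult{Code: 0, Data: tx}
	})
}
```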