diff --git a/.github/workflows/pr-tests.yml b/.github/workflows/pr-tests.yml index b72c17302a..3bf04ed6d7 100644 --- a/.github/workflows/pr-tests.yml +++ b/.github/workflows/pr-tests.yml @@ -33,6 +33,16 @@ jobs: with: job_key: unit-tests-${{ matrix.package }} + - name: Enable unprivileged uffd mode + run: | + echo 1 | sudo tee /proc/sys/vm/unprivileged_userfaultfd + + - name: Enable hugepages + run: | + sudo mkdir -p /mnt/hugepages + sudo mount -t hugetlbfs none /mnt/hugepages + echo 128 | sudo tee /proc/sys/vm/nr_hugepages + - name: Run tests working-directory: ${{ matrix.package }} run: go test -v ${{ matrix.test_path }} diff --git a/packages/orchestrator/internal/sandbox/block/tracker.go b/packages/orchestrator/internal/sandbox/block/tracker.go index b0caf19411..c94c51410c 100644 --- a/packages/orchestrator/internal/sandbox/block/tracker.go +++ b/packages/orchestrator/internal/sandbox/block/tracker.go @@ -1,66 +1,77 @@ package block import ( - "context" - "fmt" "sync" - "sync/atomic" "github.com/bits-and-blooms/bitset" "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" ) -type TrackedSliceDevice struct { - data ReadonlyDevice - blockSize int64 +type Tracker struct { + b *bitset.BitSet + mu sync.RWMutex - nilTracking atomic.Bool - dirty *bitset.BitSet - dirtyMu sync.Mutex - empty []byte + blockSize int64 } -func NewTrackedSliceDevice(blockSize int64, device ReadonlyDevice) (*TrackedSliceDevice, error) { - return &TrackedSliceDevice{ - data: device, - empty: make([]byte, blockSize), +func NewTracker(blockSize int64) *Tracker { + return &Tracker{ + // The bitset resizes automatically based on the maximum set bit. 
+ b: bitset.New(0), blockSize: blockSize, - }, nil + } } -func (t *TrackedSliceDevice) Disable() error { - size, err := t.data.Size() - if err != nil { - return fmt.Errorf("failed to get device size: %w", err) +func NewTrackerFromBitset(b *bitset.BitSet, blockSize int64) *Tracker { + return &Tracker{ + b: b, + blockSize: blockSize, } +} - t.dirty = bitset.New(uint(header.TotalBlocks(size, t.blockSize))) - // We are starting with all being dirty. - t.dirty.FlipRange(0, t.dirty.Len()) - - t.nilTracking.Store(true) +func (t *Tracker) Has(off int64) bool { + t.mu.RLock() + defer t.mu.RUnlock() - return nil + return t.b.Test(uint(header.BlockIdx(off, t.blockSize))) } -func (t *TrackedSliceDevice) Slice(ctx context.Context, off int64, length int64) ([]byte, error) { - if t.nilTracking.Load() { - t.dirtyMu.Lock() - t.dirty.Clear(uint(header.BlockIdx(off, t.blockSize))) - t.dirtyMu.Unlock() +func (t *Tracker) Add(off int64) bool { + t.mu.Lock() + defer t.mu.Unlock() - return t.empty, nil + if t.b.Test(uint(header.BlockIdx(off, t.blockSize))) { + return false } - return t.data.Slice(ctx, off, length) + t.b.Set(uint(header.BlockIdx(off, t.blockSize))) + + return true } -// Return which bytes were not read since Disable. -// This effectively returns the bytes that have been requested after paused vm and are not dirty. -func (t *TrackedSliceDevice) Dirty() *bitset.BitSet { - t.dirtyMu.Lock() - defer t.dirtyMu.Unlock() +func (t *Tracker) Reset() { + t.mu.Lock() + defer t.mu.Unlock() - return t.dirty.Clone() + t.b.ClearAll() +} + +// BitSet returns a clone of the bitset and the block size. 
+func (t *Tracker) BitSet() *bitset.BitSet { + t.mu.RLock() + defer t.mu.RUnlock() + + return t.b.Clone() +} + +func (t *Tracker) BlockSize() int64 { + return t.blockSize +} + +func (t *Tracker) Clone() *Tracker { + return &Tracker{ + b: t.BitSet(), + blockSize: t.BlockSize(), + } } diff --git a/packages/orchestrator/internal/sandbox/block/tracker_test.go b/packages/orchestrator/internal/sandbox/block/tracker_test.go new file mode 100644 index 0000000000..75d1f58fd2 --- /dev/null +++ b/packages/orchestrator/internal/sandbox/block/tracker_test.go @@ -0,0 +1,109 @@ +package block + +import ( + "testing" +) + +func TestTracker_AddAndHas(t *testing.T) { + const pageSize = 4096 + tr := NewTracker(pageSize) + + offset := int64(pageSize * 4) + + // Initially should not be marked + if tr.Has(offset) { + t.Errorf("Expected offset %d not to be marked initially", offset) + } + + // After adding, should be marked + tr.Add(offset) + if !tr.Has(offset) { + t.Errorf("Expected offset %d to be marked after Add", offset) + } +} + +func TestTracker_Reset(t *testing.T) { + const pageSize = 4096 + tr := NewTracker(pageSize) + + offset := int64(pageSize * 4) + + // Add offset and verify it's marked + tr.Add(offset) + if !tr.Has(offset) { + t.Errorf("Expected offset %d to be marked after Add", offset) + } + + // After reset, should not be marked + tr.Reset() + if tr.Has(offset) { + t.Errorf("Expected offset %d to be cleared after Reset", offset) + } +} + +func TestTracker_MultipleOffsets(t *testing.T) { + const pageSize = 4096 + tr := NewTracker(pageSize) + + offsets := []int64{0, pageSize, 2 * pageSize, 10 * pageSize} + + // Add multiple offsets + for _, o := range offsets { + tr.Add(o) + } + + // Verify all offsets are marked + for _, o := range offsets { + if !tr.Has(o) { + t.Errorf("Expected offset %d to be marked", o) + } + } +} + +func TestTracker_ResetClearsAll(t *testing.T) { + const pageSize = 4096 + tr := NewTracker(pageSize) + + offsets := []int64{0, pageSize, 2 * pageSize, 10 
* pageSize} + + // Add multiple offsets + for _, o := range offsets { + tr.Add(o) + } + + // Reset should clear all + tr.Reset() + + // Verify all offsets are cleared + for _, o := range offsets { + if tr.Has(o) { + t.Errorf("Expected offset %d to be cleared after Reset", o) + } + } +} + +func TestTracker_MisalignedOffset(t *testing.T) { + const pageSize = 4096 + tr := NewTracker(pageSize) + + // Test with misaligned offset + misalignedOffset := int64(123) + tr.Add(misalignedOffset) + + // Should be set for the block containing the offset—that is, block 0 (0..4095) + if !tr.Has(misalignedOffset) { + t.Errorf("Expected misaligned offset %d to be marked (should mark its containing block)", misalignedOffset) + } + + // Now check that any offset in the same block is also considered marked + anotherOffsetInSameBlock := int64(1000) + if !tr.Has(anotherOffsetInSameBlock) { + t.Errorf("Expected offset %d to be marked as in same block as %d", anotherOffsetInSameBlock, misalignedOffset) + } + + // But not for a different block + offsetInNextBlock := int64(pageSize) // convert to int64 to match Has signature + if tr.Has(offsetInNextBlock) { + t.Errorf("Did not expect offset %d to be marked", offsetInNextBlock) + } +} diff --git a/packages/orchestrator/internal/sandbox/sandbox.go b/packages/orchestrator/internal/sandbox/sandbox.go index ffd1031b54..d637300a1d 100644 --- a/packages/orchestrator/internal/sandbox/sandbox.go +++ b/packages/orchestrator/internal/sandbox/sandbox.go @@ -674,10 +674,15 @@ func (s *Sandbox) Pause( return nil, fmt.Errorf("failed to pause VM: %w", err) } - if err := s.memory.Disable(); err != nil { + if err := s.memory.Disable(ctx); err != nil { return nil, fmt.Errorf("failed to disable uffd: %w", err) } + dirty, err := s.memory.Dirty(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get dirty pages: %w", err) + } + // Snapfile is not closed as it's returned and cached for later use (like resume) snapfile := 
template.NewLocalFileLink(snapshotTemplateFiles.CacheSnapfilePath()) // Memfile is also closed on diff creation processing @@ -721,7 +726,7 @@ func (s *Sandbox) Pause( originalMemfile.Header(), &MemoryDiffCreator{ memfile: memfile, - dirtyPages: s.memory.Dirty(), + dirtyPages: dirty.BitSet(), blockSize: originalMemfile.BlockSize(), doneHook: func(context.Context) error { return memfile.Close() @@ -933,7 +938,7 @@ func serveMemory( ctx, span := tracer.Start(ctx, "serve-memory") defer span.End() - fcUffd, err := uffd.New(memfile, socketPath, memfile.BlockSize()) + fcUffd, err := uffd.New(memfile, socketPath) if err != nil { return nil, fmt.Errorf("failed to create uffd: %w", err) } diff --git a/packages/orchestrator/internal/sandbox/uffd/mapping/firecracker.go b/packages/orchestrator/internal/sandbox/uffd/mapping/firecracker.go deleted file mode 100644 index dbff8bbd7a..0000000000 --- a/packages/orchestrator/internal/sandbox/uffd/mapping/firecracker.go +++ /dev/null @@ -1,32 +0,0 @@ -package mapping - -import "fmt" - -type GuestRegionUffdMapping struct { - BaseHostVirtAddr uintptr `json:"base_host_virt_addr"` - Size uintptr `json:"size"` - Offset uintptr `json:"offset"` - // This is actually in bytes. - // This field is deprecated in the newer version of the Firecracer with a new field `page_size`. 
- PageSize uintptr `json:"page_size_kib"` -} - -func (m *GuestRegionUffdMapping) relativeOffset(addr uintptr) int64 { - return int64(m.Offset + addr - m.BaseHostVirtAddr) -} - -type FcMappings []GuestRegionUffdMapping - -// Returns the relative offset and the page size of the mapped range for a given address -func (m FcMappings) GetRange(addr uintptr) (int64, int64, error) { - for _, m := range m { - if addr < m.BaseHostVirtAddr || m.BaseHostVirtAddr+m.Size <= addr { - // Outside of this mapping - continue - } - - return m.relativeOffset(addr), int64(m.PageSize), nil - } - - return 0, 0, fmt.Errorf("address %d not found in any mapping", addr) -} diff --git a/packages/orchestrator/internal/sandbox/uffd/mapping/mapping.go b/packages/orchestrator/internal/sandbox/uffd/mapping/mapping.go deleted file mode 100644 index ffa8c4a6fa..0000000000 --- a/packages/orchestrator/internal/sandbox/uffd/mapping/mapping.go +++ /dev/null @@ -1,5 +0,0 @@ -package mapping - -type Mappings interface { - GetRange(addr uintptr) (offset int64, pagesize int64, err error) -} diff --git a/packages/orchestrator/internal/sandbox/uffd/memory/mapping.go b/packages/orchestrator/internal/sandbox/uffd/memory/mapping.go new file mode 100644 index 0000000000..fab88886fd --- /dev/null +++ b/packages/orchestrator/internal/sandbox/uffd/memory/mapping.go @@ -0,0 +1,24 @@ +package memory + +import ( + "fmt" +) + +type Mapping struct { + Regions []Region +} + +func NewMapping(regions []Region) *Mapping { + return &Mapping{Regions: regions} +} + +// GetOffset returns the relative offset and the page size of the mapped range for a given address. 
+func (m *Mapping) GetOffset(hostVirtAddr uintptr) (int64, uint64, error) { + for _, r := range m.Regions { + if hostVirtAddr >= r.BaseHostVirtAddr && hostVirtAddr < r.endHostVirtAddr() { + return r.shiftedOffset(hostVirtAddr), uint64(r.PageSize), nil + } + } + + return 0, 0, fmt.Errorf("address %d not found in any mapping", hostVirtAddr) +} diff --git a/packages/orchestrator/internal/sandbox/uffd/memory/mapping_test.go b/packages/orchestrator/internal/sandbox/uffd/memory/mapping_test.go new file mode 100644 index 0000000000..ddddce8e0d --- /dev/null +++ b/packages/orchestrator/internal/sandbox/uffd/memory/mapping_test.go @@ -0,0 +1,168 @@ +package memory + +import ( + "testing" +) + +func TestMapping_GetOffset(t *testing.T) { + regions := []Region{ + { + BaseHostVirtAddr: 0x1000, + Size: 0x2000, + Offset: 0x5000, + PageSize: 4096, + }, + { + BaseHostVirtAddr: 0x5000, + Size: 0x1000, + Offset: 0x8000, + PageSize: 4096, + }, + } + mapping := NewMapping(regions) + + tests := []struct { + name string + hostVirtAddr uintptr + expectedOffset int64 + expectedSize uint64 + expectError bool + }{ + { + name: "valid address in first region", + hostVirtAddr: 0x1500, + expectedOffset: 0x5500, // 0x5000 + (0x1500 - 0x1000) + expectedSize: 4096, + expectError: false, + }, + { + name: "valid address in second region", + hostVirtAddr: 0x5500, + expectedOffset: 0x8500, // 0x8000 + (0x5500 - 0x5000) + expectedSize: 4096, + expectError: false, + }, + { + name: "address before first region", + hostVirtAddr: 0x500, + expectError: true, + }, + { + name: "address after last region", + hostVirtAddr: 0x7000, + expectError: true, + }, + { + name: "address in gap between regions", + hostVirtAddr: 0x4000, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + offset, size, err := mapping.GetOffset(tt.hostVirtAddr) + + if tt.expectError { + if err == nil { + t.Errorf("Expected error but got none") + } + + return + } + + if err != nil { + 
t.Errorf("Unexpected error: %v", err) + + return + } + + if offset != tt.expectedOffset { + t.Errorf("Expected offset %d, got %d", tt.expectedOffset, offset) + } + + if size != tt.expectedSize { + t.Errorf("Expected size %d, got %d", tt.expectedSize, size) + } + }) + } +} + +func TestMapping_EmptyRegions(t *testing.T) { + mapping := NewMapping([]Region{}) + + // Test GetOffset with empty regions + _, _, err := mapping.GetOffset(0x1000) + if err == nil { + t.Errorf("Expected error for empty regions, got none") + } +} + +func TestMapping_OverlappingRegions(t *testing.T) { + // Test with overlapping regions (edge case) + regions := []Region{ + { + BaseHostVirtAddr: 0x1000, + Size: 0x2000, + Offset: 0x5000, + PageSize: 4096, + }, + { + BaseHostVirtAddr: 0x2000, // Overlaps with first region + Size: 0x1000, + Offset: 0x8000, + PageSize: 4096, + }, + } + mapping := NewMapping(regions) + + // The first matching region should be returned + offset, _, err := mapping.GetOffset(0x2500) // In overlap area + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + // Should get result from first region + expectedOffset := int64(0x5000 + (0x2500 - 0x1000)) // 0x6500 + if offset != expectedOffset { + t.Errorf("Expected offset %d, got %d", expectedOffset, offset) + } +} + +func TestMapping_BoundaryConditions(t *testing.T) { + regions := []Region{ + { + BaseHostVirtAddr: 0x1000, + Size: 0x2000, + Offset: 0x5000, + PageSize: 4096, + }, + } + mapping := NewMapping(regions) + + // Test exact start boundary + offset, _, err := mapping.GetOffset(0x1000) + if err != nil { + t.Errorf("Unexpected error at start boundary: %v", err) + } + expectedOffset := int64(0x5000) // 0x5000 + (0x1000 - 0x1000) + if offset != expectedOffset { + t.Errorf("Expected offset %d at start boundary, got %d", expectedOffset, offset) + } + + // Test just before end boundary (exclusive) + offset, _, err = mapping.GetOffset(0x2FFF) // 0x1000 + 0x2000 - 1 + if err != nil { + t.Errorf("Unexpected error just 
before end boundary: %v", err) + } + expectedOffset = int64(0x5000 + (0x2FFF - 0x1000)) // 0x6FFF + if offset != expectedOffset { + t.Errorf("Expected offset %d just before end boundary, got %d", expectedOffset, offset) + } + + // Test exact end boundary (should fail - exclusive) + _, _, err = mapping.GetOffset(0x3000) // 0x1000 + 0x2000 + if err == nil { + t.Errorf("Expected error at end boundary (exclusive), got none") + } +} diff --git a/packages/orchestrator/internal/sandbox/uffd/memory/region.go b/packages/orchestrator/internal/sandbox/uffd/memory/region.go new file mode 100644 index 0000000000..5670603380 --- /dev/null +++ b/packages/orchestrator/internal/sandbox/uffd/memory/region.go @@ -0,0 +1,23 @@ +package memory + +// Region is a mapping of a region of memory of the guest to a region of memory on the host. +// The serialization is based on the Firecracker UFFD protocol communication. +type Region struct { + BaseHostVirtAddr uintptr `json:"base_host_virt_addr"` + Size uintptr `json:"size"` + Offset uintptr `json:"offset"` + // This is actually in bytes. + // This field is deprecated in the newer version of the Firecracer with a new field `page_size`. + PageSize uintptr `json:"page_size_kib"` +} + +// endHostVirtAddr returns the end address of the region in host virtual address. +// The end address is exclusive. +func (r *Region) endHostVirtAddr() uintptr { + return r.BaseHostVirtAddr + r.Size +} + +// shiftedOffset returns the offset of the given address in the region. 
+func (r *Region) shiftedOffset(addr uintptr) int64 { + return int64(addr - r.BaseHostVirtAddr + r.Offset) +} diff --git a/packages/orchestrator/internal/sandbox/uffd/memory_backend.go b/packages/orchestrator/internal/sandbox/uffd/memory_backend.go index 4c65f5d977..84abf75a57 100644 --- a/packages/orchestrator/internal/sandbox/uffd/memory_backend.go +++ b/packages/orchestrator/internal/sandbox/uffd/memory_backend.go @@ -3,14 +3,14 @@ package uffd import ( "context" - "github.com/bits-and-blooms/bitset" - + "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/block" "github.com/e2b-dev/infra/packages/shared/pkg/utils" ) type MemoryBackend interface { - Disable() error - Dirty() *bitset.BitSet + Dirty(ctx context.Context) (*block.Tracker, error) + // Disable switch the uffd to start serving empty pages. + Disable(ctx context.Context) error Start(ctx context.Context, sandboxId string) error Stop() error diff --git a/packages/orchestrator/internal/sandbox/uffd/noop.go b/packages/orchestrator/internal/sandbox/uffd/noop.go index e323828d08..2507d38cec 100644 --- a/packages/orchestrator/internal/sandbox/uffd/noop.go +++ b/packages/orchestrator/internal/sandbox/uffd/noop.go @@ -5,6 +5,7 @@ import ( "github.com/bits-and-blooms/bitset" + "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/block" "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" "github.com/e2b-dev/infra/packages/shared/pkg/utils" ) @@ -13,7 +14,7 @@ type NoopMemory struct { size int64 blockSize int64 - dirty *bitset.BitSet + dirty *block.Tracker exit *utils.ErrorOnce } @@ -23,23 +24,23 @@ var _ MemoryBackend = (*NoopMemory)(nil) func NewNoopMemory(size, blockSize int64) *NoopMemory { blocks := header.TotalBlocks(size, blockSize) - dirty := bitset.New(uint(blocks)) - dirty.FlipRange(0, dirty.Len()) + b := bitset.New(uint(blocks)) + b.FlipRange(0, b.Len()) return &NoopMemory{ size: size, blockSize: blockSize, - dirty: dirty, + dirty: block.NewTrackerFromBitset(b, 
blockSize), exit: utils.NewErrorOnce(), } } -func (m *NoopMemory) Disable() error { +func (m *NoopMemory) Disable(context.Context) error { return nil } -func (m *NoopMemory) Dirty() *bitset.BitSet { - return m.dirty +func (m *NoopMemory) Dirty(context.Context) (*block.Tracker, error) { + return m.dirty.Clone(), nil } func (m *NoopMemory) Start(context.Context, string) error { diff --git a/packages/orchestrator/internal/sandbox/uffd/serve.go b/packages/orchestrator/internal/sandbox/uffd/serve.go deleted file mode 100644 index c5866b52dd..0000000000 --- a/packages/orchestrator/internal/sandbox/uffd/serve.go +++ /dev/null @@ -1,200 +0,0 @@ -package uffd - -import ( - "context" - "errors" - "fmt" - "syscall" - "unsafe" - - "go.uber.org/zap" - "golang.org/x/sync/errgroup" - "golang.org/x/sys/unix" - - "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/block" - "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/uffd/fdexit" - "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/uffd/mapping" - "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/uffd/userfaultfd" -) - -var ErrUnexpectedEventType = errors.New("unexpected event type") - -type GuestRegionUffdMapping struct { - BaseHostVirtAddr uintptr `json:"base_host_virt_addr"` - Size uintptr `json:"size"` - Offset uintptr `json:"offset"` - PageSize uintptr `json:"page_size_kib"` -} - -func Serve( - ctx context.Context, - uffd int, - mappings mapping.Mappings, - src block.Slicer, - fdExit *fdexit.FdExit, - logger *zap.Logger, -) error { - pollFds := []unix.PollFd{ - {Fd: int32(uffd), Events: unix.POLLIN}, - {Fd: fdExit.Reader(), Events: unix.POLLIN}, - } - - var eg errgroup.Group - - missingPagesBeingHandled := map[int64]struct{}{} - -outerLoop: - for { - if _, err := unix.Poll( - pollFds, - -1, - ); err != nil { - if err == unix.EINTR { - logger.Debug("uffd: interrupted polling, going back to polling") - - continue - } - - if err == unix.EAGAIN { - logger.Debug("uffd: 
eagain during polling, going back to polling") - - continue - } - - logger.Error("UFFD serve polling error", zap.Error(err)) - - return fmt.Errorf("failed polling: %w", err) - } - - exitFd := pollFds[1] - if exitFd.Revents&unix.POLLIN != 0 { - errMsg := eg.Wait() - if errMsg != nil { - logger.Warn("UFFD fd exit error while waiting for goroutines to finish", zap.Error(errMsg)) - - return fmt.Errorf("failed to handle uffd: %w", errMsg) - } - - return nil - } - - uffdFd := pollFds[0] - if uffdFd.Revents&unix.POLLIN == 0 { - // Uffd is not ready for reading as there is nothing to read on the fd. - // https://github.com/firecracker-microvm/firecracker/issues/5056 - // https://elixir.bootlin.com/linux/v6.8.12/source/fs/userfaultfd.c#L1149 - // TODO: Check for all the errors - // - https://docs.kernel.org/admin-guide/mm/userfaultfd.html - // - https://elixir.bootlin.com/linux/v6.8.12/source/fs/userfaultfd.c - // - https://man7.org/linux/man-pages/man2/userfaultfd.2.html - // It might be possible to just check for data != 0 in the syscall.Read loop - // but I don't feel confident about doing that. - logger.Debug("uffd: no data in fd, going back to polling") - - continue - } - - buf := make([]byte, unsafe.Sizeof(userfaultfd.UffdMsg{})) - - for { - n, err := syscall.Read(uffd, buf) - if err == syscall.EINTR { - logger.Debug("uffd: interrupted read, reading again") - - continue - } - - if err == nil { - // There is no error so we can proceed. - break - } - - if err == syscall.EAGAIN { - logger.Debug("uffd: eagain error, going back to polling", zap.Error(err), zap.Int("read_bytes", n)) - - // Continue polling the fd. 
- continue outerLoop - } - - logger.Error("uffd: read error", zap.Error(err)) - - return fmt.Errorf("failed to read: %w", err) - } - - msg := *(*userfaultfd.UffdMsg)(unsafe.Pointer(&buf[0])) - if userfaultfd.GetMsgEvent(&msg) != userfaultfd.UFFD_EVENT_PAGEFAULT { - logger.Error("UFFD serve unexpected event type", zap.Any("event_type", userfaultfd.GetMsgEvent(&msg))) - - return ErrUnexpectedEventType - } - - arg := userfaultfd.GetMsgArg(&msg) - pagefault := (*(*userfaultfd.UffdPagefault)(unsafe.Pointer(&arg[0]))) - - addr := userfaultfd.GetPagefaultAddress(&pagefault) - - offset, pagesize, err := mappings.GetRange(uintptr(addr)) - if err != nil { - logger.Error("UFFD serve get mapping error", zap.Error(err)) - - return fmt.Errorf("failed to map: %w", err) - } - - if _, ok := missingPagesBeingHandled[offset]; ok { - continue - } - - missingPagesBeingHandled[offset] = struct{}{} - - eg.Go(func() error { - defer func() { - if r := recover(); r != nil { - logger.Error("UFFD serve panic", zap.Any("offset", offset), zap.Any("pagesize", pagesize), zap.Any("panic", r)) - } - }() - - b, err := src.Slice(ctx, offset, pagesize) - if err != nil { - signalErr := fdExit.SignalExit() - - joinedErr := errors.Join(err, signalErr) - - logger.Error("UFFD serve slice error", zap.Error(joinedErr)) - - return fmt.Errorf("failed to read from source: %w", joinedErr) - } - - cpy := userfaultfd.NewUffdioCopy( - b, - addr&^userfaultfd.CULong(pagesize-1), - userfaultfd.CULong(pagesize), - 0, - 0, - ) - - if _, _, errno := syscall.Syscall( - syscall.SYS_IOCTL, - uintptr(uffd), - userfaultfd.UFFDIO_COPY, - uintptr(unsafe.Pointer(&cpy)), - ); errno != 0 { - if errno == unix.EEXIST { - logger.Debug("UFFD serve page already mapped", zap.Any("offset", offset), zap.Any("pagesize", pagesize)) - - // Page is already mapped - return nil - } - - signalErr := fdExit.SignalExit() - - joinedErr := errors.Join(errno, signalErr) - - logger.Error("UFFD serve uffdio copy error", zap.Error(joinedErr)) - - return 
fmt.Errorf("failed uffdio copy %w", joinedErr) - } - - return nil - }) - } -} diff --git a/packages/orchestrator/internal/sandbox/uffd/testutils/diff_byte.go b/packages/orchestrator/internal/sandbox/uffd/testutils/diff_byte.go new file mode 100644 index 0000000000..68298ea6ea --- /dev/null +++ b/packages/orchestrator/internal/sandbox/uffd/testutils/diff_byte.go @@ -0,0 +1,20 @@ +package testutils + +// FirstDifferentByte returns the first byte index where a and b differ. +// It also returns the differing byte values (want, got). +// If slices are identical, it returns idx -1. +func FirstDifferentByte(a, b []byte) (idx int, want, got byte) { + smallerSize := min(len(a), len(b)) + + for i := range smallerSize { + if a[i] != b[i] { + return i, b[i], a[i] + } + } + + if len(a) != len(b) { + return smallerSize, 0, 0 + } + + return -1, 0, 0 +} diff --git a/packages/orchestrator/internal/sandbox/uffd/testutils/logger.go b/packages/orchestrator/internal/sandbox/uffd/testutils/logger.go new file mode 100644 index 0000000000..fa197edbd9 --- /dev/null +++ b/packages/orchestrator/internal/sandbox/uffd/testutils/logger.go @@ -0,0 +1,43 @@ +package testutils + +import ( + "testing" + + "go.uber.org/zap" + "go.uber.org/zap/zapcore" +) + +type testWriter struct { + t *testing.T +} + +func (w *testWriter) Write(p []byte) (n int, err error) { + w.t.Log(string(p)) + + return len(p), nil +} + +// NewTestLogger creates a new zap logger that logs all zap logs to the test output. 
+func NewTestLogger(t *testing.T) *zap.Logger { + t.Helper() + + encoderCfg := zap.NewDevelopmentEncoderConfig() + encoderCfg.EncodeLevel = zapcore.CapitalColorLevelEncoder + encoderCfg.CallerKey = zapcore.OmitKey + encoderCfg.ConsoleSeparator = " " + encoderCfg.TimeKey = "" + encoderCfg.MessageKey = "message" + encoderCfg.LevelKey = "level" + encoderCfg.NameKey = "logger" + encoderCfg.StacktraceKey = "stacktrace" + encoderCfg.EncodeTime = zapcore.RFC3339NanoTimeEncoder + encoderCfg.EncodeCaller = zapcore.ShortCallerEncoder + encoderCfg.EncodeDuration = zapcore.StringDurationEncoder + + encoder := zapcore.NewConsoleEncoder(encoderCfg) + + testSyncer := zapcore.AddSync(&testWriter{t}) + core := zapcore.NewCore(encoder, testSyncer, zap.DebugLevel) + + return zap.New(core, zap.AddCaller()) +} diff --git a/packages/orchestrator/internal/sandbox/uffd/testutils/memory_slicer.go b/packages/orchestrator/internal/sandbox/uffd/testutils/memory_slicer.go new file mode 100644 index 0000000000..ba2cac99e1 --- /dev/null +++ b/packages/orchestrator/internal/sandbox/uffd/testutils/memory_slicer.go @@ -0,0 +1,47 @@ +package testutils + +import ( + "context" + + "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/block" +) + +// MemorySlicer exposes byte slice via the Slicer interface. +// This is used for testing purposes. 
+type MemorySlicer struct { + content []byte + pagesize int64 + + accessed *block.Tracker +} + +var _ block.Slicer = (*MemorySlicer)(nil) + +func newMemorySlicer(content []byte, pagesize int64) *MemorySlicer { + return &MemorySlicer{ + content: content, + pagesize: pagesize, + accessed: block.NewTracker(pagesize), + } +} + +func (s *MemorySlicer) Slice(_ context.Context, offset, size int64) ([]byte, error) { + for i := offset; i < offset+size; i += s.pagesize { + s.accessed.Add(i) + } + + return s.content[offset : offset+size], nil +} + +func (s *MemorySlicer) Size() (int64, error) { + return int64(len(s.content)), nil +} + +func (s *MemorySlicer) Content() []byte { + return s.content +} + +// Offsets returns offsets of the content that were accessed via the Slice method. +func (s *MemorySlicer) Accessed() *block.Tracker { + return s.accessed.Clone() +} diff --git a/packages/orchestrator/internal/sandbox/uffd/testutils/page_mmap.go b/packages/orchestrator/internal/sandbox/uffd/testutils/page_mmap.go new file mode 100644 index 0000000000..e85a234ec5 --- /dev/null +++ b/packages/orchestrator/internal/sandbox/uffd/testutils/page_mmap.go @@ -0,0 +1,47 @@ +package testutils + +import ( + "fmt" + "math" + "syscall" + "unsafe" + + "golang.org/x/sys/unix" + + "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" +) + +func NewPageMmap(size, pagesize uint64) ([]byte, uintptr, func() error, error) { + if pagesize == header.PageSize { + return newMmap(size, header.PageSize, 0) + } + + if pagesize == header.HugepageSize { + return newMmap(size, header.HugepageSize, unix.MAP_HUGETLB|unix.MAP_HUGE_2MB) + } + + return nil, 0, nil, fmt.Errorf("unsupported page size: %d", pagesize) +} + +// Even though UFFD behaves differently with file backend memory (and hugetlbfs file backed), the FC uses MAP_PRIVATE|MAP_ANONYMOUS, so the following stub is correct to test for FC. 
+// - https://docs.kernel.org/admin-guide/mm/userfaultfd.html#write-protect-notifications +// - https://github.com/firecracker-microvm/firecracker/blob/a305f362d0e6f7ba926c73e65452cb51262a44d8/src/vmm/src/persist.rs#L499 +func newMmap(size, pagesize uint64, flags int) ([]byte, uintptr, func() error, error) { + l := int(math.Ceil(float64(size)/float64(pagesize)) * float64(pagesize)) + b, err := syscall.Mmap( + -1, + 0, + l, + syscall.PROT_READ|syscall.PROT_WRITE, + syscall.MAP_PRIVATE|syscall.MAP_ANONYMOUS|flags, + ) + if err != nil { + return nil, 0, nil, fmt.Errorf("failed to mmap: %w", err) + } + + closeMmap := func() error { + return syscall.Munmap(b) + } + + return b, uintptr(unsafe.Pointer(&b[0])), closeMmap, nil +} diff --git a/packages/orchestrator/internal/sandbox/uffd/testutils/random_data.go b/packages/orchestrator/internal/sandbox/uffd/testutils/random_data.go new file mode 100644 index 0000000000..c8bfe8ed1e --- /dev/null +++ b/packages/orchestrator/internal/sandbox/uffd/testutils/random_data.go @@ -0,0 +1,17 @@ +package testutils + +import ( + "crypto/rand" +) + +func RandomPages(pagesize, numberOfPages uint64) *MemorySlicer { + size := pagesize * numberOfPages + + n := int(size) + buf := make([]byte, n) + if _, err := rand.Read(buf); err != nil { + panic(err) + } + + return newMemorySlicer(buf, int64(pagesize)) +} diff --git a/packages/orchestrator/internal/sandbox/uffd/handler.go b/packages/orchestrator/internal/sandbox/uffd/uffd.go similarity index 60% rename from packages/orchestrator/internal/sandbox/uffd/handler.go rename to packages/orchestrator/internal/sandbox/uffd/uffd.go index 947ff168cf..66e8aca128 100644 --- a/packages/orchestrator/internal/sandbox/uffd/handler.go +++ b/packages/orchestrator/internal/sandbox/uffd/uffd.go @@ -10,13 +10,13 @@ import ( "syscall" "time" - "github.com/bits-and-blooms/bitset" "go.opentelemetry.io/otel" "go.uber.org/zap" "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/block" 
"github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/uffd/fdexit" - "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/uffd/mapping" + "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/uffd/memory" + "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/uffd/userfaultfd" "github.com/e2b-dev/infra/packages/shared/pkg/logger" "github.com/e2b-dev/infra/packages/shared/pkg/utils" ) @@ -26,29 +26,22 @@ var tracer = otel.Tracer("github.com/e2b-dev/infra/packages/orchestrator/interna const ( uffdMsgListenerTimeout = 10 * time.Second fdSize = 4 - mappingsSize = 1024 + regionMappingsSize = 1024 ) type Uffd struct { - exit *utils.ErrorOnce - readyCh chan struct{} - - fdExit *fdexit.FdExit - - lis *net.UnixListener - - memfile *block.TrackedSliceDevice + exit *utils.ErrorOnce + readyCh chan struct{} + fdExit *fdexit.FdExit + lis *net.UnixListener socketPath string + memfile block.ReadonlyDevice + handler utils.SetOnce[*userfaultfd.Userfaultfd] } var _ MemoryBackend = (*Uffd)(nil) -func New(memfile block.ReadonlyDevice, socketPath string, blockSize int64) (*Uffd, error) { - trackedMemfile, err := block.NewTrackedSliceDevice(blockSize, memfile) - if err != nil { - return nil, fmt.Errorf("failed to create tracked slice device: %w", err) - } - +func New(memfile block.ReadonlyDevice, socketPath string) (*Uffd, error) { fdExit, err := fdexit.New() if err != nil { return nil, fmt.Errorf("failed to create fd exit: %w", err) @@ -58,8 +51,9 @@ func New(memfile block.ReadonlyDevice, socketPath string, blockSize int64) (*Uff exit: utils.NewErrorOnce(), readyCh: make(chan struct{}, 1), fdExit: fdExit, - memfile: trackedMemfile, socketPath: socketPath, + memfile: memfile, + handler: *utils.NewSetOnce[*userfaultfd.Userfaultfd](), }, nil } @@ -106,19 +100,19 @@ func (u *Uffd) handle(ctx context.Context, sandboxId string) error { unixConn := conn.(*net.UnixConn) - mappingsBuf := make([]byte, mappingsSize) + regionMappingsBuf := make([]byte, 
regionMappingsSize) uffdBuf := make([]byte, syscall.CmsgSpace(fdSize)) - numBytesMappings, numBytesFd, _, _, err := unixConn.ReadMsgUnix(mappingsBuf, uffdBuf) + numBytesMappings, numBytesFd, _, _, err := unixConn.ReadMsgUnix(regionMappingsBuf, uffdBuf) if err != nil { return fmt.Errorf("failed to read unix msg from connection: %w", err) } - mappingsBuf = mappingsBuf[:numBytesMappings] + regionMappingsBuf = regionMappingsBuf[:numBytesMappings] - var m mapping.FcMappings + var regions []memory.Region - err = json.Unmarshal(mappingsBuf, &m) + err = json.Unmarshal(regionMappingsBuf, ®ions) if err != nil { return fmt.Errorf("failed parsing memory mapping data: %w", err) } @@ -141,24 +135,48 @@ func (u *Uffd) handle(ctx context.Context, sandboxId string) error { return fmt.Errorf("expected 1 fd: found %d", len(fds)) } - uffd := fds[0] + m := memory.NewMapping(regions) + + uffd, err := userfaultfd.NewUserfaultfdFromFd( + uintptr(fds[0]), + u.memfile, + u.memfile.BlockSize(), + m, + zap.L().With(logger.WithSandboxID(sandboxId)), + ) + if err != nil { + return fmt.Errorf("failed to create uffd: %w", err) + } + + u.handler.SetValue(uffd) defer func() { - closeErr := syscall.Close(uffd) + closeErr := uffd.Close() if closeErr != nil { zap.L().Error("failed to close uffd", logger.WithSandboxID(sandboxId), zap.String("socket_path", u.socketPath), zap.Error(closeErr)) } }() + for _, region := range m.Regions { + // Mark the memory region as write protected. + // It seems the memory in FC is by default registered with the WP flag capability. 
+ // - https://github.com/firecracker-microvm/firecracker/blob/f335a0adf46f0680a141eb1e76fe31ac258918c5/src/vmm/src/persist.rs#L477 + // - https://github.com/bytecodealliance/userfaultfd-rs/blob/main/src/builder.rs + err := uffd.Register( + region.BaseHostVirtAddr+region.Offset, + uint64(region.Size), + userfaultfd.UFFDIO_REGISTER_MODE_WP|userfaultfd.UFFDIO_REGISTER_MODE_MISSING, + ) + if err != nil { + return fmt.Errorf("failed to reregister memory region with write protection %d-%d", region.Offset, region.Offset+region.Size) + } + } + u.readyCh <- struct{}{} - err = Serve( + err = uffd.Serve( ctx, - uffd, - m, - u.memfile, u.fdExit, - zap.L().With(logger.WithSandboxID(sandboxId)), ) if err != nil { return fmt.Errorf("failed handling uffd: %w", err) @@ -175,18 +193,28 @@ func (u *Uffd) Ready() chan struct{} { return u.readyCh } -func (u *Uffd) Exit() *utils.ErrorOnce { - return u.exit -} +func (u *Uffd) Disable(ctx context.Context) error { + uffd, err := u.handler.WaitWithContext(ctx) + if err != nil { + return fmt.Errorf("failed to get uffd: %w", err) + } + + uffd.Disable() -func (u *Uffd) TrackAndReturnNil() error { - return u.lis.Close() + return nil } -func (u *Uffd) Disable() error { - return u.memfile.Disable() +func (u *Uffd) Exit() *utils.ErrorOnce { + return u.exit } -func (u *Uffd) Dirty() *bitset.BitSet { - return u.memfile.Dirty() +// Dirty waits for all the requests in flight to be finished and then returns clone of the dirty tracker. +// Call *after* pausing the firecracker process—to let the uffd process all the requests. 
+func (u *Uffd) Dirty(ctx context.Context) (*block.Tracker, error) { + uffd, err := u.handler.WaitWithContext(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get uffd: %w", err) + } + + return uffd.Dirty(ctx) } diff --git a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/constants.go b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/constants.go index 63a3b3e71e..86c8633883 100644 --- a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/constants.go +++ b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/constants.go @@ -34,14 +34,13 @@ const ( UFFDIO_API = C.UFFDIO_API UFFDIO_REGISTER = C.UFFDIO_REGISTER - UFFDIO_WRITEPROTECT = C.UFFDIO_WRITEPROTECT UFFDIO_COPY = C.UFFDIO_COPY + UFFDIO_WRITEPROTECT = C.UFFDIO_WRITEPROTECT UFFD_PAGEFAULT_FLAG_WP = C.UFFD_PAGEFAULT_FLAG_WP UFFD_PAGEFAULT_FLAG_WRITE = C.UFFD_PAGEFAULT_FLAG_WRITE - UFFD_FEATURE_MISSING_HUGETLBFS = C.UFFD_FEATURE_MISSING_HUGETLBFS - UFFD_FEATURE_WP_HUGETLBFS_SHMEM = C.UFFD_FEATURE_WP_HUGETLBFS_SHMEM + UFFD_FEATURE_MISSING_HUGETLBFS = C.UFFD_FEATURE_MISSING_HUGETLBFS ) type ( @@ -79,7 +78,7 @@ func NewUffdioRegister(start, length, mode CULong) UffdioRegister { func NewUffdioCopy(b []byte, address CULong, pagesize CULong, mode CULong, bytesCopied CLong) UffdioCopy { return UffdioCopy{ src: CULong(uintptr(unsafe.Pointer(&b[0]))), - dst: address &^ (pagesize - 1), + dst: address, len: pagesize, mode: mode, copy: bytesCopied, @@ -104,14 +103,6 @@ func GetMsgArg(msg *UffdMsg) [24]byte { return msg.arg } -func GetPagefaultAddress(pagefault *UffdPagefault) CULong { - return pagefault.address -} - -func IsWritePageFault(pagefault *UffdPagefault) bool { - return pagefault.flags&UFFD_PAGEFAULT_FLAG_WRITE != 0 -} - -func IsWriteProtectPageFault(pagefault *UffdPagefault) bool { - return pagefault.flags&UFFD_PAGEFAULT_FLAG_WP != 0 +func GetPagefaultAddress(pagefault *UffdPagefault) uintptr { + return uintptr(pagefault.address) } diff --git 
a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/diagram.mermaid b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/diagram.mermaid new file mode 100644 index 0000000000..7b4ead1174 --- /dev/null +++ b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/diagram.mermaid @@ -0,0 +1,4 @@ +flowchart TD +A[missing page] -- write (WRITE flag) --> B(COPY) --> C[dirty page] +A -- read (MISSING flag) --> D(COPY + MODE_WP) --> E[faulted page] +E -- write (WP flag) --> F(remove MODE_WP) --> C \ No newline at end of file diff --git a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/serve.go b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/serve.go new file mode 100644 index 0000000000..4dc7d5ba4c --- /dev/null +++ b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/serve.go @@ -0,0 +1,275 @@ +package userfaultfd + +import ( + "context" + "errors" + "fmt" + "syscall" + "unsafe" + + "go.uber.org/zap" + "golang.org/x/sys/unix" + + "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/uffd/fdexit" + "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" +) + +func (u *Userfaultfd) Serve( + ctx context.Context, + fdExit *fdexit.FdExit, +) error { + pollFds := []unix.PollFd{ + {Fd: int32(u.fd), Events: unix.POLLIN}, + {Fd: fdExit.Reader(), Events: unix.POLLIN}, + } + +outerLoop: + for { + if _, err := unix.Poll( + pollFds, + -1, + ); err != nil { + if err == unix.EINTR { + u.logger.Debug("uffd: interrupted polling, going back to polling") + + continue + } + + if err == unix.EAGAIN { + u.logger.Debug("uffd: eagain during polling, going back to polling") + + continue + } + + u.logger.Error("UFFD serve polling error", zap.Error(err)) + + return fmt.Errorf("failed polling: %w", err) + } + + exitFd := pollFds[1] + if exitFd.Revents&unix.POLLIN != 0 { + errMsg := u.wg.Wait() + if errMsg != nil { + u.logger.Warn("UFFD fd exit error while waiting for goroutines to finish", zap.Error(errMsg)) + + return fmt.Errorf("failed to handle 
uffd: %w", errMsg) + } + + return nil + } + + uffdFd := pollFds[0] + if uffdFd.Revents&unix.POLLIN == 0 { + // Uffd is not ready for reading as there is nothing to read on the fd. + // https://github.com/firecracker-microvm/firecracker/issues/5056 + // https://elixir.bootlin.com/linux/v6.8.12/source/fs/userfaultfd.c#L1149 + // TODO: Check for all the errors + // - https://docs.kernel.org/admin-guide/mm/userfaultfd.html + // - https://elixir.bootlin.com/linux/v6.8.12/source/fs/userfaultfd.c + // - https://man7.org/linux/man-pages/man2/userfaultfd.2.html + // It might be possible to just check for data != 0 in the syscall.Read loop + // but I don't feel confident about doing that. + u.logger.Debug("uffd: no data in fd, going back to polling") + + continue + } + + buf := make([]byte, unsafe.Sizeof(UffdMsg{})) + + for { + n, err := syscall.Read(int(u.fd), buf) + if err == syscall.EINTR { + u.logger.Debug("uffd: interrupted read, reading again") + + continue + } + + if err == nil { + // There is no error so we can proceed. + break + } + + if err == syscall.EAGAIN { + u.logger.Debug("uffd: eagain error, going back to polling", zap.Error(err), zap.Int("read_bytes", n)) + + // Continue polling the fd. 
+ continue outerLoop + } + + u.logger.Error("uffd: read error", zap.Error(err)) + + return fmt.Errorf("failed to read: %w", err) + } + + msg := *(*UffdMsg)(unsafe.Pointer(&buf[0])) + if GetMsgEvent(&msg) != UFFD_EVENT_PAGEFAULT { + u.logger.Error("UFFD serve unexpected event type", zap.Any("event_type", GetMsgEvent(&msg))) + + return ErrUnexpectedEventType + } + + arg := GetMsgArg(&msg) + pagefault := (*(*UffdPagefault)(unsafe.Pointer(&arg[0]))) + flags := pagefault.flags + + addr := GetPagefaultAddress(&pagefault) + + offset, pagesize, err := u.ma.GetOffset(addr) + if err != nil { + u.logger.Error("UFFD serve get mapping error", zap.Error(err)) + + return fmt.Errorf("failed to map: %w", err) + } + + // Handle write to write protected page (WP flag) + if flags&UFFD_PAGEFAULT_FLAG_WP != 0 { + err := u.handleWriteProtection(ctx, addr, offset, pagesize) + if err != nil { + return fmt.Errorf("failed to handle write protection: %w", err) + } + + continue + } + + // Handle write to missing page (WRITE flag) + if flags&UFFD_PAGEFAULT_FLAG_WRITE != 0 { + err := u.handleMissing(ctx, fdExit.SignalExit, addr, offset, pagesize, true) + if err != nil { + return fmt.Errorf("failed to handle missing write: %w", err) + } + + continue + } + + // Handle read to missing page ("MISSING" flag) + if flags == 0 { + err := u.handleMissing(ctx, fdExit.SignalExit, addr, offset, pagesize, false) + if err != nil { + return fmt.Errorf("failed to handle missing: %w", err) + } + + continue + } + + u.logger.Warn("UFFD serve unexpected event type", zap.Any("event_type", flags)) + } +} + +func (u *Userfaultfd) handleMissing( + ctx context.Context, + onFailure func() error, + addr uintptr, + offset int64, + pagesize uint64, + write bool, +) error { + if write { + u.writesInProgress.Add() + } else if !u.missingRequests.Add(offset) { + return nil + } + + err := u.workerSem.Acquire(ctx, 1) + if err != nil { + return fmt.Errorf("failed to acquire semaphore: %w", err) + } + + u.wg.Go(func() error { + 
defer u.workerSem.Release(1) + + defer func() { + if r := recover(); r != nil { + u.logger.Error("UFFD serve panic", zap.Any("pagesize", pagesize), zap.Any("panic", r)) + } + }() + + defer func() { + if write { + u.writesInProgress.Done() + } + }() + + var b []byte + + if u.disabled.Load() { + b = header.EmptyHugePage[:pagesize] + } else { + sliceB, sliceErr := u.src.Slice(ctx, offset, int64(pagesize)) + if sliceErr != nil { + signalErr := onFailure() + + joinedErr := errors.Join(sliceErr, signalErr) + + u.logger.Error("UFFD serve slice error", zap.Error(joinedErr)) + + return fmt.Errorf("failed to read from source: %w", joinedErr) + } + + b = sliceB + } + var copyMode CULong + + if !write { + copyMode |= UFFDIO_COPY_MODE_WP + } + + copyErr := u.copy(addr, b, pagesize, copyMode) + if errors.Is(copyErr, unix.EEXIST) { + // Page is already mapped + + return nil + } + + if copyErr != nil { + signalErr := onFailure() + + joinedErr := errors.Join(copyErr, signalErr) + + u.logger.Error("UFFD serve uffdio copy error", zap.Error(joinedErr)) + + return fmt.Errorf("failed uffdio copy %w", joinedErr) + } + + // We mark the page as dirty if it was a write to a page that was not already mapped. 
+ if write { + u.dirty.Add(offset) + } + + return nil + }) + + return nil +} + +func (u *Userfaultfd) handleWriteProtection(ctx context.Context, addr uintptr, offset int64, pagesize uint64) error { + err := u.workerSem.Acquire(ctx, 1) + if err != nil { + return fmt.Errorf("failed to acquire semaphore: %w", err) + } + + u.writesInProgress.Add() + + u.wg.Go(func() error { + defer u.workerSem.Release(1) + + defer func() { + if r := recover(); r != nil { + u.logger.Error("UFFD remove write protection panic", zap.Any("offset", offset), zap.Any("pagesize", pagesize), zap.Any("panic", r)) + } + }() + + defer u.writesInProgress.Done() + + wpErr := u.RemoveWriteProtection(addr, pagesize) + if wpErr != nil { + return fmt.Errorf("error removing write protection from page %d", addr) + } + + // We mark the page as dirty if it was a write to a page that was already mapped. + u.dirty.Add(offset) + + return nil + }) + + return nil +} diff --git a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/syscalls.go b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/syscalls.go new file mode 100644 index 0000000000..c5f3684c67 --- /dev/null +++ b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/syscalls.go @@ -0,0 +1,96 @@ +package userfaultfd + +import ( + "fmt" + "syscall" + "unsafe" + + "go.uber.org/zap" + + "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/block" + "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/uffd/memory" + "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" +) + +// flags: syscall.O_CLOEXEC|syscall.O_NONBLOCK +func newUserfaultfd(flags uintptr, src block.Slicer, pagesize int64, m *memory.Mapping, logger *zap.Logger) (*Userfaultfd, error) { + uffd, _, errno := syscall.Syscall(NR_userfaultfd, flags, 0, 0) + if errno != 0 { + return nil, fmt.Errorf("userfaultfd syscall failed: %w", errno) + } + + return NewUserfaultfdFromFd(uffd, src, pagesize, m, logger) +} + +// features: UFFD_FEATURE_MISSING_HUGETLBFS 
+// This is already called by the FC +func (u *Userfaultfd) configureApi(pagesize uint64) error { + var features CULong + + // Only set the hugepage feature if we're using hugepages + if pagesize == header.HugepageSize { + features |= UFFD_FEATURE_MISSING_HUGETLBFS + } + + api := NewUffdioAPI(UFFD_API, features) + ret, _, errno := syscall.Syscall(syscall.SYS_IOCTL, u.fd, UFFDIO_API, uintptr(unsafe.Pointer(&api))) + if errno != 0 { + return fmt.Errorf("UFFDIO_API ioctl failed: %w (ret=%d)", errno, ret) + } + + return nil +} + +// mode: UFFDIO_REGISTER_MODE_WP|UFFDIO_REGISTER_MODE_MISSING +// This is already called by the FC, but only with the UFFDIO_REGISTER_MODE_MISSING +// We need to call it with UFFDIO_REGISTER_MODE_WP when we use both missing and wp +func (u *Userfaultfd) Register(addr uintptr, size uint64, mode CULong) error { + register := NewUffdioRegister(CULong(addr), CULong(size), mode) + + ret, _, errno := syscall.Syscall(syscall.SYS_IOCTL, u.fd, UFFDIO_REGISTER, uintptr(unsafe.Pointer(®ister))) + if errno != 0 { + return fmt.Errorf("UFFDIO_REGISTER ioctl failed: %w (ret=%d)", errno, ret) + } + + return nil +} + +func (u *Userfaultfd) writeProtect(addr uintptr, size uint64, mode CULong) error { + register := NewUffdioWriteProtect(CULong(addr), CULong(size), mode) + + ret, _, errno := syscall.Syscall(syscall.SYS_IOCTL, u.fd, UFFDIO_WRITEPROTECT, uintptr(unsafe.Pointer(®ister))) + if errno != 0 { + return fmt.Errorf("UFFDIO_WRITEPROTECT ioctl failed: %w (ret=%d)", errno, ret) + } + + return nil +} + +func (u *Userfaultfd) RemoveWriteProtection(addr uintptr, size uint64) error { + return u.writeProtect(addr, size, 0) +} + +func (u *Userfaultfd) AddWriteProtection(addr uintptr, size uint64) error { + return u.writeProtect(addr, size, UFFDIO_WRITEPROTECT_MODE_WP) +} + +// mode: UFFDIO_COPY_MODE_WP +// When we use both missing and wp, we need to use UFFDIO_COPY_MODE_WP, otherwise copying would unprotect the page +func (u *Userfaultfd) copy(addr uintptr, data 
[]byte, pagesize uint64, mode CULong) error { + cpy := NewUffdioCopy(data, CULong(addr)&^CULong(pagesize-1), CULong(pagesize), mode, 0) + + if _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, u.fd, UFFDIO_COPY, uintptr(unsafe.Pointer(&cpy))); errno != 0 { + return errno + } + + // Check if the copied size matches the requested pagesize + if uint64(cpy.copy) != pagesize { + return fmt.Errorf("UFFDIO_COPY copied %d bytes, expected %d", cpy.copy, pagesize) + } + + return nil +} + +func (u *Userfaultfd) Close() error { + return syscall.Close(int(u.fd)) +} diff --git a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd.go b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd.go new file mode 100644 index 0000000000..0652f8ec62 --- /dev/null +++ b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd.go @@ -0,0 +1,65 @@ +package userfaultfd + +import ( + "context" + "errors" + "fmt" + "sync/atomic" + + "go.uber.org/zap" + "golang.org/x/sync/errgroup" + "golang.org/x/sync/semaphore" + + "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/block" + "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/uffd/memory" + "github.com/e2b-dev/infra/packages/shared/pkg/utils" +) + +var ErrUnexpectedEventType = errors.New("unexpected event type") + +type Userfaultfd struct { + fd uintptr + + src block.Slicer + ma *memory.Mapping + dirty *block.Tracker + disabled atomic.Bool + + missingRequests *block.Tracker + workerSem *semaphore.Weighted + + writesInProgress *utils.SettleCounter + wg errgroup.Group + + logger *zap.Logger +} + +// NewUserfaultfdFromFd creates a new userfaultfd instance with optional configuration. 
+func NewUserfaultfdFromFd(fd uintptr, src block.Slicer, pagesize int64, m *memory.Mapping, logger *zap.Logger) (*Userfaultfd, error) { + return &Userfaultfd{ + fd: fd, + src: src, + dirty: block.NewTracker(pagesize), + missingRequests: block.NewTracker(pagesize), + disabled: atomic.Bool{}, + workerSem: semaphore.NewWeighted(2048), + ma: m, + writesInProgress: utils.NewZeroSettleCounter(), + logger: logger, + }, nil +} + +func (u *Userfaultfd) Disable() { + u.disabled.Store(true) +} + +func (u *Userfaultfd) Dirty(ctx context.Context) (*block.Tracker, error) { + err := u.writesInProgress.Wait(ctx) + if err != nil { + return nil, fmt.Errorf("failed to wait for write requests: %w", err) + } + + u.missingRequests.Reset() + + return u.dirty.Clone(), nil +} diff --git a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd_test.go b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd_test.go new file mode 100644 index 0000000000..9736822dbb --- /dev/null +++ b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd_test.go @@ -0,0 +1,809 @@ +package userfaultfd + +import ( + "bytes" + "context" + "fmt" + "slices" + "syscall" + "testing" + "time" + + "github.com/bits-and-blooms/bitset" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" + + "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/uffd/fdexit" + "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/uffd/memory" + "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/uffd/testutils" + "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" + "github.com/e2b-dev/infra/packages/shared/pkg/utils" +) + +type testConfig struct { + name string + // Page size of the memory area. + pagesize uint64 + // Number of pages in the memory area. + numberOfPages uint64 + // Operations to trigger on the memory area. 
+ operations []operation +} + +type operationMode uint32 + +const ( + operationModeRead operationMode = 1 << iota + operationModeWrite +) + +type operation struct { + // Offset in bytes. Must be smaller than the (numberOfPages-1) * pagesize as it reads a page and it must be aligned to the pagesize from the testConfig. + offset int64 + mode operationMode +} + +func TestUffdMissing(t *testing.T) { + tests := []testConfig{ + { + name: "standard 4k page, operation at start", + pagesize: header.PageSize, + numberOfPages: 32, + operations: []operation{ + { + offset: 0, + mode: operationModeRead, + }, + }, + }, + { + name: "standard 4k page, operation at middle", + pagesize: header.PageSize, + numberOfPages: 32, + operations: []operation{ + { + offset: 15 * header.PageSize, + mode: operationModeRead, + }, + }, + }, + { + name: "standard 4k page, operation at last page", + pagesize: header.PageSize, + numberOfPages: 32, + operations: []operation{ + { + offset: 31 * header.PageSize, + mode: operationModeRead, + }, + }, + }, + { + name: "hugepage, operation at start", + pagesize: header.HugepageSize, + numberOfPages: 8, + operations: []operation{ + { + offset: 0, + mode: operationModeRead, + }, + }, + }, + { + name: "hugepage, operation at middle", + pagesize: header.HugepageSize, + numberOfPages: 8, + operations: []operation{ + { + offset: 3 * header.HugepageSize, + mode: operationModeRead, + }, + }, + }, + { + name: "hugepage, operation at last page", + pagesize: header.HugepageSize, + numberOfPages: 8, + operations: []operation{ + { + offset: 7 * header.HugepageSize, + mode: operationModeRead, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h, cleanupFunc := configureTest(t, tt) + defer cleanupFunc() + + for _, operation := range tt.operations { + if operation.mode == operationModeRead { + err := h.executeRead(t.Context(), operation) + require.NoError(t, err) + } + } + + err := h.uffd.writesInProgress.Wait(t.Context()) + 
require.NoError(t, err) + + expectedAccessedOffsets := getOperationsOffsets(tt.operations, operationModeRead|operationModeWrite) + assert.Equal(t, expectedAccessedOffsets, h.getAccessedOffsets(), "checking which pages were faulted)") + }) + } +} + +func TestUffdWriteProtection(t *testing.T) { + tests := []testConfig{ + { + name: "standard 4k page, single write", + pagesize: header.PageSize, + numberOfPages: 32, + operations: []operation{ + { + offset: 0, + mode: operationModeWrite, + }, + }, + }, + { + name: "standard 4k page, single read then write on first page (MISSING then WP)", + pagesize: header.PageSize, + numberOfPages: 32, + operations: []operation{ + { + offset: 0, + mode: operationModeRead, + }, + { + offset: 0, + mode: operationModeWrite, + }, + }, + }, + { + name: "standard 4k page, single write then read on first page (WRITE then skipping)", + pagesize: header.PageSize, + numberOfPages: 32, + operations: []operation{ + { + offset: 0, + mode: operationModeWrite, + }, + { + offset: 0, + mode: operationModeRead, + }, + }, + }, + { + name: "standard 4k page, single read then write on non-first page (MISSING then WP)", + pagesize: header.PageSize, + numberOfPages: 32, + operations: []operation{ + { + offset: 15 * header.PageSize, + mode: operationModeRead, + }, + { + offset: 15 * header.PageSize, + mode: operationModeWrite, + }, + }, + }, + { + name: "standard 4k page, two writes on different pages", + pagesize: header.PageSize, + numberOfPages: 32, + operations: []operation{ + { + offset: 15 * header.PageSize, + mode: operationModeWrite, + }, + { + offset: 16 * header.PageSize, + mode: operationModeWrite, + }, + }, + }, + { + name: "standard 4k page, two writes on same page (WRITE then skipping)", + pagesize: header.PageSize, + numberOfPages: 32, + operations: []operation{ + { + offset: 0, + mode: operationModeWrite, + }, + { + offset: 0, + mode: operationModeWrite, + }, + }, + }, + { + name: "standard 4k page, three writes on same page (WRITE then 
skipping)", + pagesize: header.PageSize, + numberOfPages: 32, + operations: []operation{ + { + offset: 0, + mode: operationModeWrite, + }, + { + offset: 0, + mode: operationModeWrite, + }, + { + offset: 0, + mode: operationModeWrite, + }, + }, + }, + { + name: "standard 4k page, read then two writes on same page (MISSING then WP then WP)", + pagesize: header.PageSize, + numberOfPages: 32, + operations: []operation{ + { + offset: 0, + mode: operationModeRead, + }, + { + offset: 0, + mode: operationModeWrite, + }, + { + offset: 0, + mode: operationModeWrite, + }, + }, + }, + { + name: "hugepage, single write", + pagesize: header.HugepageSize, + numberOfPages: 8, + operations: []operation{ + { + offset: 0, + mode: operationModeWrite, + }, + }, + }, + { + name: "hugepage, single read then write on first page (MISSING then WP)", + pagesize: header.HugepageSize, + numberOfPages: 8, + operations: []operation{ + { + offset: 0, + mode: operationModeRead, + }, + { + offset: 0, + mode: operationModeWrite, + }, + }, + }, + { + name: "hugepage, single read then write on non-first page (MISSING then WP)", + pagesize: header.HugepageSize, + numberOfPages: 8, + operations: []operation{ + { + offset: 3 * header.HugepageSize, + mode: operationModeRead, + }, + { + offset: 3 * header.HugepageSize, + mode: operationModeWrite, + }, + }, + }, + { + name: "hugepage, single write then read on non-first page (WRITE then skipping)", + pagesize: header.HugepageSize, + numberOfPages: 8, + operations: []operation{ + { + offset: 3 * header.HugepageSize, + mode: operationModeWrite, + }, + { + offset: 3 * header.HugepageSize, + mode: operationModeRead, + }, + }, + }, + { + name: "hugepage, two writes on different pages", + pagesize: header.HugepageSize, + numberOfPages: 8, + operations: []operation{ + { + offset: 3 * header.HugepageSize, + mode: operationModeWrite, + }, + { + offset: 4 * header.HugepageSize, + mode: operationModeWrite, + }, + }, + }, + { + name: "hugepage, two writes on same page 
(WRITE then skipping)", + pagesize: header.HugepageSize, + numberOfPages: 8, + operations: []operation{ + { + offset: 0 * header.HugepageSize, + mode: operationModeWrite, + }, + { + offset: 0 * header.HugepageSize, + mode: operationModeWrite, + }, + }, + }, + { + name: "hugepage, three writes on same page (WRITE then skipping)", + pagesize: header.HugepageSize, + numberOfPages: 8, + operations: []operation{ + { + offset: 0 * header.HugepageSize, + mode: operationModeWrite, + }, + { + offset: 0 * header.HugepageSize, + mode: operationModeWrite, + }, + { + offset: 0 * header.HugepageSize, + mode: operationModeWrite, + }, + }, + }, + { + name: "hugepage, read then two writes on same page (MISSING then WP then WP)", + pagesize: header.HugepageSize, + numberOfPages: 8, + operations: []operation{ + { + offset: 0 * header.HugepageSize, + mode: operationModeRead, + }, + { + offset: 0 * header.HugepageSize, + mode: operationModeWrite, + }, + { + offset: 0 * header.HugepageSize, + mode: operationModeWrite, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h, cleanup := configureTest(t, tt) + t.Cleanup(cleanup) + + for _, operation := range tt.operations { + if operation.mode == operationModeRead { + err := h.executeRead(t.Context(), operation) + require.NoError(t, err) + } + + if operation.mode == operationModeWrite { + err := h.executeWrite(t.Context(), operation) + require.NoError(t, err) + } + } + + err := h.uffd.writesInProgress.Wait(t.Context()) + require.NoError(t, err) + + expectedWriteOffsets := getOperationsOffsets(tt.operations, operationModeWrite) + assert.Equal(t, expectedWriteOffsets, h.getWriteOffsets(), "checking which pages were written to") + + expectedAccessedOffsets := getOperationsOffsets(tt.operations, operationModeRead|operationModeWrite) + assert.Equal(t, expectedAccessedOffsets, h.getAccessedOffsets(), "checking which pages were faulted)") + }) + } +} + +// Will trigger UFFD and with higher volume overload it 
before reaching our code. +func TestUffdParallelWP(t *testing.T) { + parallelOperations := 10_000 + + tt := testConfig{ + pagesize: header.PageSize, + numberOfPages: 5, + } + + h, cleanup := configureTest(t, tt) + t.Cleanup(cleanup) + + readOp := operation{ + offset: 0, + mode: operationModeRead, + } + + // Single read to add Write protection to the page + err := h.executeRead(t.Context(), readOp) + require.NoError(t, err) + + writeOp := operation{ + offset: 0, + mode: operationModeWrite, + } + + var verr errgroup.Group + + for range parallelOperations { + verr.Go(func() error { + return h.executeWrite(t.Context(), writeOp) + }) + } + + err = verr.Wait() + require.NoError(t, err) + + assert.Equal(t, []uint{0}, h.getAccessedOffsets(), "pages accessed (page 0)") + assert.Equal(t, []uint{0}, h.getWriteOffsets(), "pages written to (page 0)") +} + +// Will trigger UFFD and with higher volume overload it before reaching our code. +func TestUffdParallelWrite(t *testing.T) { + parallelOperations := 10_00 + + tt := testConfig{ + pagesize: header.PageSize, + numberOfPages: 2, + } + + h, cleanup := configureTest(t, tt) + t.Cleanup(cleanup) + + writeOp := operation{ + offset: 0, + mode: operationModeWrite, + } + + var verr errgroup.Group + + for range parallelOperations { + verr.Go(func() error { + return h.executeWrite(t.Context(), writeOp) + }) + } + + err := verr.Wait() + require.NoError(t, err) + + assert.Equal(t, []uint{0}, h.getAccessedOffsets(), "pages accessed (page 0)") + assert.Equal(t, []uint{0}, h.getWriteOffsets(), "pages written to (page 0)") +} + +func TestUffdParallelWriteWithPrefault(t *testing.T) { + parallelOperations := 10_000_000 + + tt := testConfig{ + pagesize: header.PageSize, + numberOfPages: 2, + } + + h, cleanup := configureTest(t, tt) + t.Cleanup(cleanup) + + writeOp := operation{ + offset: 0, + mode: operationModeWrite, + } + + err := h.executeWrite(t.Context(), writeOp) + require.NoError(t, err) + + var verr errgroup.Group + + for range 
parallelOperations { + verr.Go(func() error { + return h.executeWrite(t.Context(), writeOp) + }) + } + + err = verr.Wait() + require.NoError(t, err) + + assert.Equal(t, []uint{0}, h.getAccessedOffsets(), "pages accessed (page 0)") + assert.Equal(t, []uint{0}, h.getWriteOffsets(), "pages written to (page 0)") +} + +func TestUffdParallelMissing(t *testing.T) { + parallelOperations := 100_0000 + + tt := testConfig{ + pagesize: header.PageSize, + numberOfPages: 2, + } + + h, cleanup := configureTest(t, tt) + t.Cleanup(cleanup) + + readOp := operation{ + offset: 0, + mode: operationModeRead, + } + + var verr errgroup.Group + + for range parallelOperations { + verr.Go(func() error { + return h.executeRead(t.Context(), readOp) + }) + } + + err := verr.Wait() + require.NoError(t, err) + + assert.Equal(t, []uint{0}, h.getAccessedOffsets(), "pages accessed (page 0)") +} + +func TestUffdParallelMissingWithPrefault(t *testing.T) { + parallelOperations := 10_000_000 + + tt := testConfig{ + pagesize: header.PageSize, + numberOfPages: 2, + } + + h, cleanup := configureTest(t, tt) + t.Cleanup(cleanup) + + readOp := operation{ + offset: 0, + mode: operationModeRead, + } + + err := h.executeRead(t.Context(), readOp) + require.NoError(t, err) + + var verr errgroup.Group + + for range parallelOperations { + verr.Go(func() error { + return h.executeRead(t.Context(), readOp) + }) + } + + err = verr.Wait() + require.NoError(t, err) + + assert.Equal(t, []uint{0}, h.getAccessedOffsets(), "pages accessed (page 0)") +} + +func TestUffdSerialWP(t *testing.T) { + serialOperations := 1_000_000 + + tt := testConfig{ + pagesize: header.PageSize, + numberOfPages: 2, + } + + h, cleanup := configureTest(t, tt) + t.Cleanup(cleanup) + + readOp := operation{ + offset: 0, + mode: operationModeRead, + } + + err := h.executeRead(t.Context(), readOp) + require.NoError(t, err) + + writeOp := operation{ + offset: 0, + mode: operationModeWrite, + } + + var verr errgroup.Group + + for range serialOperations { 
+ err = h.executeWrite(t.Context(), writeOp) + require.NoError(t, err) + } + + err = verr.Wait() + require.NoError(t, err) + + assert.Equal(t, []uint{0}, h.getAccessedOffsets(), "pages accessed (page 0)") +} + +func TestUffdSerialMissing(t *testing.T) { + serialOperations := 1_000_000 + + tt := testConfig{ + pagesize: header.PageSize, + numberOfPages: 2, + } + + h, cleanup := configureTest(t, tt) + t.Cleanup(cleanup) + + readOp := operation{ + offset: 0, + mode: operationModeRead, + } + + var verr errgroup.Group + + for range serialOperations { + err := h.executeRead(t.Context(), readOp) + require.NoError(t, err) + } + + err := verr.Wait() + require.NoError(t, err) + + assert.Equal(t, []uint{0}, h.getAccessedOffsets(), "pages accessed (page 0)") +} + +type testHandler struct { + memoryArea *[]byte + pagesize uint64 + data *testutils.MemorySlicer + memoryMap *memory.Mapping + uffd *Userfaultfd +} + +func (h *testHandler) getAccessedOffsets() []uint { + return utils.Map(slices.Collect(h.uffd.missingRequests.BitSet().Union(h.uffd.dirty.BitSet()).EachSet()), func(offset uint) uint { + return uint(header.BlockOffset(int64(offset), int64(h.pagesize))) + }) +} + +func (h *testHandler) getWriteOffsets() []uint { + return utils.Map(slices.Collect(h.uffd.dirty.BitSet().EachSet()), func(offset uint) uint { + return uint(header.BlockOffset(int64(offset), int64(h.pagesize))) + }) +} + +func (h *testHandler) executeRead(ctx context.Context, op operation) error { + readBytes := (*h.memoryArea)[op.offset : op.offset+int64(h.pagesize)] + + expectedBytes, err := h.data.Slice(ctx, op.offset, int64(h.pagesize)) + if err != nil { + return err + } + + if !bytes.Equal(readBytes, expectedBytes) { + idx, want, got := testutils.FirstDifferentByte(readBytes, expectedBytes) + + return fmt.Errorf("content mismatch: want '%x, got %x at index %d", want, got, idx) + } + + return nil +} + +func (h *testHandler) executeWrite(ctx context.Context, op operation) error { + bytesToWrite, err := 
h.data.Slice(ctx, op.offset, int64(h.pagesize)) + if err != nil { + return err + } + + n := copy((*h.memoryArea)[op.offset:op.offset+int64(h.pagesize)], bytesToWrite) + if n != int(h.pagesize) { + return fmt.Errorf("copy length mismatch: want %d, got %d", h.pagesize, n) + } + + // err = h.uffd.writesInProgress.Wait(ctx) + // if err != nil { + // return fmt.Errorf("failed to wait for write requests finish: %w", err) + // } + + // if !h.uffd.dirty.Has(op.offset) { + // return fmt.Errorf("dirty bit not set for page at offset %d, all dirty offsets: %v", op.offset, h.getWriteOffsets()) + // } + + return nil +} + +func configureTest(t *testing.T, tt testConfig) (*testHandler, func()) { + t.Helper() + + cleanupList := []func(){} + + cleanup := func() { + slices.Reverse(cleanupList) + + for _, cleanup := range cleanupList { + cleanup() + } + } + + data := testutils.RandomPages(tt.pagesize, tt.numberOfPages) + + size, err := data.Size() + require.NoError(t, err) + + memoryArea, memoryStart, unmap, err := testutils.NewPageMmap(uint64(size), tt.pagesize) + require.NoError(t, err) + + cleanupList = append(cleanupList, func() { + unmap() + }) + + m := memory.NewMapping([]memory.Region{ + { + BaseHostVirtAddr: memoryStart, + Size: uintptr(size), + Offset: uintptr(0), + PageSize: uintptr(tt.pagesize), + }, + }) + + logger := testutils.NewTestLogger(t) + + fdExit, err := fdexit.New() + require.NoError(t, err) + + cleanupList = append(cleanupList, func() { + fdExit.Close() + }) + + uffd, err := newUserfaultfd(syscall.O_CLOEXEC|syscall.O_NONBLOCK, data, int64(tt.pagesize), m, logger) + require.NoError(t, err) + + cleanupList = append(cleanupList, func() { + uffd.Close() + }) + + err = uffd.configureApi(tt.pagesize) + require.NoError(t, err) + + err = uffd.Register(memoryStart, uint64(size), UFFDIO_REGISTER_MODE_MISSING|UFFDIO_REGISTER_MODE_WP) + require.NoError(t, err) + + exitUffd := make(chan struct{}, 1) + + go func() { + err := uffd.Serve(t.Context(), fdExit) + assert.NoError(t, 
// SettleCounter counts in-flight operations and lets callers block until the
// count returns to its settle value.
//
// counter and settleValue are guarded by mu. cond is bound to mu and is
// broadcast whenever the counter reaches settleValue, when the counter is
// force-settled by close, or when a waiting context is done.
type SettleCounter struct {
	mu          sync.Mutex
	cond        *sync.Cond
	counter     int64
	settleValue int64
}

// NewZeroSettleCounter returns a SettleCounter whose counter starts at 0 and
// whose settle value is 0, so it is settled until the first Add.
func NewZeroSettleCounter() *SettleCounter {
	c := &SettleCounter{settleValue: 0}

	// The condition variable shares the counter's mutex so that waiters
	// observe counter updates and wake-ups atomically.
	c.cond = sync.NewCond(&c.mu)

	return c
}

// add applies delta to the counter and wakes every waiter if the counter
// lands exactly on the settle value.
func (s *SettleCounter) add(delta int64) {
	s.mu.Lock()
	defer s.mu.Unlock()

	s.counter += delta

	if s.counter == s.settleValue {
		s.cond.Broadcast() // wake up all waiters
	}
}

// Add increments the counter by one.
func (s *SettleCounter) Add() {
	s.add(1)
}

// Done decrements the counter by one.
func (s *SettleCounter) Done() {
	s.add(-1)
}

// close forces the counter to its settle value, regardless of outstanding
// Add calls, and wakes every waiter so that pending Wait calls return nil.
func (s *SettleCounter) close() {
	s.mu.Lock()
	defer s.mu.Unlock()

	s.counter = s.settleValue
	s.cond.Broadcast()
}

// Wait blocks until the counter equals the settle value or ctx is done.
//
// It returns nil once the counter is settled. If the counter is already
// settled on entry it returns nil without consulting ctx. Otherwise, when ctx
// is done before the counter settles, it returns ctx.Err().
func (s *SettleCounter) Wait(ctx context.Context) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	// fast path
	if s.counter == s.settleValue {
		return nil
	}

	// sync.Cond has no native cancellation, so arrange for a Broadcast when
	// ctx is done. The callback takes the mutex first: this guarantees the
	// Broadcast cannot slip in between the ctx.Err() check below and the
	// call to cond.Wait, which would otherwise be a missed wake-up.
	stop := context.AfterFunc(ctx, func() {
		s.mu.Lock()
		defer s.mu.Unlock()

		s.cond.Broadcast() // wake waiters to check ctx
	})
	defer stop()

	for s.counter != s.settleValue {
		if err := ctx.Err(); err != nil {
			return err
		}

		s.cond.Wait()
	}

	return nil
}
a/packages/shared/pkg/utils/settle_counter_test.go b/packages/shared/pkg/utils/settle_counter_test.go new file mode 100644 index 0000000000..2a943baae7 --- /dev/null +++ b/packages/shared/pkg/utils/settle_counter_test.go @@ -0,0 +1,228 @@ +package utils + +import ( + "context" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestSettleCounter_NewZeroSettleCounter(t *testing.T) { + sc := NewZeroSettleCounter() + + // Counter should start at 0 (settle value) + require.Equal(t, int64(0), sc.counter, "Expected counter to start at 0") +} + +func TestSettleCounter_AddAndDone(t *testing.T) { + sc := NewZeroSettleCounter() + + // Test Add + sc.Add() + require.Equal(t, int64(1), sc.counter, "Expected counter to be 1 after Add") + + sc.Add() + require.Equal(t, int64(2), sc.counter, "Expected counter to be 2 after second Add") + + // Test Done + sc.Done() + require.Equal(t, int64(1), sc.counter, "Expected counter to be 1 after Done") + + sc.Done() + require.Equal(t, int64(0), sc.counter, "Expected counter to be 0 after second Done") +} + +func TestSettleCounter_Wait_AlreadySettled(t *testing.T) { + sc := NewZeroSettleCounter() + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + // Counter is already at 0, should return immediately + err := sc.Wait(ctx) + require.NoError(t, err, "Expected no error when already settled") +} + +func TestSettleCounter_Wait_SettlesAfterDone(t *testing.T) { + sc := NewZeroSettleCounter() + + // Add some work + sc.Add() + sc.Add() + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + + // Start waiting in a goroutine + var wg sync.WaitGroup + var waitErr error + wg.Add(1) + go func() { + defer wg.Done() + waitErr = sc.Wait(ctx) + }() + + // Give the wait goroutine time to start + time.Sleep(10 * time.Millisecond) + + // Complete the work + sc.Done() + sc.Done() + + // Wait for the wait to complete + wg.Wait() + + 
require.NoError(t, waitErr, "Expected no error when settling") +} + +func TestSettleCounter_Wait_ContextTimeout(t *testing.T) { + sc := NewZeroSettleCounter() + + // Add work that won't be completed + sc.Add() + sc.Add() + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + err := sc.Wait(ctx) + require.Error(t, err, "Expected context timeout error") + require.ErrorIs(t, err, context.DeadlineExceeded, "Expected context.DeadlineExceeded") +} + +func TestSettleCounter_Wait_ContextCancel(t *testing.T) { + sc := NewZeroSettleCounter() + + // Add work that won't be completed + sc.Add() + sc.Add() + + ctx, cancel := context.WithCancel(context.Background()) + + // Start waiting in a goroutine + var wg sync.WaitGroup + var waitErr error + wg.Add(1) + go func() { + defer wg.Done() + waitErr = sc.Wait(ctx) + }() + + // Give the wait goroutine time to start + time.Sleep(10 * time.Millisecond) + + // Cancel the context + cancel() + + // Wait for the wait to complete + wg.Wait() + + require.Error(t, waitErr, "Expected context cancellation error") + + require.ErrorIs(t, waitErr, context.Canceled, "Expected context.Canceled") +} + +func TestSettleCounter_Close(t *testing.T) { + sc := NewZeroSettleCounter() + + // Add some work + sc.Add() + sc.Add() + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + + // Start waiting in a goroutine + var wg sync.WaitGroup + var waitErr error + wg.Add(1) + go func() { + defer wg.Done() + waitErr = sc.Wait(ctx) + }() + + // Give the wait goroutine time to start + time.Sleep(10 * time.Millisecond) + + // Close should settle the counter + sc.close() + + // Wait for the wait to complete + wg.Wait() + + require.NoError(t, waitErr, "Expected no error when closing") + + // Counter should be at settle value (0) + require.Equal(t, int64(0), sc.counter, "Expected counter to be 0 after Close") +} + +func TestSettleCounter_ConcurrentOperations(t *testing.T) { + sc := 
NewZeroSettleCounter() + + const numGoroutines = 10 + const operationsPerGoroutine = 100 + + _, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + var wg sync.WaitGroup + + // Start multiple goroutines that add and then done + for range numGoroutines { + wg.Add(1) + go func() { + defer wg.Done() + for range operationsPerGoroutine { + sc.Add() + sc.Done() + } + }() + } + + // Wait for all operations to complete + wg.Wait() + + // Counter should be back to 0 + require.Equal(t, int64(0), sc.counter, "Expected counter to be 0 after all operations") +} + +func TestSettleCounter_MultipleWaiters(t *testing.T) { + sc := NewZeroSettleCounter() + + // Add some work + sc.Add() + sc.Add() + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + + const numWaiters = 5 + var wg sync.WaitGroup + errors := make([]error, numWaiters) + + // Start multiple waiters + for i := range numWaiters { + wg.Add(1) + go func(index int) { + defer wg.Done() + errors[index] = sc.Wait(ctx) + }(i) + } + + // Give waiters time to start + time.Sleep(10 * time.Millisecond) + + // Complete the work + sc.Done() + sc.Done() + + // Wait for all waiters to complete + wg.Wait() + + // All waiters should succeed + for i, err := range errors { + require.NoError(t, err, "Waiter %d got unexpected error: %v", i, err) + } +}