diff --git a/packages/api/internal/orchestrator/keep_alive.go b/packages/api/internal/orchestrator/keep_alive.go index 033da24863..687c230cee 100644 --- a/packages/api/internal/orchestrator/keep_alive.go +++ b/packages/api/internal/orchestrator/keep_alive.go @@ -23,6 +23,10 @@ func (o *Orchestrator) KeepAliveFor(ctx context.Context, sandboxID string, durat now := time.Now() updateFunc := func(sbx sandbox.Sandbox) (sandbox.Sandbox, error) { + if sbx.State != sandbox.StateRunning { + return sbx, &sandbox.NotFoundError{SandboxID: sandboxID} + } + maxAllowedTTL := getMaxAllowedTTL(now, sbx.StartTime, duration, sbx.MaxInstanceLength) newEndTime := now.Add(maxAllowedTTL) diff --git a/packages/api/internal/sandbox/store/memory/operations.go b/packages/api/internal/sandbox/store/memory/operations.go index bfd5da5497..be50c0cd44 100644 --- a/packages/api/internal/sandbox/store/memory/operations.go +++ b/packages/api/internal/sandbox/store/memory/operations.go @@ -120,6 +120,7 @@ func (s *Store) Update(sandboxID string, updateFunc func(sandbox.Sandbox) (sandb item.mu.Lock() defer item.mu.Unlock() + sbx, err := updateFunc(item._data) if err != nil { return sandbox.Sandbox{}, err diff --git a/tests/integration/internal/tests/api/sandboxes/sandbox_set_timeout_test.go b/tests/integration/internal/tests/api/sandboxes/sandbox_set_timeout_test.go new file mode 100644 index 0000000000..80ae72de92 --- /dev/null +++ b/tests/integration/internal/tests/api/sandboxes/sandbox_set_timeout_test.go @@ -0,0 +1,61 @@ +package sandboxes + +import ( + "fmt" + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" + + "github.com/e2b-dev/infra/tests/integration/internal/api" + "github.com/e2b-dev/infra/tests/integration/internal/setup" + "github.com/e2b-dev/infra/tests/integration/internal/utils" +) + +func TestSandboxSetTimeoutPausingSandbox(t *testing.T) { + c := setup.GetAPIClient() + + t.Run("test set timeout while pausing", func(t *testing.T) { + sbx := utils.SetupSandboxWithCleanup(t, c, utils.WithAutoPause(true)) + sbxId := sbx.SandboxID + + // Pause the sandbox + wg := errgroup.Group{} + wg.Go(func() error { + pauseResp, err := c.PostSandboxesSandboxIDPauseWithResponse(t.Context(), sbxId, setup.WithAPIKey()) + if err != nil { + return err + } + + if pauseResp.StatusCode() != http.StatusNoContent { + return fmt.Errorf("unexpected status code: %d", pauseResp.StatusCode()) + } + + return nil + }) + + for range 5 { + time.Sleep(200 * time.Millisecond) + wg.Go(func() error { + setTimeoutResp, err := c.PostSandboxesSandboxIDTimeoutWithResponse(t.Context(), sbxId, api.PostSandboxesSandboxIDTimeoutJSONRequestBody{ + Timeout: 15, + }, + setup.WithAPIKey()) + if err != nil { + return err + } + + if setTimeoutResp.StatusCode() != http.StatusNotFound { + return fmt.Errorf("unexpected status code: %d", setTimeoutResp.StatusCode()) + } + + return nil + }) + } + + err := wg.Wait() + require.NoError(t, err) + }) +}