Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 150 additions & 0 deletions service/history/activity_command_task_dispatcher.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
package history

import (
"context"
"errors"
"fmt"
"time"

enumspb "go.temporal.io/api/enums/v1"
nexuspb "go.temporal.io/api/nexus/v1"
taskqueuepb "go.temporal.io/api/taskqueue/v1"
workerpb "go.temporal.io/api/worker/v1"
"go.temporal.io/server/api/matchingservice/v1"
"go.temporal.io/server/common/debug"
"go.temporal.io/server/common/log"
"go.temporal.io/server/common/log/tag"
"go.temporal.io/server/common/metrics"
"go.temporal.io/server/common/payload"
"go.temporal.io/server/common/resource"
"go.temporal.io/server/service/history/configs"
"go.temporal.io/server/service/history/tasks"
)

const (
// activityCommandTaskTimeout is the deadline that execute places on a single
// dispatch attempt (via context.WithTimeout): 10 seconds multiplied by
// debug.TimeoutMultiplier.
activityCommandTaskTimeout = time.Second * 10 * debug.TimeoutMultiplier
)

// activityCommandTaskDispatcher handles dispatching activity command tasks to workers.
//
// A task is wrapped in a Nexus StartOperation request and sent through the
// matching client to the task queue named by the task's Destination.
type activityCommandTaskDispatcher struct {
// matchingRawClient is the client used to issue the DispatchNexusTask call.
matchingRawClient resource.MatchingRawClient
// config supplies the EnableActivityCancellationNexusTask switch that execute
// consults before doing any work.
config *configs.Config
// metricsHandler is stored at construction time; it is not read by the
// dispatcher methods defined alongside this type.
metricsHandler metrics.Handler
// logger receives warnings about dispatch errors and worker-reported failures.
logger log.Logger
}

// newActivityCommandTaskDispatcher returns a dispatcher wired with the given
// matching client, history config, metrics handler and logger. The arguments
// are stored as-is; no validation or defaulting is performed.
func newActivityCommandTaskDispatcher(
	matchingRawClient resource.MatchingRawClient,
	config *configs.Config,
	metricsHandler metrics.Handler,
	logger log.Logger,
) *activityCommandTaskDispatcher {
	dispatcher := new(activityCommandTaskDispatcher)
	dispatcher.matchingRawClient = matchingRawClient
	dispatcher.config = config
	dispatcher.metricsHandler = metricsHandler
	dispatcher.logger = logger
	return dispatcher
}

// execute dispatches a single activity command task to its worker.
//
// It is a no-op (returns nil) when the EnableActivityCancellationNexusTask
// config switch is off or when the task carries no task tokens. Otherwise the
// dispatch runs under a context limited to activityCommandTaskTimeout and the
// result of dispatchToWorker is returned unchanged.
func (d *activityCommandTaskDispatcher) execute(
	ctx context.Context,
	task *tasks.ActivityCommandTask,
) error {
	// The feature switch is evaluated first; the token check only runs when
	// the feature is enabled.
	if !d.config.EnableActivityCancellationNexusTask() || len(task.TaskTokens) == 0 {
		return nil
	}

	dispatchCtx, cancel := context.WithTimeout(ctx, activityCommandTaskTimeout)
	defer cancel()

	return d.dispatchToWorker(dispatchCtx, task)
}

// dispatchToWorker sends the task to the worker as a Nexus StartOperation
// request routed through the matching client.
//
// The task's command type and task tokens are encoded as an
// ActivityNotificationRequest payload addressed to the worker service's
// NotifyActivity operation, and dispatched to the normal-kind task queue named
// task.Destination in namespace task.NamespaceID.
//
// Returns a wrapped error if the payload cannot be encoded, the client error
// (after logging a warning) if the dispatch call fails, and otherwise whatever
// handleDispatchResponse returns for the response.
func (d *activityCommandTaskDispatcher) dispatchToWorker(
	ctx context.Context,
	task *tasks.ActivityCommandTask,
) error {
	encoded, err := payload.Encode(&workerpb.ActivityNotificationRequest{
		NotificationType: workerpb.ActivityNotificationType(task.CommandType),
		TaskTokens:       task.TaskTokens,
	})
	if err != nil {
		return fmt.Errorf("failed to encode activity command request: %w", err)
	}

	startOperation := &nexuspb.StartOperationRequest{
		Service:   workerpb.WorkerService.ServiceName,
		Operation: workerpb.WorkerService.NotifyActivity.Name(),
		Payload:   encoded,
	}
	dispatchRequest := &matchingservice.DispatchNexusTaskRequest{
		NamespaceId: task.NamespaceID,
		TaskQueue: &taskqueuepb.TaskQueue{
			Name: task.Destination,
			Kind: enumspb.TASK_QUEUE_KIND_NORMAL,
		},
		Request: &nexuspb.Request{
			Header: map[string]string{},
			Variant: &nexuspb.Request_StartOperation{
				StartOperation: startOperation,
			},
		},
	}

	resp, err := d.matchingRawClient.DispatchNexusTask(ctx, dispatchRequest)
	if err != nil {
		d.logger.Warn(
			"Failed to dispatch activity command to worker",
			tag.NewStringTag("control_queue", task.Destination),
			tag.Error(err),
		)
		return err
	}

	return d.handleDispatchResponse(resp, task.Destination)
}

// handleDispatchResponse turns a DispatchNexusTask response into the task's
// outcome.
//
//   - A request timeout is logged and returned as an error.
//   - A handler-level failure is logged and returned as an error carrying the
//     failure message.
//   - An operation-level failure inside the StartOperation response is logged
//     but returns nil.
//   - Anything else, including a response with no StartOperation result,
//     returns nil.
//
// controlQueue is only used to tag the log lines.
func (d *activityCommandTaskDispatcher) handleDispatchResponse(
	resp *matchingservice.DispatchNexusTaskResponse,
	controlQueue string,
) error {
	queueTag := tag.NewStringTag("control_queue", controlQueue)

	// Check for timeout (no worker polling).
	if resp.GetRequestTimeout() != nil {
		d.logger.Warn("No worker polling control queue for activity command", queueTag)
		return errors.New("no worker polling control queue")
	}

	// Check for worker handler failure.
	if handlerFailure := resp.GetFailure(); handlerFailure != nil {
		message := handlerFailure.GetMessage()
		d.logger.Warn(
			"Worker handler failed for activity command",
			queueTag,
			tag.NewStringTag("failure_message", message),
		)
		return fmt.Errorf("worker handler failed: %s", message)
	}

	// Check for operation failure (terminal - don't retry). The generated
	// getters tolerate nil receivers, so a missing response or missing
	// StartOperation result simply yields a nil failure here.
	if opFailure := resp.GetResponse().GetStartOperation().GetFailure(); opFailure != nil {
		d.logger.Warn(
			"Activity command operation failure",
			queueTag,
			tag.NewStringTag("failure_message", opFailure.GetMessage()),
		)
	}

	return nil
}

28 changes: 21 additions & 7 deletions service/history/outbound_queue_active_task_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import (
"go.temporal.io/server/common/debug"
"go.temporal.io/server/common/log"
"go.temporal.io/server/common/metrics"
"go.temporal.io/server/common/resource"
"go.temporal.io/server/service/history/configs"
"go.temporal.io/server/service/history/consts"
historyi "go.temporal.io/server/service/history/interfaces"
"go.temporal.io/server/service/history/queues"
Expand All @@ -24,7 +26,8 @@ const (

type outboundQueueActiveTaskExecutor struct {
stateMachineEnvironment
chasmEngine chasm.Engine
chasmEngine chasm.Engine
activityCommandTaskDispatcher *activityCommandTaskDispatcher
}

var _ queues.Executor = &outboundQueueActiveTaskExecutor{}
Expand All @@ -35,17 +38,26 @@ func newOutboundQueueActiveTaskExecutor(
logger log.Logger,
metricsHandler metrics.Handler,
chasmEngine chasm.Engine,
matchingRawClient resource.MatchingRawClient,
config *configs.Config,
) *outboundQueueActiveTaskExecutor {
scopedMetricsHandler := metricsHandler.WithTags(
metrics.OperationTag(metrics.OperationOutboundQueueProcessorScope),
)
return &outboundQueueActiveTaskExecutor{
stateMachineEnvironment: stateMachineEnvironment{
shardContext: shardCtx,
cache: workflowCache,
logger: logger,
metricsHandler: metricsHandler.WithTags(
metrics.OperationTag(metrics.OperationOutboundQueueProcessorScope),
),
shardContext: shardCtx,
cache: workflowCache,
logger: logger,
metricsHandler: scopedMetricsHandler,
},
chasmEngine: chasmEngine,
activityCommandTaskDispatcher: newActivityCommandTaskDispatcher(
matchingRawClient,
config,
scopedMetricsHandler,
logger,
),
}
}

Expand Down Expand Up @@ -92,6 +104,8 @@ func (e *outboundQueueActiveTaskExecutor) Execute(
return respond(e.executeStateMachineTask(ctx, task))
case *tasks.ChasmTask:
return respond(e.executeChasmSideEffectTask(ctx, task))
case *tasks.ActivityCommandTask:
return respond(e.activityCommandTaskDispatcher.execute(ctx, task))
}

return respond(queueserrors.NewUnprocessableTaskError(fmt.Sprintf("unknown task type '%T'", task)))
Expand Down
2 changes: 2 additions & 0 deletions service/history/outbound_queue_active_task_executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ func (s *outboundQueueActiveTaskExecutorSuite) SetupTest() {
s.logger,
s.metricsHandler,
s.mockChasmEngine,
nil, // matchingRawClient - not used in these tests
nil, // config - not used in these tests
)
}

Expand Down
4 changes: 4 additions & 0 deletions service/history/outbound_queue_factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"go.temporal.io/server/common/metrics"
"go.temporal.io/server/common/namespace"
"go.temporal.io/server/common/quotas"
"go.temporal.io/server/common/resource"
ctasks "go.temporal.io/server/common/tasks"
"go.temporal.io/server/common/telemetry"
"go.temporal.io/server/service/history/circuitbreakerpool"
Expand All @@ -31,6 +32,7 @@ type outboundQueueFactoryParams struct {

QueueFactoryBaseParams
CircuitBreakerPool *circuitbreakerpool.OutboundQueueCircuitBreakerPool
MatchingRawClient resource.MatchingRawClient
}

type groupLimiter struct {
Expand Down Expand Up @@ -227,6 +229,8 @@ func (f *outboundQueueFactory) CreateQueue(
logger,
metricsHandler,
f.ChasmEngine,
f.MatchingRawClient,
shardContext.GetConfig(),
)

standbyExecutor := newOutboundQueueStandbyTaskExecutor(
Expand Down
149 changes: 149 additions & 0 deletions tests/activity_command_task_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
package tests

import (
"context"
"testing"
"time"

commandpb "go.temporal.io/api/command/v1"
commonpb "go.temporal.io/api/common/v1"
enumspb "go.temporal.io/api/enums/v1"
taskqueuepb "go.temporal.io/api/taskqueue/v1"
workerpb "go.temporal.io/api/worker/v1"
"go.temporal.io/api/workflowservice/v1"
"go.temporal.io/server/common/dynamicconfig"
"go.temporal.io/server/tests/testcore"
"google.golang.org/protobuf/types/known/durationpb"
)

// TestDispatchCancelToWorker tests that when an activity cancellation is requested,
// the server dispatches an ActivityCommandTask to the worker's control queue via Nexus.
//
// Flow:
//  1. Start a workflow and complete its first workflow task by scheduling an activity.
//  2. Poll the activity task, advertising a worker instance key and control queue.
//  3. Request workflow cancellation, then complete the next workflow task with a
//     RequestCancelActivityTask command for the scheduled activity.
//  4. Poll the control queue as a Nexus task queue until a request arrives and
//     verify it is a StartOperation for the worker service's NotifyActivity operation.
//
// The whole test runs under a single 90 second context; the final wait is
// bounded by that same deadline.
func TestDispatchCancelToWorker(t *testing.T) {
	env := testcore.NewEnv(t, testcore.WithDynamicConfig(dynamicconfig.EnableActivityCancellationNexusTask, true))

	ctx, cancel := context.WithTimeout(context.Background(), 90*time.Second)
	defer cancel()

	tv := env.Tv()
	poller := env.TaskPoller()

	// Get the control queue name from test vars
	controlQueueName := tv.ControlQueueName(env.Namespace().String())
	t.Logf("WorkerInstanceKey: %s", tv.WorkerInstanceKey())
	t.Logf("ControlQueueName: %s", controlQueueName)

	// Start the workflow
	startResp, err := env.FrontendClient().StartWorkflowExecution(ctx, &workflowservice.StartWorkflowExecutionRequest{
		RequestId:                tv.Any().String(),
		Namespace:                env.Namespace().String(),
		WorkflowId:               tv.WorkflowID(),
		WorkflowType:             tv.WorkflowType(),
		TaskQueue:                tv.TaskQueue(),
		WorkflowExecutionTimeout: durationpb.New(60 * time.Second),
		WorkflowTaskTimeout:      durationpb.New(10 * time.Second),
	})
	env.NoError(err)
	t.Logf("Started workflow: %s/%s", tv.WorkflowID(), startResp.GetRunId())

	// Poll and complete first workflow task - schedule the activity
	_, err = poller.PollAndHandleWorkflowTask(tv,
		func(task *workflowservice.PollWorkflowTaskQueueResponse) (*workflowservice.RespondWorkflowTaskCompletedRequest, error) {
			return &workflowservice.RespondWorkflowTaskCompletedRequest{
				Commands: []*commandpb.Command{
					{
						CommandType: enumspb.COMMAND_TYPE_SCHEDULE_ACTIVITY_TASK,
						Attributes: &commandpb.Command_ScheduleActivityTaskCommandAttributes{
							ScheduleActivityTaskCommandAttributes: &commandpb.ScheduleActivityTaskCommandAttributes{
								ActivityId:             tv.ActivityID(),
								ActivityType:           tv.ActivityType(),
								TaskQueue:              tv.TaskQueue(),
								ScheduleToCloseTimeout: durationpb.New(60 * time.Second),
								StartToCloseTimeout:    durationpb.New(60 * time.Second),
							},
						},
					},
				},
			}, nil
		})
	env.NoError(err)
	t.Log("Scheduled activity")

	// Poll for activity task and start running the activity.
	activityPollResp, err := env.FrontendClient().PollActivityTaskQueue(ctx, &workflowservice.PollActivityTaskQueueRequest{
		Namespace:              env.Namespace().String(),
		TaskQueue:              tv.TaskQueue(),
		Identity:               tv.WorkerIdentity(),
		WorkerInstanceKey:      tv.WorkerInstanceKey(),
		WorkerControlTaskQueue: controlQueueName,
	})
	env.NoError(err)
	env.NotNil(activityPollResp)
	env.NotEmpty(activityPollResp.GetTaskToken())
	t.Log("Activity started with WorkerInstanceKey")

	// Request workflow cancellation
	t.Log("Requesting workflow cancellation...")
	_, err = env.FrontendClient().RequestCancelWorkflowExecution(ctx, &workflowservice.RequestCancelWorkflowExecutionRequest{
		Namespace: env.Namespace().String(),
		WorkflowExecution: &commonpb.WorkflowExecution{
			WorkflowId: tv.WorkflowID(),
			RunId:      startResp.GetRunId(),
		},
	})
	env.NoError(err)

	// Simulate what the SDK does when a workflow is cancelled.
	// Poll and complete the workflow task with RequestCancelActivityTask command.
	// This sets CancelRequested=true and triggers the dispatch of ActivityCommandTask.
	_, err = poller.PollAndHandleWorkflowTask(tv,
		func(task *workflowservice.PollWorkflowTaskQueueResponse) (*workflowservice.RespondWorkflowTaskCompletedRequest, error) {
			// Find the scheduled event ID
			var scheduledEventID int64
			for _, event := range task.GetHistory().GetEvents() {
				if event.GetEventType() == enumspb.EVENT_TYPE_ACTIVITY_TASK_SCHEDULED {
					scheduledEventID = event.GetEventId()
					break
				}
			}
			return &workflowservice.RespondWorkflowTaskCompletedRequest{
				Commands: []*commandpb.Command{
					{
						CommandType: enumspb.COMMAND_TYPE_REQUEST_CANCEL_ACTIVITY_TASK,
						Attributes: &commandpb.Command_RequestCancelActivityTaskCommandAttributes{
							RequestCancelActivityTaskCommandAttributes: &commandpb.RequestCancelActivityTaskCommandAttributes{
								ScheduledEventId: scheduledEventID,
							},
						},
					},
				},
			}, nil
		})
	env.NoError(err)
	t.Log("Workflow task completed with RequestCancelActivityTask command")

	// Every poll below runs on a context derived from ctx, so nothing can
	// succeed once ctx has expired. Bound the wait by the time ctx has left
	// instead of a longer fixed duration; otherwise a failing run keeps
	// issuing instantly-failing polls until the longer timer runs out.
	// ctx was created with WithTimeout above, so it always has a deadline.
	deadline, _ := ctx.Deadline()

	// Poll Nexus control queue until we receive the notification request
	var nexusPollResp *workflowservice.PollNexusTaskQueueResponse
	env.Eventually(func() bool {
		pollCtx, pollCancel := context.WithTimeout(ctx, 5*time.Second)
		defer pollCancel()
		resp, err := env.FrontendClient().PollNexusTaskQueue(pollCtx, &workflowservice.PollNexusTaskQueueRequest{
			Namespace: env.Namespace().String(),
			TaskQueue: &taskqueuepb.TaskQueue{Name: controlQueueName, Kind: enumspb.TASK_QUEUE_KIND_NORMAL},
			Identity:  tv.WorkerIdentity(),
		})
		if err == nil && resp.GetRequest() != nil {
			nexusPollResp = resp
			return true
		}
		return false
	}, time.Until(deadline), 100*time.Millisecond, "Timed out waiting for Nexus task")

	// Verify we received the notification request on the control queue.
	// Reads go through the generated getters, which tolerate nil receivers.
	env.NotNil(nexusPollResp.GetRequest(), "Expected to receive Nexus request on control queue")

	startOp := nexusPollResp.GetRequest().GetStartOperation()
	env.NotNil(startOp, "Expected StartOperation in Nexus request")
	env.Equal(workerpb.WorkerService.ServiceName, startOp.GetService(), "Expected WorkerService")
	env.Equal(workerpb.WorkerService.NotifyActivity.Name(), startOp.GetOperation(), "Expected notify-activity operation")
	t.Log("SUCCESS: Received notify-activity Nexus request on control queue")
}
Loading