diff --git a/backend/api/api.go b/backend/api/api.go index 3419c1a98..3e71e43a2 100644 --- a/backend/api/api.go +++ b/backend/api/api.go @@ -125,6 +125,9 @@ func Init( if err != nil { return nil, err } + if err = observabilityHandler.RunTaskScheduleTask(ctx); err != nil { + return nil, err + } observabilityHandler.RunAsync(ctx) return &apis.APIHandler{ diff --git a/backend/api/handler/coze/loop/apis/wire_gen.go b/backend/api/handler/coze/loop/apis/wire_gen.go index f090ab564..7b5ae44ac 100644 --- a/backend/api/handler/coze/loop/apis/wire_gen.go +++ b/backend/api/handler/coze/loop/apis/wire_gen.go @@ -8,7 +8,6 @@ package apis import ( "context" - "github.com/cloudwego/kitex/pkg/endpoint" "github.com/coze-dev/coze-loop/backend/infra/ck" "github.com/coze-dev/coze-loop/backend/infra/db" diff --git a/backend/modules/observability/application/convertor/page.go b/backend/modules/observability/application/convertor/page.go new file mode 100755 index 000000000..8fe8e09dd --- /dev/null +++ b/backend/modules/observability/application/convertor/page.go @@ -0,0 +1,30 @@ +// Copyright (c) 2025 coze-dev Authors +// SPDX-License-Identifier: Apache-2.0 + +package convertor + +import ( + "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/common" + entity "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/common" + "github.com/coze-dev/coze-loop/backend/pkg/lang/ptr" +) + +func OrderByDTO2DO(orderBy *common.OrderBy) *entity.OrderBy { + if orderBy == nil { + return nil + } + return &entity.OrderBy{ + Field: orderBy.GetField(), + IsAsc: orderBy.GetIsAsc(), + } +} + +func OrderByDO2DTO(orderBy *entity.OrderBy) *common.OrderBy { + if orderBy == nil { + return nil + } + return &common.OrderBy{ + Field: ptr.Of(orderBy.Field), + IsAsc: ptr.Of(orderBy.IsAsc), + } +} diff --git a/backend/modules/observability/application/convertor/task/filter.go b/backend/modules/observability/application/convertor/task/filter.go new file mode 100755 index 000000000..a22c0efb2 --- /dev/null +++ b/backend/modules/observability/application/convertor/task/filter.go @@ -0,0 +1,103 @@ +// Copyright (c) 2025 coze-dev Authors +// SPDX-License-Identifier: Apache-2.0 + +package task + +import ( + "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/filter" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" + "github.com/coze-dev/coze-loop/backend/pkg/lang/ptr" +) + +func TaskFiltersDTO2DO(filters *filter.TaskFilterFields) *entity.TaskFilterFields { + if filters == nil { + return nil + } + result := &entity.TaskFilterFields{} + if filters.QueryAndOr != nil { + relation := entity.QueryRelation(*filters.QueryAndOr) + result.QueryAndOr = &relation + } + if len(filters.FilterFields) == 0 { + return result + } + result.FilterFields = make([]*entity.TaskFilterField, 0, len(filters.FilterFields)) + for _, field := range filters.FilterFields { + if field == nil { + continue + } + result.FilterFields = append(result.FilterFields, taskFilterFieldDTO2DO(field)) + } + return result +} + +func taskFilterFieldDTO2DO(field *filter.TaskFilterField) *entity.TaskFilterField { + if field == nil { + return nil + } + result := &entity.TaskFilterField{ + Values: append([]string(nil), field.Values...), + SubFilter: taskFilterFieldDTO2DO(field.SubFilter), + } + if field.FieldName != nil { + name := entity.TaskFieldName(*field.FieldName) + result.FieldName = &name + } + if field.FieldType != nil { + fieldType := entity.FieldType(*field.FieldType) + result.FieldType = &fieldType + } + if field.QueryType != nil { + queryType := entity.QueryType(*field.QueryType) + result.QueryType = &queryType + } + if field.QueryAndOr != nil { + relation := entity.QueryRelation(*field.QueryAndOr) + result.QueryAndOr = &relation + } + return result +} + +func TaskFiltersDO2DTO(filters *entity.TaskFilterFields) *filter.TaskFilterFields { + if filters == nil { + return nil + } + result := &filter.TaskFilterFields{} + if filters.QueryAndOr != nil { + result.QueryAndOr = ptr.Of(filter.QueryRelation(*filters.QueryAndOr)) + } + if len(filters.FilterFields) == 0 { + return result + } + result.FilterFields = make([]*filter.TaskFilterField, 0, len(filters.FilterFields)) + for _, field := range filters.FilterFields { + if field == nil { + continue + } + result.FilterFields = append(result.FilterFields, taskFilterFieldDO2DTO(field)) + } + return result +} + +func taskFilterFieldDO2DTO(field *entity.TaskFilterField) *filter.TaskFilterField { + if field == nil { + return nil + } + result := &filter.TaskFilterField{ + Values: append([]string(nil), field.Values...), + SubFilter: taskFilterFieldDO2DTO(field.SubFilter), + } + if field.FieldName != nil { + result.FieldName = ptr.Of(string(*field.FieldName)) + } + if field.FieldType != nil { + result.FieldType = ptr.Of(filter.FieldType(*field.FieldType)) + } + if field.QueryType != nil { + result.QueryType = ptr.Of(filter.QueryType(*field.QueryType)) + } + if field.QueryAndOr != nil { + result.QueryAndOr = ptr.Of(filter.QueryRelation(*field.QueryAndOr)) + } + return result +} diff --git a/backend/modules/observability/application/convertor/task/task.go b/backend/modules/observability/application/convertor/task/task.go index 99b2c62ee..38831ee19 100644 --- a/backend/modules/observability/application/convertor/task/task.go +++ b/backend/modules/observability/application/convertor/task/task.go @@ -18,11 +18,11 @@ import ( "github.com/coze-dev/coze-loop/backend/modules/observability/application/convertor" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" entity_common "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/common" - obErrorx "github.com/coze-dev/coze-loop/backend/modules/observability/pkg/errno" - "github.com/coze-dev/coze-loop/backend/pkg/errorx" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/loop_span" "github.com/coze-dev/coze-loop/backend/pkg/lang/ptr" "github.com/coze-dev/coze-loop/backend/pkg/lang/slices" "github.com/coze-dev/coze-loop/backend/pkg/logs" + "github.com/samber/lo" ) func TaskDOs2DTOs(ctx context.Context, taskPOs []*entity.ObservabilityTask, userInfos map[string]*entity_common.UserInfo) []*task.Task { @@ -61,8 +61,8 @@ func TaskDO2DTO(ctx context.Context, v *entity.ObservabilityTask, userMap map[st Name: v.Name, Description: v.Description, WorkspaceID: ptr.Of(v.WorkspaceID), - TaskType: v.TaskType, - TaskStatus: ptr.Of(v.TaskStatus), + TaskType: task.TaskType(v.TaskType), + TaskStatus: ptr.Of(task.TaskStatus(v.TaskStatus)), Rule: RuleDO2DTO(v.SpanFilter, v.EffectiveTime, v.Sampler, v.BackfillEffectiveTime), TaskConfig: TaskConfigDO2DTO(v.TaskConfig), TaskDetail: taskDetail, @@ -84,8 +84,8 @@ func TaskRunDO2DTO(ctx context.Context, v *entity.TaskRun, userMap map[string]*e ID: v.ID, WorkspaceID: v.WorkspaceID, TaskID: v.TaskID, - TaskType: v.TaskType, - RunStatus: v.RunStatus, + TaskType: task.TaskRunType(v.TaskType), + RunStatus: task.RunStatus(v.RunStatus), RunDetail: RunDetailDO2DTO(v.RunDetail), BackfillRunDetail: BackfillRunDetailDO2DTO(v.BackfillDetail), RunStartAt: v.RunStartAt.UnixMilli(), @@ -177,8 +177,8 @@ func SpanFilterDO2DTO(spanFilter *entity.SpanFilterFields) *filter.SpanFilterFie return &filter.SpanFilterFields{ Filters: convertor.FilterFieldsDO2DTO(&spanFilter.Filters), - PlatformType: &spanFilter.PlatformType, - SpanListType: &spanFilter.SpanListType, + PlatformType: lo.ToPtr(common.PlatformType(spanFilter.PlatformType)), + SpanListType: lo.ToPtr(common.SpanListType(spanFilter.SpanListType)), } } @@ -204,7 +204,7 @@ func SamplerDO2DTO(sampler *entity.Sampler) *task.Sampler { IsCycle: ptr.Of(sampler.IsCycle), CycleCount: ptr.Of(sampler.CycleCount), CycleInterval: ptr.Of(sampler.CycleInterval), - CycleTimeUnit: ptr.Of(sampler.CycleTimeUnit), + CycleTimeUnit: ptr.Of(string(sampler.CycleTimeUnit)), } } @@ -305,7 +305,7 @@ func UserInfoPO2DO(userInfo *entity_common.UserInfo, userID string) *common.User } } -func TaskDTO2DO(taskDTO *task.Task, userID string, spanFilters *entity.SpanFilterFields) *entity.ObservabilityTask { +func TaskDTO2DO(taskDTO *task.Task) *entity.ObservabilityTask { if taskDTO == nil { return nil } @@ -316,31 +316,16 @@ func TaskDTO2DO(taskDTO *task.Task, userID string, spanFilters *entity.SpanFilte if taskDTO.GetBaseInfo().GetUpdatedBy() != nil { updatedBy = taskDTO.GetBaseInfo().GetUpdatedBy().GetUserID() } - if userID != "" { - createdBy = userID - updatedBy = userID - } else { - if taskDTO.GetBaseInfo().GetCreatedBy() != nil { - createdBy = taskDTO.GetBaseInfo().GetCreatedBy().GetUserID() - } - if taskDTO.GetBaseInfo().GetUpdatedBy() != nil { - updatedBy = taskDTO.GetBaseInfo().GetUpdatedBy().GetUserID() - } - } - var spanFilterDO *entity.SpanFilterFields - if spanFilters != nil { - spanFilterDO = spanFilters - } else { - spanFilterDO = SpanFilterDTO2DO(taskDTO.GetRule().GetSpanFilters()) - } + + spanFilterDO := SpanFilterDTO2DO(taskDTO.GetRule().GetSpanFilters()) return &entity.ObservabilityTask{ ID: taskDTO.GetID(), WorkspaceID: taskDTO.GetWorkspaceID(), Name: taskDTO.GetName(), Description: ptr.Of(taskDTO.GetDescription()), - TaskType: taskDTO.GetTaskType(), - TaskStatus: taskDTO.GetTaskStatus(), + TaskType: entity.TaskType(taskDTO.GetTaskType()), + TaskStatus: entity.TaskStatus(taskDTO.GetTaskStatus()), TaskDetail: RunDetailDTO2DO(taskDTO.GetTaskDetail()), SpanFilter: spanFilterDO, EffectiveTime: EffectiveTimeDTO2DO(taskDTO.GetRule().GetEffectiveTime()), @@ -359,8 +344,8 @@ func SpanFilterDTO2DO(spanFilterFields *filter.SpanFilterFields) *entity.SpanFil return nil } return &entity.SpanFilterFields{ - PlatformType: *spanFilterFields.PlatformType, - SpanListType: *spanFilterFields.SpanListType, + PlatformType: loop_span.PlatformType(*spanFilterFields.PlatformType), + SpanListType: loop_span.SpanListType(*spanFilterFields.SpanListType), Filters: *convertor.FilterFieldsDTO2DO(spanFilterFields.Filters), } } @@ -396,7 +381,7 @@ func SamplerDTO2DO(sampler *task.Sampler) *entity.Sampler { IsCycle: sampler.GetIsCycle(), CycleCount: sampler.GetCycleCount(), CycleInterval: sampler.GetCycleInterval(), - CycleTimeUnit: sampler.GetCycleTimeUnit(), + CycleTimeUnit: entity.TimeUnit(sampler.GetCycleTimeUnit()), } } @@ -408,6 +393,7 @@ func TaskConfigDTO2DO(taskConfig *task.TaskConfig) *entity.TaskConfig { for _, autoEvaluateConfig := range taskConfig.AutoEvaluateConfigs { var fieldMappings []*entity.EvaluateFieldMapping if len(autoEvaluateConfig.FieldMappings) > 0 { + // todo tyf 这段逻辑挪到service层 var evalSetNames []string jspnPathMapping := make(map[string]string) for _, config := range autoEvaluateConfig.FieldMappings { @@ -471,8 +457,8 @@ func TaskRunDTO2DO(taskRun *task.TaskRun) *entity.TaskRun { ID: taskRun.ID, TaskID: taskRun.TaskID, WorkspaceID: taskRun.WorkspaceID, - TaskType: taskRun.TaskType, - RunStatus: taskRun.RunStatus, + TaskType: entity.TaskRunType(taskRun.TaskType), + RunStatus: entity.TaskRunStatus(taskRun.RunStatus), RunDetail: RunDetailDTO2DO(taskRun.RunDetail), BackfillDetail: BackfillRunDetailDTO2DO(taskRun.BackfillRunDetail), RunStartAt: time.UnixMilli(taskRun.RunStartAt), @@ -531,82 +517,6 @@ func BackfillRunDetailDTO2DO(v *task.BackfillDetail) *entity.BackfillDetail { } } -func CheckEffectiveTime(ctx context.Context, effectiveTime *task.EffectiveTime, taskStatus task.TaskStatus, effectiveTimeDO *entity.EffectiveTime) (*entity.EffectiveTime, error) { - if effectiveTimeDO == nil { - logs.CtxError(ctx, "EffectiveTimePO2DO error") - return nil, errorx.NewByCode(obErrorx.CommercialCommonInvalidParamCodeCode, errorx.WithExtraMsg("effective time is nil")) - } - var validEffectiveTime entity.EffectiveTime - // 开始时间不能大于结束时间 - if effectiveTime.GetStartAt() >= effectiveTime.GetEndAt() { - logs.CtxError(ctx, "Start time must be less than end time") - return nil, errorx.NewByCode(obErrorx.CommercialCommonInvalidParamCodeCode, errorx.WithExtraMsg("start time must be less than end time")) - } - // 开始、结束时间不能小于当前时间 - if effectiveTimeDO.StartAt != effectiveTime.GetStartAt() && effectiveTime.GetStartAt() < time.Now().UnixMilli() { - logs.CtxError(ctx, "update time must be greater than current time") - return nil, errorx.NewByCode(obErrorx.CommercialCommonInvalidParamCodeCode, errorx.WithExtraMsg("start time must be greater than current time")) - } - if effectiveTimeDO.EndAt != effectiveTime.GetEndAt() && effectiveTime.GetEndAt() < time.Now().UnixMilli() { - logs.CtxError(ctx, "update time must be greater than current time") - return nil, errorx.NewByCode(obErrorx.CommercialCommonInvalidParamCodeCode, errorx.WithExtraMsg("start time must be greater than current time")) - } - validEffectiveTime.StartAt = effectiveTimeDO.StartAt - validEffectiveTime.EndAt = effectiveTimeDO.EndAt - switch taskStatus { - case task.TaskStatusUnstarted: - if validEffectiveTime.StartAt != 0 { - validEffectiveTime.StartAt = *effectiveTime.StartAt - } - if validEffectiveTime.EndAt != 0 { - validEffectiveTime.EndAt = *effectiveTime.EndAt - } - case task.TaskStatusRunning, task.TaskStatusPending: - if validEffectiveTime.EndAt != 0 { - validEffectiveTime.EndAt = *effectiveTime.EndAt - } - default: - logs.CtxError(ctx, "Invalid task status:%s", taskStatus) - return nil, errorx.NewByCode(obErrorx.CommercialCommonInvalidParamCodeCode, errorx.WithExtraMsg("invalid task status")) - } - return &validEffectiveTime, nil -} - -func CheckTaskStatus(ctx context.Context, taskStatus task.TaskStatus, currentTaskStatus task.TaskStatus) (task.TaskStatus, error) { - var validTaskStatus task.TaskStatus - // [0530]todo: 任务状态校验 - switch taskStatus { - case task.TaskStatusUnstarted: - if currentTaskStatus == task.TaskStatusUnstarted { - validTaskStatus = taskStatus - } else { - logs.CtxError(ctx, "Invalid task status:%s", taskStatus) - return "", errorx.NewByCode(obErrorx.CommercialCommonInvalidParamCodeCode, errorx.WithExtraMsg("invalid task status")) - } - case task.TaskStatusRunning: - if currentTaskStatus == task.TaskStatusUnstarted || currentTaskStatus == task.TaskStatusPending { - validTaskStatus = taskStatus - } else { - logs.CtxError(ctx, "Invalid task status:%s,currentTaskStatus:%s", taskStatus, currentTaskStatus) - return "", errorx.NewByCode(obErrorx.CommercialCommonInvalidParamCodeCode, errorx.WithExtraMsg("invalid task status")) - } - case task.TaskStatusPending: - if currentTaskStatus == task.TaskStatusRunning { - validTaskStatus = task.TaskStatusPending - } - case task.TaskStatusDisabled: - if currentTaskStatus == task.TaskStatusUnstarted || currentTaskStatus == task.TaskStatusPending { - validTaskStatus = task.TaskStatusDisabled - } - case task.TaskStatusSuccess: - if currentTaskStatus != task.TaskStatusSuccess { - validTaskStatus = task.TaskStatusSuccess - } - } - - return validTaskStatus, nil -} - func getLastPartAfterDot(s string) string { s = strings.TrimRight(s, ".") lastDotIndex := strings.LastIndex(s, ".") diff --git a/backend/modules/observability/application/convertor/task/task_test.go b/backend/modules/observability/application/convertor/task/task_test.go index 9a48ea739..96d3d529d 100755 --- a/backend/modules/observability/application/convertor/task/task_test.go +++ b/backend/modules/observability/application/convertor/task/task_test.go @@ -19,8 +19,6 @@ import ( "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" entityCommon "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/common" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/loop_span" - obErrorx "github.com/coze-dev/coze-loop/backend/modules/observability/pkg/errno" - "github.com/coze-dev/coze-loop/backend/pkg/errorx" "github.com/coze-dev/coze-loop/backend/pkg/lang/ptr" ) @@ -102,7 +100,7 @@ func TestTaskDOs2DTOs(t *testing.T) { IsCycle: true, CycleCount: 2, CycleInterval: 3, - CycleTimeUnit: kitTask.TimeUnitDay, + CycleTimeUnit: entity.TimeUnitDay, }, TaskConfig: &entity.TaskConfig{}, CreatedAt: now, @@ -239,20 +237,9 @@ func TestTaskDTO2DO(t *testing.T) { }, } - overrideSpan := &entity.SpanFilterFields{ - PlatformType: kitCommon.PlatformTypeCozeloop, - SpanListType: kitCommon.SpanListTypeRootSpan, - Filters: loop_span.FilterFields{ - QueryAndOr: ptr.Of(loop_span.QueryAndOrEnumAnd), - FilterFields: []*loop_span.FilterField{}, - }, - } - - entityTask := TaskDTO2DO(dto, "override", overrideSpan) + entityTask := TaskDTO2DO(dto) if assert.NotNil(t, entityTask) { assert.Equal(t, int64(11), entityTask.ID) - assert.Equal(t, "override", entityTask.CreatedBy) - assert.Equal(t, overrideSpan, entityTask.SpanFilter) assert.NotZero(t, entityTask.CreatedAt.Unix()) assert.Equal(t, int64(1), entityTask.TaskDetail.SuccessCount) assert.Equal(t, float64(0.3), entityTask.Sampler.SampleRate) @@ -277,149 +264,6 @@ func TestSpanFilterPO2DO(t *testing.T) { assert.Nil(t, SpanFilterPO2DO(ctx, &invalid)) } -func TestCheckEffectiveTime(t *testing.T) { - t.Parallel() - - ctx := context.Background() - now := time.Now() - - getCode := func(err error) int32 { - statusErr, ok := errorx.FromStatusError(err) - if !ok { - return 0 - } - return statusErr.Code() - } - - futureStart := now.Add(2 * time.Hour).UnixMilli() - futureEnd := now.Add(3 * time.Hour).UnixMilli() - - cases := []struct { - name string - effective *kitTask.EffectiveTime - status kitTask.TaskStatus - current *entity.EffectiveTime - wantStart int64 - wantEnd int64 - wantErrCode int32 - }{ - { - name: "nil current", - effective: &kitTask.EffectiveTime{StartAt: gptr.Of(futureStart), EndAt: gptr.Of(futureEnd)}, - status: kitTask.TaskStatusUnstarted, - current: nil, - wantErrCode: obErrorx.CommercialCommonInvalidParamCodeCode, - }, - { - name: "start after end", - effective: &kitTask.EffectiveTime{StartAt: gptr.Of(futureEnd), EndAt: gptr.Of(futureStart)}, - status: kitTask.TaskStatusUnstarted, - current: &entity.EffectiveTime{StartAt: futureStart, EndAt: futureEnd}, - wantErrCode: obErrorx.CommercialCommonInvalidParamCodeCode, - }, - { - name: "update start in past", - effective: &kitTask.EffectiveTime{StartAt: gptr.Of(now.Add(-time.Hour).UnixMilli()), EndAt: gptr.Of(futureEnd)}, - status: kitTask.TaskStatusRunning, - current: &entity.EffectiveTime{StartAt: futureStart, EndAt: futureEnd}, - wantErrCode: obErrorx.CommercialCommonInvalidParamCodeCode, - }, - { - name: "update end in past", - effective: &kitTask.EffectiveTime{StartAt: gptr.Of(futureStart), EndAt: gptr.Of(now.Add(-time.Hour).UnixMilli())}, - status: kitTask.TaskStatusRunning, - current: &entity.EffectiveTime{StartAt: futureStart, EndAt: futureEnd}, - wantErrCode: obErrorx.CommercialCommonInvalidParamCodeCode, - }, - { - name: "unstarted updates both", - effective: &kitTask.EffectiveTime{StartAt: gptr.Of(futureStart), EndAt: gptr.Of(futureEnd)}, - status: kitTask.TaskStatusUnstarted, - current: &entity.EffectiveTime{StartAt: 100, EndAt: 200}, - wantStart: futureStart, - wantEnd: futureEnd, - }, - { - name: "running keeps start", - effective: &kitTask.EffectiveTime{StartAt: gptr.Of(futureEnd), EndAt: gptr.Of(futureEnd + 1000)}, - status: kitTask.TaskStatusRunning, - current: &entity.EffectiveTime{StartAt: 111, EndAt: 222}, - wantStart: 111, - wantEnd: futureEnd + 1000, - }, - { - name: "invalid status", - effective: &kitTask.EffectiveTime{StartAt: gptr.Of(futureStart), EndAt: gptr.Of(futureEnd)}, - status: kitTask.TaskStatus("unknown"), - current: &entity.EffectiveTime{StartAt: futureStart, EndAt: futureEnd}, - wantErrCode: obErrorx.CommercialCommonInvalidParamCodeCode, - }, - } - - for _, tc := range cases { - tc := tc - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - got, err := CheckEffectiveTime(ctx, tc.effective, tc.status, tc.current) - if tc.wantErrCode != 0 { - assert.NotNil(t, err) - assert.Equal(t, tc.wantErrCode, getCode(err)) - assert.Nil(t, got) - return - } - assert.NoError(t, err) - if assert.NotNil(t, got) { - assert.Equal(t, tc.wantStart, got.StartAt) - assert.Equal(t, tc.wantEnd, got.EndAt) - } - }) - } -} - -func TestCheckTaskStatus(t *testing.T) { - t.Parallel() - - ctx := context.Background() - getCode := func(err error) int32 { - statusErr, ok := errorx.FromStatusError(err) - if !ok { - return 0 - } - return statusErr.Code() - } - - cases := []struct { - name string - status kitTask.TaskStatus - current kitTask.TaskStatus - want kitTask.TaskStatus - wantErrCode int32 - }{ - {"unstarted ok", kitTask.TaskStatusUnstarted, kitTask.TaskStatusUnstarted, kitTask.TaskStatusUnstarted, 0}, - {"unstarted invalid", kitTask.TaskStatusUnstarted, kitTask.TaskStatusRunning, "", obErrorx.CommercialCommonInvalidParamCodeCode}, - {"running ok", kitTask.TaskStatusRunning, kitTask.TaskStatusPending, kitTask.TaskStatusRunning, 0}, - {"running invalid", kitTask.TaskStatusRunning, kitTask.TaskStatusSuccess, "", obErrorx.CommercialCommonInvalidParamCodeCode}, - {"pending ok", kitTask.TaskStatusPending, kitTask.TaskStatusRunning, kitTask.TaskStatusPending, 0}, - {"disabled ok", kitTask.TaskStatusDisabled, kitTask.TaskStatusPending, kitTask.TaskStatusDisabled, 0}, - {"success ok", kitTask.TaskStatusSuccess, kitTask.TaskStatusRunning, kitTask.TaskStatusSuccess, 0}, - {"pending no transition", kitTask.TaskStatusPending, kitTask.TaskStatusDisabled, "", 0}, - } - - for _, tc := range cases { - tc := tc - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - got, err := CheckTaskStatus(ctx, tc.status, tc.current) - if tc.wantErrCode != 0 { - assert.Equal(t, tc.wantErrCode, getCode(err)) - return - } - assert.NoError(t, err) - assert.Equal(t, tc.want, got) - }) - } -} - func TestGetLastPartAfterDot(t *testing.T) { t.Parallel() diff --git a/backend/modules/observability/application/task.go b/backend/modules/observability/application/task.go index 5900639a2..054854056 100644 --- a/backend/modules/observability/application/task.go +++ b/backend/modules/observability/application/task.go @@ -9,34 +9,32 @@ import ( "time" "github.com/coze-dev/coze-loop/backend/infra/middleware/session" - "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/common" - "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/filter" - domain_task "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/task" "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/task" "github.com/coze-dev/coze-loop/backend/modules/observability/application/convertor" tconv "github.com/coze-dev/coze-loop/backend/modules/observability/application/convertor/task" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/component/rpc" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/component/scheduledtask" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service" - task_processor "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe/processor" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe/processor" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe/tracehub" - "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/loop_span" - trace_Svc "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/service" - "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/service/trace/span_filter" obErrorx "github.com/coze-dev/coze-loop/backend/modules/observability/pkg/errno" "github.com/coze-dev/coze-loop/backend/pkg/errorx" - "github.com/coze-dev/coze-loop/backend/pkg/lang/ptr" + "github.com/coze-dev/coze-loop/backend/pkg/logs" + "github.com/samber/lo" ) type ITaskQueueConsumer interface { SpanTrigger(ctx context.Context, event *entity.RawSpan) error - CallBack(ctx context.Context, event *entity.AutoEvalEvent) error - Correction(ctx context.Context, event *entity.CorrectionEvent) error + AutoEvalCallback(ctx context.Context, event *entity.AutoEvalEvent) error + AutoEvalCorrection(ctx context.Context, event *entity.CorrectionEvent) error BackFill(ctx context.Context, event *entity.BackFillEvent) error } + type ITaskApplication interface { task.TaskService ITaskQueueConsumer + RunTaskScheduleTask(ctx context.Context) error } func NewTaskApplication( @@ -46,30 +44,33 @@ func NewTaskApplication( evaluationService rpc.IEvaluationRPCAdapter, userService rpc.IUserProvider, tracehubSvc tracehub.ITraceHubService, - taskProcessor task_processor.TaskProcessor, - buildHelper trace_Svc.TraceFilterProcessorBuilder, + taskProcessor processor.TaskProcessor, + taskCallbackService service.ITaskCallbackService, + scheduledTasks []scheduledtask.ScheduledTask, ) (ITaskApplication, error) { return &TaskApplication{ - taskSvc: taskService, - authSvc: authService, - evalSvc: evalService, - evaluationSvc: evaluationService, - userSvc: userService, - tracehubSvc: tracehubSvc, - taskProcessor: taskProcessor, - buildHelper: buildHelper, + taskSvc: taskService, + authSvc: authService, + evalSvc: evalService, + evaluationSvc: evaluationService, + userSvc: userService, + tracehubSvc: tracehubSvc, + taskProcessor: taskProcessor, + taskCallbackSvc: taskCallbackService, + scheduledTasks: scheduledTasks, }, nil } type TaskApplication struct { - taskSvc service.ITaskService - authSvc rpc.IAuthProvider - evalSvc rpc.IEvaluatorRPCAdapter - evaluationSvc rpc.IEvaluationRPCAdapter - userSvc rpc.IUserProvider - tracehubSvc tracehub.ITraceHubService - taskProcessor task_processor.TaskProcessor - buildHelper trace_Svc.TraceFilterProcessorBuilder + taskSvc service.ITaskService + authSvc rpc.IAuthProvider + evalSvc rpc.IEvaluatorRPCAdapter + evaluationSvc rpc.IEvaluationRPCAdapter + userSvc rpc.IUserProvider + tracehubSvc tracehub.ITraceHubService + taskProcessor processor.TaskProcessor + taskCallbackSvc service.ITaskCallbackService + scheduledTasks []scheduledtask.ScheduledTask } func (t *TaskApplication) CheckTaskName(ctx context.Context, req *task.CheckTaskNameRequest) (*task.CheckTaskNameResponse, error) { @@ -114,13 +115,13 @@ func (t *TaskApplication) CreateTask(ctx context.Context, req *task.CreateTaskRe if userID == "" { return nil, errorx.NewByCode(obErrorx.UserParseFailedCode) } + // 创建task - req.Task.TaskStatus = ptr.Of(domain_task.TaskStatusUnstarted) - spanFilers, err := t.buildSpanFilters(ctx, req.Task.GetRule().GetSpanFilters(), req.GetTask().GetWorkspaceID()) - if err != nil { - return nil, err - } - sResp, err := t.taskSvc.CreateTask(ctx, &service.CreateTaskReq{Task: tconv.TaskDTO2DO(req.GetTask(), userID, spanFilers)}) + taskDO := tconv.TaskDTO2DO(req.GetTask()) + taskDO.TaskStatus = entity.TaskStatusUnstarted + taskDO.CreatedBy = userID + taskDO.UpdatedBy = userID + sResp, err := t.taskSvc.CreateTask(ctx, &service.CreateTaskReq{Task: taskDO}) if err != nil { return resp, err } @@ -128,50 +129,6 @@ func (t *TaskApplication) CreateTask(ctx context.Context, req *task.CreateTaskRe return &task.CreateTaskResponse{TaskID: sResp.TaskID}, nil } -func (t *TaskApplication) buildSpanFilters(ctx context.Context, spanFilterFields *filter.SpanFilterFields, workspaceID int64) (*entity.SpanFilterFields, error) { - spanFilters := &entity.SpanFilterFields{ - PlatformType: *spanFilterFields.PlatformType, - SpanListType: *spanFilterFields.SpanListType, - } - filters := convertor.FilterFieldsDTO2DO(spanFilterFields.GetFilters()) - spanFilters.Filters = *filters - switch spanFilterFields.GetPlatformType() { - case common.PlatformTypeCozeBot, common.PlatformTypeProject, common.PlatformTypeWorkflow, common.PlatformTypeInnerCozeBot: - platformFilter, err := t.buildHelper.BuildPlatformRelatedFilter(ctx, loop_span.PlatformType(spanFilterFields.GetPlatformType())) - if err != nil { - return nil, err - } - env := &span_filter.SpanEnv{ - WorkspaceID: workspaceID, - } - basicFilter, forceQuery, err := platformFilter.BuildBasicSpanFilter(ctx, env) - if err != nil { - return nil, err - } else if len(basicFilter) == 0 && !forceQuery { // if it's null, no need to query from ck - return nil, nil - } - for _, filter := range basicFilter { - filters.FilterFields = append(filters.FilterFields, &loop_span.FilterField{ - FieldName: filter.FieldName, - FieldType: filter.FieldType, - Values: filter.Values, - QueryType: filter.QueryType, - QueryAndOr: filter.QueryAndOr, - SubFilter: filter.SubFilter, - Hidden: true, - }) - } - - return &entity.SpanFilterFields{ - Filters: *filters, - PlatformType: *spanFilterFields.PlatformType, - SpanListType: *spanFilterFields.SpanListType, - }, nil - default: - return spanFilters, nil - } -} - func (t *TaskApplication) validateCreateTaskReq(ctx context.Context, req *task.CreateTaskRequest) error { // 参数验证 if req == nil || req.GetTask() == nil { @@ -211,12 +168,16 @@ func (t *TaskApplication) UpdateTask(ctx context.Context, req *task.UpdateTaskRe strconv.FormatInt(req.GetTaskID(), 10)); err != nil { return nil, err } + var taskStatus *entity.TaskStatus + if req.TaskStatus != nil { + taskStatus = lo.ToPtr(entity.TaskStatus(req.GetTaskStatus())) + } err := t.taskSvc.UpdateTask(ctx, &service.UpdateTaskReq{ TaskID: req.GetTaskID(), WorkspaceID: req.GetWorkspaceID(), - TaskStatus: req.TaskStatus, + TaskStatus: taskStatus, Description: req.Description, - EffectiveTime: req.EffectiveTime, + EffectiveTime: tconv.EffectiveTimeDTO2DO(req.EffectiveTime), SampleRate: req.SampleRate, }) if err != nil { @@ -239,12 +200,13 @@ func (t *TaskApplication) ListTasks(ctx context.Context, req *task.ListTasksRequ false); err != nil { return resp, err } + sResp, err := t.taskSvc.ListTasks(ctx, &service.ListTasksReq{ WorkspaceID: req.GetWorkspaceID(), - TaskFilters: req.GetTaskFilters(), + TaskFilters: tconv.TaskFiltersDTO2DO(req.GetTaskFilters()), Limit: req.GetLimit(), Offset: req.GetOffset(), - OrderBy: req.GetOrderBy(), + OrderBy: convertor.OrderByDTO2DO(req.GetOrderBy()), }) if err != nil { return resp, err @@ -252,9 +214,21 @@ func (t *TaskApplication) ListTasks(ctx context.Context, req *task.ListTasksRequ if sResp == nil { return resp, nil } + + userMap := make(map[string]bool) + for _, tp := range sResp.Tasks { + userMap[tp.CreatedBy] = true + userMap[tp.UpdatedBy] = true + } + _, userInfoMap, err := t.userSvc.GetUserInfo(ctx, lo.Keys(userMap)) + if err != nil { + logs.CtxError(ctx, "MGetUserInfo err:%v", err) + } + tasks := tconv.TaskDOs2DTOs(ctx, sResp.Tasks, userInfoMap) + return &task.ListTasksResponse{ - Tasks: sResp.Tasks, - Total: sResp.Total, + Tasks: tasks, + Total: &sResp.Total, }, nil } @@ -271,6 +245,7 @@ func (t *TaskApplication) GetTask(ctx context.Context, req *task.GetTaskRequest) false); err != nil { return resp, err } + sResp, err := t.taskSvc.GetTask(ctx, &service.GetTaskReq{ TaskID: req.GetTaskID(), WorkspaceID: req.GetWorkspaceID(), @@ -282,23 +257,65 @@ func (t *TaskApplication) GetTask(ctx context.Context, req *task.GetTaskRequest) return resp, nil } + taskDO := sResp.Task + _, userInfoMap, err := t.userSvc.GetUserInfo(ctx, []string{taskDO.CreatedBy, taskDO.UpdatedBy}) + if err != nil { + logs.CtxError(ctx, "MGetUserInfo err:%v", err) + } + return &task.GetTaskResponse{ - Task: sResp.Task, + Task: tconv.TaskDO2DTO(ctx, taskDO, userInfoMap), }, nil } func (t *TaskApplication) SpanTrigger(ctx context.Context, event *entity.RawSpan) error { - return t.tracehubSvc.SpanTrigger(ctx, event) + span := event.RawSpanConvertToLoopSpan() + if span != nil { + if err := t.tracehubSvc.SpanTrigger(ctx, span); err != nil { + logs.CtxError(ctx, "SpanTrigger err:%v", err) + // span trigger 失败,不处理 + return nil + } + } + return nil } -func (t *TaskApplication) CallBack(ctx context.Context, event *entity.AutoEvalEvent) error { - return t.tracehubSvc.CallBack(ctx, event) +func (t *TaskApplication) AutoEvalCallback(ctx context.Context, event *entity.AutoEvalEvent) error { + if err := event.Validate(); err != nil { + logs.CtxError(ctx, "event is invalid, event: %#v, err: %v", event, err) + // 结构校验失败,不处理 + return nil + } + + return t.taskCallbackSvc.AutoEvalCallback(ctx, event) } -func (t *TaskApplication) Correction(ctx context.Context, event *entity.CorrectionEvent) error { - return t.tracehubSvc.Correction(ctx, event) +func (t *TaskApplication) AutoEvalCorrection(ctx context.Context, event *entity.CorrectionEvent) error { + if err := event.Validate(); err != nil { + logs.CtxError(ctx, "event is invalid, event: %#v, err: %v", event, err) + // 结构校验失败,不处理 + return nil + } + + return t.taskCallbackSvc.AutoEvalCorrection(ctx, event) } func (t *TaskApplication) BackFill(ctx context.Context, event *entity.BackFillEvent) error { + if err := event.Validate(); err != nil { + logs.CtxError(ctx, "event is invalid, event: %#v, err: %v", event, err) + // 结构校验失败,不处理 + return nil + } + return t.tracehubSvc.BackFill(ctx, event) } + +func (t *TaskApplication) RunTaskScheduleTask(ctx context.Context) error { + for _, scheduledTask := range t.scheduledTasks { + if err := scheduledTask.Run(); err != nil { + logs.CtxError(ctx, "RunTaskScheduleTask err:%v", err) + return err + } + } + return nil +} diff --git a/backend/modules/observability/application/task_test.go b/backend/modules/observability/application/task_test.go index db20bbf3f..46a1b25b2 100755 --- a/backend/modules/observability/application/task_test.go +++ b/backend/modules/observability/application/task_test.go @@ -11,6 +11,9 @@ import ( "time" "github.com/bytedance/gg/gptr" + tconv "github.com/coze-dev/coze-loop/backend/modules/observability/application/convertor/task" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/common" + "github.com/samber/lo" "github.com/stretchr/testify/assert" "go.uber.org/mock/gomock" @@ -25,11 +28,6 @@ import ( svc "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service" svcmock "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/mocks" tracehubmock "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe/tracehub/mocks" - loop_span "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/loop_span" - traceSvc "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/service" - traceSvcMock "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/service/mocks" - span_filter "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/service/trace/span_filter" - filtermocks "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/service/trace/span_filter/mocks" obErrorx "github.com/coze-dev/coze-loop/backend/modules/observability/pkg/errno" "github.com/coze-dev/coze-loop/backend/pkg/errorx" ) @@ -310,224 +308,6 @@ func TestTaskApplication_CreateTask(t *testing.T) { } } -func TestTaskApplication_buildSpanFilters(t *testing.T) { - t.Parallel() - - type fields struct { - builder traceSvc.TraceFilterProcessorBuilder - } - - type args struct { - spanFilters *filterdto.SpanFilterFields - workspaceID int64 - } - - tests := []struct { - name string - fieldsBuilder func(ctrl *gomock.Controller, t *testing.T, a args) fields - args args - assertFunc func(t *testing.T, original *filterdto.SpanFilterFields, got *entity.SpanFilterFields, err error) - }{ - { - name: "non supported platform returns original", - fieldsBuilder: func(ctrl *gomock.Controller, t *testing.T, a args) fields { - return fields{} - }, - args: args{ - spanFilters: &filterdto.SpanFilterFields{ - Filters: &filterdto.FilterFields{ - FilterFields: []*filterdto.FilterField{ - { - FieldName: gptr.Of("custom_field"), - FieldType: gptr.Of(filterdto.FieldTypeString), - Values: []string{"value"}, - }, - }, - }, - PlatformType: gptr.Of(commondomain.PlatformTypeCozeloop), - SpanListType: gptr.Of(commondomain.SpanListTypeRootSpan), - }, - workspaceID: 100, - }, - assertFunc: func(t *testing.T, original *filterdto.SpanFilterFields, got *entity.SpanFilterFields, err error) { - assert.NoError(t, err) - if assert.NotNil(t, got) { - assert.Equal(t, commondomain.PlatformTypeCozeloop, got.PlatformType) - assert.Equal(t, commondomain.SpanListTypeRootSpan, got.SpanListType) - dtoFilters := original.GetFilters().GetFilterFields() - if assert.Len(t, got.Filters.FilterFields, len(dtoFilters)) && len(dtoFilters) > 0 { - firstDTO := dtoFilters[0] - firstDomain := got.Filters.FilterFields[0] - if assert.NotNil(t, firstDTO.FieldName) { - assert.Equal(t, *firstDTO.FieldName, firstDomain.FieldName) - } - if assert.NotNil(t, firstDTO.FieldType) { - assert.Equal(t, loop_span.FieldType(*firstDTO.FieldType), firstDomain.FieldType) - } - assert.Equal(t, firstDTO.Values, firstDomain.Values) - assert.False(t, firstDomain.Hidden) - } - } - }, - }, - { - name: "build platform filter error", - fieldsBuilder: func(ctrl *gomock.Controller, t *testing.T, a args) fields { - builder := traceSvcMock.NewMockTraceFilterProcessorBuilder(ctrl) - builder.EXPECT().BuildPlatformRelatedFilter(gomock.Any(), loop_span.PlatformType(commondomain.PlatformTypeCozeBot)).Return(nil, errors.New("build platform error")) - return fields{builder: builder} - }, - args: args{ - spanFilters: &filterdto.SpanFilterFields{ - Filters: &filterdto.FilterFields{ - FilterFields: []*filterdto.FilterField{}, - }, - PlatformType: gptr.Of(commondomain.PlatformTypeCozeBot), - SpanListType: gptr.Of(commondomain.SpanListTypeRootSpan), - }, - workspaceID: 200, - }, - assertFunc: func(t *testing.T, original *filterdto.SpanFilterFields, got *entity.SpanFilterFields, err error) { - assert.Nil(t, got) - assert.EqualError(t, err, "build platform error") - }, - }, - { - name: "build basic span filter error", - fieldsBuilder: func(ctrl *gomock.Controller, t *testing.T, a args) fields { - builder := traceSvcMock.NewMockTraceFilterProcessorBuilder(ctrl) - platformFilter := filtermocks.NewMockFilter(ctrl) - builder.EXPECT().BuildPlatformRelatedFilter(gomock.Any(), loop_span.PlatformType(commondomain.PlatformTypeWorkflow)).Return(platformFilter, nil) - platformFilter.EXPECT(). - BuildBasicSpanFilter(gomock.Any(), gomock.AssignableToTypeOf(&span_filter.SpanEnv{})). - DoAndReturn(func(_ context.Context, env *span_filter.SpanEnv) ([]*loop_span.FilterField, bool, error) { - assert.Equal(t, a.workspaceID, env.WorkspaceID) - return nil, false, errors.New("build basic error") - }) - return fields{builder: builder} - }, - args: args{ - spanFilters: &filterdto.SpanFilterFields{ - Filters: &filterdto.FilterFields{ - FilterFields: []*filterdto.FilterField{}, - }, - PlatformType: gptr.Of(commondomain.PlatformTypeWorkflow), - SpanListType: gptr.Of(commondomain.SpanListTypeRootSpan), - }, - workspaceID: 300, - }, - assertFunc: func(t *testing.T, original *filterdto.SpanFilterFields, got *entity.SpanFilterFields, err error) { - assert.Nil(t, got) - assert.EqualError(t, err, "build basic error") - }, - }, - { - name: "empty basic filter without force returns nil", - fieldsBuilder: func(ctrl *gomock.Controller, t *testing.T, a args) fields { - builder := traceSvcMock.NewMockTraceFilterProcessorBuilder(ctrl) - platformFilter := filtermocks.NewMockFilter(ctrl) - builder.EXPECT().BuildPlatformRelatedFilter(gomock.Any(), loop_span.PlatformType(commondomain.PlatformTypeInnerCozeBot)).Return(platformFilter, nil) - platformFilter.EXPECT(). - BuildBasicSpanFilter(gomock.Any(), gomock.AssignableToTypeOf(&span_filter.SpanEnv{})). - DoAndReturn(func(_ context.Context, env *span_filter.SpanEnv) ([]*loop_span.FilterField, bool, error) { - assert.Equal(t, a.workspaceID, env.WorkspaceID) - return []*loop_span.FilterField{}, false, nil - }) - return fields{builder: builder} - }, - args: args{ - spanFilters: &filterdto.SpanFilterFields{ - Filters: &filterdto.FilterFields{ - FilterFields: []*filterdto.FilterField{}, - }, - PlatformType: gptr.Of(commondomain.PlatformTypeInnerCozeBot), - SpanListType: gptr.Of(commondomain.SpanListTypeRootSpan), - }, - workspaceID: 400, - }, - assertFunc: func(t *testing.T, original *filterdto.SpanFilterFields, got *entity.SpanFilterFields, err error) { - assert.NoError(t, err) - assert.Nil(t, got) - }, - }, - { - name: "merge platform filters success", - fieldsBuilder: func(ctrl *gomock.Controller, t *testing.T, a args) fields { - builder := traceSvcMock.NewMockTraceFilterProcessorBuilder(ctrl) - platformFilter := filtermocks.NewMockFilter(ctrl) - builder.EXPECT().BuildPlatformRelatedFilter(gomock.Any(), loop_span.PlatformType(commondomain.PlatformTypeProject)).Return(platformFilter, nil) - platformFilter.EXPECT(). - BuildBasicSpanFilter(gomock.Any(), gomock.AssignableToTypeOf(&span_filter.SpanEnv{})). - DoAndReturn(func(_ context.Context, env *span_filter.SpanEnv) ([]*loop_span.FilterField, bool, error) { - assert.Equal(t, a.workspaceID, env.WorkspaceID) - return []*loop_span.FilterField{ - { - FieldName: loop_span.SpanFieldSpaceId, - FieldType: loop_span.FieldTypeString, - Values: []string{"tenant"}, - }, - }, false, nil - }) - return fields{builder: builder} - }, - args: args{ - spanFilters: &filterdto.SpanFilterFields{ - Filters: &filterdto.FilterFields{ - FilterFields: []*filterdto.FilterField{ - { - FieldName: gptr.Of("custom_field"), - FieldType: gptr.Of(filterdto.FieldTypeString), - Values: []string{"origin"}, - }, - }, - }, - PlatformType: gptr.Of(commondomain.PlatformTypeProject), - SpanListType: gptr.Of(commondomain.SpanListTypeRootSpan), - }, - workspaceID: 500, - }, - assertFunc: func(t *testing.T, original *filterdto.SpanFilterFields, got *entity.SpanFilterFields, err error) { - assert.NoError(t, err) - if assert.NotNil(t, got) { - assert.Equal(t, commondomain.PlatformTypeProject, got.PlatformType) - assert.Equal(t, commondomain.SpanListTypeRootSpan, got.SpanListType) - originalFilters := original.GetFilters().GetFilterFields() - if assert.Len(t, got.Filters.FilterFields, len(originalFilters)+1) && len(originalFilters) > 0 { - firstDomain := got.Filters.FilterFields[0] - firstDTO := originalFilters[0] - if assert.NotNil(t, firstDTO.FieldName) { - assert.Equal(t, *firstDTO.FieldName, firstDomain.FieldName) - } - assert.False(t, firstDomain.Hidden) - appended := got.Filters.FilterFields[len(originalFilters)] - assert.Equal(t, loop_span.SpanFieldSpaceId, appended.FieldName) - assert.True(t, appended.Hidden) - assert.Equal(t, []string{"tenant"}, appended.Values) - } - } - }, - }, - } - - for _, tt := range tests { - caseItem := tt - t.Run(caseItem.name, func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - fields := caseItem.fieldsBuilder(ctrl, t, caseItem.args) - app := &TaskApplication{ - buildHelper: fields.builder, - } - - got, err := app.buildSpanFilters(context.Background(), caseItem.args.spanFilters, caseItem.args.workspaceID) - - caseItem.assertFunc(t, caseItem.args.spanFilters, got, err) - }) - } -} - func TestTaskApplication_UpdateTask(t *testing.T) { t.Parallel() @@ -647,8 +427,8 @@ func TestTaskApplication_ListTasks(t *testing.T) { t.Parallel() taskListResp := &svc.ListTasksResp{ - Tasks: []*taskdto.Task{{Name: "task1"}}, - Total: gptr.Of(int64(1)), + Tasks: []*entity.ObservabilityTask{{Name: "task1"}}, + Total: int64(1), } tests := []struct { name string @@ -708,10 +488,13 @@ func TestTaskApplication_ListTasks(t *testing.T) { }, }, { - name: "success", - ctx: context.Background(), - req: &taskapi.ListTasksRequest{WorkspaceID: 789}, - expectResp: &taskapi.ListTasksResponse{Tasks: taskListResp.Tasks, Total: taskListResp.Total}, + name: "success", + ctx: context.Background(), + req: &taskapi.ListTasksRequest{WorkspaceID: 789}, + expectResp: &taskapi.ListTasksResponse{ + Tasks: tconv.TaskDOs2DTOs(context.Background(), taskListResp.Tasks, map[string]*common.UserInfo{}), + Total: lo.ToPtr(taskListResp.Total), + }, fieldsBuilder: func(ctrl *gomock.Controller) (svc.ITaskService, rpc.IAuthProvider) { auth := rpcmock.NewMockIAuthProvider(ctrl) auth.EXPECT().CheckWorkspacePermission(gomock.Any(), rpc.AuthActionTraceTaskList, strconv.FormatInt(789, 10), false).Return(nil) @@ -759,7 +542,7 @@ func TestTaskApplication_ListTasks(t *testing.T) { func TestTaskApplication_GetTask(t *testing.T) { t.Parallel() - taskResp := &svc.GetTaskResp{Task: &taskdto.Task{Name: "task"}} + taskResp := &svc.GetTaskResp{Task: &entity.ObservabilityTask{Name: "task"}} tests := []struct { name string @@ -820,7 +603,7 @@ func TestTaskApplication_GetTask(t *testing.T) { name: "success", ctx: context.Background(), req: &taskapi.GetTaskRequest{WorkspaceID: 202, TaskID: 3}, - expectResp: &taskapi.GetTaskResponse{Task: taskResp.Task}, + expectResp: &taskapi.GetTaskResponse{Task: tconv.TaskDO2DTO(context.Background(), taskResp.Task, map[string]*common.UserInfo{})}, fieldsBuilder: func(ctrl *gomock.Controller) (svc.ITaskService, rpc.IAuthProvider) { auth := rpcmock.NewMockIAuthProvider(ctrl) auth.EXPECT().CheckWorkspacePermission(gomock.Any(), rpc.AuthActionTraceTaskList, strconv.FormatInt(202, 10), false).Return(nil) @@ -917,23 +700,23 @@ func TestTaskApplication_CallBack(t *testing.T) { event := &entity.AutoEvalEvent{} tests := []struct { name string - mockSvc func(ctrl *gomock.Controller) *tracehubmock.MockITraceHubService + mockSvc func(ctrl *gomock.Controller) *svcmock.MockITaskCallbackService expectErr bool }{ { name: "trace hub error", - mockSvc: func(ctrl *gomock.Controller) *tracehubmock.MockITraceHubService { - svc := tracehubmock.NewMockITraceHubService(ctrl) - svc.EXPECT().CallBack(gomock.Any(), event).Return(errors.New("hub error")) + mockSvc: func(ctrl *gomock.Controller) *svcmock.MockITaskCallbackService { + svc := svcmock.NewMockITaskCallbackService(ctrl) + svc.EXPECT().AutoEvalCallback(gomock.Any(), event).Return(errors.New("hub error")) return svc }, expectErr: true, }, { name: "success", - mockSvc: func(ctrl *gomock.Controller) *tracehubmock.MockITraceHubService { - svc := tracehubmock.NewMockITraceHubService(ctrl) - svc.EXPECT().CallBack(gomock.Any(), event).Return(nil) + mockSvc: func(ctrl *gomock.Controller) *svcmock.MockITaskCallbackService { + svc := svcmock.NewMockITaskCallbackService(ctrl) + svc.EXPECT().AutoEvalCallback(gomock.Any(), event).Return(nil) return svc }, }, @@ -947,8 +730,8 @@ func TestTaskApplication_CallBack(t *testing.T) { defer ctrl.Finish() traceSvc := caseItem.mockSvc(ctrl) - app := &TaskApplication{tracehubSvc: traceSvc} - err := app.CallBack(context.Background(), event) + app := &TaskApplication{taskCallbackSvc: traceSvc} + err := app.AutoEvalCallback(context.Background(), event) if caseItem.expectErr { assert.Error(t, err) } else { @@ -964,23 +747,23 @@ func TestTaskApplication_Correction(t *testing.T) { event := &entity.CorrectionEvent{} tests := []struct { name string - mockSvc func(ctrl *gomock.Controller) *tracehubmock.MockITraceHubService + mockSvc func(ctrl *gomock.Controller) *svcmock.MockITaskCallbackService expectErr bool }{ { name: "trace hub error", - mockSvc: func(ctrl *gomock.Controller) *tracehubmock.MockITraceHubService { - svc := tracehubmock.NewMockITraceHubService(ctrl) - svc.EXPECT().Correction(gomock.Any(), event).Return(errors.New("hub error")) + mockSvc: func(ctrl *gomock.Controller) *svcmock.MockITaskCallbackService { + svc := svcmock.NewMockITaskCallbackService(ctrl) + svc.EXPECT().AutoEvalCorrection(gomock.Any(), event).Return(errors.New("hub error")) return svc }, expectErr: true, }, { name: "success", - mockSvc: func(ctrl *gomock.Controller) *tracehubmock.MockITraceHubService { - svc := tracehubmock.NewMockITraceHubService(ctrl) - svc.EXPECT().Correction(gomock.Any(), event).Return(nil) + mockSvc: func(ctrl *gomock.Controller) *svcmock.MockITaskCallbackService { + svc := svcmock.NewMockITaskCallbackService(ctrl) + svc.EXPECT().AutoEvalCorrection(gomock.Any(), event).Return(nil) return svc }, }, @@ -994,8 +777,8 @@ func TestTaskApplication_Correction(t *testing.T) { defer ctrl.Finish() traceSvc := caseItem.mockSvc(ctrl) - app := &TaskApplication{tracehubSvc: traceSvc} - err := app.Correction(context.Background(), event) + app := &TaskApplication{taskCallbackSvc: traceSvc} + err := app.AutoEvalCorrection(context.Background(), event) if caseItem.expectErr { assert.Error(t, err) } else { diff --git a/backend/modules/observability/application/wire.go b/backend/modules/observability/application/wire.go index 9bf613bae..2185c7ad9 100644 --- a/backend/modules/observability/application/wire.go +++ b/backend/modules/observability/application/wire.go @@ -24,18 +24,20 @@ import ( "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/foundation/auth/authservice" "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/foundation/file/fileservice" "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/foundation/user/userservice" - "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/task" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/component/config" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/component/rpc" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/component/scheduledtask" metrics_entity "github.com/coze-dev/coze-loop/backend/modules/observability/domain/metric/entity" metric_service "github.com/coze-dev/coze-loop/backend/modules/observability/domain/metric/service" metric_general "github.com/coze-dev/coze-loop/backend/modules/observability/domain/metric/service/metric/general" metric_model "github.com/coze-dev/coze-loop/backend/modules/observability/domain/metric/service/metric/model" metric_service_def "github.com/coze-dev/coze-loop/backend/modules/observability/domain/metric/service/metric/service" metric_tool "github.com/coze-dev/coze-loop/backend/modules/observability/domain/metric/service/metric/tool" + task_entity "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" trepo "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/repo" taskSvc "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service" task_processor "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe/processor" + taskst "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe/scheduledtask" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe/tracehub" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/collector/exporter" @@ -55,7 +57,7 @@ import ( obrepo "github.com/coze-dev/coze-loop/backend/modules/observability/infra/repo" ckdao "github.com/coze-dev/coze-loop/backend/modules/observability/infra/repo/ck" mysqldao "github.com/coze-dev/coze-loop/backend/modules/observability/infra/repo/mysql" - tredis "github.com/coze-dev/coze-loop/backend/modules/observability/infra/repo/redis/dao" + redis2 "github.com/coze-dev/coze-loop/backend/modules/observability/infra/repo/redis" "github.com/coze-dev/coze-loop/backend/modules/observability/infra/rpc/auth" "github.com/coze-dev/coze-loop/backend/modules/observability/infra/rpc/dataset" "github.com/coze-dev/coze-loop/backend/modules/observability/infra/rpc/evaluation" @@ -77,10 +79,11 @@ var ( obrepo.NewTaskRepoImpl, // obrepo.NewTaskRunRepoImpl, mysqldao.NewTaskDaoImpl, - tredis.NewTaskDAO, - tredis.NewTaskRunDAO, + redis2.NewTaskDAO, + redis2.NewTaskRunDAO, mysqldao.NewTaskRunDaoImpl, mq2.NewBackfillProducerImpl, + NewScheduledTask, ) traceDomainSet = wire.NewSet( service.NewTraceServiceImpl, @@ -134,6 +137,7 @@ var ( evaluation.NewEvaluationRPCProvider, NewTaskLocker, traceDomainSet, + taskSvc.NewTaskCallbackServiceImpl, ) metricsSet = wire.NewSet( NewMetricApplication, @@ -281,10 +285,24 @@ func NewInitTaskProcessor(datasetServiceProvider *service.DatasetServiceAdaptor, evaluationService rpc.IEvaluationRPCAdapter, taskRepo trepo.ITaskRepo, ) *task_processor.TaskProcessor { taskProcessor := task_processor.NewTaskProcessor() - taskProcessor.Register(task.TaskTypeAutoEval, task_processor.NewAutoEvaluteProcessor(0, datasetServiceProvider, evalService, evaluationService, taskRepo)) + taskProcessor.Register(task_entity.TaskTypeAutoEval, task_processor.NewAutoEvaluteProcessor(0, datasetServiceProvider, evalService, evaluationService, taskRepo)) return taskProcessor } +func NewScheduledTask( + locker lock.ILocker, + config config.ITraceConfig, + traceHubService tracehub.ITraceHubService, + taskService taskSvc.ITaskService, + taskProcessor task_processor.TaskProcessor, + taskRepo trepo.ITaskRepo, +) []scheduledtask.ScheduledTask { + return []scheduledtask.ScheduledTask{ + taskst.NewStatusCheckTask(locker, config, traceHubService, taskService, taskProcessor, taskRepo), + taskst.NewLocalCacheRefreshTask(traceHubService, taskRepo), + } +} + func InitTraceApplication( db db.Provider, ckDb ck.Provider, diff --git a/backend/modules/observability/application/wire_gen.go b/backend/modules/observability/application/wire_gen.go index 6d30d8a1f..6bc481c78 100644 --- a/backend/modules/observability/application/wire_gen.go +++ b/backend/modules/observability/application/wire_gen.go @@ -24,18 +24,20 @@ import ( "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/foundation/auth/authservice" "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/foundation/file/fileservice" "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/foundation/user/userservice" - "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/task" config2 "github.com/coze-dev/coze-loop/backend/modules/observability/domain/component/config" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/component/rpc" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/component/scheduledtask" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/metric/entity" service2 "github.com/coze-dev/coze-loop/backend/modules/observability/domain/metric/service" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/metric/service/metric/general" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/metric/service/metric/model" service4 "github.com/coze-dev/coze-loop/backend/modules/observability/domain/metric/service/metric/service" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/metric/service/metric/tool" + entity3 "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" repo3 "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/repo" service3 "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe/processor" + scheduledtask2 "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe/scheduledtask" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe/tracehub" entity2 "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/collector/exporter" @@ -55,7 +57,7 @@ import ( "github.com/coze-dev/coze-loop/backend/modules/observability/infra/repo" ck2 "github.com/coze-dev/coze-loop/backend/modules/observability/infra/repo/ck" "github.com/coze-dev/coze-loop/backend/modules/observability/infra/repo/mysql" - "github.com/coze-dev/coze-loop/backend/modules/observability/infra/repo/redis/dao" + redis2 "github.com/coze-dev/coze-loop/backend/modules/observability/infra/repo/redis" "github.com/coze-dev/coze-loop/backend/modules/observability/infra/rpc/auth" "github.com/coze-dev/coze-loop/backend/modules/observability/infra/rpc/dataset" "github.com/coze-dev/coze-loop/backend/modules/observability/infra/rpc/evaluation" @@ -72,7 +74,7 @@ import ( // Injectors from wire.go: -func InitTraceApplication(db2 db.Provider, ckDb ck.Provider, redis2 redis.Cmdable, meter metrics.Meter, mqFactory mq.IFactory, configFactory conf.IConfigLoaderFactory, idgen2 idgen.IIDGenerator, fileClient fileservice.Client, benefit2 benefit.IBenefitService, authClient authservice.Client, userClient userservice.Client, evalService evaluatorservice.Client, evalSetService evaluationsetservice.Client, tagService tagservice.Client, datasetService datasetservice.Client) (ITraceApplication, error) { +func InitTraceApplication(db2 db.Provider, ckDb ck.Provider, redis3 redis.Cmdable, meter metrics.Meter, mqFactory mq.IFactory, configFactory conf.IConfigLoaderFactory, idgen2 idgen.IIDGenerator, fileClient fileservice.Client, benefit2 benefit.IBenefitService, authClient authservice.Client, userClient userservice.Client, evalService evaluatorservice.Client, evalSetService evaluationsetservice.Client, tagService tagservice.Client, datasetService datasetservice.Client) (ITraceApplication, error) { iSpansDao, err := ck2.NewSpansCkDaoImpl(ckDb) if err != nil { return nil, err @@ -104,9 +106,9 @@ func InitTraceApplication(db2 db.Provider, ckDb ck.Provider, redis2 redis.Cmdabl iTenantProvider := tenant.NewTenantProvider(iTraceConfig) iEvaluatorRPCAdapter := evaluator.NewEvaluatorRPCProvider(evalService) iTaskDao := mysql.NewTaskDaoImpl(db2) - iTaskDAO := dao.NewTaskDAO(redis2) + iTaskDAO := redis2.NewTaskDAO(redis3) iTaskRunDao := mysql.NewTaskRunDaoImpl(db2) - iTaskRunDAO := dao.NewTaskRunDAO(redis2) + iTaskRunDAO := redis2.NewTaskRunDAO(redis3) iTaskRepo := repo.NewTaskRepoImpl(iTaskDao, idgen2, iTaskDAO, iTaskRunDao, iTaskRunDAO) iTraceService, err := service.NewTraceServiceImpl(iTraceRepo, iTraceConfig, iTraceProducer, iAnnotationProducer, iTraceMetrics, traceFilterProcessorBuilder, iTenantProvider, iEvaluatorRPCAdapter, iTaskRepo) if err != nil { @@ -129,7 +131,7 @@ func InitTraceApplication(db2 db.Provider, ckDb ck.Provider, redis2 redis.Cmdabl return iTraceApplication, nil } -func InitOpenAPIApplication(mqFactory mq.IFactory, configFactory conf.IConfigLoaderFactory, fileClient fileservice.Client, ckDb ck.Provider, benefit2 benefit.IBenefitService, limiterFactory limiter.IRateLimiterFactory, authClient authservice.Client, meter metrics.Meter, db2 db.Provider, redis2 redis.Cmdable, idgen2 idgen.IIDGenerator, evalService evaluatorservice.Client) (IObservabilityOpenAPIApplication, error) { +func InitOpenAPIApplication(mqFactory mq.IFactory, configFactory conf.IConfigLoaderFactory, fileClient fileservice.Client, ckDb ck.Provider, benefit2 benefit.IBenefitService, limiterFactory limiter.IRateLimiterFactory, authClient authservice.Client, meter metrics.Meter, db2 db.Provider, redis3 redis.Cmdable, idgen2 idgen.IIDGenerator, evalService evaluatorservice.Client) (IObservabilityOpenAPIApplication, error) { iSpansDao, err := ck2.NewSpansCkDaoImpl(ckDb) if err != nil { return nil, err @@ -161,9 +163,9 @@ func InitOpenAPIApplication(mqFactory mq.IFactory, configFactory conf.IConfigLoa iTenantProvider := tenant.NewTenantProvider(iTraceConfig) iEvaluatorRPCAdapter := evaluator.NewEvaluatorRPCProvider(evalService) iTaskDao := mysql.NewTaskDaoImpl(db2) - iTaskDAO := dao.NewTaskDAO(redis2) + iTaskDAO := redis2.NewTaskDAO(redis3) iTaskRunDao := mysql.NewTaskRunDaoImpl(db2) - iTaskRunDAO := dao.NewTaskRunDAO(redis2) + iTaskRunDAO := redis2.NewTaskRunDAO(redis3) iTaskRepo := repo.NewTaskRepoImpl(iTaskDao, idgen2, iTaskDAO, iTaskRunDao, iTaskRunDAO) iTraceService, err := service.NewTraceServiceImpl(iTraceRepo, iTraceConfig, iTraceProducer, iAnnotationProducer, iTraceMetrics, traceFilterProcessorBuilder, iTenantProvider, iEvaluatorRPCAdapter, iTaskRepo) if err != nil { @@ -240,13 +242,12 @@ func InitTraceIngestionApplication(configFactory conf.IConfigLoaderFactory, ckDb return iTraceIngestionApplication, nil } -func InitTaskApplication(db2 db.Provider, idgen2 idgen.IIDGenerator, configFactory conf.IConfigLoaderFactory, benefit2 benefit.IBenefitService, ckDb ck.Provider, redis2 redis.Cmdable, mqFactory mq.IFactory, userClient userservice.Client, authClient authservice.Client, evalService evaluatorservice.Client, evalSetService evaluationsetservice.Client, exptService experimentservice.Client, datasetService datasetservice.Client, fileClient fileservice.Client, taskProcessor processor.TaskProcessor, aid int32) (ITaskApplication, error) { +func InitTaskApplication(db2 db.Provider, idgen2 idgen.IIDGenerator, configFactory conf.IConfigLoaderFactory, benefit2 benefit.IBenefitService, ckDb ck.Provider, redis3 redis.Cmdable, mqFactory mq.IFactory, userClient userservice.Client, authClient authservice.Client, evalService evaluatorservice.Client, evalSetService evaluationsetservice.Client, exptService experimentservice.Client, datasetService datasetservice.Client, fileClient fileservice.Client, taskProcessor processor.TaskProcessor, aid int32) (ITaskApplication, error) { iTaskDao := mysql.NewTaskDaoImpl(db2) - iTaskDAO := dao.NewTaskDAO(redis2) + iTaskDAO := redis2.NewTaskDAO(redis3) iTaskRunDao := mysql.NewTaskRunDaoImpl(db2) - iTaskRunDAO := dao.NewTaskRunDAO(redis2) + iTaskRunDAO := redis2.NewTaskRunDAO(redis3) iTaskRepo := repo.NewTaskRepoImpl(iTaskDao, idgen2, iTaskDAO, iTaskRunDao, iTaskRunDAO) - iUserProvider := user.NewUserRPCProvider(userClient) iConfigLoader, err := NewTraceConfigLoader(configFactory) if err != nil { return nil, err @@ -260,11 +261,14 @@ func InitTaskApplication(db2 db.Provider, idgen2 idgen.IIDGenerator, configFacto iEvaluatorRPCAdapter := evaluator.NewEvaluatorRPCProvider(evalService) iEvaluationRPCAdapter := evaluation.NewEvaluationRPCProvider(exptService) processorTaskProcessor := NewInitTaskProcessor(datasetServiceAdaptor, iEvaluatorRPCAdapter, iEvaluationRPCAdapter, iTaskRepo) - iTaskService, err := service3.NewTaskServiceImpl(iTaskRepo, iUserProvider, idgen2, iBackfillProducer, processorTaskProcessor) + iFileProvider := file.NewFileRPCProvider(fileClient) + traceFilterProcessorBuilder := NewTraceProcessorBuilder(iTraceConfig, iFileProvider, benefit2) + iTaskService, err := service3.NewTaskServiceImpl(iTaskRepo, idgen2, iBackfillProducer, processorTaskProcessor, traceFilterProcessorBuilder) if err != nil { return nil, err } iAuthProvider := auth.NewAuthProvider(authClient) + iUserProvider := user.NewUserRPCProvider(userClient) iSpansDao, err := ck2.NewSpansCkDaoImpl(ckDb) if err != nil { return nil, err @@ -278,14 +282,14 @@ func InitTaskApplication(db2 db.Provider, idgen2 idgen.IIDGenerator, configFacto return nil, err } iTenantProvider := tenant.NewTenantProvider(iTraceConfig) - iFileProvider := file.NewFileRPCProvider(fileClient) - traceFilterProcessorBuilder := NewTraceProcessorBuilder(iTraceConfig, iFileProvider, benefit2) - iLocker := NewTaskLocker(redis2) - iTraceHubService, err := tracehub.NewTraceHubImpl(iTaskRepo, iTraceRepo, iTenantProvider, traceFilterProcessorBuilder, processorTaskProcessor, benefit2, aid, iBackfillProducer, iLocker, iConfigLoader) + iLocker := NewTaskLocker(redis3) + iTraceHubService, err := tracehub.NewTraceHubImpl(iTaskRepo, iTraceRepo, iTenantProvider, traceFilterProcessorBuilder, processorTaskProcessor, aid, iBackfillProducer, iLocker, iTraceConfig) if err != nil { return nil, err } - iTaskApplication, err := NewTaskApplication(iTaskService, iAuthProvider, iEvaluatorRPCAdapter, iEvaluationRPCAdapter, iUserProvider, iTraceHubService, taskProcessor, traceFilterProcessorBuilder) + iTaskCallbackService := service3.NewTaskCallbackServiceImpl(iTaskRepo, iTraceRepo, taskProcessor, iTenantProvider, iTraceConfig, benefit2) + v := NewScheduledTask(iLocker, iTraceConfig, iTraceHubService, iTaskService, taskProcessor, iTaskRepo) + iTaskApplication, err := NewTaskApplication(iTaskService, iAuthProvider, iEvaluatorRPCAdapter, iEvaluationRPCAdapter, iUserProvider, iTraceHubService, taskProcessor, iTaskCallbackService, v) if err != nil { return nil, err } @@ -296,7 +300,7 @@ func InitTaskApplication(db2 db.Provider, idgen2 idgen.IIDGenerator, configFacto var ( taskDomainSet = wire.NewSet( - NewInitTaskProcessor, service3.NewTaskServiceImpl, repo.NewTaskRepoImpl, mysql.NewTaskDaoImpl, dao.NewTaskDAO, dao.NewTaskRunDAO, mysql.NewTaskRunDaoImpl, producer.NewBackfillProducerImpl, + NewInitTaskProcessor, service3.NewTaskServiceImpl, repo.NewTaskRepoImpl, mysql.NewTaskDaoImpl, redis2.NewTaskDAO, redis2.NewTaskRunDAO, mysql.NewTaskRunDaoImpl, producer.NewBackfillProducerImpl, NewScheduledTask, ) traceDomainSet = wire.NewSet(service.NewTraceServiceImpl, service.NewTraceExportServiceImpl, repo.NewTraceCKRepoImpl, ck2.NewSpansCkDaoImpl, ck2.NewAnnotationCkDaoImpl, metrics2.NewTraceMetricsImpl, collector.NewEventCollectorProvider, producer.NewTraceProducerImpl, producer.NewAnnotationProducerImpl, file.NewFileRPCProvider, NewTraceConfigLoader, NewTraceProcessorBuilder, config.NewTraceConfigCenter, tenant.NewTenantProvider, workspace.NewWorkspaceProvider, evaluator.NewEvaluatorRPCProvider, NewDatasetServiceAdapter, @@ -313,7 +317,7 @@ var ( NewOpenAPIApplication, auth.NewAuthProvider, traceDomainSet, ) taskSet = wire.NewSet(tracehub.NewTraceHubImpl, NewTaskApplication, auth.NewAuthProvider, user.NewUserRPCProvider, evaluation.NewEvaluationRPCProvider, NewTaskLocker, - traceDomainSet, + traceDomainSet, service3.NewTaskCallbackServiceImpl, ) metricsSet = wire.NewSet( NewMetricApplication, service2.NewMetricsService, repo.NewTraceMetricCKRepoImpl, tenant.NewTenantProvider, auth.NewAuthProvider, NewTraceConfigLoader, @@ -371,6 +375,17 @@ func NewInitTaskProcessor(datasetServiceProvider *service.DatasetServiceAdaptor, evaluationService rpc.IEvaluationRPCAdapter, taskRepo repo3.ITaskRepo, ) *processor.TaskProcessor { taskProcessor := processor.NewTaskProcessor() - taskProcessor.Register(task.TaskTypeAutoEval, processor.NewAutoEvaluteProcessor(0, datasetServiceProvider, evalService, evaluationService, taskRepo)) + taskProcessor.Register(entity3.TaskTypeAutoEval, processor.NewAutoEvaluteProcessor(0, datasetServiceProvider, evalService, evaluationService, taskRepo)) return taskProcessor } + +func NewScheduledTask( + locker lock.ILocker, config3 config2.ITraceConfig, + + traceHubService tracehub.ITraceHubService, + taskService service3.ITaskService, + taskProcessor processor.TaskProcessor, + taskRepo repo3.ITaskRepo, +) []scheduledtask.ScheduledTask { + return []scheduledtask.ScheduledTask{scheduledtask2.NewStatusCheckTask(locker, config3, traceHubService, taskService, taskProcessor, taskRepo), scheduledtask2.NewLocalCacheRefreshTask(traceHubService, taskRepo)} +} diff --git a/backend/modules/observability/domain/component/config/config.go b/backend/modules/observability/domain/component/config/config.go index f587b3188..cb0a29255 100644 --- a/backend/modules/observability/domain/component/config/config.go +++ b/backend/modules/observability/domain/component/config/config.go @@ -128,6 +128,7 @@ type ITraceConfig interface { GetQueryMaxQPS(ctx context.Context, key string) (int, error) GetKeySpanTypes(ctx context.Context) map[string][]string GetBackfillMqProducerCfg(ctx context.Context) (*MqProducerCfg, error) + GetConsumerListening(ctx context.Context) (*ConsumerListening, error) conf.IConfigLoader } diff --git a/backend/modules/observability/domain/component/scheduledtask/scheduledtask.go b/backend/modules/observability/domain/component/scheduledtask/scheduledtask.go new file mode 100644 index 000000000..cdb57cf23 --- /dev/null +++ b/backend/modules/observability/domain/component/scheduledtask/scheduledtask.go @@ -0,0 +1,62 @@ +// Copyright (c) 2025 coze-dev Authors +// SPDX-License-Identifier: Apache-2.0 + +package scheduledtask + +import ( + "context" + "time" + + "github.com/coze-dev/coze-loop/backend/modules/llm/pkg/goroutineutil" + "github.com/coze-dev/coze-loop/backend/pkg/logs" +) + +type ScheduledTask interface { + Run() error + RunOnce(ctx context.Context) error + Stop() error +} + +type BaseScheduledTask struct { + name string + timeInterval time.Duration + stopChan chan struct{} +} + +func NewBaseScheduledTask(name string, timeInterval time.Duration) BaseScheduledTask { + return BaseScheduledTask{ + name: name, + timeInterval: timeInterval, + stopChan: make(chan struct{}), + } +} + +func (b *BaseScheduledTask) Run() error { + ticker := time.NewTicker(b.timeInterval) + goroutineutil.GoWithDefaultRecovery(context.Background(), func() { + for { + select { + case <-ticker.C: + ctx := context.Background() + startTime := time.Now() + if err := b.RunOnce(ctx); err != nil { + logs.CtxError(ctx, "ScheduledTask [%s] run error: %v, cost: %v", b.name, err, time.Since(startTime)) + } else { + logs.CtxInfo(ctx, "ScheduledTask [%s] run success, cost: %v", b.name, time.Since(startTime)) + } + case <-b.stopChan: + return + } + } + }) + return nil +} + +func (b *BaseScheduledTask) RunOnce(ctx context.Context) error { + panic("implement me") +} + +func (b *BaseScheduledTask) Stop() error { + close(b.stopChan) + return nil +} diff --git a/backend/modules/observability/domain/task/entity/event.go b/backend/modules/observability/domain/task/entity/event.go index c6bcd86e8..512ddefe2 100644 --- a/backend/modules/observability/domain/task/entity/event.go +++ b/backend/modules/observability/domain/task/entity/event.go @@ -4,9 +4,12 @@ package entity import ( + "fmt" "strconv" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/loop_span" + obErrorx "github.com/coze-dev/coze-loop/backend/modules/observability/pkg/errno" + "github.com/coze-dev/coze-loop/backend/pkg/errorx" ) type RawSpan struct { @@ -174,6 +177,14 @@ type AutoEvalEvent struct { ExptID int64 `json:"expt_id"` TurnEvalResults []*OnlineExptTurnEvalResult `json:"turn_eval_results"` } + +func (e *AutoEvalEvent) Validate() error { + if e.TurnEvalResults == nil || len(e.TurnEvalResults) == 0 { + return fmt.Errorf("turn_eval_results is required") + } + return nil +} + type OnlineExptTurnEvalResult struct { EvaluatorVersionID int64 `json:"evaluator_version_id"` EvaluatorRecordID int64 `json:"evaluator_record_id"` @@ -251,6 +262,22 @@ func (s *OnlineExptTurnEvalResult) GetWorkspaceIDFromExt() (string, int64) { return workspaceIDStr, workspaceID } +func (s *OnlineExptTurnEvalResult) GetRunID() (int64, error) { + taskRunIDStr := s.Ext["run_id"] + if taskRunIDStr == "" { + return 0, fmt.Errorf("run_id not found in ext") + } + + return strconv.ParseInt(taskRunIDStr, 10, 64) +} + +func (s *OnlineExptTurnEvalResult) GetUserID() string { + if s.BaseInfo == nil || s.BaseInfo.UpdatedBy == nil { + return "" + } + return s.BaseInfo.UpdatedBy.UserID +} + type EvaluatorRunError struct { Code int32 `json:"code"` Message string `json:"message"` @@ -277,9 +304,14 @@ type CorrectionEvent struct { UpdatedAt int64 `json:"updated_at"` } -type BackFillEvent struct { - SpaceID int64 `json:"space_id"` - TaskID int64 `json:"task_id"` +func (c *CorrectionEvent) Validate() error { + if c.EvaluatorRecordID == 0 { + return errorx.NewByCode(obErrorx.CommercialCommonInvalidParamCodeCode, errorx.WithExtraMsg("evaluator_record_id is empty")) + } + if c.EvaluatorResult == nil || c.EvaluatorResult.Correction == nil { + return errorx.NewByCode(obErrorx.CommercialCommonInvalidParamCodeCode, errorx.WithExtraMsg("correction is empty")) + } + return nil } func (c *CorrectionEvent) GetSpanIDFromExt() string { @@ -331,3 +363,25 @@ func (c *CorrectionEvent) GetWorkspaceIDFromExt() (string, int64) { } return workspaceIDStr, workspaceID } + +func (c *CorrectionEvent) GetUpdateBy() string { + if c == nil || c.EvaluatorResult == nil || c.EvaluatorResult.Correction == nil { + return "" + } + return c.EvaluatorResult.Correction.UpdatedBy +} + +type BackFillEvent struct { + SpaceID int64 `json:"space_id"` + TaskID int64 `json:"task_id"` +} + +func (b *BackFillEvent) Validate() error { + if b.SpaceID == 0 { + return errorx.NewByCode(obErrorx.CommercialCommonInvalidParamCodeCode, errorx.WithExtraMsg("space_id is empty")) + } + if b.TaskID == 0 { + return errorx.NewByCode(obErrorx.CommercialCommonInvalidParamCodeCode, errorx.WithExtraMsg("task_id is empty")) + } + return nil +} diff --git a/backend/modules/observability/domain/task/entity/filter.go b/backend/modules/observability/domain/task/entity/filter.go new file mode 100755 index 000000000..23f515597 --- /dev/null +++ b/backend/modules/observability/domain/task/entity/filter.go @@ -0,0 +1,69 @@ +// Copyright (c) 2025 coze-dev Authors +// SPDX-License-Identifier: Apache-2.0 + +package entity + +// QueryType represents the operator applied to filter values. +type QueryType string + +// QueryRelation represents the logical relation between multiple filter expressions. +type QueryRelation string + +// FieldType describes the type of a field used in filter expressions. +type FieldType string + +// TaskFieldName defines the supported task field names for filtering. +type TaskFieldName string + +const ( + QueryTypeMatch QueryType = "match" + QueryTypeEq QueryType = "eq" + QueryTypeNotEq QueryType = "not_eq" + QueryTypeLte QueryType = "lte" + QueryTypeGte QueryType = "gte" + QueryTypeLt QueryType = "lt" + QueryTypeGt QueryType = "gt" + QueryTypeExist QueryType = "exist" + QueryTypeNotExist QueryType = "not_exist" + QueryTypeIn QueryType = "in" + QueryTypeNotIn QueryType = "not_in" + QueryTypeNotMatch QueryType = "not_match" + + QueryRelationAnd QueryRelation = "and" + QueryRelationOr QueryRelation = "or" + + FieldTypeString FieldType = "string" + FieldTypeLong FieldType = "long" + FieldTypeDouble FieldType = "double" + FieldTypeBool FieldType = "bool" + + TaskFieldNameTaskStatus TaskFieldName = "task_status" + TaskFieldNameTaskName TaskFieldName = "task_name" + TaskFieldNameTaskType TaskFieldName = "task_type" + TaskFieldNameSampleRate TaskFieldName = "sample_rate" + TaskFieldNameCreatedBy TaskFieldName = "created_by" +) + +// TaskFilterFields aggregates multiple TaskFilterField expressions. +type TaskFilterFields struct { + QueryAndOr *QueryRelation + FilterFields []*TaskFilterField +} + +// GetQueryAndOr returns the relation between filter expressions. +func (f *TaskFilterFields) GetQueryAndOr() string { + if f == nil || f.QueryAndOr == nil { + return string(QueryRelationAnd) + } + return string(*f.QueryAndOr) +} + +// TaskFilterField describes a single filter clause. +type TaskFilterField struct { + FieldName *TaskFieldName + FieldType *FieldType + Values []string + QueryType *QueryType + QueryAndOr *QueryRelation + SubFilter *TaskFilterField +} diff --git a/backend/modules/observability/domain/task/entity/task.go b/backend/modules/observability/domain/task/entity/task.go index 76c8a41a2..8d822fc12 100644 --- a/backend/modules/observability/domain/task/entity/task.go +++ b/backend/modules/observability/domain/task/entity/task.go @@ -4,22 +4,70 @@ package entity import ( + "context" "time" - "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/common" "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/dataset" - "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/task" + taskdto "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/task" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/loop_span" + obErrorx "github.com/coze-dev/coze-loop/backend/modules/observability/pkg/errno" + "github.com/coze-dev/coze-loop/backend/pkg/errorx" + "github.com/coze-dev/coze-loop/backend/pkg/logs" ) +type TimeUnit string + +const ( + TimeUnitDay = "day" + TimeUnitWeek = "week" + TimeUnitNull = "null" +) + +type TaskStatus string + +const ( + TaskStatusUnstarted TaskStatus = "unstarted" + TaskStatusRunning TaskStatus = "running" + TaskStatusFailed TaskStatus = "failed" + TaskStatusSuccess TaskStatus = "success" + TaskStatusPending TaskStatus = "pending" + TaskStatusDisabled TaskStatus = "disabled" +) + +type TaskType string + +const ( + TaskTypeAutoEval TaskType = "auto_evaluate" + TaskTypeAutoDataReflow TaskType = "auto_data_reflow" +) + +type TaskRunType string + +const ( + TaskRunTypeBackFill TaskRunType = "back_fill" + TaskRunTypeNewData TaskRunType = "new_data" +) + +type TaskRunStatus string + +const ( + TaskRunStatusRunning TaskRunStatus = "running" + TaskRunStatusDone TaskRunStatus = "done" +) + +type StatusChangeEvent struct { + Before TaskStatus + After TaskStatus +} + // do type ObservabilityTask struct { ID int64 // Task ID WorkspaceID int64 // 空间ID Name string // 任务名称 Description *string // 任务描述 - TaskType string // 任务类型 - TaskStatus string // 任务状态 + TaskType TaskType // 任务类型 + TaskStatus TaskStatus // 任务状态 TaskDetail *RunDetail // 任务运行详情 SpanFilter *SpanFilterFields // span 过滤条件 EffectiveTime *EffectiveTime // 生效时间 @@ -41,8 +89,8 @@ type RunDetail struct { } type SpanFilterFields struct { Filters loop_span.FilterFields `json:"filters"` - PlatformType common.PlatformType `json:"platform_type"` - SpanListType common.SpanListType `json:"span_list_type"` + PlatformType loop_span.PlatformType `json:"platform_type"` + SpanListType loop_span.SpanListType `json:"span_list_type"` } type EffectiveTime struct { // ms timestamp @@ -51,12 +99,12 @@ type EffectiveTime struct { EndAt int64 `json:"end_at"` } type Sampler struct { - SampleRate float64 `json:"sample_rate"` - SampleSize int64 `json:"sample_size"` - IsCycle bool `json:"is_cycle"` - CycleCount int64 `json:"cycle_count"` - CycleInterval int64 `json:"cycle_interval"` - CycleTimeUnit string `json:"cycle_time_unit"` + SampleRate float64 `json:"sample_rate"` + SampleSize int64 `json:"sample_size"` + IsCycle bool `json:"is_cycle"` + CycleCount int64 `json:"cycle_count"` + CycleInterval int64 `json:"cycle_interval"` + CycleTimeUnit TimeUnit `json:"cycle_time_unit"` } type TaskConfig struct { AutoEvaluateConfigs []*AutoEvaluateConfig `json:"auto_evaluate_configs"` @@ -85,8 +133,8 @@ type TaskRun struct { ID int64 // Task Run ID TaskID int64 // Task ID WorkspaceID int64 // 空间ID - TaskType string // 任务类型 - RunStatus string // Task Run状态 + TaskType TaskRunType // 任务类型 + RunStatus TaskRunStatus // Task Run状态 RunDetail *RunDetail // Task Run运行详情 BackfillDetail *BackfillDetail // 历史回溯运行详情 RunStartAt time.Time // run 开始时间 @@ -126,7 +174,7 @@ type DataReflowRunConfig struct { Status string `json:"status"` } -func (t ObservabilityTask) GetRunTimeRange() (startAt, endAt int64) { +func (t *ObservabilityTask) GetRunTimeRange() (startAt, endAt int64) { if t.EffectiveTime == nil { return 0, 0 } @@ -135,9 +183,9 @@ func (t ObservabilityTask) GetRunTimeRange() (startAt, endAt int64) { endAt = t.EffectiveTime.EndAt } else { switch t.Sampler.CycleTimeUnit { - case task.TimeUnitDay: + case TimeUnitDay: endAt = startAt + (t.Sampler.CycleInterval)*24*time.Hour.Milliseconds() - case task.TimeUnitWeek: + case TimeUnitWeek: endAt = startAt + (t.Sampler.CycleInterval)*7*24*time.Hour.Milliseconds() default: endAt = startAt + (t.Sampler.CycleInterval)*24*time.Hour.Milliseconds() @@ -146,34 +194,34 @@ func (t ObservabilityTask) GetRunTimeRange() (startAt, endAt int64) { return startAt, endAt } -func (t ObservabilityTask) IsFinished() bool { +func (t *ObservabilityTask) IsFinished() bool { switch t.TaskStatus { - case task.TaskStatusSuccess, task.TaskStatusDisabled, task.TaskStatusPending: + case TaskStatusSuccess, TaskStatusDisabled, TaskStatusPending: return true default: return false } } -func (t ObservabilityTask) GetBackfillTaskRun() *TaskRun { - for _, taskRunPO := range t.TaskRuns { - if taskRunPO.TaskType == task.TaskRunTypeBackFill { - return taskRunPO +func (t *ObservabilityTask) GetBackfillTaskRun() *TaskRun { + for _, taskRun := range t.TaskRuns { + if taskRun.TaskType == TaskRunTypeBackFill { + return taskRun } } return nil } -func (t ObservabilityTask) GetCurrentTaskRun() *TaskRun { - for _, taskRunPO := range t.TaskRuns { - if taskRunPO.TaskType == task.TaskRunTypeNewData && taskRunPO.RunStatus == task.TaskStatusRunning { - return taskRunPO +func (t *ObservabilityTask) GetCurrentTaskRun() *TaskRun { + for _, taskRun := range t.TaskRuns { + if taskRun.TaskType == TaskRunTypeNewData && taskRun.RunStatus == TaskRunStatusRunning { + return taskRun } } return nil } -func (t ObservabilityTask) GetTaskttl() int64 { +func (t *ObservabilityTask) GetTaskTTL() int64 { var ttl int64 if t.EffectiveTime != nil { ttl = t.EffectiveTime.EndAt - t.EffectiveTime.StartAt @@ -183,3 +231,93 @@ func (t ObservabilityTask) GetTaskttl() int64 { } return ttl } + +func (t *ObservabilityTask) SetEffectiveTime(ctx context.Context, effectiveTime EffectiveTime) error { + if t.EffectiveTime == nil { + logs.CtxError(ctx, "EffectiveTime is null.") + return errorx.NewByCode(obErrorx.CommercialCommonInvalidParamCodeCode, errorx.WithExtraMsg("effective time is nil")) + } + // 开始时间不能大于结束时间 + if effectiveTime.StartAt >= effectiveTime.EndAt { + logs.CtxError(ctx, "Start time must be less than end time") + return errorx.NewByCode(obErrorx.CommercialCommonInvalidParamCodeCode, errorx.WithExtraMsg("start time must be less than end time")) + } + // 开始、结束时间不能小于当前时间 + if t.EffectiveTime.StartAt != effectiveTime.StartAt && effectiveTime.StartAt < time.Now().UnixMilli() { + logs.CtxError(ctx, "update time must be greater than current time") + return errorx.NewByCode(obErrorx.CommercialCommonInvalidParamCodeCode, errorx.WithExtraMsg("start time must be greater than current time")) + } + if t.EffectiveTime.EndAt != effectiveTime.EndAt && effectiveTime.EndAt < time.Now().UnixMilli() { + logs.CtxError(ctx, "update time must be greater than current time") + return errorx.NewByCode(obErrorx.CommercialCommonInvalidParamCodeCode, errorx.WithExtraMsg("start time must be greater than current time")) + } + switch t.TaskStatus { + case TaskStatusUnstarted: + if effectiveTime.StartAt != 0 { + t.EffectiveTime.StartAt = effectiveTime.StartAt + } + if effectiveTime.EndAt != 0 { + t.EffectiveTime.EndAt = effectiveTime.EndAt + } + case TaskStatusRunning, TaskStatusPending: + if effectiveTime.EndAt != 0 { + t.EffectiveTime.EndAt = effectiveTime.EndAt + } + default: + logs.CtxError(ctx, "Invalid task status:%s", t.TaskStatus) + return errorx.NewByCode(obErrorx.CommercialCommonInvalidParamCodeCode, errorx.WithExtraMsg("invalid task status")) + } + return nil +} + +func (t *ObservabilityTask) SetTaskStatus(ctx context.Context, taskStatus TaskStatus) (*StatusChangeEvent, error) { + currentTaskStatus := t.TaskStatus + if currentTaskStatus == taskStatus { + return nil, nil + } + + switch taskStatus { + case taskdto.TaskStatusUnstarted: + break + case taskdto.TaskStatusRunning: + if currentTaskStatus == taskdto.TaskStatusUnstarted || currentTaskStatus == taskdto.TaskStatusPending { + t.TaskStatus = taskStatus + return &StatusChangeEvent{ + Before: currentTaskStatus, + After: taskStatus, + }, nil + } + case taskdto.TaskStatusPending: + if currentTaskStatus == taskdto.TaskStatusRunning { + t.TaskStatus = taskStatus + return &StatusChangeEvent{ + Before: currentTaskStatus, + After: taskStatus, + }, nil + } + case taskdto.TaskStatusDisabled: + if currentTaskStatus == taskdto.TaskStatusUnstarted || currentTaskStatus == taskdto.TaskStatusPending { + t.TaskStatus = taskStatus + return &StatusChangeEvent{ + Before: currentTaskStatus, + After: taskStatus, + }, nil + } + case taskdto.TaskStatusSuccess: + break + } + + logs.CtxError(ctx, "Invalid task status. Before:[%s], after:[%s]", currentTaskStatus, taskStatus) + return nil, errorx.NewByCode(obErrorx.CommercialCommonInvalidParamCodeCode, errorx.WithExtraMsg("invalid task status")) +} + +func (t *ObservabilityTask) ShouldTriggerBackfill() bool { + // 检查回填时间配置 + if t.BackfillEffectiveTime == nil { + return false + } + + return t.BackfillEffectiveTime.StartAt > 0 && + t.BackfillEffectiveTime.EndAt > 0 && + t.BackfillEffectiveTime.StartAt < t.BackfillEffectiveTime.EndAt +} diff --git a/backend/modules/observability/domain/task/entity/task_test.go b/backend/modules/observability/domain/task/entity/task_test.go new file mode 100644 index 000000000..e1276646b --- /dev/null +++ b/backend/modules/observability/domain/task/entity/task_test.go @@ -0,0 +1,86 @@ +// Copyright (c) 2025 coze-dev Authors +// SPDX-License-Identifier: Apache-2.0 +package entity + +import ( + "context" + "reflect" + "testing" +) + +func TestObservabilityTask_SetTaskStatus(t *testing.T) { + tests := []struct { + name string // 测试用例名称 + initialTask ObservabilityTask // 任务的初始状态 + targetStatus TaskStatus // 目标设置的状态 + wantEvent *StatusChangeEvent // 期望返回的事件 + wantErr bool // 是否期望发生错误 + finalStatus TaskStatus // 期望的最终任务状态 + }{ + { + name: "状态相同时不进行变更", + initialTask: ObservabilityTask{TaskStatus: TaskStatusRunning}, + targetStatus: TaskStatusRunning, + wantEvent: nil, + wantErr: false, + finalStatus: TaskStatusRunning, + }, + { + name: "有效状态流转:从未开始到运行中", + initialTask: ObservabilityTask{TaskStatus: TaskStatusUnstarted}, + targetStatus: TaskStatusRunning, + wantEvent: &StatusChangeEvent{ + Before: TaskStatusUnstarted, + After: TaskStatusRunning, + }, + wantErr: false, + finalStatus: TaskStatusRunning, + }, + { + name: "有效状态流转:从挂起到运行中", + initialTask: ObservabilityTask{TaskStatus: TaskStatusPending}, + targetStatus: TaskStatusRunning, + wantEvent: &StatusChangeEvent{ + Before: TaskStatusPending, + After: TaskStatusRunning, + }, + wantErr: false, + finalStatus: TaskStatusRunning, + }, + { + name: "无效状态流转:从禁用状态到其他状态", + initialTask: ObservabilityTask{TaskStatus: TaskStatusDisabled}, + targetStatus: TaskStatusRunning, + wantEvent: nil, + wantErr: true, + finalStatus: TaskStatusDisabled, + }, + } + + // 遍历并执行所有测试用例 + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Arrange: 创建一个任务副本以防止并发测试时修改原始测试用例数据 + task := tt.initialTask + + // Act: 调用被测方法 + gotEvent, err := task.SetTaskStatus(context.Background(), tt.targetStatus) + + // Assert: 校验错误是否符合预期 + if (err != nil) != tt.wantErr { + t.Errorf("SetTaskStatus() error = %v, wantErr %v", err, tt.wantErr) + return + } + + // Assert: 校验返回的事件是否符合预期 + if !reflect.DeepEqual(gotEvent, tt.wantEvent) { + t.Errorf("SetTaskStatus() gotEvent = %v, want %v", gotEvent, tt.wantEvent) + } + + // Assert: 校验任务的最终状态是否符合预期 + if task.TaskStatus != tt.finalStatus { + t.Errorf("Final task status = %v, want %v", task.TaskStatus, tt.finalStatus) + } + }) + } +} diff --git a/backend/modules/observability/domain/task/repo/mocks/Task.go b/backend/modules/observability/domain/task/repo/mocks/Task.go index 6af237884..0020c6c71 100644 --- a/backend/modules/observability/domain/task/repo/mocks/Task.go +++ b/backend/modules/observability/domain/task/repo/mocks/Task.go @@ -3,7 +3,7 @@ // // Generated by this command: // -// mockgen -destination=modules/observability/domain/task/repo/mocks/Task.go -package=mocks github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/repo ITaskRepo +// mockgen -destination=mocks/Task.go -package=mocks . ITaskRepo // // Package mocks is a generated GoMock package. @@ -14,7 +14,7 @@ import ( reflect "reflect" entity "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" - mysql "github.com/coze-dev/coze-loop/backend/modules/observability/infra/repo/mysql" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/repo" gomock "go.uber.org/mock/gomock" ) @@ -22,6 +22,7 @@ import ( type MockITaskRepo struct { ctrl *gomock.Controller recorder *MockITaskRepoMockRecorder + isgomock struct{} } // MockITaskRepoMockRecorder is the mock recorder for MockITaskRepo. @@ -41,409 +42,394 @@ func (m *MockITaskRepo) EXPECT() *MockITaskRepoMockRecorder { return m.recorder } +// AddNonFinalTask mocks base method. +func (m *MockITaskRepo) AddNonFinalTask(ctx context.Context, spaceID string, taskID int64) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AddNonFinalTask", ctx, spaceID, taskID) + ret0, _ := ret[0].(error) + return ret0 +} + +// AddNonFinalTask indicates an expected call of AddNonFinalTask. +func (mr *MockITaskRepoMockRecorder) AddNonFinalTask(ctx, spaceID, taskID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddNonFinalTask", reflect.TypeOf((*MockITaskRepo)(nil).AddNonFinalTask), ctx, spaceID, taskID) +} + // CreateTask mocks base method. -func (m *MockITaskRepo) CreateTask(arg0 context.Context, arg1 *entity.ObservabilityTask) (int64, error) { +func (m *MockITaskRepo) CreateTask(ctx context.Context, do *entity.ObservabilityTask) (int64, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateTask", arg0, arg1) + ret := m.ctrl.Call(m, "CreateTask", ctx, do) ret0, _ := ret[0].(int64) ret1, _ := ret[1].(error) return ret0, ret1 } // CreateTask indicates an expected call of CreateTask. -func (mr *MockITaskRepoMockRecorder) CreateTask(arg0, arg1 any) *gomock.Call { +func (mr *MockITaskRepoMockRecorder) CreateTask(ctx, do any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateTask", reflect.TypeOf((*MockITaskRepo)(nil).CreateTask), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateTask", reflect.TypeOf((*MockITaskRepo)(nil).CreateTask), ctx, do) } // CreateTaskRun mocks base method. -func (m *MockITaskRepo) CreateTaskRun(arg0 context.Context, arg1 *entity.TaskRun) (int64, error) { +func (m *MockITaskRepo) CreateTaskRun(ctx context.Context, do *entity.TaskRun) (int64, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateTaskRun", arg0, arg1) + ret := m.ctrl.Call(m, "CreateTaskRun", ctx, do) ret0, _ := ret[0].(int64) ret1, _ := ret[1].(error) return ret0, ret1 } // CreateTaskRun indicates an expected call of CreateTaskRun. -func (mr *MockITaskRepoMockRecorder) CreateTaskRun(arg0, arg1 any) *gomock.Call { +func (mr *MockITaskRepoMockRecorder) CreateTaskRun(ctx, do any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateTaskRun", reflect.TypeOf((*MockITaskRepo)(nil).CreateTaskRun), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateTaskRun", reflect.TypeOf((*MockITaskRepo)(nil).CreateTaskRun), ctx, do) } // DecrTaskCount mocks base method. -func (m *MockITaskRepo) DecrTaskCount(arg0 context.Context, arg1, arg2 int64) error { +func (m *MockITaskRepo) DecrTaskCount(ctx context.Context, taskID, ttl int64) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DecrTaskCount", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "DecrTaskCount", ctx, taskID, ttl) ret0, _ := ret[0].(error) return ret0 } // DecrTaskCount indicates an expected call of DecrTaskCount. -func (mr *MockITaskRepoMockRecorder) DecrTaskCount(arg0, arg1, arg2 any) *gomock.Call { +func (mr *MockITaskRepoMockRecorder) DecrTaskCount(ctx, taskID, ttl any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecrTaskCount", reflect.TypeOf((*MockITaskRepo)(nil).DecrTaskCount), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecrTaskCount", reflect.TypeOf((*MockITaskRepo)(nil).DecrTaskCount), ctx, taskID, ttl) } // DecrTaskRunCount mocks base method. -func (m *MockITaskRepo) DecrTaskRunCount(arg0 context.Context, arg1, arg2, arg3 int64) error { +func (m *MockITaskRepo) DecrTaskRunCount(ctx context.Context, taskID, taskRunID, ttl int64) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DecrTaskRunCount", arg0, arg1, arg2, arg3) + ret := m.ctrl.Call(m, "DecrTaskRunCount", ctx, taskID, taskRunID, ttl) ret0, _ := ret[0].(error) return ret0 } // DecrTaskRunCount indicates an expected call of DecrTaskRunCount. -func (mr *MockITaskRepoMockRecorder) DecrTaskRunCount(arg0, arg1, arg2, arg3 any) *gomock.Call { +func (mr *MockITaskRepoMockRecorder) DecrTaskRunCount(ctx, taskID, taskRunID, ttl any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecrTaskRunCount", reflect.TypeOf((*MockITaskRepo)(nil).DecrTaskRunCount), arg0, arg1, arg2, arg3) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecrTaskRunCount", reflect.TypeOf((*MockITaskRepo)(nil).DecrTaskRunCount), ctx, taskID, taskRunID, ttl) } // DecrTaskRunSuccessCount mocks base method. -func (m *MockITaskRepo) DecrTaskRunSuccessCount(arg0 context.Context, arg1, arg2 int64) error { +func (m *MockITaskRepo) DecrTaskRunSuccessCount(ctx context.Context, taskID, taskRunID int64) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DecrTaskRunSuccessCount", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "DecrTaskRunSuccessCount", ctx, taskID, taskRunID) ret0, _ := ret[0].(error) return ret0 } // DecrTaskRunSuccessCount indicates an expected call of DecrTaskRunSuccessCount. -func (mr *MockITaskRepoMockRecorder) DecrTaskRunSuccessCount(arg0, arg1, arg2 any) *gomock.Call { +func (mr *MockITaskRepoMockRecorder) DecrTaskRunSuccessCount(ctx, taskID, taskRunID any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecrTaskRunSuccessCount", reflect.TypeOf((*MockITaskRepo)(nil).DecrTaskRunSuccessCount), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecrTaskRunSuccessCount", reflect.TypeOf((*MockITaskRepo)(nil).DecrTaskRunSuccessCount), ctx, taskID, taskRunID) } // DeleteTask mocks base method. -func (m *MockITaskRepo) DeleteTask(arg0 context.Context, arg1 *entity.ObservabilityTask) error { +func (m *MockITaskRepo) DeleteTask(ctx context.Context, do *entity.ObservabilityTask) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DeleteTask", arg0, arg1) + ret := m.ctrl.Call(m, "DeleteTask", ctx, do) ret0, _ := ret[0].(error) return ret0 } // DeleteTask indicates an expected call of DeleteTask. -func (mr *MockITaskRepoMockRecorder) DeleteTask(arg0, arg1 any) *gomock.Call { +func (mr *MockITaskRepoMockRecorder) DeleteTask(ctx, do any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteTask", reflect.TypeOf((*MockITaskRepo)(nil).DeleteTask), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteTask", reflect.TypeOf((*MockITaskRepo)(nil).DeleteTask), ctx, do) } // GetBackfillTaskRun mocks base method. -func (m *MockITaskRepo) GetBackfillTaskRun(arg0 context.Context, arg1 *int64, arg2 int64) (*entity.TaskRun, error) { +func (m *MockITaskRepo) GetBackfillTaskRun(ctx context.Context, workspaceID *int64, taskID int64) (*entity.TaskRun, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetBackfillTaskRun", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "GetBackfillTaskRun", ctx, workspaceID, taskID) ret0, _ := ret[0].(*entity.TaskRun) ret1, _ := ret[1].(error) return ret0, ret1 } // GetBackfillTaskRun indicates an expected call of GetBackfillTaskRun. -func (mr *MockITaskRepoMockRecorder) GetBackfillTaskRun(arg0, arg1, arg2 any) *gomock.Call { +func (mr *MockITaskRepoMockRecorder) GetBackfillTaskRun(ctx, workspaceID, taskID any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetBackfillTaskRun", reflect.TypeOf((*MockITaskRepo)(nil).GetBackfillTaskRun), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetBackfillTaskRun", reflect.TypeOf((*MockITaskRepo)(nil).GetBackfillTaskRun), ctx, workspaceID, taskID) } // GetLatestNewDataTaskRun mocks base method. -func (m *MockITaskRepo) GetLatestNewDataTaskRun(arg0 context.Context, arg1 *int64, arg2 int64) (*entity.TaskRun, error) { +func (m *MockITaskRepo) GetLatestNewDataTaskRun(ctx context.Context, workspaceID *int64, taskID int64) (*entity.TaskRun, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetLatestNewDataTaskRun", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "GetLatestNewDataTaskRun", ctx, workspaceID, taskID) ret0, _ := ret[0].(*entity.TaskRun) ret1, _ := ret[1].(error) return ret0, ret1 } // GetLatestNewDataTaskRun indicates an expected call of GetLatestNewDataTaskRun. -func (mr *MockITaskRepoMockRecorder) GetLatestNewDataTaskRun(arg0, arg1, arg2 any) *gomock.Call { +func (mr *MockITaskRepoMockRecorder) GetLatestNewDataTaskRun(ctx, workspaceID, taskID any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLatestNewDataTaskRun", reflect.TypeOf((*MockITaskRepo)(nil).GetLatestNewDataTaskRun), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLatestNewDataTaskRun", reflect.TypeOf((*MockITaskRepo)(nil).GetLatestNewDataTaskRun), ctx, workspaceID, taskID) } -// GetObjListWithTask mocks base method. -func (m *MockITaskRepo) GetObjListWithTask(arg0 context.Context) ([]string, []string, []*entity.ObservabilityTask) { +// GetTask mocks base method. +func (m *MockITaskRepo) GetTask(ctx context.Context, id int64, workspaceID *int64, userID *string) (*entity.ObservabilityTask, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetObjListWithTask", arg0) - ret0, _ := ret[0].([]string) - ret1, _ := ret[1].([]string) - ret2, _ := ret[2].([]*entity.ObservabilityTask) - return ret0, ret1, ret2 + ret := m.ctrl.Call(m, "GetTask", ctx, id, workspaceID, userID) + ret0, _ := ret[0].(*entity.ObservabilityTask) + ret1, _ := ret[1].(error) + return ret0, ret1 } -// GetObjListWithTask indicates an expected call of GetObjListWithTask. -func (mr *MockITaskRepoMockRecorder) GetObjListWithTask(arg0 any) *gomock.Call { +// GetTask indicates an expected call of GetTask. +func (mr *MockITaskRepoMockRecorder) GetTask(ctx, id, workspaceID, userID any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetObjListWithTask", reflect.TypeOf((*MockITaskRepo)(nil).GetObjListWithTask), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTask", reflect.TypeOf((*MockITaskRepo)(nil).GetTask), ctx, id, workspaceID, userID) } -// GetTask mocks base method. -func (m *MockITaskRepo) GetTask(arg0 context.Context, arg1 int64, arg2 *int64, arg3 *string) (*entity.ObservabilityTask, error) { +// GetTaskByCache mocks base method. +func (m *MockITaskRepo) GetTaskByCache(ctx context.Context, taskID int64) (*entity.ObservabilityTask, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetTask", arg0, arg1, arg2, arg3) + ret := m.ctrl.Call(m, "GetTaskByCache", ctx, taskID) ret0, _ := ret[0].(*entity.ObservabilityTask) ret1, _ := ret[1].(error) return ret0, ret1 } -// GetTask indicates an expected call of GetTask. -func (mr *MockITaskRepoMockRecorder) GetTask(arg0, arg1, arg2, arg3 any) *gomock.Call { +// GetTaskByCache indicates an expected call of GetTaskByCache. +func (mr *MockITaskRepoMockRecorder) GetTaskByCache(ctx, taskID any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTask", reflect.TypeOf((*MockITaskRepo)(nil).GetTask), arg0, arg1, arg2, arg3) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTaskByCache", reflect.TypeOf((*MockITaskRepo)(nil).GetTaskByCache), ctx, taskID) } // GetTaskCount mocks base method. -func (m *MockITaskRepo) GetTaskCount(arg0 context.Context, arg1 int64) (int64, error) { +func (m *MockITaskRepo) GetTaskCount(ctx context.Context, taskID int64) (int64, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetTaskCount", arg0, arg1) + ret := m.ctrl.Call(m, "GetTaskCount", ctx, taskID) ret0, _ := ret[0].(int64) ret1, _ := ret[1].(error) return ret0, ret1 } // GetTaskCount indicates an expected call of GetTaskCount. -func (mr *MockITaskRepoMockRecorder) GetTaskCount(arg0, arg1 any) *gomock.Call { +func (mr *MockITaskRepoMockRecorder) GetTaskCount(ctx, taskID any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTaskCount", reflect.TypeOf((*MockITaskRepo)(nil).GetTaskCount), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTaskCount", reflect.TypeOf((*MockITaskRepo)(nil).GetTaskCount), ctx, taskID) } // GetTaskRunCount mocks base method. -func (m *MockITaskRepo) GetTaskRunCount(arg0 context.Context, arg1, arg2 int64) (int64, error) { +func (m *MockITaskRepo) GetTaskRunCount(ctx context.Context, taskID, taskRunID int64) (int64, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetTaskRunCount", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "GetTaskRunCount", ctx, taskID, taskRunID) ret0, _ := ret[0].(int64) ret1, _ := ret[1].(error) return ret0, ret1 } // GetTaskRunCount indicates an expected call of GetTaskRunCount. -func (mr *MockITaskRepoMockRecorder) GetTaskRunCount(arg0, arg1, arg2 any) *gomock.Call { +func (mr *MockITaskRepoMockRecorder) GetTaskRunCount(ctx, taskID, taskRunID any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTaskRunCount", reflect.TypeOf((*MockITaskRepo)(nil).GetTaskRunCount), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTaskRunCount", reflect.TypeOf((*MockITaskRepo)(nil).GetTaskRunCount), ctx, taskID, taskRunID) } // GetTaskRunFailCount mocks base method. -func (m *MockITaskRepo) GetTaskRunFailCount(arg0 context.Context, arg1, arg2 int64) (int64, error) { +func (m *MockITaskRepo) GetTaskRunFailCount(ctx context.Context, taskID, taskRunID int64) (int64, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetTaskRunFailCount", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "GetTaskRunFailCount", ctx, taskID, taskRunID) ret0, _ := ret[0].(int64) ret1, _ := ret[1].(error) return ret0, ret1 } // GetTaskRunFailCount indicates an expected call of GetTaskRunFailCount. -func (mr *MockITaskRepoMockRecorder) GetTaskRunFailCount(arg0, arg1, arg2 any) *gomock.Call { +func (mr *MockITaskRepoMockRecorder) GetTaskRunFailCount(ctx, taskID, taskRunID any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTaskRunFailCount", reflect.TypeOf((*MockITaskRepo)(nil).GetTaskRunFailCount), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTaskRunFailCount", reflect.TypeOf((*MockITaskRepo)(nil).GetTaskRunFailCount), ctx, taskID, taskRunID) } // GetTaskRunSuccessCount mocks base method. -func (m *MockITaskRepo) GetTaskRunSuccessCount(arg0 context.Context, arg1, arg2 int64) (int64, error) { +func (m *MockITaskRepo) GetTaskRunSuccessCount(ctx context.Context, taskID, taskRunID int64) (int64, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetTaskRunSuccessCount", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "GetTaskRunSuccessCount", ctx, taskID, taskRunID) ret0, _ := ret[0].(int64) ret1, _ := ret[1].(error) return ret0, ret1 } // GetTaskRunSuccessCount indicates an expected call of GetTaskRunSuccessCount. -func (mr *MockITaskRepoMockRecorder) GetTaskRunSuccessCount(arg0, arg1, arg2 any) *gomock.Call { +func (mr *MockITaskRepoMockRecorder) GetTaskRunSuccessCount(ctx, taskID, taskRunID any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTaskRunSuccessCount", reflect.TypeOf((*MockITaskRepo)(nil).GetTaskRunSuccessCount), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTaskRunSuccessCount", reflect.TypeOf((*MockITaskRepo)(nil).GetTaskRunSuccessCount), ctx, taskID, taskRunID) } // IncrTaskCount mocks base method. -func (m *MockITaskRepo) IncrTaskCount(arg0 context.Context, arg1, arg2 int64) error { +func (m *MockITaskRepo) IncrTaskCount(ctx context.Context, taskID, ttl int64) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "IncrTaskCount", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "IncrTaskCount", ctx, taskID, ttl) ret0, _ := ret[0].(error) return ret0 } // IncrTaskCount indicates an expected call of IncrTaskCount. -func (mr *MockITaskRepoMockRecorder) IncrTaskCount(arg0, arg1, arg2 any) *gomock.Call { +func (mr *MockITaskRepoMockRecorder) IncrTaskCount(ctx, taskID, ttl any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IncrTaskCount", reflect.TypeOf((*MockITaskRepo)(nil).IncrTaskCount), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IncrTaskCount", reflect.TypeOf((*MockITaskRepo)(nil).IncrTaskCount), ctx, taskID, ttl) } // IncrTaskRunCount mocks base method. -func (m *MockITaskRepo) IncrTaskRunCount(arg0 context.Context, arg1, arg2, arg3 int64) error { +func (m *MockITaskRepo) IncrTaskRunCount(ctx context.Context, taskID, taskRunID, ttl int64) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "IncrTaskRunCount", arg0, arg1, arg2, arg3) + ret := m.ctrl.Call(m, "IncrTaskRunCount", ctx, taskID, taskRunID, ttl) ret0, _ := ret[0].(error) return ret0 } // IncrTaskRunCount indicates an expected call of IncrTaskRunCount. -func (mr *MockITaskRepoMockRecorder) IncrTaskRunCount(arg0, arg1, arg2, arg3 any) *gomock.Call { +func (mr *MockITaskRepoMockRecorder) IncrTaskRunCount(ctx, taskID, taskRunID, ttl any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IncrTaskRunCount", reflect.TypeOf((*MockITaskRepo)(nil).IncrTaskRunCount), arg0, arg1, arg2, arg3) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IncrTaskRunCount", reflect.TypeOf((*MockITaskRepo)(nil).IncrTaskRunCount), ctx, taskID, taskRunID, ttl) } // IncrTaskRunFailCount mocks base method. -func (m *MockITaskRepo) IncrTaskRunFailCount(arg0 context.Context, arg1, arg2, arg3 int64) error { +func (m *MockITaskRepo) IncrTaskRunFailCount(ctx context.Context, taskID, taskRunID, ttl int64) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "IncrTaskRunFailCount", arg0, arg1, arg2, arg3) + ret := m.ctrl.Call(m, "IncrTaskRunFailCount", ctx, taskID, taskRunID, ttl) ret0, _ := ret[0].(error) return ret0 } // IncrTaskRunFailCount indicates an expected call of IncrTaskRunFailCount. -func (mr *MockITaskRepoMockRecorder) IncrTaskRunFailCount(arg0, arg1, arg2, arg3 any) *gomock.Call { +func (mr *MockITaskRepoMockRecorder) IncrTaskRunFailCount(ctx, taskID, taskRunID, ttl any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IncrTaskRunFailCount", reflect.TypeOf((*MockITaskRepo)(nil).IncrTaskRunFailCount), arg0, arg1, arg2, arg3) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IncrTaskRunFailCount", reflect.TypeOf((*MockITaskRepo)(nil).IncrTaskRunFailCount), ctx, taskID, taskRunID, ttl) } // IncrTaskRunSuccessCount mocks base method. -func (m *MockITaskRepo) IncrTaskRunSuccessCount(arg0 context.Context, arg1, arg2, arg3 int64) error { +func (m *MockITaskRepo) IncrTaskRunSuccessCount(ctx context.Context, taskID, taskRunID, ttl int64) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "IncrTaskRunSuccessCount", arg0, arg1, arg2, arg3) + ret := m.ctrl.Call(m, "IncrTaskRunSuccessCount", ctx, taskID, taskRunID, ttl) ret0, _ := ret[0].(error) return ret0 } // IncrTaskRunSuccessCount indicates an expected call of IncrTaskRunSuccessCount. -func (mr *MockITaskRepoMockRecorder) IncrTaskRunSuccessCount(arg0, arg1, arg2, arg3 any) *gomock.Call { +func (mr *MockITaskRepoMockRecorder) IncrTaskRunSuccessCount(ctx, taskID, taskRunID, ttl any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IncrTaskRunSuccessCount", reflect.TypeOf((*MockITaskRepo)(nil).IncrTaskRunSuccessCount), arg0, arg1, arg2, arg3) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IncrTaskRunSuccessCount", reflect.TypeOf((*MockITaskRepo)(nil).IncrTaskRunSuccessCount), ctx, taskID, taskRunID, ttl) } -// ListNonFinalTask mocks base method. -func (m *MockITaskRepo) ListNonFinalTask(arg0 context.Context, arg1 string) ([]int64, error) { +// ListNonFinalTaskBySpaceID mocks base method. +func (m *MockITaskRepo) ListNonFinalTaskBySpaceID(ctx context.Context, spaceID string) ([]int64, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ListNonFinalTask", arg0, arg1) + ret := m.ctrl.Call(m, "ListNonFinalTaskBySpaceID", ctx, spaceID) ret0, _ := ret[0].([]int64) ret1, _ := ret[1].(error) return ret0, ret1 } -// ListNonFinalTask indicates an expected call of ListNonFinalTask. -func (mr *MockITaskRepoMockRecorder) ListNonFinalTask(arg0, arg1 any) *gomock.Call { +// ListNonFinalTaskBySpaceID indicates an expected call of ListNonFinalTaskBySpaceID. +func (mr *MockITaskRepoMockRecorder) ListNonFinalTaskBySpaceID(ctx, spaceID any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListNonFinalTask", reflect.TypeOf((*MockITaskRepo)(nil).ListNonFinalTask), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListNonFinalTaskBySpaceID", reflect.TypeOf((*MockITaskRepo)(nil).ListNonFinalTaskBySpaceID), ctx, spaceID) } -// ListTasks mocks base method. -func (m *MockITaskRepo) ListTasks(arg0 context.Context, arg1 mysql.ListTaskParam) ([]*entity.ObservabilityTask, int64, error) { +// ListNonFinalTasks mocks base method. +func (m *MockITaskRepo) ListNonFinalTasks(ctx context.Context) ([]*entity.ObservabilityTask, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ListTasks", arg0, arg1) + ret := m.ctrl.Call(m, "ListNonFinalTasks", ctx) ret0, _ := ret[0].([]*entity.ObservabilityTask) - ret1, _ := ret[1].(int64) - ret2, _ := ret[2].(error) - return ret0, ret1, ret2 + ret1, _ := ret[1].(error) + return ret0, ret1 } -// ListTasks indicates an expected call of ListTasks. -func (mr *MockITaskRepoMockRecorder) ListTasks(arg0, arg1 any) *gomock.Call { +// ListNonFinalTasks indicates an expected call of ListNonFinalTasks. +func (mr *MockITaskRepoMockRecorder) ListNonFinalTasks(ctx any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListTasks", reflect.TypeOf((*MockITaskRepo)(nil).ListTasks), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListNonFinalTasks", reflect.TypeOf((*MockITaskRepo)(nil).ListNonFinalTasks), ctx) } -// AddNonFinalTask mocks base method. -func (m *MockITaskRepo) AddNonFinalTask(arg0 context.Context, arg1 string, arg2 int64) error { +// ListTasks mocks base method. +func (m *MockITaskRepo) ListTasks(ctx context.Context, param repo.ListTaskParam) ([]*entity.ObservabilityTask, int64, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AddNonFinalTask", arg0, arg1, arg2) - ret0, _ := ret[0].(error) - return ret0 + ret := m.ctrl.Call(m, "ListTasks", ctx, param) + ret0, _ := ret[0].([]*entity.ObservabilityTask) + ret1, _ := ret[1].(int64) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 } -// AddNonFinalTask indicates an expected call of AddNonFinalTask. -func (mr *MockITaskRepoMockRecorder) AddNonFinalTask(arg0, arg1, arg2 any) *gomock.Call { +// ListTasks indicates an expected call of ListTasks. +func (mr *MockITaskRepoMockRecorder) ListTasks(ctx, param any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddNonFinalTask", reflect.TypeOf((*MockITaskRepo)(nil).AddNonFinalTask), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListTasks", reflect.TypeOf((*MockITaskRepo)(nil).ListTasks), ctx, param) } // RemoveNonFinalTask mocks base method. -func (m *MockITaskRepo) RemoveNonFinalTask(arg0 context.Context, arg1 string, arg2 int64) error { +func (m *MockITaskRepo) RemoveNonFinalTask(ctx context.Context, spaceID string, taskID int64) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "RemoveNonFinalTask", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "RemoveNonFinalTask", ctx, spaceID, taskID) ret0, _ := ret[0].(error) return ret0 } // RemoveNonFinalTask indicates an expected call of RemoveNonFinalTask. -func (mr *MockITaskRepoMockRecorder) RemoveNonFinalTask(arg0, arg1, arg2 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveNonFinalTask", reflect.TypeOf((*MockITaskRepo)(nil).RemoveNonFinalTask), arg0, arg1, arg2) -} - -// GetTaskByRedis mocks base method. -func (m *MockITaskRepo) GetTaskByRedis(arg0 context.Context, arg1 int64) (*entity.ObservabilityTask, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetTaskByRedis", arg0, arg1) - ret0, _ := ret[0].(*entity.ObservabilityTask) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetTaskByRedis indicates an expected call of GetTaskByRedis. -func (mr *MockITaskRepoMockRecorder) GetTaskByRedis(arg0, arg1 any) *gomock.Call { +func (mr *MockITaskRepoMockRecorder) RemoveNonFinalTask(ctx, spaceID, taskID any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTaskByRedis", reflect.TypeOf((*MockITaskRepo)(nil).GetTaskByRedis), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveNonFinalTask", reflect.TypeOf((*MockITaskRepo)(nil).RemoveNonFinalTask), ctx, spaceID, taskID) } // UpdateTask mocks base method. -func (m *MockITaskRepo) UpdateTask(arg0 context.Context, arg1 *entity.ObservabilityTask) error { +func (m *MockITaskRepo) UpdateTask(ctx context.Context, do *entity.ObservabilityTask) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateTask", arg0, arg1) + ret := m.ctrl.Call(m, "UpdateTask", ctx, do) ret0, _ := ret[0].(error) return ret0 } // UpdateTask indicates an expected call of UpdateTask. -func (mr *MockITaskRepoMockRecorder) UpdateTask(arg0, arg1 any) *gomock.Call { +func (mr *MockITaskRepoMockRecorder) UpdateTask(ctx, do any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateTask", reflect.TypeOf((*MockITaskRepo)(nil).UpdateTask), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateTask", reflect.TypeOf((*MockITaskRepo)(nil).UpdateTask), ctx, do) } // UpdateTaskRun mocks base method. -func (m *MockITaskRepo) UpdateTaskRun(arg0 context.Context, arg1 *entity.TaskRun) error { +func (m *MockITaskRepo) UpdateTaskRun(ctx context.Context, do *entity.TaskRun) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateTaskRun", arg0, arg1) + ret := m.ctrl.Call(m, "UpdateTaskRun", ctx, do) ret0, _ := ret[0].(error) return ret0 } // UpdateTaskRun indicates an expected call of UpdateTaskRun. -func (mr *MockITaskRepoMockRecorder) UpdateTaskRun(arg0, arg1 any) *gomock.Call { +func (mr *MockITaskRepoMockRecorder) UpdateTaskRun(ctx, do any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateTaskRun", reflect.TypeOf((*MockITaskRepo)(nil).UpdateTaskRun), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateTaskRun", reflect.TypeOf((*MockITaskRepo)(nil).UpdateTaskRun), ctx, do) } // UpdateTaskRunWithOCC mocks base method. -func (m *MockITaskRepo) UpdateTaskRunWithOCC(arg0 context.Context, arg1, arg2 int64, arg3 map[string]any) error { +func (m *MockITaskRepo) UpdateTaskRunWithOCC(ctx context.Context, id, workspaceID int64, updateMap map[string]any) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateTaskRunWithOCC", arg0, arg1, arg2, arg3) + ret := m.ctrl.Call(m, "UpdateTaskRunWithOCC", ctx, id, workspaceID, updateMap) ret0, _ := ret[0].(error) return ret0 } // UpdateTaskRunWithOCC indicates an expected call of UpdateTaskRunWithOCC. -func (mr *MockITaskRepoMockRecorder) UpdateTaskRunWithOCC(arg0, arg1, arg2, arg3 any) *gomock.Call { +func (mr *MockITaskRepoMockRecorder) UpdateTaskRunWithOCC(ctx, id, workspaceID, updateMap any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateTaskRunWithOCC", reflect.TypeOf((*MockITaskRepo)(nil).UpdateTaskRunWithOCC), arg0, arg1, arg2, arg3) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateTaskRunWithOCC", reflect.TypeOf((*MockITaskRepo)(nil).UpdateTaskRunWithOCC), ctx, id, workspaceID, updateMap) } // UpdateTaskWithOCC mocks base method. -func (m *MockITaskRepo) UpdateTaskWithOCC(arg0 context.Context, arg1, arg2 int64, arg3 map[string]any) error { +func (m *MockITaskRepo) UpdateTaskWithOCC(ctx context.Context, id, workspaceID int64, updateMap map[string]any) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateTaskWithOCC", arg0, arg1, arg2, arg3) + ret := m.ctrl.Call(m, "UpdateTaskWithOCC", ctx, id, workspaceID, updateMap) ret0, _ := ret[0].(error) return ret0 } // UpdateTaskWithOCC indicates an expected call of UpdateTaskWithOCC. -func (mr *MockITaskRepoMockRecorder) UpdateTaskWithOCC(arg0, arg1, arg2, arg3 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateTaskWithOCC", reflect.TypeOf((*MockITaskRepo)(nil).UpdateTaskWithOCC), arg0, arg1, arg2, arg3) -} - -// SetTask mocks base method. -func (m *MockITaskRepo) SetTask(arg0 context.Context, arg1 *entity.ObservabilityTask) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SetTask", arg0, arg1) - ret0, _ := ret[0].(error) - return ret0 -} - -// SetTask indicates an expected call of SetTask. -func (mr *MockITaskRepoMockRecorder) SetTask(arg0, arg1 any) *gomock.Call { +func (mr *MockITaskRepoMockRecorder) UpdateTaskWithOCC(ctx, id, workspaceID, updateMap any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetTask", reflect.TypeOf((*MockITaskRepo)(nil).SetTask), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateTaskWithOCC", reflect.TypeOf((*MockITaskRepo)(nil).UpdateTaskWithOCC), ctx, id, workspaceID, updateMap) } diff --git a/backend/modules/observability/domain/task/repo/task.go b/backend/modules/observability/domain/task/repo/task.go index 775443549..e8e2854d6 100644 --- a/backend/modules/observability/domain/task/repo/task.go +++ b/backend/modules/observability/domain/task/repo/task.go @@ -7,9 +7,17 @@ import ( "context" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" - "github.com/coze-dev/coze-loop/backend/modules/observability/infra/repo/mysql" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/common" ) +type ListTaskParam struct { + WorkspaceIDs []int64 + TaskFilters *entity.TaskFilterFields + ReqLimit int32 + ReqOffset int32 + OrderBy *common.OrderBy +} + //go:generate mockgen -destination=mocks/Task.go -package=mocks . ITaskRepo type ITaskRepo interface { // task @@ -17,8 +25,10 @@ type ITaskRepo interface { UpdateTask(ctx context.Context, do *entity.ObservabilityTask) error UpdateTaskWithOCC(ctx context.Context, id int64, workspaceID int64, updateMap map[string]interface{}) error GetTask(ctx context.Context, id int64, workspaceID *int64, userID *string) (*entity.ObservabilityTask, error) - ListTasks(ctx context.Context, param mysql.ListTaskParam) ([]*entity.ObservabilityTask, int64, error) + ListTasks(ctx context.Context, param ListTaskParam) ([]*entity.ObservabilityTask, int64, error) DeleteTask(ctx context.Context, do *entity.ObservabilityTask) error + // ListNonFinalTasks Only return Task without TaskRun + ListNonFinalTasks(ctx context.Context) ([]*entity.ObservabilityTask, error) // task run CreateTaskRun(ctx context.Context, do *entity.TaskRun) (int64, error) @@ -41,16 +51,13 @@ type ITaskRepo interface { GetTaskRunSuccessCount(ctx context.Context, taskID, taskRunID int64) (int64, error) IncrTaskRunSuccessCount(ctx context.Context, taskID, taskRunID int64, ttl int64) error DecrTaskRunSuccessCount(ctx context.Context, taskID, taskRunID int64) error - IncrTaskRunFailCount(ctx context.Context, taskID, taskRunID int64, ttl int64) error GetTaskRunFailCount(ctx context.Context, taskID, taskRunID int64) (int64, error) - - GetObjListWithTask(ctx context.Context) ([]string, []string, []*entity.ObservabilityTask) + IncrTaskRunFailCount(ctx context.Context, taskID, taskRunID int64, ttl int64) error // 非终态task列表by spaceID - ListNonFinalTask(ctx context.Context, spaceID string) ([]int64, error) + ListNonFinalTaskBySpaceID(ctx context.Context, spaceID string) ([]int64, error) AddNonFinalTask(ctx context.Context, spaceID string, taskID int64) error RemoveNonFinalTask(ctx context.Context, spaceID string, taskID int64) error - GetTaskByRedis(ctx context.Context, taskID int64) (*entity.ObservabilityTask, error) - SetTask(ctx context.Context, task *entity.ObservabilityTask) error + GetTaskByCache(ctx context.Context, taskID int64) (*entity.ObservabilityTask, error) } diff --git a/backend/modules/observability/domain/task/service/mocks/task_callback_service.go b/backend/modules/observability/domain/task/service/mocks/task_callback_service.go new file mode 100644 index 000000000..0fd9ecb1e --- /dev/null +++ b/backend/modules/observability/domain/task/service/mocks/task_callback_service.go @@ -0,0 +1,70 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service (interfaces: ITaskCallbackService) +// +// Generated by this command: +// +// mockgen -destination=mocks/task_callback_service.go -package=mocks . ITaskCallbackService +// + +// Package mocks is a generated GoMock package. +package mocks + +import ( + context "context" + reflect "reflect" + + entity "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" + gomock "go.uber.org/mock/gomock" +) + +// MockITaskCallbackService is a mock of ITaskCallbackService interface. +type MockITaskCallbackService struct { + ctrl *gomock.Controller + recorder *MockITaskCallbackServiceMockRecorder + isgomock struct{} +} + +// MockITaskCallbackServiceMockRecorder is the mock recorder for MockITaskCallbackService. +type MockITaskCallbackServiceMockRecorder struct { + mock *MockITaskCallbackService +} + +// NewMockITaskCallbackService creates a new mock instance. +func NewMockITaskCallbackService(ctrl *gomock.Controller) *MockITaskCallbackService { + mock := &MockITaskCallbackService{ctrl: ctrl} + mock.recorder = &MockITaskCallbackServiceMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockITaskCallbackService) EXPECT() *MockITaskCallbackServiceMockRecorder { + return m.recorder +} + +// AutoEvalCallback mocks base method. +func (m *MockITaskCallbackService) AutoEvalCallback(ctx context.Context, event *entity.AutoEvalEvent) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AutoEvalCallback", ctx, event) + ret0, _ := ret[0].(error) + return ret0 +} + +// AutoEvalCallback indicates an expected call of AutoEvalCallback. +func (mr *MockITaskCallbackServiceMockRecorder) AutoEvalCallback(ctx, event any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AutoEvalCallback", reflect.TypeOf((*MockITaskCallbackService)(nil).AutoEvalCallback), ctx, event) +} + +// AutoEvalCorrection mocks base method. +func (m *MockITaskCallbackService) AutoEvalCorrection(ctx context.Context, event *entity.CorrectionEvent) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AutoEvalCorrection", ctx, event) + ret0, _ := ret[0].(error) + return ret0 +} + +// AutoEvalCorrection indicates an expected call of AutoEvalCorrection. +func (mr *MockITaskCallbackServiceMockRecorder) AutoEvalCorrection(ctx, event any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AutoEvalCorrection", reflect.TypeOf((*MockITaskCallbackService)(nil).AutoEvalCorrection), ctx, event) +} diff --git a/backend/modules/observability/domain/task/service/task_callback.go b/backend/modules/observability/domain/task/service/task_callback.go new file mode 100644 index 000000000..90d1fda40 --- /dev/null +++ b/backend/modules/observability/domain/task/service/task_callback.go @@ -0,0 +1,257 @@ +// Copyright (c) 2025 coze-dev Authors +// SPDX-License-Identifier: Apache-2.0 + +package service + +import ( + "context" + "fmt" + "time" + + "github.com/coze-dev/coze-loop/backend/infra/external/benefit" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/component/config" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/component/tenant" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/repo" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe/processor" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/loop_span" + tracerepo "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/repo" + obErrorx "github.com/coze-dev/coze-loop/backend/modules/observability/pkg/errno" + "github.com/coze-dev/coze-loop/backend/pkg/errorx" + "github.com/coze-dev/coze-loop/backend/pkg/lang/ptr" + "github.com/coze-dev/coze-loop/backend/pkg/logs" + "github.com/samber/lo" +) + +//go:generate mockgen -destination=mocks/task_callback_service.go -package=mocks . ITaskCallbackService +type ITaskCallbackService interface { + AutoEvalCallback(ctx context.Context, event *entity.AutoEvalEvent) error + AutoEvalCorrection(ctx context.Context, event *entity.CorrectionEvent) error +} + +type TaskCallbackServiceImpl struct { + taskRepo repo.ITaskRepo + traceRepo tracerepo.ITraceRepo + taskProcessor processor.TaskProcessor + tenantProvider tenant.ITenantProvider + config config.ITraceConfig + benefitSvc benefit.IBenefitService +} + +func NewTaskCallbackServiceImpl( + taskRepo repo.ITaskRepo, + traceRepo tracerepo.ITraceRepo, + taskProcessor processor.TaskProcessor, + tenantProvider tenant.ITenantProvider, + config config.ITraceConfig, + benefitSvc benefit.IBenefitService, +) ITaskCallbackService { + return &TaskCallbackServiceImpl{ + taskRepo: taskRepo, + traceRepo: traceRepo, + taskProcessor: taskProcessor, + tenantProvider: tenantProvider, + config: config, + benefitSvc: benefitSvc, + } +} + +func (t *TaskCallbackServiceImpl) AutoEvalCallback(ctx context.Context, event *entity.AutoEvalEvent) error { + for _, turn := range event.TurnEvalResults { + workspaceIDStr, workspaceID := turn.GetWorkspaceIDFromExt() + tenants, err := t.tenantProvider.GetTenantsByPlatformType(ctx, loop_span.PlatformType("callback_all")) + if err != nil { + return err + } + storageDuration := t.config.GetTraceDataMaxDurationDay(ctx, lo.ToPtr(string(loop_span.PlatformDefault))) + res, err := t.benefitSvc.CheckTraceBenefit(ctx, &benefit.CheckTraceBenefitParams{ + ConnectorUID: turn.GetUserID(), + SpaceID: workspaceID, + }) + if err != nil { + logs.CtxWarn(ctx, "fail to check trace benefit, %v", err) + } else if res == nil { + logs.CtxWarn(ctx, "fail to get trace benefit, got nil response") + } else { + storageDuration = res.StorageDuration + } + + spans, err := t.getSpan(ctx, + tenants, + []string{turn.GetSpanIDFromExt()}, + turn.GetTraceIDFromExt(), + workspaceIDStr, + turn.GetStartTimeFromExt()/1000-(24*time.Duration(storageDuration)*time.Hour).Milliseconds(), + turn.GetStartTimeFromExt()/1000+10*time.Minute.Milliseconds(), + ) + if err != nil { + return err + } + if len(spans) == 0 { + logs.CtxWarn(ctx, "span not found, span_id: %s", turn.GetSpanIDFromExt()) + return fmt.Errorf("span not found, span_id: %s", turn.GetSpanIDFromExt()) + } + span := spans[0] + + // Newly added: write Redis counters based on the Status + err = t.updateTaskRunDetailsCount(ctx, turn.GetTaskIDFromExt(), turn, storageDuration*24*60*60) + if err != nil { + logs.CtxWarn(ctx, "Update TaskRun count failed: taskID=%d, status=%d, err=%v", + turn.GetTaskIDFromExt(), turn.Status, err) + // Continue processing without interrupting the flow + } + + annotation, err := span.AddAutoEvalAnnotation( + turn.GetTaskIDFromExt(), + turn.EvaluatorRecordID, + turn.EvaluatorVersionID, + turn.Score, + turn.Reasoning, + turn.GetUserID(), + ) + if err != nil { + return err + } + + err = t.traceRepo.InsertAnnotations(ctx, &tracerepo.InsertAnnotationParam{ + Tenant: span.GetTenant(), + TTL: span.GetTTL(ctx), + Annotations: []*loop_span.Annotation{annotation}, + }) + if err != nil { + return err + } + } + return nil +} + +func (t *TaskCallbackServiceImpl) AutoEvalCorrection(ctx context.Context, event *entity.CorrectionEvent) error { + workspaceIDStr, workspaceID := event.GetWorkspaceIDFromExt() + if workspaceID == 0 { + return fmt.Errorf("workspace_id is empty") + } + tenants, err := t.tenantProvider.GetTenantsByPlatformType(ctx, loop_span.PlatformType("callback_all")) + if err != nil { + return err + } + spans, err := t.getSpan(ctx, + tenants, + []string{event.GetSpanIDFromExt()}, + event.GetTraceIDFromExt(), + workspaceIDStr, + event.GetStartTimeFromExt()/1000-time.Second.Milliseconds(), + event.GetStartTimeFromExt()/1000+time.Second.Milliseconds(), + ) + if err != nil { + return err + } + if len(spans) == 0 { + return fmt.Errorf("span not found, span_id: %s", event.GetSpanIDFromExt()) + } + span := spans[0] + + annotations, err := t.traceRepo.ListAnnotations(ctx, &tracerepo.ListAnnotationsParam{ + Tenants: tenants, + SpanID: event.GetSpanIDFromExt(), + TraceID: event.GetTraceIDFromExt(), + WorkspaceId: workspaceID, + StartAt: event.GetStartTimeFromExt() - 5*time.Second.Milliseconds(), + EndAt: event.GetStartTimeFromExt() + 5*time.Second.Milliseconds(), + }) + if err != nil { + return err + } + + annotation, ok := annotations.FindByEvaluatorRecordID(event.EvaluatorRecordID) + if !ok { + logs.CtxError(ctx, "annotation not found, evaluator_record_id: %d", event.EvaluatorRecordID) + return fmt.Errorf("annotation not found, evaluator_record_id: %d", event.EvaluatorRecordID) + } + + annotation.CorrectAutoEvaluateScore(event.EvaluatorResult.Correction.Score, event.EvaluatorResult.Correction.Explain, event.GetUpdateBy()) + + // Then synchronize the observability data + param := &tracerepo.InsertAnnotationParam{ + Tenant: span.GetTenant(), + TTL: span.GetTTL(ctx), + Annotations: []*loop_span.Annotation{annotation}, + } + if err = t.traceRepo.InsertAnnotations(ctx, param); err != nil { + recordID := lo.Ternary(annotation.GetAutoEvaluateMetadata() != nil, annotation.GetAutoEvaluateMetadata().EvaluatorRecordID, 0) + // If the synchronous update fails, compensate asynchronously + // TODO: asynchronous processing has issues and may duplicate + logs.CtxError(ctx, "Sync upsert annotation failed, try async upsert. span_id=[%v], recored_id=[%v], err:%v", + annotation.SpanID, recordID, err) + return nil + } + return nil +} + +func (t *TaskCallbackServiceImpl) getSpan(ctx context.Context, tenants []string, spanIds []string, traceId, workspaceId string, startAt, endAt int64) ([]*loop_span.Span, error) { + if len(spanIds) == 0 || workspaceId == "" { + return nil, errorx.NewByCode(obErrorx.CommercialCommonInvalidParamCodeCode) + } + var filterFields []*loop_span.FilterField + filterFields = append(filterFields, &loop_span.FilterField{ + FieldName: loop_span.SpanFieldSpanId, + FieldType: loop_span.FieldTypeString, + Values: spanIds, + QueryType: ptr.Of(loop_span.QueryTypeEnumIn), + }) + filterFields = append(filterFields, &loop_span.FilterField{ + FieldName: loop_span.SpanFieldSpaceId, + FieldType: loop_span.FieldTypeString, + Values: []string{workspaceId}, + QueryType: ptr.Of(loop_span.QueryTypeEnumEq), + }) + if traceId != "" { + filterFields = append(filterFields, &loop_span.FilterField{ + FieldName: loop_span.SpanFieldTraceId, + FieldType: loop_span.FieldTypeString, + Values: []string{traceId}, + + QueryType: ptr.Of(loop_span.QueryTypeEnumEq), + }) + } + var spans []*loop_span.Span + // todo 目前可能有不同tenant在不同存储中,需要上层多次查询。后续逻辑需要下沉到repo中。 + for _, tenant := range tenants { + res, err := t.traceRepo.ListSpans(ctx, &tracerepo.ListSpansParam{ + Tenants: []string{tenant}, + Filters: &loop_span.FilterFields{ + FilterFields: filterFields, + }, + StartAt: startAt, + EndAt: endAt, + NotQueryAnnotation: true, + Limit: int32(len(spanIds)), + }) + if err != nil { + logs.CtxError(ctx, "failed to list span, %v", err) + return spans, err + } + spans = append(spans, res.Spans...) + } + logs.CtxInfo(ctx, "list span, spans: %v", spans) + + return spans, nil +} + +// updateTaskRunStatusCount updates the Redis count based on Status +func (t *TaskCallbackServiceImpl) updateTaskRunDetailsCount(ctx context.Context, taskID int64, turn *entity.OnlineExptTurnEvalResult, ttl int64) error { + taskRunID, err := turn.GetRunID() + if err != nil { + return fmt.Errorf("invalid task_run_id, err: %v", err) + } + // Increase the corresponding counter based on Status + switch turn.Status { + case entity.EvaluatorRunStatus_Success: + return t.taskRepo.IncrTaskRunSuccessCount(ctx, taskID, taskRunID, ttl) + case entity.EvaluatorRunStatus_Fail: + return t.taskRepo.IncrTaskRunFailCount(ctx, taskID, taskRunID, ttl) + default: + logs.CtxWarn(ctx, "unknown status, skip count: taskID=%d, taskRunID=%d, status=%d", + taskID, taskRunID, turn.Status) + return nil + } +} diff --git a/backend/modules/observability/domain/task/service/task_callback_test.go b/backend/modules/observability/domain/task/service/task_callback_test.go new file mode 100755 index 000000000..a09a4aab9 --- /dev/null +++ b/backend/modules/observability/domain/task/service/task_callback_test.go @@ -0,0 +1,311 @@ +// Copyright (c) 2025 coze-dev Authors +// SPDX-License-Identifier: Apache-2.0 + +package service + +import ( + "context" + "errors" + "strconv" + "testing" + "time" + + "go.uber.org/mock/gomock" + + "github.com/coze-dev/coze-loop/backend/infra/external/benefit" + benefit_mocks "github.com/coze-dev/coze-loop/backend/infra/external/benefit/mocks" + tenant_mocks "github.com/coze-dev/coze-loop/backend/modules/observability/domain/component/tenant/mocks" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" + repo_mocks "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/repo/mocks" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/loop_span" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/repo" + trace_repo_mocks "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/repo/mocks" + "github.com/stretchr/testify/require" +) + +func TestTaskCallbackServiceImpl_CallBackSuccess(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + mockBenefit := benefit_mocks.NewMockIBenefitService(ctrl) + mockTenant := tenant_mocks.NewMockITenantProvider(ctrl) + mockTraceRepo := trace_repo_mocks.NewMockITraceRepo(ctrl) + mockTaskRepo := repo_mocks.NewMockITaskRepo(ctrl) + + impl := &TaskCallbackServiceImpl{ + benefitSvc: mockBenefit, + tenantProvider: mockTenant, + traceRepo: mockTraceRepo, + taskRepo: mockTaskRepo, + } + + mockTenant.EXPECT().GetTenantsByPlatformType(gomock.Any(), gomock.Any()).Return([]string{"tenant"}, nil).AnyTimes() + mockBenefit.EXPECT().CheckTraceBenefit(gomock.Any(), gomock.Any()).Return(&benefit.CheckTraceBenefitResult{StorageDuration: 1}, nil).AnyTimes() + + now := time.Now() + span := &loop_span.Span{ + SpanID: "span-1", + TraceID: "trace-1", + SystemTagsString: map[string]string{loop_span.SpanFieldTenant: "tenant"}, + LogicDeleteTime: now.Add(24 * time.Hour).UnixMicro(), + StartTime: now.UnixMicro(), + } + + mockTraceRepo.EXPECT().ListSpans(gomock.Any(), gomock.AssignableToTypeOf(&repo.ListSpansParam{})).Return(&repo.ListSpansResult{Spans: loop_span.SpanList{span}}, nil) + mockTaskRepo.EXPECT().IncrTaskRunSuccessCount(gomock.Any(), int64(101), int64(202), gomock.Any()).Return(nil) + mockTraceRepo.EXPECT().InsertAnnotations(gomock.Any(), gomock.AssignableToTypeOf(&repo.InsertAnnotationParam{})).DoAndReturn( + func(_ context.Context, param *repo.InsertAnnotationParam) error { + require.Len(t, param.Annotations, 1) + require.Equal(t, loop_span.AnnotationTypeAutoEvaluate, param.Annotations[0].AnnotationType) + return nil + }, + ) + + startTime := now.Add(-time.Minute).UnixMilli() + event := &entity.AutoEvalEvent{ + TurnEvalResults: []*entity.OnlineExptTurnEvalResult{ + { + EvaluatorVersionID: 1, + Score: 0.9, + Reasoning: "ok", + Status: entity.EvaluatorRunStatus_Success, + BaseInfo: &entity.BaseInfo{ + CreatedBy: &entity.UserInfo{UserID: "user-1"}, + }, + Ext: map[string]string{ + "workspace_id": strconv.FormatInt(1, 10), + "span_id": "span-1", + "trace_id": "trace-1", + "start_time": strconv.FormatInt(startTime*1000, 10), + "task_id": strconv.FormatInt(101, 10), + "run_id": strconv.FormatInt(202, 10), + }, + }, + }, + } + + require.NoError(t, impl.AutoEvalCallback(context.Background(), event)) +} + +func TestTraceHubServiceImpl_CallBackSpanNotFound(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + mockBenefit := benefit_mocks.NewMockIBenefitService(ctrl) + mockTenant := tenant_mocks.NewMockITenantProvider(ctrl) + mockTraceRepo := trace_repo_mocks.NewMockITraceRepo(ctrl) + + impl := &TaskCallbackServiceImpl{ + benefitSvc: mockBenefit, + tenantProvider: mockTenant, + traceRepo: mockTraceRepo, + } + + mockTenant.EXPECT().GetTenantsByPlatformType(gomock.Any(), gomock.Any()).Return([]string{"tenant"}, nil).AnyTimes() + mockBenefit.EXPECT().CheckTraceBenefit(gomock.Any(), gomock.Any()).Return(&benefit.CheckTraceBenefitResult{StorageDuration: 1}, nil).AnyTimes() + mockTraceRepo.EXPECT().ListSpans(gomock.Any(), gomock.AssignableToTypeOf(&repo.ListSpansParam{})).Return(&repo.ListSpansResult{}, nil) + + event := &entity.AutoEvalEvent{ + TurnEvalResults: []*entity.OnlineExptTurnEvalResult{ + { + Status: entity.EvaluatorRunStatus_Success, + BaseInfo: &entity.BaseInfo{ + CreatedBy: &entity.UserInfo{UserID: "user-1"}, + }, + Ext: map[string]string{ + "workspace_id": "1", + "span_id": "span-1", + "trace_id": "trace-1", + "start_time": strconv.FormatInt(time.Now().UnixMilli()*1000, 10), + "task_id": "101", + "run_id": "202", + }, + }, + }, + } + + require.Error(t, impl.AutoEvalCallback(context.Background(), event)) +} + +func TestTaskCallbackServiceImpl_getSpan(t *testing.T) { + t.Parallel() + + ctx := context.Background() + tenants := []string{"tenant"} + spanIDs := []string{"span-1"} + traceID := "trace-1" + workspaceID := "ws-1" + start := int64(1000) + end := int64(2000) + + t.Run("with_trace_id", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + mockTraceRepo := trace_repo_mocks.NewMockITraceRepo(ctrl) + impl := &TaskCallbackServiceImpl{traceRepo: mockTraceRepo} + expectedSpan := &loop_span.Span{SpanID: spanIDs[0], TraceID: traceID} + + mockTraceRepo.EXPECT().ListSpans(gomock.Any(), gomock.AssignableToTypeOf(&repo.ListSpansParam{})).DoAndReturn( + func(_ context.Context, param *repo.ListSpansParam) (*repo.ListSpansResult, error) { + require.Equal(t, tenants, param.Tenants) + require.Equal(t, start, param.StartAt) + require.Equal(t, end, param.EndAt) + require.True(t, param.NotQueryAnnotation) + require.Equal(t, int32(2), param.Limit) + require.Len(t, param.Filters.FilterFields, 3) + return &repo.ListSpansResult{Spans: loop_span.SpanList{expectedSpan}}, nil + }, + ) + + spans, err := impl.getSpan(ctx, tenants, spanIDs, traceID, workspaceID, start, end) + require.NoError(t, err) + require.Equal(t, []*loop_span.Span{expectedSpan}, spans) + }) + + t.Run("without_trace_id", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + mockTraceRepo := trace_repo_mocks.NewMockITraceRepo(ctrl) + impl := &TaskCallbackServiceImpl{traceRepo: mockTraceRepo} + expectedSpan := &loop_span.Span{SpanID: spanIDs[0]} + + mockTraceRepo.EXPECT().ListSpans(gomock.Any(), gomock.AssignableToTypeOf(&repo.ListSpansParam{})).DoAndReturn( + func(_ context.Context, param *repo.ListSpansParam) (*repo.ListSpansResult, error) { + require.Equal(t, tenants, param.Tenants) + require.Len(t, param.Filters.FilterFields, 2) + return &repo.ListSpansResult{Spans: loop_span.SpanList{expectedSpan}}, nil + }, + ) + + spans, err := impl.getSpan(ctx, tenants, spanIDs, "", workspaceID, start, end) + require.NoError(t, err) + require.Equal(t, []*loop_span.Span{expectedSpan}, spans) + }) + + t.Run("empty_span_ids", func(t *testing.T) { + t.Parallel() + impl := &TaskCallbackServiceImpl{} + _, err := impl.getSpan(ctx, tenants, nil, traceID, workspaceID, start, end) + require.Error(t, err) + }) + + t.Run("empty_workspace", func(t *testing.T) { + t.Parallel() + impl := &TaskCallbackServiceImpl{} + _, err := impl.getSpan(ctx, tenants, spanIDs, traceID, "", start, end) + require.Error(t, err) + }) + + t.Run("repo_error", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + mockTraceRepo := trace_repo_mocks.NewMockITraceRepo(ctrl) + impl := &TaskCallbackServiceImpl{traceRepo: mockTraceRepo} + + mockTraceRepo.EXPECT().ListSpans(gomock.Any(), gomock.AssignableToTypeOf(&repo.ListSpansParam{})).Return(nil, errors.New("list error")) + + _, err := impl.getSpan(ctx, tenants, spanIDs, traceID, workspaceID, start, end) + require.Error(t, err) + }) + + t.Run("no_data", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + mockTraceRepo := trace_repo_mocks.NewMockITraceRepo(ctrl) + impl := &TaskCallbackServiceImpl{traceRepo: mockTraceRepo} + + mockTraceRepo.EXPECT().ListSpans(gomock.Any(), gomock.AssignableToTypeOf(&repo.ListSpansParam{})).Return(&repo.ListSpansResult{}, nil) + + spans, err := impl.getSpan(ctx, tenants, spanIDs, traceID, workspaceID, start, end) + require.NoError(t, err) + require.Nil(t, spans) + }) +} + +func TestTaskCallbackServiceImpl_updateTaskRunDetailsCount(t *testing.T) { + t.Parallel() + + ctx := context.Background() + taskID := int64(101) + runIDStr := "202" + runID := int64(202) + + tests := []struct { + name string + status entity.EvaluatorRunStatus + expectSuccess bool + expectFail bool + expectErr bool + }{ + { + name: "success_status", + status: entity.EvaluatorRunStatus_Success, + expectSuccess: true, + }, + { + name: "fail_status", + status: entity.EvaluatorRunStatus_Fail, + expectFail: true, + }, + { + name: "unknown_status", + status: entity.EvaluatorRunStatus_Unknown, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + mockRepo := repo_mocks.NewMockITaskRepo(ctrl) + impl := &TaskCallbackServiceImpl{taskRepo: mockRepo} + + turn := &entity.OnlineExptTurnEvalResult{ + Status: tt.status, + Ext: map[string]string{ + "run_id": runIDStr, + }, + } + + if tt.expectSuccess { + mockRepo.EXPECT().IncrTaskRunSuccessCount(ctx, taskID, runID, gomock.Any()).Return(nil) + } + if tt.expectFail { + mockRepo.EXPECT().IncrTaskRunFailCount(ctx, taskID, runID, gomock.Any()).Return(nil) + } + + err := impl.updateTaskRunDetailsCount(ctx, taskID, turn, 0) + if tt.expectErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + }) + } + + t.Run("missing_run_id", func(t *testing.T) { + t.Parallel() + impl := &TaskCallbackServiceImpl{} + err := impl.updateTaskRunDetailsCount(ctx, taskID, &entity.OnlineExptTurnEvalResult{Ext: map[string]string{}}, 0) + require.Error(t, err) + }) + + t.Run("invalid_run_id", func(t *testing.T) { + t.Parallel() + impl := &TaskCallbackServiceImpl{} + err := impl.updateTaskRunDetailsCount(ctx, taskID, &entity.OnlineExptTurnEvalResult{Ext: map[string]string{"run_id": "abc"}}, 0) + require.Error(t, err) + }) +} diff --git a/backend/modules/observability/domain/task/service/task_service.go b/backend/modules/observability/domain/task/service/task_service.go index 623c43c6f..74bc33ae1 100644 --- a/backend/modules/observability/domain/task/service/task_service.go +++ b/backend/modules/observability/domain/task/service/task_service.go @@ -12,21 +12,17 @@ import ( "github.com/bytedance/gg/gptr" "github.com/coze-dev/coze-loop/backend/infra/idgen" "github.com/coze-dev/coze-loop/backend/infra/middleware/session" - "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/common" - "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/filter" - "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/task" - tconv "github.com/coze-dev/coze-loop/backend/modules/observability/application/convertor/task" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/component/mq" - "github.com/coze-dev/coze-loop/backend/modules/observability/domain/component/rpc" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/repo" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe/processor" - loop_span "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/loop_span" - "github.com/coze-dev/coze-loop/backend/modules/observability/infra/repo/mysql" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/common" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/loop_span" + traceservice "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/service" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/service/trace/span_filter" obErrorx "github.com/coze-dev/coze-loop/backend/modules/observability/pkg/errno" "github.com/coze-dev/coze-loop/backend/pkg/errorx" - "github.com/coze-dev/coze-loop/backend/pkg/lang/ptr" "github.com/coze-dev/coze-loop/backend/pkg/logs" ) @@ -39,28 +35,28 @@ type CreateTaskResp struct { type UpdateTaskReq struct { TaskID int64 WorkspaceID int64 - TaskStatus *task.TaskStatus + TaskStatus *entity.TaskStatus Description *string - EffectiveTime *task.EffectiveTime + EffectiveTime *entity.EffectiveTime SampleRate *float64 } type ListTasksReq struct { WorkspaceID int64 - TaskFilters *filter.TaskFilterFields + TaskFilters *entity.TaskFilterFields Limit int32 Offset int32 OrderBy *common.OrderBy } type ListTasksResp struct { - Tasks []*task.Task - Total *int64 + Tasks []*entity.ObservabilityTask + Total int64 } type GetTaskReq struct { TaskID int64 WorkspaceID int64 } type GetTaskResp struct { - Task *task.Task + Task *entity.ObservabilityTask } type CheckTaskNameReq struct { WorkspaceID int64 @@ -77,37 +73,40 @@ type ITaskService interface { ListTasks(ctx context.Context, req *ListTasksReq) (resp *ListTasksResp, err error) GetTask(ctx context.Context, req *GetTaskReq) (resp *GetTaskResp, err error) CheckTaskName(ctx context.Context, req *CheckTaskNameReq) (resp *CheckTaskNameResp, err error) + + SendBackfillMessage(ctx context.Context, event *entity.BackFillEvent) error } func NewTaskServiceImpl( tRepo repo.ITaskRepo, - userProvider rpc.IUserProvider, idGenerator idgen.IIDGenerator, backfillProducer mq.IBackfillProducer, taskProcessor *processor.TaskProcessor, + buildHelper traceservice.TraceFilterProcessorBuilder, ) (ITaskService, error) { return &TaskServiceImpl{ TaskRepo: tRepo, - userProvider: userProvider, idGenerator: idGenerator, backfillProducer: backfillProducer, taskProcessor: *taskProcessor, + buildHelper: buildHelper, }, nil } type TaskServiceImpl struct { TaskRepo repo.ITaskRepo - userProvider rpc.IUserProvider idGenerator idgen.IIDGenerator backfillProducer mq.IBackfillProducer taskProcessor processor.TaskProcessor + buildHelper traceservice.TraceFilterProcessorBuilder } func (t *TaskServiceImpl) CreateTask(ctx context.Context, req *CreateTaskReq) (resp *CreateTaskResp, err error) { + taskDO := req.Task // 校验task name是否存在 checkResp, err := t.CheckTaskName(ctx, &CheckTaskNameReq{ - WorkspaceID: req.Task.WorkspaceID, - Name: req.Task.Name, + WorkspaceID: taskDO.WorkspaceID, + Name: taskDO.Name, }) if err != nil { logs.CtxError(ctx, "CheckTaskName err:%v", err) @@ -117,42 +116,46 @@ func (t *TaskServiceImpl) CreateTask(ctx context.Context, req *CreateTaskReq) (r logs.CtxError(ctx, "task name exist") return nil, errorx.NewByCode(obErrorx.CommonInvalidParamCode, errorx.WithExtraMsg("task name exist")) } - proc := t.taskProcessor.GetTaskProcessor(req.Task.TaskType) + + if err := t.buildSpanFilters(ctx, taskDO); err != nil { + logs.CtxError(ctx, "buildSpanFilters err:%v", err) + return nil, err + } + + proc := t.taskProcessor.GetTaskProcessor(taskDO.TaskType) // 校验配置项是否有效 - if err = proc.ValidateConfig(ctx, req.Task); err != nil { + if err = proc.ValidateConfig(ctx, taskDO); err != nil { logs.CtxError(ctx, "ValidateConfig err:%v", err) return nil, errorx.NewByCode(obErrorx.CommonInvalidParamCode, errorx.WithExtraMsg(fmt.Sprintf("config invalid:%v", err))) } - id, err := t.TaskRepo.CreateTask(ctx, req.Task) + id, err := t.TaskRepo.CreateTask(ctx, taskDO) if err != nil { return nil, err } // 创建任务的数据准备 // 数据回流任务——创建/更新输出数据集 // 自动评测历史回溯——创建空壳子 - req.Task.ID = id - if err = proc.OnCreateTaskChange(ctx, req.Task); err != nil { + taskDO.ID = id + if err = proc.OnTaskCreated(ctx, taskDO); err != nil { logs.CtxError(ctx, "create initial task run failed, task_id=%d, err=%v", id, err) - if err1 := t.TaskRepo.DeleteTask(ctx, req.Task); err1 != nil { + if err1 := t.TaskRepo.DeleteTask(ctx, taskDO); err1 != nil { logs.CtxError(ctx, "delete task failed, task_id=%d, err=%v", id, err1) } return nil, err } // 历史回溯数据发MQ - if t.shouldTriggerBackfill(req.Task) { + if taskDO.ShouldTriggerBackfill() { backfillEvent := &entity.BackFillEvent{ - SpaceID: req.Task.WorkspaceID, + SpaceID: taskDO.WorkspaceID, TaskID: id, } - // 异步发送MQ消息,不阻塞任务创建流程 - go func() { - if err := t.sendBackfillMessage(context.Background(), backfillEvent); err != nil { - logs.CtxWarn(ctx, "send backfill message failed, task_id=%d, err=%v", id, err) - } - }() + if err := t.SendBackfillMessage(context.Background(), backfillEvent); err != nil { + // 失败了会有定时任务进行补偿 + logs.CtxWarn(ctx, "send backfill message failed, task_id=%d, err=%v", id, err) + } } return &CreateTaskResp{TaskID: &id}, nil @@ -164,7 +167,7 @@ func (t *TaskServiceImpl) UpdateTask(ctx context.Context, req *UpdateTaskReq) (e return err } if taskDO == nil { - logs.CtxError(ctx, "task not found") + logs.CtxError(ctx, "task [%d] not found", req.TaskID) return errorx.NewByCode(obErrorx.CommercialCommonInvalidParamCodeCode, errorx.WithExtraMsg("task not found")) } userID := session.UserIDInCtxOrEmpty(ctx) @@ -176,32 +179,31 @@ func (t *TaskServiceImpl) UpdateTask(ctx context.Context, req *UpdateTaskReq) (e taskDO.Description = req.Description } if req.EffectiveTime != nil { - validEffectiveTime, err := tconv.CheckEffectiveTime(ctx, req.EffectiveTime, taskDO.TaskStatus, taskDO.EffectiveTime) - if err != nil { + if err := taskDO.SetEffectiveTime(ctx, *req.EffectiveTime); err != nil { return err } - taskDO.EffectiveTime = validEffectiveTime } if req.SampleRate != nil { taskDO.Sampler.SampleRate = *req.SampleRate } if req.TaskStatus != nil { - validTaskStatus, err := tconv.CheckTaskStatus(ctx, *req.TaskStatus, taskDO.TaskStatus) + event, err := taskDO.SetTaskStatus(ctx, *req.TaskStatus) if err != nil { return err } - if validTaskStatus != "" { - if validTaskStatus == task.TaskStatusDisabled { + + if event != nil { + if event.After == entity.TaskStatusDisabled { // 禁用操作处理 proc := t.taskProcessor.GetTaskProcessor(taskDO.TaskType) var taskRun *entity.TaskRun for _, tr := range taskDO.TaskRuns { - if tr.RunStatus == task.RunStatusRunning { + if tr.RunStatus == entity.TaskRunStatusRunning { taskRun = tr break } } - if err = proc.OnFinishTaskRunChange(ctx, taskexe.OnFinishTaskRunChangeReq{ + if err = proc.OnTaskRunFinished(ctx, taskexe.OnTaskRunFinishedReq{ Task: taskDO, TaskRun: taskRun, }); err != nil { @@ -213,7 +215,6 @@ func (t *TaskServiceImpl) UpdateTask(ctx context.Context, req *UpdateTaskReq) (e logs.CtxError(ctx, "remove non final task failed, task_id=%d, err=%v", taskDO.ID, err) } } - taskDO.TaskStatus = *req.TaskStatus } } taskDO.UpdatedBy = userID @@ -225,7 +226,7 @@ func (t *TaskServiceImpl) UpdateTask(ctx context.Context, req *UpdateTaskReq) (e } func (t *TaskServiceImpl) ListTasks(ctx context.Context, req *ListTasksReq) (resp *ListTasksResp, err error) { - taskDOs, total, err := t.TaskRepo.ListTasks(ctx, mysql.ListTaskParam{ + taskDOs, total, err := t.TaskRepo.ListTasks(ctx, repo.ListTaskParam{ WorkspaceIDs: []int64{req.WorkspaceID}, TaskFilters: req.TaskFilters, ReqLimit: req.Limit, @@ -240,22 +241,12 @@ func (t *TaskServiceImpl) ListTasks(ctx context.Context, req *ListTasksReq) (res logs.CtxInfo(ctx, "GetTasks tasks is nil") return resp, nil } - userMap := make(map[string]bool) - users := make([]string, 0) - for _, tp := range taskDOs { - userMap[tp.CreatedBy] = true - userMap[tp.UpdatedBy] = true - } - for u := range userMap { - users = append(users, u) - } - _, userInfoMap, err := t.userProvider.GetUserInfo(ctx, users) - if err != nil { - logs.CtxError(ctx, "MGetUserInfo err:%v", err) - } + + taskDOs = filterHiddenFilters(taskDOs) + return &ListTasksResp{ - Tasks: tconv.TaskDOs2DTOs(ctx, filterHiddenFilters(taskDOs), userInfoMap), - Total: ptr.Of(total), + Tasks: taskDOs, + Total: total, }, nil } @@ -269,11 +260,10 @@ func (t *TaskServiceImpl) GetTask(ctx context.Context, req *GetTaskReq) (resp *G logs.CtxError(ctx, "GetTasks tasks is nil") return resp, nil } - _, userInfoMap, err := t.userProvider.GetUserInfo(ctx, []string{taskDO.CreatedBy, taskDO.UpdatedBy}) - if err != nil { - logs.CtxError(ctx, "MGetUserInfo err:%v", err) - } - return &GetTaskResp{Task: tconv.TaskDO2DTO(ctx, filterHiddenFilters([]*entity.ObservabilityTask{taskDO})[0], userInfoMap)}, nil + + taskDO = filterHiddenFilters([]*entity.ObservabilityTask{taskDO})[0] + + return &GetTaskResp{Task: taskDO}, nil } func filterHiddenFilters(tasks []*entity.ObservabilityTask) []*entity.ObservabilityTask { @@ -330,15 +320,15 @@ func filterVisibleFilterFields(fields *loop_span.FilterFields) *loop_span.Filter } func (t *TaskServiceImpl) CheckTaskName(ctx context.Context, req *CheckTaskNameReq) (resp *CheckTaskNameResp, err error) { - taskPOs, _, err := t.TaskRepo.ListTasks(ctx, mysql.ListTaskParam{ + taskPOs, _, err := t.TaskRepo.ListTasks(ctx, repo.ListTaskParam{ WorkspaceIDs: []int64{req.WorkspaceID}, - TaskFilters: &filter.TaskFilterFields{ - FilterFields: []*filter.TaskFilterField{ + TaskFilters: &entity.TaskFilterFields{ + FilterFields: []*entity.TaskFilterField{ { - FieldName: gptr.Of(filter.TaskFieldNameTaskName), - FieldType: gptr.Of(filter.FieldTypeString), + FieldName: gptr.Of(entity.TaskFieldNameTaskName), + FieldType: gptr.Of(entity.FieldTypeString), Values: []string{req.Name}, - QueryType: gptr.Of(filter.QueryTypeMatch), + QueryType: gptr.Of(entity.QueryTypeMatch), }, }, }, @@ -358,30 +348,39 @@ func (t *TaskServiceImpl) CheckTaskName(ctx context.Context, req *CheckTaskNameR return &CheckTaskNameResp{Pass: gptr.Of(pass)}, nil } -// shouldTriggerBackfill 判断是否需要发送历史回溯MQ -func (t *TaskServiceImpl) shouldTriggerBackfill(taskDO *entity.ObservabilityTask) bool { - // 检查任务类型 - taskType := taskDO.TaskType - if taskType != task.TaskTypeAutoEval && taskType != task.TaskTypeAutoDataReflow { - return false +// SendBackfillMessage 发送MQ消息 +func (t *TaskServiceImpl) SendBackfillMessage(ctx context.Context, event *entity.BackFillEvent) error { + if t.backfillProducer == nil { + return errorx.NewByCode(obErrorx.CommonInternalErrorCode, errorx.WithExtraMsg("backfill producer not initialized")) } - // 检查回填时间配置 + return t.backfillProducer.SendBackfill(ctx, event) +} - if taskDO.BackfillEffectiveTime == nil { - return false +func (t *TaskServiceImpl) buildSpanFilters(ctx context.Context, taskDO *entity.ObservabilityTask) error { + f, err := t.buildHelper.BuildPlatformRelatedFilter(ctx, taskDO.SpanFilter.PlatformType) + if err != nil { + return err + } + env := &span_filter.SpanEnv{ + WorkspaceID: taskDO.WorkspaceID, } - return taskDO.BackfillEffectiveTime.StartAt > 0 && - taskDO.BackfillEffectiveTime.EndAt > 0 && - taskDO.BackfillEffectiveTime.StartAt < taskDO.BackfillEffectiveTime.EndAt -} + // coze场景中,需要将basic filter提前固化到数据库中,避免任务触发时重复调用coze接口 + basicFilter, forceQuery, err := f.BuildBasicSpanFilter(ctx, env) + if err != nil { + return err + } else if len(basicFilter) == 0 && !forceQuery { + logs.CtxInfo(ctx, "Build basic filter failed, platform type: [%s], workspaceID: [%d]", + taskDO.SpanFilter.PlatformType, taskDO.WorkspaceID) + return errorx.NewByCode(obErrorx.CommercialCommonInvalidParamCodeCode, errorx.WithExtraMsg("User has no permission")) + } -// sendBackfillMessage 发送MQ消息 -func (t *TaskServiceImpl) sendBackfillMessage(ctx context.Context, event *entity.BackFillEvent) error { - if t.backfillProducer == nil { - return errorx.NewByCode(obErrorx.CommonInternalErrorCode, errorx.WithExtraMsg("backfill producer not initialized")) + // basic filter对用户不可见 + for _, filter := range basicFilter { + filter.SetHidden(true) } - return t.backfillProducer.SendBackfill(ctx, event) + taskDO.SpanFilter.Filters.FilterFields = append(taskDO.SpanFilter.Filters.FilterFields, basicFilter...) + return nil } diff --git a/backend/modules/observability/domain/task/service/task_service_test.go b/backend/modules/observability/domain/task/service/task_service_test.go index a65312dfe..389760d6d 100755 --- a/backend/modules/observability/domain/task/service/task_service_test.go +++ b/backend/modules/observability/domain/task/service/task_service_test.go @@ -14,18 +14,16 @@ import ( "go.uber.org/mock/gomock" "github.com/coze-dev/coze-loop/backend/infra/middleware/session" - "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/filter" "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/task" componentmq "github.com/coze-dev/coze-loop/backend/modules/observability/domain/component/mq" - rpc "github.com/coze-dev/coze-loop/backend/modules/observability/domain/component/rpc" - rpcmock "github.com/coze-dev/coze-loop/backend/modules/observability/domain/component/rpc/mocks" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" taskrepo "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/repo" repomocks "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/repo/mocks" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe/processor" - entitycommon "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/common" - loop_span "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/loop_span" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/loop_span" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/service/trace/span_filter" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/service/trace/span_processor" obErrorx "github.com/coze-dev/coze-loop/backend/modules/observability/pkg/errno" "github.com/coze-dev/coze-loop/backend/pkg/errorx" ) @@ -55,19 +53,67 @@ func (f *fakeProcessor) OnUpdateTaskChange(context.Context, *entity.Observabilit return nil } -func (f *fakeProcessor) OnFinishTaskChange(context.Context, taskexe.OnFinishTaskChangeReq) error { +func (f *fakeProcessor) OnFinishTaskChange(context.Context, taskexe.OnTaskFinishedReq) error { return nil } -func (f *fakeProcessor) OnCreateTaskRunChange(context.Context, taskexe.OnCreateTaskRunChangeReq) error { +func (f *fakeProcessor) OnCreateTaskRunChange(context.Context, taskexe.OnTaskRunCreatedReq) error { return nil } -func (f *fakeProcessor) OnFinishTaskRunChange(context.Context, taskexe.OnFinishTaskRunChangeReq) error { +func (f *fakeProcessor) OnFinishTaskRunChange(context.Context, taskexe.OnTaskRunFinishedReq) error { f.onFinishRunCalled = true return f.onFinishRunErr } +type stubTraceFilterBuilder struct{} + +func (s *stubTraceFilterBuilder) BuildPlatformRelatedFilter(context.Context, loop_span.PlatformType) (span_filter.Filter, error) { + return &stubSpanFilter{}, nil +} + +func (s *stubTraceFilterBuilder) BuildGetTraceProcessors(context.Context, span_processor.Settings) ([]span_processor.Processor, error) { + return nil, nil +} + +func (s *stubTraceFilterBuilder) BuildListSpansProcessors(context.Context, span_processor.Settings) ([]span_processor.Processor, error) { + return nil, nil +} + +func (s *stubTraceFilterBuilder) BuildAdvanceInfoProcessors(context.Context, span_processor.Settings) ([]span_processor.Processor, error) { + return nil, nil +} + +func (s *stubTraceFilterBuilder) BuildIngestTraceProcessors(context.Context, span_processor.Settings) ([]span_processor.Processor, error) { + return nil, nil +} + +func (s *stubTraceFilterBuilder) BuildSearchTraceOApiProcessors(context.Context, span_processor.Settings) ([]span_processor.Processor, error) { + return nil, nil +} + +func (s *stubTraceFilterBuilder) BuildListSpansOApiProcessors(context.Context, span_processor.Settings) ([]span_processor.Processor, error) { + return nil, nil +} + +type stubSpanFilter struct{} + +func (s *stubSpanFilter) BuildBasicSpanFilter(context.Context, *span_filter.SpanEnv) ([]*loop_span.FilterField, bool, error) { + return nil, true, nil +} + +func (s *stubSpanFilter) BuildRootSpanFilter(context.Context, *span_filter.SpanEnv) ([]*loop_span.FilterField, error) { + return nil, nil +} + +func (s *stubSpanFilter) BuildLLMSpanFilter(context.Context, *span_filter.SpanEnv) ([]*loop_span.FilterField, error) { + return nil, nil +} + +func (s *stubSpanFilter) BuildALLSpanFilter(context.Context, *span_filter.SpanEnv) ([]*loop_span.FilterField, error) { + return nil, nil +} + type stubBackfillProducer struct { ch chan *entity.BackFillEvent err error @@ -80,11 +126,11 @@ func (s *stubBackfillProducer) SendBackfill(ctx context.Context, message *entity return s.err } -func newTaskServiceWithProcessor(t *testing.T, repo taskrepo.ITaskRepo, userProvider rpc.IUserProvider, backfill componentmq.IBackfillProducer, proc taskexe.Processor, taskType task.TaskType) *TaskServiceImpl { +func newTaskServiceWithProcessor(t *testing.T, repo taskrepo.ITaskRepo, backfill componentmq.IBackfillProducer, proc taskexe.Processor, taskType entity.TaskType) *TaskServiceImpl { t.Helper() tp := processor.NewTaskProcessor() tp.Register(taskType, proc) - service, err := NewTaskServiceImpl(repo, userProvider, nil, backfill, tp) + service, err := NewTaskServiceImpl(repo, nil, backfill, tp, &stubTraceFilterBuilder{}) assert.NoError(t, err) return service.(*TaskServiceImpl) } @@ -108,13 +154,14 @@ func TestTaskServiceImpl_CreateTask(t *testing.T) { backfillCh := make(chan *entity.BackFillEvent, 1) backfill := &stubBackfillProducer{ch: backfillCh} - svc := newTaskServiceWithProcessor(t, repoMock, nil, backfill, proc, task.TaskTypeAutoEval) + svc := newTaskServiceWithProcessor(t, repoMock, backfill, proc, entity.TaskTypeAutoEval) reqTask := &entity.ObservabilityTask{ WorkspaceID: 123, Name: "task", - TaskType: task.TaskTypeAutoEval, - TaskStatus: task.TaskStatusUnstarted, + TaskType: entity.TaskTypeAutoEval, + TaskStatus: entity.TaskStatusUnstarted, + SpanFilter: &entity.SpanFilterFields{}, BackfillEffectiveTime: &entity.EffectiveTime{StartAt: time.Now().Add(time.Second).UnixMilli(), EndAt: time.Now().Add(2 * time.Second).UnixMilli()}, Sampler: &entity.Sampler{}, EffectiveTime: &entity.EffectiveTime{StartAt: time.Now().Add(time.Second).UnixMilli(), EndAt: time.Now().Add(2 * time.Second).UnixMilli()}, @@ -144,9 +191,9 @@ func TestTaskServiceImpl_CreateTask(t *testing.T) { repoMock.EXPECT().ListTasks(gomock.Any(), gomock.Any()).Return(nil, int64(0), nil) proc := &fakeProcessor{validateErr: errors.New("invalid config")} - svc := newTaskServiceWithProcessor(t, repoMock, nil, nil, proc, task.TaskTypeAutoEval) + svc := newTaskServiceWithProcessor(t, repoMock, nil, proc, entity.TaskTypeAutoEval) - reqTask := &entity.ObservabilityTask{WorkspaceID: 1, Name: "task", TaskType: task.TaskTypeAutoEval, Sampler: &entity.Sampler{}, EffectiveTime: &entity.EffectiveTime{}} + reqTask := &entity.ObservabilityTask{WorkspaceID: 1, Name: "task", TaskType: entity.TaskTypeAutoEval, Sampler: &entity.Sampler{}, EffectiveTime: &entity.EffectiveTime{}, SpanFilter: &entity.SpanFilterFields{}} resp, err := svc.CreateTask(context.Background(), &CreateTaskReq{Task: reqTask}) assert.Nil(t, resp) assert.Error(t, err) @@ -165,8 +212,8 @@ func TestTaskServiceImpl_CreateTask(t *testing.T) { repoMock.EXPECT().ListTasks(gomock.Any(), gomock.Any()).Return([]*entity.ObservabilityTask{{}}, int64(1), nil) proc := &fakeProcessor{} - svc := newTaskServiceWithProcessor(t, repoMock, nil, nil, proc, task.TaskTypeAutoEval) - reqTask := &entity.ObservabilityTask{WorkspaceID: 1, Name: "task", TaskType: task.TaskTypeAutoEval, Sampler: &entity.Sampler{}, EffectiveTime: &entity.EffectiveTime{}} + svc := newTaskServiceWithProcessor(t, repoMock, nil, proc, entity.TaskTypeAutoEval) + reqTask := &entity.ObservabilityTask{WorkspaceID: 1, Name: "task", TaskType: entity.TaskTypeAutoEval, Sampler: &entity.Sampler{}, EffectiveTime: &entity.EffectiveTime{}, SpanFilter: &entity.SpanFilterFields{}} resp, err := svc.CreateTask(context.Background(), &CreateTaskReq{Task: reqTask}) assert.Nil(t, resp) assert.Error(t, err) @@ -188,8 +235,8 @@ func TestTaskServiceImpl_CreateTask(t *testing.T) { repoMock.EXPECT().DeleteTask(gomock.Any(), gomock.AssignableToTypeOf(&entity.ObservabilityTask{})).Return(nil) proc := &fakeProcessor{onCreateErr: errors.New("hook fail")} - svc := newTaskServiceWithProcessor(t, repoMock, nil, nil, proc, task.TaskTypeAutoEval) - reqTask := &entity.ObservabilityTask{WorkspaceID: 1, Name: "task", TaskType: task.TaskTypeAutoEval, Sampler: &entity.Sampler{}, EffectiveTime: &entity.EffectiveTime{}} + svc := newTaskServiceWithProcessor(t, repoMock, nil, proc, entity.TaskTypeAutoEval) + reqTask := &entity.ObservabilityTask{WorkspaceID: 1, Name: "task", TaskType: entity.TaskTypeAutoEval, Sampler: &entity.Sampler{}, EffectiveTime: &entity.EffectiveTime{}, SpanFilter: &entity.SpanFilterFields{}} resp, err := svc.CreateTask(context.Background(), &CreateTaskReq{Task: reqTask}) assert.Nil(t, resp) assert.EqualError(t, err, "hook fail") @@ -235,12 +282,12 @@ func TestTaskServiceImpl_UpdateTask(t *testing.T) { defer ctrl.Finish() repoMock := repomocks.NewMockITaskRepo(ctrl) - taskDO := &entity.ObservabilityTask{TaskType: task.TaskTypeAutoEval, TaskStatus: task.TaskStatusUnstarted, EffectiveTime: &entity.EffectiveTime{}, Sampler: &entity.Sampler{}} + taskDO := &entity.ObservabilityTask{TaskType: entity.TaskTypeAutoEval, TaskStatus: entity.TaskStatusUnstarted, EffectiveTime: &entity.EffectiveTime{}, Sampler: &entity.Sampler{}} repoMock.EXPECT().GetTask(gomock.Any(), int64(1), gomock.Any(), gomock.Nil()).Return(taskDO, nil) proc := &fakeProcessor{} svc := &TaskServiceImpl{TaskRepo: repoMock} - svc.taskProcessor.Register(task.TaskTypeAutoEval, proc) + svc.taskProcessor.Register(entity.TaskTypeAutoEval, proc) err := svc.UpdateTask(context.Background(), &UpdateTaskReq{TaskID: 1, WorkspaceID: 2}) statusErr, ok := errorx.FromStatusError(err) @@ -258,11 +305,11 @@ func TestTaskServiceImpl_UpdateTask(t *testing.T) { repoMock := repomocks.NewMockITaskRepo(ctrl) now := time.Now() taskDO := &entity.ObservabilityTask{ - TaskType: task.TaskTypeAutoEval, - TaskStatus: task.TaskStatusUnstarted, + TaskType: entity.TaskTypeAutoEval, + TaskStatus: entity.TaskStatusUnstarted, EffectiveTime: &entity.EffectiveTime{StartAt: startAt, EndAt: startAt + 3600000}, Sampler: &entity.Sampler{SampleRate: 0.1}, - TaskRuns: []*entity.TaskRun{{RunStatus: task.RunStatusRunning}}, + TaskRuns: []*entity.TaskRun{{RunStatus: entity.TaskRunStatusRunning}}, UpdatedAt: now, UpdatedBy: "", } @@ -273,7 +320,7 @@ func TestTaskServiceImpl_UpdateTask(t *testing.T) { proc := &fakeProcessor{} svc := &TaskServiceImpl{TaskRepo: repoMock} - svc.taskProcessor.Register(task.TaskTypeAutoEval, proc) + svc.taskProcessor.Register(entity.TaskTypeAutoEval, proc) desc := "updated" newStart := startAt + 1000 @@ -283,13 +330,13 @@ func TestTaskServiceImpl_UpdateTask(t *testing.T) { TaskID: 1, WorkspaceID: 2, Description: &desc, - EffectiveTime: &task.EffectiveTime{StartAt: &newStart, EndAt: &newEnd}, + EffectiveTime: &entity.EffectiveTime{StartAt: newStart, EndAt: newEnd}, SampleRate: &sampleRate, - TaskStatus: gptr.Of(task.TaskStatusDisabled), + TaskStatus: gptr.Of(entity.TaskStatusDisabled), }) assert.NoError(t, err) assert.True(t, proc.onFinishRunCalled) - assert.Equal(t, task.TaskStatusDisabled, taskDO.TaskStatus) + assert.Equal(t, entity.TaskStatusDisabled, taskDO.TaskStatus) assert.Equal(t, "user1", taskDO.UpdatedBy) if assert.NotNil(t, taskDO.Description) { assert.Equal(t, desc, *taskDO.Description) @@ -306,11 +353,11 @@ func TestTaskServiceImpl_UpdateTask(t *testing.T) { repoMock := repomocks.NewMockITaskRepo(ctrl) taskDO := &entity.ObservabilityTask{ - TaskType: task.TaskTypeAutoEval, - TaskStatus: task.TaskStatusUnstarted, + TaskType: entity.TaskTypeAutoEval, + TaskStatus: entity.TaskStatusUnstarted, EffectiveTime: &entity.EffectiveTime{StartAt: time.Now().UnixMilli(), EndAt: time.Now().Add(time.Hour).UnixMilli()}, Sampler: &entity.Sampler{}, - TaskRuns: []*entity.TaskRun{{RunStatus: task.RunStatusRunning}}, + TaskRuns: []*entity.TaskRun{{RunStatus: entity.TaskRunStatusRunning}}, } repoMock.EXPECT().GetTask(gomock.Any(), int64(1), gomock.Any(), gomock.Nil()).Return(taskDO, nil) @@ -319,14 +366,14 @@ func TestTaskServiceImpl_UpdateTask(t *testing.T) { proc := &fakeProcessor{} svc := &TaskServiceImpl{TaskRepo: repoMock} - svc.taskProcessor.Register(task.TaskTypeAutoEval, proc) + svc.taskProcessor.Register(entity.TaskTypeAutoEval, proc) sampleRate := 0.6 err := svc.UpdateTask(session.WithCtxUser(context.Background(), &session.User{ID: "user"}), &UpdateTaskReq{ TaskID: 1, WorkspaceID: 2, SampleRate: &sampleRate, - TaskStatus: gptr.Of(task.TaskStatusDisabled), + TaskStatus: gptr.Of(entity.TaskStatusDisabled), }) assert.NoError(t, err) assert.True(t, proc.onFinishRunCalled) @@ -340,11 +387,11 @@ func TestTaskServiceImpl_UpdateTask(t *testing.T) { startAt := time.Now().Add(2 * time.Hour).UnixMilli() repoMock := repomocks.NewMockITaskRepo(ctrl) taskDO := &entity.ObservabilityTask{ - TaskType: task.TaskTypeAutoEval, - TaskStatus: task.TaskStatusUnstarted, + TaskType: entity.TaskTypeAutoEval, + TaskStatus: entity.TaskStatusUnstarted, EffectiveTime: &entity.EffectiveTime{StartAt: startAt, EndAt: startAt + 3600000}, Sampler: &entity.Sampler{}, - TaskRuns: []*entity.TaskRun{{RunStatus: task.RunStatusRunning}}, + TaskRuns: []*entity.TaskRun{{RunStatus: entity.TaskRunStatusRunning}}, } repoMock.EXPECT().GetTask(gomock.Any(), int64(1), gomock.Any(), gomock.Nil()).Return(taskDO, nil) @@ -352,7 +399,7 @@ func TestTaskServiceImpl_UpdateTask(t *testing.T) { proc := &fakeProcessor{onFinishRunErr: errors.New("finish fail")} svc := &TaskServiceImpl{TaskRepo: repoMock} - svc.taskProcessor.Register(task.TaskTypeAutoEval, proc) + svc.taskProcessor.Register(entity.TaskTypeAutoEval, proc) newStart := startAt + 1000 newEnd := startAt + 7200000 @@ -360,9 +407,9 @@ func TestTaskServiceImpl_UpdateTask(t *testing.T) { err := svc.UpdateTask(session.WithCtxUser(context.Background(), &session.User{ID: "user"}), &UpdateTaskReq{ TaskID: 1, WorkspaceID: 2, - EffectiveTime: &task.EffectiveTime{StartAt: &newStart, EndAt: &newEnd}, + EffectiveTime: &entity.EffectiveTime{StartAt: newStart, EndAt: newEnd}, SampleRate: &sampleRate, - TaskStatus: gptr.Of(task.TaskStatusDisabled), + TaskStatus: gptr.Of(entity.TaskStatusDisabled), }) assert.EqualError(t, err, "finish fail") }) @@ -391,7 +438,6 @@ func TestTaskServiceImpl_ListTasks(t *testing.T) { defer ctrl.Finish() repoMock := repomocks.NewMockITaskRepo(ctrl) - userMock := rpcmock.NewMockIUserProvider(ctrl) hiddenField := &loop_span.FilterField{FieldName: "hidden", Values: []string{"1"}, Hidden: true} visibleField := &loop_span.FilterField{FieldName: "visible", Values: []string{"val"}} @@ -402,8 +448,8 @@ func TestTaskServiceImpl_ListTasks(t *testing.T) { ID: 1, Name: "task", WorkspaceID: 2, - TaskType: task.TaskTypeAutoEval, - TaskStatus: task.TaskStatusUnstarted, + TaskType: entity.TaskTypeAutoEval, + TaskStatus: entity.TaskStatusUnstarted, CreatedBy: "user1", UpdatedBy: "user2", EffectiveTime: &entity.EffectiveTime{}, @@ -414,25 +460,23 @@ func TestTaskServiceImpl_ListTasks(t *testing.T) { }}, } repoMock.EXPECT().ListTasks(gomock.Any(), gomock.Any()).Return([]*entity.ObservabilityTask{taskDO}, int64(1), nil) - userMock.EXPECT().GetUserInfo(gomock.Any(), gomock.Any()).Return(nil, map[string]*entitycommon.UserInfo{}, nil) - svc := &TaskServiceImpl{TaskRepo: repoMock, userProvider: userMock} - resp, err := svc.ListTasks(context.Background(), &ListTasksReq{WorkspaceID: 2, TaskFilters: &filter.TaskFilterFields{}}) + svc := &TaskServiceImpl{TaskRepo: repoMock} + resp, err := svc.ListTasks(context.Background(), &ListTasksReq{WorkspaceID: 2, TaskFilters: &entity.TaskFilterFields{}}) assert.NoError(t, err) if assert.NotNil(t, resp) { - assert.EqualValues(t, 1, *resp.Total) + assert.EqualValues(t, 1, resp.Total) assert.Len(t, resp.Tasks, 1) - filterFields := resp.Tasks[0].GetRule().GetSpanFilters().GetFilters() - if assert.NotNil(t, filterFields) { - fields := filterFields.GetFilterFields() + task := resp.Tasks[0] + if assert.NotNil(t, task.SpanFilter) { + fields := task.SpanFilter.Filters.FilterFields assert.Len(t, fields, 2) - assert.Equal(t, "visible", fields[0].GetFieldName()) - assert.Equal(t, []string{"val"}, fields[0].GetValues()) - sub := fields[1].GetSubFilter() - if assert.NotNil(t, sub) { - subFields := sub.GetFilterFields() + assert.Equal(t, "visible", fields[0].FieldName) + assert.Equal(t, []string{"val"}, fields[0].Values) + if sub := fields[1].SubFilter; assert.NotNil(t, sub) { + subFields := sub.FilterFields assert.Len(t, subFields, 1) - assert.Equal(t, "child", subFields[0].GetFieldName()) + assert.Equal(t, "child", subFields[0].FieldName) } } } @@ -476,7 +520,6 @@ func TestTaskServiceImpl_GetTask(t *testing.T) { defer ctrl.Finish() repoMock := repomocks.NewMockITaskRepo(ctrl) - userMock := rpcmock.NewMockIUserProvider(ctrl) subHidden := &loop_span.FilterField{FieldName: "inner_hidden", Values: []string{"v"}, Hidden: true} subVisible := &loop_span.FilterField{FieldName: "inner_visible", Values: []string{"v"}} @@ -485,8 +528,8 @@ func TestTaskServiceImpl_GetTask(t *testing.T) { hidden := &loop_span.FilterField{FieldName: "outer_hidden", Values: []string{"v"}, Hidden: true} taskDO := &entity.ObservabilityTask{ - TaskType: task.TaskTypeAutoEval, - TaskStatus: task.TaskStatusUnstarted, + TaskType: entity.TaskTypeAutoEval, + TaskStatus: entity.TaskStatusUnstarted, CreatedBy: "user1", UpdatedBy: "user2", EffectiveTime: &entity.EffectiveTime{}, @@ -498,22 +541,20 @@ func TestTaskServiceImpl_GetTask(t *testing.T) { } repoMock.EXPECT().GetTask(gomock.Any(), int64(1), gomock.Any(), gomock.Nil()).Return(taskDO, nil) - userMock.EXPECT().GetUserInfo(gomock.Any(), gomock.Any()).Return(nil, map[string]*entitycommon.UserInfo{}, nil) - svc := &TaskServiceImpl{TaskRepo: repoMock, userProvider: userMock} + svc := &TaskServiceImpl{TaskRepo: repoMock} resp, err := svc.GetTask(context.Background(), &GetTaskReq{TaskID: 1, WorkspaceID: 2}) assert.NoError(t, err) if assert.NotNil(t, resp) { - filters := resp.Task.GetRule().GetSpanFilters().GetFilters() - if assert.NotNil(t, filters) { - fields := filters.GetFilterFields() + task := resp.Task + if assert.NotNil(t, task.SpanFilter) { + fields := task.SpanFilter.Filters.FilterFields assert.Len(t, fields, 2) - assert.Equal(t, "outer_visible", fields[0].GetFieldName()) - sub := fields[1].GetSubFilter() - if assert.NotNil(t, sub) { - subFields := sub.GetFilterFields() + assert.Equal(t, "outer_visible", fields[0].FieldName) + if sub := fields[1].SubFilter; assert.NotNil(t, sub) { + subFields := sub.FilterFields assert.Len(t, subFields, 1) - assert.Equal(t, "inner_visible", subFields[0].GetFieldName()) + assert.Equal(t, "inner_visible", subFields[0].FieldName) } } } @@ -574,18 +615,18 @@ func TestTaskServiceImpl_shouldTriggerBackfill(t *testing.T) { service := &TaskServiceImpl{} t.Run("task type mismatch", func(t *testing.T) { - taskDO := &entity.ObservabilityTask{TaskType: "other"} + taskDO := &entity.ObservabilityTask{TaskType: entity.TaskType("other")} assert.False(t, service.shouldTriggerBackfill(taskDO)) }) t.Run("missing effective time", func(t *testing.T) { - taskDO := &entity.ObservabilityTask{TaskType: task.TaskTypeAutoEval} + taskDO := &entity.ObservabilityTask{TaskType: entity.TaskTypeAutoEval} assert.False(t, service.shouldTriggerBackfill(taskDO)) }) t.Run("valid", func(t *testing.T) { taskDO := &entity.ObservabilityTask{ - TaskType: task.TaskTypeAutoDataReflow, + TaskType: entity.TaskTypeAutoDataReflow, BackfillEffectiveTime: &entity.EffectiveTime{StartAt: 1, EndAt: 2}, } assert.True(t, service.shouldTriggerBackfill(taskDO)) @@ -595,7 +636,7 @@ func TestTaskServiceImpl_shouldTriggerBackfill(t *testing.T) { func TestTaskServiceImpl_sendBackfillMessage(t *testing.T) { t.Run("producer nil", func(t *testing.T) { svc := &TaskServiceImpl{} - err := svc.sendBackfillMessage(context.Background(), &entity.BackFillEvent{}) + err := svc.SendBackfillMessage(context.Background(), &entity.BackFillEvent{}) statusErr, ok := errorx.FromStatusError(err) if assert.True(t, ok) { assert.EqualValues(t, obErrorx.CommonInternalErrorCode, statusErr.Code()) @@ -605,7 +646,7 @@ func TestTaskServiceImpl_sendBackfillMessage(t *testing.T) { t.Run("success", func(t *testing.T) { ch := make(chan *entity.BackFillEvent, 1) svc := &TaskServiceImpl{backfillProducer: &stubBackfillProducer{ch: ch}} - err := svc.sendBackfillMessage(context.Background(), &entity.BackFillEvent{TaskID: 1}) + err := svc.SendBackfillMessage(context.Background(), &entity.BackFillEvent{TaskID: 1}) assert.NoError(t, err) select { case event := <-ch: diff --git a/backend/modules/observability/domain/task/service/taskexe/processor.go b/backend/modules/observability/domain/task/service/taskexe/processor.go new file mode 100644 index 000000000..fa1542a2e --- /dev/null +++ b/backend/modules/observability/domain/task/service/taskexe/processor.go @@ -0,0 +1,45 @@ +// Copyright (c) 2025 coze-dev Authors +// SPDX-License-Identifier: Apache-2.0 + +package taskexe + +import ( + "context" + + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/loop_span" +) + +type Trigger struct { + Task *entity.ObservabilityTask + Span *loop_span.Span + TaskRun *entity.TaskRun +} + +type OnTaskRunCreatedReq struct { + CurrentTask *entity.ObservabilityTask + RunType entity.TaskRunType + RunStartAt int64 + RunEndAt int64 +} +type OnTaskRunFinishedReq struct { + Task *entity.ObservabilityTask + TaskRun *entity.TaskRun +} +type OnTaskFinishedReq struct { + Task *entity.ObservabilityTask + TaskRun *entity.TaskRun + IsFinish bool +} + +type Processor interface { + ValidateConfig(ctx context.Context, config any) error + Invoke(ctx context.Context, trigger *Trigger) error + + OnTaskCreated(ctx context.Context, currentTask *entity.ObservabilityTask) error + OnTaskUpdated(ctx context.Context, currentTask *entity.ObservabilityTask, taskOp entity.TaskStatus) error + OnTaskFinished(ctx context.Context, param OnTaskFinishedReq) error + + OnTaskRunCreated(ctx context.Context, param OnTaskRunCreatedReq) error + OnTaskRunFinished(ctx context.Context, param OnTaskRunFinishedReq) error +} diff --git a/backend/modules/observability/domain/task/service/taskexe/processor/auto_evaluate.go b/backend/modules/observability/domain/task/service/taskexe/processor/auto_evaluate.go index 5d8b43e52..c25be5912 100644 --- a/backend/modules/observability/domain/task/service/taskexe/processor/auto_evaluate.go +++ b/backend/modules/observability/domain/task/service/taskexe/processor/auto_evaluate.go @@ -63,7 +63,7 @@ func NewAutoEvaluteProcessor( func (p *AutoEvaluteProcessor) ValidateConfig(ctx context.Context, config any) error { cfg, ok := config.(*task_entity.ObservabilityTask) if !ok { - return taskexe.ErrInvalidConfig + return errorx.NewByCode(obErrorx.CommonInvalidParamCode) } if cfg.EffectiveTime != nil { startAt := cfg.EffectiveTime.StartAt @@ -112,7 +112,7 @@ func (p *AutoEvaluteProcessor) Invoke(ctx context.Context, trigger *taskexe.Trig logs.CtxInfo(ctx, "[task-debug] AutoEvaluteProcessor Invoke, turns is empty") return nil } - taskTTL := trigger.Task.GetTaskttl() + taskTTL := trigger.Task.GetTaskTTL() _ = p.taskRepo.IncrTaskCount(ctx, trigger.Task.ID, taskTTL) _ = p.taskRepo.IncrTaskRunCount(ctx, trigger.Task.ID, taskRun.ID, taskTTL) taskCount, _ := p.taskRepo.GetTaskCount(ctx, trigger.Task.ID) @@ -156,67 +156,67 @@ func (p *AutoEvaluteProcessor) Invoke(ctx context.Context, trigger *taskexe.Trig return nil } -func (p *AutoEvaluteProcessor) OnCreateTaskChange(ctx context.Context, currentTask *task_entity.ObservabilityTask) error { +func (p *AutoEvaluteProcessor) OnTaskCreated(ctx context.Context, currentTask *task_entity.ObservabilityTask) error { taskRuns, err := p.taskRepo.GetBackfillTaskRun(ctx, nil, currentTask.ID) if err != nil { logs.CtxError(ctx, "GetBackfillTaskRun failed, taskID:%d, err:%v", currentTask.ID, err) return err } if ShouldTriggerBackfill(currentTask) && taskRuns == nil { - err = p.OnCreateTaskRunChange(ctx, taskexe.OnCreateTaskRunChangeReq{ + err = p.OnTaskRunCreated(ctx, taskexe.OnTaskRunCreatedReq{ CurrentTask: currentTask, - RunType: task.TaskRunTypeBackFill, + RunType: task_entity.TaskRunTypeBackFill, RunStartAt: time.Now().UnixMilli(), RunEndAt: time.Now().UnixMilli() + (currentTask.BackfillEffectiveTime.EndAt - currentTask.BackfillEffectiveTime.StartAt), }) if err != nil { - logs.CtxError(ctx, "OnCreateTaskChange failed, taskID:%d, err:%v", currentTask.ID, err) + logs.CtxError(ctx, "OnTaskCreated failed, taskID:%d, err:%v", currentTask.ID, err) return err } - err = p.OnUpdateTaskChange(ctx, currentTask, task.TaskStatusRunning) + err = p.OnTaskUpdated(ctx, currentTask, task.TaskStatusRunning) if err != nil { - logs.CtxError(ctx, "OnCreateTaskChange failed, taskID:%d, err:%v", currentTask.ID, err) + logs.CtxError(ctx, "OnTaskCreated failed, taskID:%d, err:%v", currentTask.ID, err) return err } } if ShouldTriggerNewData(ctx, currentTask) { runStartAt, runEndAt := currentTask.GetRunTimeRange() - err = p.OnCreateTaskRunChange(ctx, taskexe.OnCreateTaskRunChangeReq{ + err = p.OnTaskRunCreated(ctx, taskexe.OnTaskRunCreatedReq{ CurrentTask: currentTask, - RunType: task.TaskRunTypeNewData, + RunType: task_entity.TaskRunTypeNewData, RunStartAt: runStartAt, RunEndAt: runEndAt, }) if err != nil { - logs.CtxError(ctx, "OnCreateTaskChange failed, taskID:%d, err:%v", currentTask.ID, err) + logs.CtxError(ctx, "OnTaskCreated failed, taskID:%d, err:%v", currentTask.ID, err) return err } - err = p.OnUpdateTaskChange(ctx, currentTask, task.TaskStatusRunning) + err = p.OnTaskUpdated(ctx, currentTask, task.TaskStatusRunning) if err != nil { - logs.CtxError(ctx, "OnCreateTaskChange failed, taskID:%d, err:%v", currentTask.ID, err) + logs.CtxError(ctx, "OnTaskCreated failed, taskID:%d, err:%v", currentTask.ID, err) return err } } return nil } -func (p *AutoEvaluteProcessor) OnUpdateTaskChange(ctx context.Context, currentTask *task_entity.ObservabilityTask, taskOp task.TaskStatus) error { +func (p *AutoEvaluteProcessor) OnTaskUpdated(ctx context.Context, currentTask *task_entity.ObservabilityTask, taskOp task_entity.TaskStatus) error { switch taskOp { - case task.TaskStatusSuccess: - if currentTask.TaskStatus != task.TaskStatusDisabled { - currentTask.TaskStatus = task.TaskStatusSuccess + case task_entity.TaskStatusSuccess: + if currentTask.TaskStatus != task_entity.TaskStatusDisabled { + currentTask.TaskStatus = task_entity.TaskStatusSuccess } - case task.TaskStatusRunning: - if currentTask.TaskStatus != task.TaskStatusDisabled && currentTask.TaskStatus != task.TaskStatusSuccess { - currentTask.TaskStatus = task.TaskStatusRunning + case task_entity.TaskStatusRunning: + if currentTask.TaskStatus != task_entity.TaskStatusDisabled && currentTask.TaskStatus != task_entity.TaskStatusSuccess { + currentTask.TaskStatus = task_entity.TaskStatusRunning } - case task.TaskStatusDisabled: - if currentTask.TaskStatus != task.TaskStatusDisabled { - currentTask.TaskStatus = task.TaskStatusDisabled + case task_entity.TaskStatusDisabled: + if currentTask.TaskStatus != task_entity.TaskStatusDisabled { + currentTask.TaskStatus = task_entity.TaskStatusDisabled } - case task.TaskStatusPending: - if currentTask.TaskStatus == task.TaskStatusPending || currentTask.TaskStatus == task.TaskStatusUnstarted { - currentTask.TaskStatus = task.TaskStatusPending + case task_entity.TaskStatusPending: + if currentTask.TaskStatus == task_entity.TaskStatusPending || currentTask.TaskStatus == task_entity.TaskStatusUnstarted { + currentTask.TaskStatus = task_entity.TaskStatusPending } default: return fmt.Errorf("OnUpdateChangeProcessor, valid taskOp:%s", taskOp) @@ -230,18 +230,18 @@ func (p *AutoEvaluteProcessor) OnUpdateTaskChange(ctx context.Context, currentTa return nil } -func (p *AutoEvaluteProcessor) OnFinishTaskChange(ctx context.Context, param taskexe.OnFinishTaskChangeReq) error { - err := p.OnFinishTaskRunChange(ctx, taskexe.OnFinishTaskRunChangeReq{ +func (p *AutoEvaluteProcessor) OnTaskFinished(ctx context.Context, param taskexe.OnTaskFinishedReq) error { + err := p.OnTaskRunFinished(ctx, taskexe.OnTaskRunFinishedReq{ Task: param.Task, TaskRun: param.TaskRun, }) if err != nil { - logs.CtxError(ctx, "OnFinishTaskRunChange failed, taskRun:%+v, err:%v", param.TaskRun, err) + logs.CtxError(ctx, "OnTaskRunFinished failed, taskRun:%+v, err:%v", param.TaskRun, err) return err } if param.IsFinish { - logs.CtxWarn(ctx, "OnFinishTaskChange, taskID:%d, taskRun:%+v, isFinish:%v", param.Task.ID, param.TaskRun, param.IsFinish) - if err := p.OnUpdateTaskChange(ctx, param.Task, task.TaskStatusSuccess); err != nil { + logs.CtxWarn(ctx, "OnTaskFinished, taskID:%d, taskRun:%+v, isFinish:%v", param.Task.ID, param.TaskRun, param.IsFinish) + if err := p.OnTaskUpdated(ctx, param.Task, task.TaskStatusSuccess); err != nil { logs.CtxError(ctx, "OnUpdateChangeProcessor failed, taskID:%d, err:%v", param.Task.ID, err) return err } @@ -260,7 +260,7 @@ const ( BackFillI18N = "BackFill" ) -func (p *AutoEvaluteProcessor) OnCreateTaskRunChange(ctx context.Context, param taskexe.OnCreateTaskRunChangeReq) error { +func (p *AutoEvaluteProcessor) OnTaskRunCreated(ctx context.Context, param taskexe.OnTaskRunCreatedReq) error { currentTask := param.CurrentTask ctx = session.WithCtxUser(ctx, &session.User{ID: currentTask.CreatedBy}) sessionInfo := p.getSession(ctx, currentTask) @@ -302,11 +302,11 @@ func (p *AutoEvaluteProcessor) OnCreateTaskRunChange(ctx context.Context, param FromEvalSet: fromEvalSet, }) } - category := getCategory(currentTask.TaskType) + category := getCategory(task.TaskType(currentTask.TaskType)) schema := convertDatasetSchemaDTO2DO(evaluationSetSchema) logs.CtxInfo(ctx, "[auto_task] CreateDataset,category:%s", category) var datasetName, exptName string - if param.RunType == task.TaskRunTypeBackFill { + if param.RunType == task_entity.TaskRunTypeBackFill { datasetName = fmt.Sprintf("%s_%s_%s_%d.%d.%d.%d", AutoEvaluateCN, BackFillCN, currentTask.Name, time.Now().Year(), time.Now().Month(), time.Now().Day(), time.Now().Unix()) exptName = fmt.Sprintf("%s_%s_%s_%d.%d.%d.%d", AutoEvaluateCN, BackFillCN, currentTask.Name, time.Now().Year(), time.Now().Month(), time.Now().Day(), time.Now().Unix()) } else { @@ -383,7 +383,7 @@ func (p *AutoEvaluteProcessor) OnCreateTaskRunChange(ctx context.Context, param TaskID: currentTask.ID, WorkspaceID: currentTask.WorkspaceID, TaskType: param.RunType, - RunStatus: task.RunStatusRunning, + RunStatus: task_entity.TaskRunStatusRunning, RunStartAt: time.UnixMilli(param.RunStartAt), RunEndAt: time.UnixMilli(param.RunEndAt), CreatedAt: time.Now(), @@ -398,7 +398,7 @@ func (p *AutoEvaluteProcessor) OnCreateTaskRunChange(ctx context.Context, param return nil } -func (p *AutoEvaluteProcessor) OnFinishTaskRunChange(ctx context.Context, param taskexe.OnFinishTaskRunChangeReq) error { +func (p *AutoEvaluteProcessor) OnTaskRunFinished(ctx context.Context, param taskexe.OnTaskRunFinishedReq) error { if param.TaskRun == nil || param.TaskRun.TaskRunConfig == nil || param.TaskRun.TaskRunConfig.AutoEvaluateRunConfig == nil { return nil } diff --git a/backend/modules/observability/domain/task/service/taskexe/processor/auto_evaluate_test.go b/backend/modules/observability/domain/task/service/taskexe/processor/auto_evaluate_test.go index c7c8c3397..b49f56368 100755 --- a/backend/modules/observability/domain/task/service/taskexe/processor/auto_evaluate_test.go +++ b/backend/modules/observability/domain/task/service/taskexe/processor/auto_evaluate_test.go @@ -107,7 +107,7 @@ func (m *taskRepoMockAdapter) RemoveNonFinalTask(context.Context, string, int64) return nil } -func (m *taskRepoMockAdapter) GetTaskByRedis(context.Context, int64) (*taskentity.ObservabilityTask, error) { +func (m *taskRepoMockAdapter) GetTaskByCache(context.Context, int64) (*taskentity.ObservabilityTask, error) { return nil, nil } @@ -125,8 +125,8 @@ func buildTestTask(t *testing.T) *taskentity.ObservabilityTask { WorkspaceID: 202, Name: "auto-eval", CreatedBy: "1001", - TaskType: task.TaskTypeAutoEval, - TaskStatus: task.TaskStatusUnstarted, + TaskType: taskentity.TaskTypeAutoEval, + TaskStatus: taskentity.TaskStatusUnstarted, EffectiveTime: &taskentity.EffectiveTime{ StartAt: start, EndAt: end, @@ -141,7 +141,7 @@ func buildTestTask(t *testing.T) *taskentity.ObservabilityTask { IsCycle: false, CycleCount: 0, CycleInterval: 1, - CycleTimeUnit: task.TimeUnitDay, + CycleTimeUnit: taskentity.TimeUnitDay, }, TaskConfig: &taskentity.TaskConfig{ AutoEvaluateConfigs: []*taskentity.AutoEvaluateConfig{ @@ -224,7 +224,8 @@ func TestAutoEvaluteProcessor_ValidateConfig(t *testing.T) { name: "invalid type", config: "bad", expectErr: func(err error) bool { - return errors.Is(err, taskexe.ErrInvalidConfig) + status, ok := errorx.FromStatusError(err) + return ok && status.Code() == obErrorx.CommonInvalidParamCode }, }, { @@ -313,8 +314,8 @@ func TestAutoEvaluteProcessor_Invoke(t *testing.T) { ID: 1001, TaskID: taskObj.ID, WorkspaceID: taskObj.WorkspaceID, - TaskType: task.TaskRunTypeNewData, - RunStatus: task.RunStatusRunning, + TaskType: taskentity.TaskRunTypeNewData, + RunStatus: taskentity.TaskRunStatusRunning, TaskRunConfig: buildTaskRunConfig(schemaStr), } span := buildSpan("{\"parts\":[]}") @@ -431,14 +432,14 @@ func TestAutoEvaluteProcessor_OnUpdateTaskChange(t *testing.T) { cases := []struct { name string - initial string - op task.TaskStatus - expect string + initial taskentity.TaskStatus + op taskentity.TaskStatus + expect taskentity.TaskStatus }{ - {"success", task.TaskStatusRunning, task.TaskStatusSuccess, task.TaskStatusSuccess}, - {"running", task.TaskStatusPending, task.TaskStatusRunning, task.TaskStatusRunning}, - {"disable", task.TaskStatusRunning, task.TaskStatusDisabled, task.TaskStatusDisabled}, - {"pending", task.TaskStatusUnstarted, task.TaskStatusPending, task.TaskStatusPending}, + {"success", taskentity.TaskStatusRunning, taskentity.TaskStatusSuccess, taskentity.TaskStatusSuccess}, + {"running", taskentity.TaskStatusPending, taskentity.TaskStatusRunning, taskentity.TaskStatusRunning}, + {"disable", taskentity.TaskStatusRunning, taskentity.TaskStatusDisabled, taskentity.TaskStatusDisabled}, + {"pending", taskentity.TaskStatusUnstarted, taskentity.TaskStatusPending, taskentity.TaskStatusPending}, } for _, tt := range cases { @@ -457,14 +458,14 @@ func TestAutoEvaluteProcessor_OnUpdateTaskChange(t *testing.T) { proc := &AutoEvaluteProcessor{taskRepo: repoAdapter} taskObj := &taskentity.ObservabilityTask{TaskStatus: caseItem.initial} - err := proc.OnUpdateTaskChange(ctx, taskObj, caseItem.op) + err := proc.OnTaskUpdated(ctx, taskObj, caseItem.op) assert.NoError(t, err) }) } t.Run("invalid op", func(t *testing.T) { proc := &AutoEvaluteProcessor{} - err := proc.OnUpdateTaskChange(ctx, &taskentity.ObservabilityTask{}, "unknown") + err := proc.OnTaskUpdated(ctx, &taskentity.ObservabilityTask{}, "unknown") assert.Error(t, err) }) } @@ -479,9 +480,9 @@ func TestAutoEvaluteProcessor_OnCreateTaskRunChange(t *testing.T) { repoAdapter := &taskRepoMockAdapter{MockITaskRepo: repoMock} taskObj := buildTestTask(t) - param := taskexe.OnCreateTaskRunChangeReq{ + param := taskexe.OnTaskRunCreatedReq{ CurrentTask: taskObj, - RunType: task.TaskRunTypeNewData, + RunType: taskentity.TaskRunTypeNewData, RunStartAt: time.Now().Add(-time.Minute).UnixMilli(), RunEndAt: time.Now().Add(time.Hour).UnixMilli(), } @@ -506,7 +507,7 @@ func TestAutoEvaluteProcessor_OnCreateTaskRunChange(t *testing.T) { } ctx := session.WithCtxUser(context.Background(), &session.User{ID: taskObj.CreatedBy}) - err := proc.OnCreateTaskRunChange(ctx, param) + err := proc.OnTaskRunCreated(ctx, param) assert.NoError(t, err) assert.NotNil(t, evalAdapter.submitReq) assert.Equal(t, int64(9001), *evalAdapter.submitReq.EvalSetID) @@ -537,13 +538,13 @@ func TestAutoEvaluteProcessor_OnFinishTaskRunChange(t *testing.T) { evaluationSvc: evalAdapter, } - err := proc.OnFinishTaskRunChange(context.Background(), taskexe.OnFinishTaskRunChangeReq{ + err := proc.OnTaskRunFinished(context.Background(), taskexe.OnTaskRunFinishedReq{ Task: &taskentity.ObservabilityTask{WorkspaceID: 1234}, TaskRun: taskRun, }) assert.NoError(t, err) assert.NotNil(t, evalAdapter.finishReq) - assert.Equal(t, task.RunStatusDone, taskRun.RunStatus) + assert.Equal(t, taskentity.TaskRunStatusDone, taskRun.RunStatus) } func TestAutoEvaluteProcessor_OnFinishTaskChange(t *testing.T) { @@ -555,7 +556,7 @@ func TestAutoEvaluteProcessor_OnFinishTaskChange(t *testing.T) { repoAdapter := &taskRepoMockAdapter{MockITaskRepo: repoMock} evalAdapter := &fakeEvaluationAdapter{} - taskObj := &taskentity.ObservabilityTask{TaskStatus: task.TaskStatusRunning, WorkspaceID: 123} + taskObj := &taskentity.ObservabilityTask{TaskStatus: taskentity.TaskStatusRunning, WorkspaceID: 123} taskRun := &taskentity.TaskRun{TaskRunConfig: &taskentity.TaskRunConfig{AutoEvaluateRunConfig: &taskentity.AutoEvaluateRunConfig{ExptID: 1, ExptRunID: 2}}} repoMock.EXPECT().UpdateTaskRun(gomock.Any(), gomock.Any()).Return(nil) @@ -566,13 +567,13 @@ func TestAutoEvaluteProcessor_OnFinishTaskChange(t *testing.T) { taskRepo: repoAdapter, } - err := proc.OnFinishTaskChange(context.Background(), taskexe.OnFinishTaskChangeReq{ + err := proc.OnTaskFinished(context.Background(), taskexe.OnTaskFinishedReq{ Task: taskObj, TaskRun: taskRun, IsFinish: true, }) assert.NoError(t, err) - assert.Equal(t, task.TaskStatusSuccess, taskObj.TaskStatus) + assert.Equal(t, taskentity.TaskStatusSuccess, taskObj.TaskStatus) } func TestAutoEvaluteProcessor_OnFinishTaskChange_Error(t *testing.T) { @@ -590,7 +591,7 @@ func TestAutoEvaluteProcessor_OnFinishTaskChange_Error(t *testing.T) { taskRepo: repoAdapter, } - err := proc.OnFinishTaskChange(context.Background(), taskexe.OnFinishTaskChangeReq{ + err := proc.OnTaskFinished(context.Background(), taskexe.OnTaskFinishedReq{ Task: &taskentity.ObservabilityTask{WorkspaceID: 123}, TaskRun: &taskentity.TaskRun{TaskRunConfig: &taskentity.TaskRunConfig{AutoEvaluateRunConfig: &taskentity.AutoEvaluateRunConfig{ExptID: 1, ExptRunID: 2}}}, }) @@ -621,10 +622,10 @@ func TestAutoEvaluteProcessor_OnCreateTaskChange(t *testing.T) { } taskObj := buildTestTask(t) - taskObj.TaskStatus = task.TaskStatusPending + taskObj.TaskStatus = taskentity.TaskStatusPending - var runTypes []task.TaskRunType - var statuses []task.TaskStatus + var runTypes []taskentity.TaskRunType + var statuses []taskentity.TaskStatus getBackfill := repoMock.EXPECT().GetBackfillTaskRun(gomock.Any(), (*int64)(nil), taskObj.ID).Return(nil, nil) createDatasetBackfill := datasetProvider.EXPECT().CreateDataset(gomock.Any(), gomock.AssignableToTypeOf(&traceentity.Dataset{})).Return(int64(9101), nil) @@ -666,11 +667,11 @@ func TestAutoEvaluteProcessor_OnCreateTaskChange(t *testing.T) { updateTaskNewData, ) - err := proc.OnCreateTaskChange(context.Background(), taskObj) + err := proc.OnTaskCreated(context.Background(), taskObj) assert.NoError(t, err) - assert.Equal(t, []task.TaskRunType{task.TaskRunTypeBackFill, task.TaskRunTypeNewData}, runTypes) - assert.Equal(t, []task.TaskStatus{task.TaskStatusRunning, task.TaskStatusRunning}, statuses) - assert.Equal(t, task.TaskStatusRunning, taskObj.TaskStatus) + assert.Equal(t, []taskentity.TaskRunType{taskentity.TaskRunTypeBackFill, taskentity.TaskRunTypeNewData}, runTypes) + assert.Equal(t, []taskentity.TaskStatus{taskentity.TaskStatusRunning, taskentity.TaskStatusRunning}, statuses) + assert.Equal(t, taskentity.TaskStatusRunning, taskObj.TaskStatus) } func TestAutoEvaluteProcessor_OnCreateTaskChange_GetBackfillError(t *testing.T) { @@ -685,7 +686,7 @@ func TestAutoEvaluteProcessor_OnCreateTaskChange_GetBackfillError(t *testing.T) proc := &AutoEvaluteProcessor{taskRepo: repoAdapter} - err := proc.OnCreateTaskChange(context.Background(), buildTestTask(t)) + err := proc.OnTaskCreated(context.Background(), buildTestTask(t)) assert.EqualError(t, err, "db error") } @@ -710,7 +711,7 @@ func TestAutoEvaluteProcessor_OnCreateTaskChange_CreateDatasetError(t *testing.T repoMock.EXPECT().GetBackfillTaskRun(gomock.Any(), (*int64)(nil), gomock.Any()).Return(nil, nil) datasetProvider.EXPECT().CreateDataset(gomock.Any(), gomock.AssignableToTypeOf(&traceentity.Dataset{})).Return(int64(0), errors.New("create fail")) - err := proc.OnCreateTaskChange(context.Background(), buildTestTask(t)) + err := proc.OnTaskCreated(context.Background(), buildTestTask(t)) assert.EqualError(t, err, "create fail") } diff --git a/backend/modules/observability/domain/task/service/taskexe/processor/factory.go b/backend/modules/observability/domain/task/service/taskexe/processor/factory.go index 131437f7f..9ccd32dc4 100644 --- a/backend/modules/observability/domain/task/service/taskexe/processor/factory.go +++ b/backend/modules/observability/domain/task/service/taskexe/processor/factory.go @@ -4,26 +4,26 @@ package processor import ( - "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/task" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe" ) type TaskProcessor struct { - taskProcessorMap map[task.TaskType]taskexe.Processor + taskProcessorMap map[entity.TaskType]taskexe.Processor } func NewTaskProcessor() *TaskProcessor { return &TaskProcessor{} } -func (t *TaskProcessor) Register(taskType task.TaskType, taskProcessor taskexe.Processor) { +func (t *TaskProcessor) Register(taskType entity.TaskType, taskProcessor taskexe.Processor) { if t.taskProcessorMap == nil { - t.taskProcessorMap = make(map[task.TaskType]taskexe.Processor) + t.taskProcessorMap = make(map[entity.TaskType]taskexe.Processor) } t.taskProcessorMap[taskType] = taskProcessor } -func (t *TaskProcessor) GetTaskProcessor(taskType task.TaskType) taskexe.Processor { +func (t *TaskProcessor) GetTaskProcessor(taskType entity.TaskType) taskexe.Processor { datasetProvider, ok := t.taskProcessorMap[taskType] if !ok { return NewNoopTaskProcessor() diff --git a/backend/modules/observability/domain/task/service/taskexe/processor/factory_test.go b/backend/modules/observability/domain/task/service/taskexe/processor/factory_test.go index 132644984..7d5773f23 100755 --- a/backend/modules/observability/domain/task/service/taskexe/processor/factory_test.go +++ b/backend/modules/observability/domain/task/service/taskexe/processor/factory_test.go @@ -10,6 +10,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/task" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe" ) @@ -18,13 +19,13 @@ func TestTaskProcessor_RegisterAndGet(t *testing.T) { taskProcessor := NewTaskProcessor() - defaultProcessor := taskProcessor.GetTaskProcessor("unknown") + defaultProcessor := taskProcessor.GetTaskProcessor(entity.TaskType("unknown")) _, ok := defaultProcessor.(*NoopTaskProcessor) assert.True(t, ok) registered := NewNoopTaskProcessor() - taskProcessor.Register(task.TaskTypeAutoEval, registered) - assert.Equal(t, registered, taskProcessor.GetTaskProcessor(task.TaskTypeAutoEval)) + taskProcessor.Register(entity.TaskTypeAutoEval, registered) + assert.Equal(t, registered, taskProcessor.GetTaskProcessor(entity.TaskTypeAutoEval)) } func TestNoopTaskProcessor_Methods(t *testing.T) { @@ -34,9 +35,9 @@ func TestNoopTaskProcessor_Methods(t *testing.T) { assert.NoError(t, p.ValidateConfig(ctx, nil)) assert.NoError(t, p.Invoke(ctx, nil)) - assert.NoError(t, p.OnCreateTaskChange(ctx, nil)) - assert.NoError(t, p.OnUpdateTaskChange(ctx, nil, task.TaskStatusRunning)) - assert.NoError(t, p.OnFinishTaskChange(ctx, taskexe.OnFinishTaskChangeReq{})) - assert.NoError(t, p.OnCreateTaskRunChange(ctx, taskexe.OnCreateTaskRunChangeReq{})) - assert.NoError(t, p.OnFinishTaskRunChange(ctx, taskexe.OnFinishTaskRunChangeReq{})) + assert.NoError(t, p.OnTaskCreated(ctx, nil)) + assert.NoError(t, p.OnTaskUpdated(ctx, nil, task.TaskStatusRunning)) + assert.NoError(t, p.OnTaskFinished(ctx, taskexe.OnTaskFinishedReq{})) + assert.NoError(t, p.OnTaskRunCreated(ctx, taskexe.OnTaskRunCreatedReq{})) + assert.NoError(t, p.OnTaskRunFinished(ctx, taskexe.OnTaskRunFinishedReq{})) } diff --git a/backend/modules/observability/domain/task/service/taskexe/processor/noop.go b/backend/modules/observability/domain/task/service/taskexe/processor/noop.go index 97da34ddb..37177c8b2 100644 --- a/backend/modules/observability/domain/task/service/taskexe/processor/noop.go +++ b/backend/modules/observability/domain/task/service/taskexe/processor/noop.go @@ -6,8 +6,7 @@ package processor import ( "context" - "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/task" - task_entity "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe" ) @@ -27,22 +26,22 @@ func (p *NoopTaskProcessor) Invoke(ctx context.Context, trigger *taskexe.Trigger return nil } -func (p *NoopTaskProcessor) OnCreateTaskChange(ctx context.Context, currentTask *task_entity.ObservabilityTask) error { +func (p *NoopTaskProcessor) OnTaskCreated(ctx context.Context, currentTask *entity.ObservabilityTask) error { return nil } -func (p *NoopTaskProcessor) OnUpdateTaskChange(ctx context.Context, currentTask *task_entity.ObservabilityTask, taskOp task.TaskStatus) error { +func (p *NoopTaskProcessor) OnTaskUpdated(ctx context.Context, currentTask *entity.ObservabilityTask, taskOp entity.TaskStatus) error { return nil } -func (p *NoopTaskProcessor) OnFinishTaskChange(ctx context.Context, param taskexe.OnFinishTaskChangeReq) error { +func (p *NoopTaskProcessor) OnTaskFinished(ctx context.Context, param taskexe.OnTaskFinishedReq) error { return nil } -func (p *NoopTaskProcessor) OnCreateTaskRunChange(ctx context.Context, param taskexe.OnCreateTaskRunChangeReq) error { +func (p *NoopTaskProcessor) OnTaskRunCreated(ctx context.Context, param taskexe.OnTaskRunCreatedReq) error { return nil } -func (p *NoopTaskProcessor) OnFinishTaskRunChange(ctx context.Context, param taskexe.OnFinishTaskRunChangeReq) error { +func (p *NoopTaskProcessor) OnTaskRunFinished(ctx context.Context, param taskexe.OnTaskRunFinishedReq) error { return nil } diff --git a/backend/modules/observability/domain/task/service/taskexe/processor/utils_test.go b/backend/modules/observability/domain/task/service/taskexe/processor/utils_test.go index 0d480490b..8c244584d 100755 --- a/backend/modules/observability/domain/task/service/taskexe/processor/utils_test.go +++ b/backend/modules/observability/domain/task/service/taskexe/processor/utils_test.go @@ -57,9 +57,9 @@ func TestShouldTriggerBackfill(t *testing.T) { task *taskentity.ObservabilityTask expected bool }{ - {"nil_time", &taskentity.ObservabilityTask{TaskType: task.TaskTypeAutoEval}, false}, - {"invalid_type", &taskentity.ObservabilityTask{TaskType: "manual"}, false}, - {"invalid_range", &taskentity.ObservabilityTask{TaskType: task.TaskTypeAutoEval, BackfillEffectiveTime: &taskentity.EffectiveTime{StartAt: 10, EndAt: 5}}, false}, + {"nil_time", &taskentity.ObservabilityTask{TaskType: taskentity.TaskTypeAutoEval}, false}, + {"invalid_type", &taskentity.ObservabilityTask{TaskType: taskentity.TaskType("manual")}, false}, + {"invalid_range", &taskentity.ObservabilityTask{TaskType: taskentity.TaskTypeAutoEval, BackfillEffectiveTime: &taskentity.EffectiveTime{StartAt: 10, EndAt: 5}}, false}, {"valid", baseTask, true}, } @@ -90,10 +90,10 @@ func TestShouldTriggerNewData(t *testing.T) { task *taskentity.ObservabilityTask expected bool }{ - {"invalid_type", &taskentity.ObservabilityTask{TaskType: "manual"}, false}, - {"nil_time", &taskentity.ObservabilityTask{TaskType: task.TaskTypeAutoEval}, false}, - {"invalid_range", &taskentity.ObservabilityTask{TaskType: task.TaskTypeAutoEval, EffectiveTime: &taskentity.EffectiveTime{StartAt: 20, EndAt: 10}}, false}, - {"start_in_future", &taskentity.ObservabilityTask{TaskType: task.TaskTypeAutoEval, EffectiveTime: &taskentity.EffectiveTime{StartAt: now.Add(time.Hour).UnixMilli(), EndAt: now.Add(2 * time.Hour).UnixMilli()}}, false}, + {"invalid_type", &taskentity.ObservabilityTask{TaskType: taskentity.TaskType("manual")}, false}, + {"nil_time", &taskentity.ObservabilityTask{TaskType: taskentity.TaskTypeAutoEval}, false}, + {"invalid_range", &taskentity.ObservabilityTask{TaskType: taskentity.TaskTypeAutoEval, EffectiveTime: &taskentity.EffectiveTime{StartAt: 20, EndAt: 10}}, false}, + {"start_in_future", &taskentity.ObservabilityTask{TaskType: taskentity.TaskTypeAutoEval, EffectiveTime: &taskentity.EffectiveTime{StartAt: now.Add(time.Hour).UnixMilli(), EndAt: now.Add(2 * time.Hour).UnixMilli()}}, false}, {"valid", baseTask, true}, } diff --git a/backend/modules/observability/domain/task/service/taskexe/scheduledtask/local_cache_refresh.go b/backend/modules/observability/domain/task/service/taskexe/scheduledtask/local_cache_refresh.go new file mode 100644 index 000000000..0d254efe9 --- /dev/null +++ b/backend/modules/observability/domain/task/service/taskexe/scheduledtask/local_cache_refresh.go @@ -0,0 +1,94 @@ +// Copyright (c) 2025 coze-dev Authors +// SPDX-License-Identifier: Apache-2.0 + +package scheduledtask + +import ( + "context" + "strconv" + "time" + + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/component/scheduledtask" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/repo" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe/tracehub" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/loop_span" + "github.com/coze-dev/coze-loop/backend/pkg/logs" + "github.com/samber/lo" +) + +type LocalCacheRefreshTask struct { + scheduledtask.BaseScheduledTask + + traceHubService tracehub.ITraceHubService + taskRepo repo.ITaskRepo +} + +func NewLocalCacheRefreshTask(traceHubService tracehub.ITraceHubService, taskRepo repo.ITaskRepo) scheduledtask.ScheduledTask { + return &LocalCacheRefreshTask{ + BaseScheduledTask: scheduledtask.NewBaseScheduledTask("LocalCacheRefreshTask", 2*time.Minute), + traceHubService: traceHubService, + taskRepo: taskRepo, + } +} + +func (t *LocalCacheRefreshTask) RunOnce(ctx context.Context) error { + logs.CtxInfo(ctx, "Start syncing task cache...") + + // 1. Retrieve spaceID, botID, and task information for all non-final tasks from the database + spaceIDs, botIDs, tasks, err := t.getNonFinalTaskInfos(ctx) + if err != nil { + logs.CtxError(ctx, "Failed to get non-final task list", "err", err) + return err + } + logs.CtxInfo(ctx, "Retrieved task information, taskCount:%d, spaceCount:%d, botCount:%d", len(tasks), len(spaceIDs), len(botIDs)) + + if err := t.traceHubService.StoneTaskCache(ctx, tracehub.TaskCacheInfo{ + WorkspaceIDs: spaceIDs, + BotIDs: botIDs, + Tasks: tasks, + UpdateTime: time.Now(), // Set the current time as the update time + }); err != nil { + logs.CtxError(ctx, "Failed to update task cache", "err", err) + return err + } + return nil +} + +func (t *LocalCacheRefreshTask) getNonFinalTaskInfos(ctx context.Context) ([]string, []string, []*entity.ObservabilityTask, error) { + tasks, err := t.taskRepo.ListNonFinalTasks(ctx) + if err != nil { + return nil, nil, nil, err + } + + spaceMap := make(map[string]interface{}) + botMap := make(map[string]interface{}) + + for _, task := range tasks { + spaceMap[strconv.FormatInt(task.WorkspaceID, 10)] = struct{}{} + if task.SpanFilter != nil && task.SpanFilter.Filters.FilterFields != nil { + extractBotIDFromFilters(task.SpanFilter.Filters.FilterFields, botMap) + } + } + + return lo.Keys(spaceMap), lo.Keys(botMap), tasks, nil +} + +// extractBotIDFromFilters 递归提取过滤器中的 bot_id 值,包括 SubFilter +func extractBotIDFromFilters(filterFields []*loop_span.FilterField, botMap map[string]interface{}) { + for _, filterField := range filterFields { + if filterField == nil { + continue + } + // 检查当前 FilterField 的 FieldName + if filterField.FieldName == "bot_id" { + for _, v := range filterField.Values { + botMap[v] = struct{}{} + } + } + // 递归处理 SubFilter + if filterField.SubFilter != nil && filterField.SubFilter.FilterFields != nil { + extractBotIDFromFilters(filterField.SubFilter.FilterFields, botMap) + } + } +} diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/scheduled_task_test.go b/backend/modules/observability/domain/task/service/taskexe/scheduledtask/scheduled_task_test.go similarity index 75% rename from backend/modules/observability/domain/task/service/taskexe/tracehub/scheduled_task_test.go rename to backend/modules/observability/domain/task/service/taskexe/scheduledtask/scheduled_task_test.go index 1ac5e587c..0fc7e4a59 100755 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/scheduled_task_test.go +++ b/backend/modules/observability/domain/task/service/taskexe/scheduledtask/scheduled_task_test.go @@ -1,7 +1,7 @@ // Copyright (c) 2025 coze-dev Authors // SPDX-License-Identifier: Apache-2.0 -package tracehub +package scheduledtask import ( "context" @@ -13,38 +13,38 @@ import ( "go.uber.org/mock/gomock" lock_mocks "github.com/coze-dev/coze-loop/backend/infra/lock/mocks" - "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/task" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" repo_mocks "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/repo/mocks" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe/processor" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/loop_span" "github.com/stretchr/testify/require" ) type trackingProcessor struct { *stubProcessor - finishReqs []taskexe.OnFinishTaskChangeReq - createRunReqs []taskexe.OnCreateTaskRunChangeReq - updateStatuses []string + finishReqs []taskexe.OnTaskFinishedReq + createRunReqs []taskexe.OnTaskRunCreatedReq + updateStatuses []entity.TaskStatus } func newTrackingProcessor() *trackingProcessor { return &trackingProcessor{stubProcessor: &stubProcessor{}} } -func (p *trackingProcessor) OnFinishTaskChange(ctx context.Context, req taskexe.OnFinishTaskChangeReq) error { +func (p *trackingProcessor) OnTaskFinished(ctx context.Context, req taskexe.OnTaskFinishedReq) error { p.finishReqs = append(p.finishReqs, req) - return p.stubProcessor.OnFinishTaskChange(ctx, req) + return p.stubProcessor.OnTaskFinished(ctx, req) } -func (p *trackingProcessor) OnCreateTaskRunChange(ctx context.Context, req taskexe.OnCreateTaskRunChangeReq) error { +func (p *trackingProcessor) OnTaskRunCreated(ctx context.Context, req taskexe.OnTaskRunCreatedReq) error { p.createRunReqs = append(p.createRunReqs, req) - return p.stubProcessor.OnCreateTaskRunChange(ctx, req) + return p.stubProcessor.OnTaskRunCreated(ctx, req) } -func (p *trackingProcessor) OnUpdateTaskChange(ctx context.Context, obsTask *entity.ObservabilityTask, status string) error { +func (p *trackingProcessor) OnTaskUpdated(ctx context.Context, obsTask *entity.ObservabilityTask, status entity.TaskStatus) error { p.updateStatuses = append(p.updateStatuses, status) - return p.stubProcessor.OnUpdateTaskChange(ctx, obsTask, status) + return p.stubProcessor.OnTaskUpdated(ctx, obsTask, status) } func TestTraceHubServiceImpl_transformTaskStatus(t *testing.T) { @@ -63,23 +63,23 @@ func TestTraceHubServiceImpl_transformTaskStatus(t *testing.T) { backfillRun := &entity.TaskRun{ ID: 2, TaskID: 1, - TaskType: string(task.TaskRunTypeBackFill), - RunStatus: string(task.RunStatusDone), + TaskType: entity.TaskRunTypeBackFill, + RunStatus: entity.TaskRunStatusDone, RunStartAt: now.Add(-3 * time.Hour), RunEndAt: now.Add(-2 * time.Hour), } currentRun := &entity.TaskRun{ ID: 3, TaskID: 1, - TaskType: string(task.TaskRunTypeNewData), - RunStatus: string(task.TaskStatusRunning), + TaskType: entity.TaskRunTypeNewData, + RunStatus: entity.TaskRunStatusRunning, RunStartAt: now.Add(-4 * time.Hour), RunEndAt: now.Add(2 * time.Hour), } taskPO := &entity.ObservabilityTask{ ID: 1, - TaskType: string(task.TaskTypeAutoEval), - TaskStatus: string(task.TaskStatusRunning), + TaskType: entity.TaskTypeAutoEval, + TaskStatus: entity.TaskStatusRunning, EffectiveTime: &entity.EffectiveTime{ StartAt: now.Add(-5 * time.Hour).UnixMilli(), EndAt: now.Add(-1 * time.Hour).UnixMilli(), @@ -95,7 +95,7 @@ func TestTraceHubServiceImpl_transformTaskStatus(t *testing.T) { proc := newTrackingProcessor() tp := processor.NewTaskProcessor() - tp.Register(task.TaskTypeAutoEval, proc) + tp.Register(entity.TaskTypeAutoEval, proc) impl := &TraceHubServiceImpl{ taskRepo: mockRepo, @@ -117,8 +117,8 @@ func TestTraceHubServiceImpl_transformTaskStatus(t *testing.T) { now := time.Now() taskPO := &entity.ObservabilityTask{ ID: 10, - TaskType: string(task.TaskTypeAutoEval), - TaskStatus: string(task.TaskStatusUnstarted), + TaskType: entity.TaskTypeAutoEval, + TaskStatus: entity.TaskStatusUnstarted, EffectiveTime: &entity.EffectiveTime{ StartAt: now.Add(-2 * time.Hour).UnixMilli(), EndAt: now.Add(time.Hour).UnixMilli(), @@ -129,7 +129,7 @@ func TestTraceHubServiceImpl_transformTaskStatus(t *testing.T) { proc := newTrackingProcessor() tp := processor.NewTaskProcessor() - tp.Register(task.TaskTypeAutoEval, proc) + tp.Register(entity.TaskTypeAutoEval, proc) impl := &TraceHubServiceImpl{ taskRepo: mockRepo, @@ -140,9 +140,9 @@ func TestTraceHubServiceImpl_transformTaskStatus(t *testing.T) { }, assert: func(t *testing.T, _ *TraceHubServiceImpl, proc *trackingProcessor) { require.Len(t, proc.createRunReqs, 1) - require.Equal(t, task.TaskRunTypeNewData, proc.createRunReqs[0].RunType) + require.Equal(t, entity.TaskRunTypeNewData, proc.createRunReqs[0].RunType) require.Len(t, proc.updateStatuses, 1) - require.Equal(t, string(task.TaskStatusRunning), proc.updateStatuses[0]) + require.Equal(t, entity.TaskStatusRunning, proc.updateStatuses[0]) }, }, { @@ -153,15 +153,15 @@ func TestTraceHubServiceImpl_transformTaskStatus(t *testing.T) { currentRun := &entity.TaskRun{ ID: 30, TaskID: 20, - TaskType: string(task.TaskRunTypeNewData), - RunStatus: string(task.TaskStatusRunning), + TaskType: entity.TaskRunTypeNewData, + RunStatus: entity.TaskRunStatusRunning, RunStartAt: now.Add(-2 * time.Hour), RunEndAt: now.Add(-time.Minute), } taskPO := &entity.ObservabilityTask{ ID: 20, - TaskType: string(task.TaskTypeAutoEval), - TaskStatus: string(task.TaskStatusRunning), + TaskType: entity.TaskTypeAutoEval, + TaskStatus: entity.TaskStatusRunning, Sampler: &entity.Sampler{IsCycle: true}, TaskRuns: []*entity.TaskRun{currentRun}, } @@ -169,7 +169,7 @@ func TestTraceHubServiceImpl_transformTaskStatus(t *testing.T) { proc := newTrackingProcessor() tp := processor.NewTaskProcessor() - tp.Register(task.TaskTypeAutoEval, proc) + tp.Register(entity.TaskTypeAutoEval, proc) impl := &TraceHubServiceImpl{ taskRepo: mockRepo, @@ -194,24 +194,24 @@ func TestTraceHubServiceImpl_transformTaskStatus(t *testing.T) { backfillRun := &entity.TaskRun{ ID: 40, TaskID: 40, - TaskType: string(task.TaskRunTypeBackFill), - RunStatus: string(task.TaskStatusRunning), + TaskType: entity.TaskRunTypeBackFill, + RunStatus: entity.TaskRunStatusRunning, RunStartAt: now.Add(-time.Hour), RunEndAt: now.Add(time.Hour), } currentRun := &entity.TaskRun{ ID: 41, TaskID: 40, - TaskType: string(task.TaskRunTypeNewData), - RunStatus: string(task.TaskStatusRunning), + TaskType: entity.TaskRunTypeNewData, + RunStatus: entity.TaskRunStatusRunning, RunStartAt: now.Add(-time.Hour), RunEndAt: now.Add(time.Hour), } taskPO := &entity.ObservabilityTask{ ID: 40, WorkspaceID: 99, - TaskType: string(task.TaskTypeAutoEval), - TaskStatus: string(task.TaskStatusRunning), + TaskType: entity.TaskTypeAutoEval, + TaskStatus: entity.TaskStatusRunning, BackfillEffectiveTime: &entity.EffectiveTime{StartAt: now.Add(-2 * time.Hour).UnixMilli(), EndAt: now.Add(time.Hour).UnixMilli()}, Sampler: &entity.Sampler{IsCycle: false}, TaskRuns: []*entity.TaskRun{backfillRun, currentRun}, @@ -223,7 +223,7 @@ func TestTraceHubServiceImpl_transformTaskStatus(t *testing.T) { proc := newTrackingProcessor() tp := processor.NewTaskProcessor() - tp.Register(task.TaskTypeAutoEval, proc) + tp.Register(entity.TaskTypeAutoEval, proc) producer := &stubBackfillProducer{ch: make(chan *entity.BackFillEvent, 1)} impl := &TraceHubServiceImpl{ @@ -336,10 +336,26 @@ func TestTraceHubServiceImpl_syncTaskCache(t *testing.T) { impl := &TraceHubServiceImpl{taskRepo: mockRepo} impl.taskCache.Store("ObjListWithTask", TaskCacheInfo{}) - workspaceIDs := []string{"space-1"} + tasks := []*entity.ObservabilityTask{ + { + ID: 100, + WorkspaceID: 1, + SpanFilter: &entity.SpanFilterFields{ + Filters: loop_span.FilterFields{ + FilterFields: []*loop_span.FilterField{ + { + FieldName: "bot_id", + Values: []string{"bot-1"}, + }, + }, + }, + }, + }, + } + workspaceIDs := []string{"1"} botIDs := []string{"bot-1"} - tasks := []*entity.ObservabilityTask{{ID: 100}} - mockRepo.EXPECT().GetObjListWithTask(gomock.Any()).Return(workspaceIDs, botIDs, tasks) + + mockRepo.EXPECT().ListNonFinalTasks(gomock.Any()).Return(tasks, nil) before := time.Now() impl.syncTaskCache() @@ -451,3 +467,88 @@ func TestTraceHubServiceImpl_listNonFinalTask(t *testing.T) { require.Nil(t, tasks) }) } + +func TestTraceHubServiceImpl_getNonFinalTaskInfos(t *testing.T) { + t.Parallel() + + t.Run("success", func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + mockRepo := repo_mocks.NewMockITaskRepo(ctrl) + impl := &TraceHubServiceImpl{taskRepo: mockRepo} + + tasks := []*entity.ObservabilityTask{ + { + WorkspaceID: 101, + SpanFilter: &entity.SpanFilterFields{ + Filters: loop_span.FilterFields{ + FilterFields: []*loop_span.FilterField{ + { + FieldName: "bot_id", + Values: []string{"bot-a", "bot-b"}, + }, + { + FieldName: "ignored", + SubFilter: &loop_span.FilterFields{ + FilterFields: []*loop_span.FilterField{ + { + FieldName: "bot_id", + Values: []string{"bot-c"}, + }, + }, + }, + }, + }, + }, + }, + }, + { + WorkspaceID: 202, + SpanFilter: &entity.SpanFilterFields{ + Filters: loop_span.FilterFields{ + FilterFields: []*loop_span.FilterField{ + { + FieldName: "other", + Values: []string{"value"}, + }, + }, + }, + }, + }, + { + WorkspaceID: 101, + }, + } + + mockRepo.EXPECT().ListNonFinalTasks(gomock.Any()).Return(tasks, nil) + + workspaceIDs, botIDs, resultTasks, err := impl.getNonFinalTaskInfos(context.Background()) + require.NoError(t, err) + require.ElementsMatch(t, []string{"101", "202"}, workspaceIDs) + require.ElementsMatch(t, []string{"bot-a", "bot-b", "bot-c"}, botIDs) + require.Equal(t, tasks, resultTasks) + }) + + t.Run("repo error", func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + mockRepo := repo_mocks.NewMockITaskRepo(ctrl) + impl := &TraceHubServiceImpl{taskRepo: mockRepo} + + expectErr := errors.New("repo err") + mockRepo.EXPECT().ListNonFinalTasks(gomock.Any()).Return(nil, expectErr) + + workspaceIDs, botIDs, tasks, err := impl.getNonFinalTaskInfos(context.Background()) + require.Error(t, err) + require.ErrorIs(t, err, expectErr) + require.Nil(t, workspaceIDs) + require.Nil(t, botIDs) + require.Nil(t, tasks) + }) +} diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/scheduled_task.go b/backend/modules/observability/domain/task/service/taskexe/scheduledtask/status_check.go old mode 100755 new mode 100644 similarity index 52% rename from backend/modules/observability/domain/task/service/taskexe/tracehub/scheduled_task.go rename to backend/modules/observability/domain/task/service/taskexe/scheduledtask/status_check.go index d9b07aa36..edf8212a7 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/scheduled_task.go +++ b/backend/modules/observability/domain/task/service/taskexe/scheduledtask/status_check.go @@ -1,27 +1,28 @@ // Copyright (c) 2025 coze-dev Authors // SPDX-License-Identifier: Apache-2.0 -package tracehub +package scheduledtask import ( "context" "fmt" - "os" - "slices" "time" - "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/filter" - "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/task" + "github.com/coze-dev/coze-loop/backend/infra/lock" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/component/config" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/component/scheduledtask" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/repo" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe" - "github.com/coze-dev/coze-loop/backend/modules/observability/infra/repo/mysql" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe/processor" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe/tracehub" + "github.com/coze-dev/coze-loop/backend/pkg/json" "github.com/coze-dev/coze-loop/backend/pkg/lang/ptr" "github.com/coze-dev/coze-loop/backend/pkg/logs" "github.com/pkg/errors" ) -// TaskRunCountInfo represents the TaskRunCount information structure type TaskRunCountInfo struct { TaskID int64 TaskRunID int64 @@ -30,86 +31,85 @@ type TaskRunCountInfo struct { TaskRunFailCount int64 } -// TaskCacheInfo represents task cache information -type TaskCacheInfo struct { - WorkspaceIDs []string - BotIDs []string - Tasks []*entity.ObservabilityTask - UpdateTime time.Time -} - const ( - transformTaskStatusLockKey = "observability:tracehub:transform_task_status" - transformTaskStatusLockTTL = 3 * time.Minute - syncTaskRunCountsLockKey = "observability:tracehub:sync_task_run_counts" + syncTaskRunCountLockTTL = 3 * time.Minute + checkTaskStatusLockKey = "observability:task:check_task_status" + checkTaskStatusLockTTL = 3 * time.Minute + backfillLockKeyTemplate = "observability:tracehub:backfill:%d" + backfillLockMaxHold = 24 * time.Hour ) -// startScheduledTask launches the scheduled task goroutine -func (h *TraceHubServiceImpl) startScheduledTask() { - h.syncTaskCache() - go func() { - for { - select { - case <-h.scheduledTaskTicker.C: - // Execute scheduled task - h.transformTaskStatus() // 抢锁 - case <-h.stopChan: - // Stop scheduled task - h.scheduledTaskTicker.Stop() - return - } - } - }() - go func() { - for { - select { - case <-h.syncTaskTicker.C: - // Execute scheduled task - h.syncTaskRunCounts() // 抢锁 - h.syncTaskCache() - case <-h.stopChan: - // Stop scheduled task - h.syncTaskTicker.Stop() - return - } - } - }() +type StatusCheckTask struct { + scheduledtask.BaseScheduledTask + + config config.ITraceConfig + locker lock.ILocker + traceHubService tracehub.ITraceHubService + taskService service.ITaskService + taskProcessor processor.TaskProcessor + taskRepo repo.ITaskRepo } -func (h *TraceHubServiceImpl) transformTaskStatus() { - const key = "consumer_listening" - cfg := &config.ConsumerListening{} - if err := h.loader.UnmarshalKey(context.Background(), key, cfg); err != nil { - return - } - if !cfg.IsEnabled || !cfg.IsAllSpace { - return +func NewStatusCheckTask( + locker lock.ILocker, + config config.ITraceConfig, + traceHubService tracehub.ITraceHubService, + taskService service.ITaskService, + taskProcessor processor.TaskProcessor, + taskRepo repo.ITaskRepo, +) scheduledtask.ScheduledTask { + return &StatusCheckTask{ + BaseScheduledTask: scheduledtask.NewBaseScheduledTask("StatusCheckTask", 5*time.Minute), + locker: locker, + config: config, + traceHubService: traceHubService, + taskService: taskService, + taskProcessor: taskProcessor, + taskRepo: taskRepo, } +} - if slices.Contains([]string{TracehubClusterName, InjectClusterName}, os.Getenv(TceCluster)) { - return +func (t *StatusCheckTask) RunOnce(ctx context.Context) error { + cfg, err := t.config.GetConsumerListening(ctx) + if err != nil { + return err + } + if !cfg.IsEnabled || !cfg.IsAllSpace { + return nil } - ctx := context.Background() - ctx = h.fillCtx(ctx) - if h.locker != nil { - locked, lockErr := h.locker.Lock(ctx, transformTaskStatusLockKey, transformTaskStatusLockTTL) + if t.locker != nil { + locked, lockErr := t.locker.Lock(ctx, checkTaskStatusLockKey, checkTaskStatusLockTTL) if lockErr != nil { logs.CtxError(ctx, "transformTaskStatus acquire lock failed", "err", lockErr) - return + return lockErr } if !locked { logs.CtxInfo(ctx, "transformTaskStatus lock held by others, skip execution") - return + return nil } } + + if err = t.checkTaskStatus(ctx); err != nil { + logs.CtxError(ctx, "Failed to check task status", "err", err) + return err + } + if err = t.syncTaskRunCount(ctx); err != nil { + logs.CtxError(ctx, "Failed to sync task run count", "err", err) + return err + } + + return nil +} + +func (t *StatusCheckTask) checkTaskStatus(ctx context.Context) error { logs.CtxInfo(ctx, "Scheduled task started...") // Read all non-final (success/disabled) tasks - taskPOs, err := h.listNonFinalTask(ctx) + taskPOs, err := t.listNonFinalTask(ctx) if err != nil { logs.CtxError(ctx, "Failed to get non-final task list", "err", err) - return + return err } logs.CtxInfo(ctx, "Scheduled task retrieved number of tasks:%d", len(taskPOs)) for _, taskPO := range taskPOs { @@ -123,36 +123,36 @@ func (h *TraceHubServiceImpl) transformTaskStatus() { endTime = time.UnixMilli(taskPO.EffectiveTime.EndAt) startTime = time.UnixMilli(taskPO.EffectiveTime.StartAt) } - proc := h.taskProcessor.GetTaskProcessor(taskPO.TaskType) + proc := t.taskProcessor.GetTaskProcessor(taskPO.TaskType) // Task time horizon reached // End when the task end time is reached logs.CtxInfo(ctx, "[auto_task]taskID:%d, endTime:%v, startTime:%v", taskPO.ID, endTime, startTime) if taskPO.BackfillEffectiveTime != nil && taskPO.EffectiveTime != nil && backfillTaskRun != nil { - if time.Now().After(endTime) && backfillTaskRun.RunStatus == task.RunStatusDone { - logs.CtxInfo(ctx, "[OnFinishTaskChange]taskID:%d, time.Now().After(endTime) && backfillTaskRun.RunStatus == task.RunStatusDone", taskPO.ID) - err = proc.OnFinishTaskChange(ctx, taskexe.OnFinishTaskChangeReq{ + if time.Now().After(endTime) && backfillTaskRun.RunStatus == entity.TaskRunStatusDone { + logs.CtxInfo(ctx, "[OnTaskFinished]taskID:%d, time.Now().After(endTime) && backfillTaskRun.RunStatus == task.RunStatusDone", taskPO.ID) + err = proc.OnTaskFinished(ctx, taskexe.OnTaskFinishedReq{ Task: taskPO, TaskRun: backfillTaskRun, IsFinish: true, }) if err != nil { - logs.CtxError(ctx, "OnFinishTaskChange err:%v", err) + logs.CtxError(ctx, "OnTaskFinished err:%v", err) continue } } - if backfillTaskRun.RunStatus != task.RunStatusDone { + if backfillTaskRun.RunStatus != entity.TaskRunStatusDone { if time.Now().Add(-backfillTaskRun.RunEndAt.Sub(backfillTaskRun.RunStartAt)).Before(backfillTaskRun.RunEndAt) { lockKey := fmt.Sprintf(backfillLockKeyTemplate, taskPO.ID) - locked, _, cancel, lockErr := h.locker.LockWithRenew(ctx, lockKey, transformTaskStatusLockTTL, backfillLockMaxHold) + locked, _, cancel, lockErr := t.locker.LockWithRenew(ctx, lockKey, syncTaskRunCountLockTTL, backfillLockMaxHold) if lockErr != nil || !locked { - _ = h.sendBackfillMessage(ctx, &entity.BackFillEvent{ + _ = t.taskService.SendBackfillMessage(ctx, &entity.BackFillEvent{ TaskID: taskPO.ID, SpaceID: taskPO.WorkspaceID, }) } defer cancel() } else { - err = proc.OnFinishTaskChange(ctx, taskexe.OnFinishTaskChangeReq{ + err = proc.OnTaskFinished(ctx, taskexe.OnTaskFinishedReq{ Task: taskPO, TaskRun: backfillTaskRun, IsFinish: false, @@ -164,31 +164,31 @@ func (h *TraceHubServiceImpl) transformTaskStatus() { } } } else if taskPO.BackfillEffectiveTime != nil && backfillTaskRun != nil { - if backfillTaskRun.RunStatus == task.RunStatusDone { - logs.CtxInfo(ctx, "[OnFinishTaskChange]taskID:%d, backfillTaskRun.RunStatus == task.RunStatusDone", taskPO.ID) - err = proc.OnFinishTaskChange(ctx, taskexe.OnFinishTaskChangeReq{ + if backfillTaskRun.RunStatus == entity.TaskRunStatusDone { + logs.CtxInfo(ctx, "[OnTaskFinished]taskID:%d, backfillTaskRun.RunStatus == task.RunStatusDone", taskPO.ID) + err = proc.OnTaskFinished(ctx, taskexe.OnTaskFinishedReq{ Task: taskPO, TaskRun: backfillTaskRun, IsFinish: true, }) if err != nil { - logs.CtxError(ctx, "OnFinishTaskChange err:%v", err) + logs.CtxError(ctx, "OnTaskFinished err:%v", err) continue } } - if backfillTaskRun.RunStatus != task.RunStatusDone { + if backfillTaskRun.RunStatus != entity.TaskRunStatusDone { if time.Now().Add(-backfillTaskRun.RunEndAt.Sub(backfillTaskRun.RunStartAt)).Before(backfillTaskRun.RunEndAt) { lockKey := fmt.Sprintf(backfillLockKeyTemplate, taskPO.ID) - locked, _, cancel, lockErr := h.locker.LockWithRenew(ctx, lockKey, transformTaskStatusLockTTL, backfillLockMaxHold) + locked, _, cancel, lockErr := t.locker.LockWithRenew(ctx, lockKey, syncTaskRunCountLockTTL, backfillLockMaxHold) if lockErr != nil || !locked { - _ = h.sendBackfillMessage(ctx, &entity.BackFillEvent{ + _ = t.taskService.SendBackfillMessage(ctx, &entity.BackFillEvent{ TaskID: taskPO.ID, SpaceID: taskPO.WorkspaceID, }) } defer cancel() } else { - err = proc.OnFinishTaskChange(ctx, taskexe.OnFinishTaskChangeReq{ + err = proc.OnTaskFinished(ctx, taskexe.OnTaskFinishedReq{ Task: taskPO, TaskRun: backfillTaskRun, IsFinish: false, @@ -201,24 +201,24 @@ func (h *TraceHubServiceImpl) transformTaskStatus() { } } else if taskPO.EffectiveTime != nil { if time.Now().After(endTime) { - logs.CtxInfo(ctx, "[OnFinishTaskChange]taskID:%d, time.Now().After(endTime)", taskPO.ID) - err = proc.OnFinishTaskChange(ctx, taskexe.OnFinishTaskChangeReq{ + logs.CtxInfo(ctx, "[OnTaskFinished]taskID:%d, time.Now().After(endTime)", taskPO.ID) + err = proc.OnTaskFinished(ctx, taskexe.OnTaskFinishedReq{ Task: taskPO, TaskRun: taskRun, IsFinish: true, }) if err != nil { - logs.CtxError(ctx, "OnFinishTaskChange err:%v", err) + logs.CtxError(ctx, "OnTaskFinished err:%v", err) continue } } } // If the task status is unstarted, create it once the task start time is reached - if taskPO.TaskStatus == task.TaskStatusUnstarted && time.Now().After(startTime) { + if taskPO.TaskStatus == entity.TaskStatusUnstarted && time.Now().After(startTime) { runStartAt, runEndAt := taskPO.GetRunTimeRange() - err = proc.OnCreateTaskRunChange(ctx, taskexe.OnCreateTaskRunChangeReq{ + err = proc.OnTaskRunCreated(ctx, taskexe.OnTaskRunCreatedReq{ CurrentTask: taskPO, - RunType: task.TaskRunTypeNewData, + RunType: entity.TaskRunTypeNewData, RunStartAt: runStartAt, RunEndAt: runEndAt, }) @@ -226,14 +226,14 @@ func (h *TraceHubServiceImpl) transformTaskStatus() { logs.CtxError(ctx, "OnCreateTaskRunChange err:%v", err) continue } - err = proc.OnUpdateTaskChange(ctx, taskPO, task.TaskStatusRunning) + err = proc.OnTaskUpdated(ctx, taskPO, entity.TaskStatusRunning) if err != nil { logs.CtxError(ctx, "OnUpdateTaskChange err:%v", err) continue } } // Handle taskRun - if taskPO.TaskStatus == task.TaskStatusRunning || taskPO.TaskStatus == task.TaskStatusPending { + if taskPO.TaskStatus == entity.TaskStatusRunning || taskPO.TaskStatus == entity.TaskStatusPending { if taskRun == nil { logs.CtxError(ctx, "taskID:%d, taskRun is nil", taskPO.ID) continue @@ -241,62 +241,45 @@ func (h *TraceHubServiceImpl) transformTaskStatus() { logs.CtxInfo(ctx, "taskID:%d, taskRun.RunEndAt:%v", taskPO.ID, taskRun.RunEndAt) // Handling repeated tasks: single task time horizon reached if time.Now().After(taskRun.RunEndAt) { - logs.CtxInfo(ctx, "[OnFinishTaskChange]taskID:%d, time.Now().After(cycleEndTime)", taskPO.ID) - err = proc.OnFinishTaskChange(ctx, taskexe.OnFinishTaskChangeReq{ + logs.CtxInfo(ctx, "[OnTaskFinished]taskID:%d, time.Now().After(cycleEndTime)", taskPO.ID) + err = proc.OnTaskFinished(ctx, taskexe.OnTaskFinishedReq{ Task: taskPO, TaskRun: taskRun, IsFinish: false, }) if err != nil { - logs.CtxError(ctx, "OnFinishTaskChange err:%v", err) + logs.CtxError(ctx, "OnTaskFinished err:%v", err) continue } if taskPO.Sampler.IsCycle { - err = proc.OnCreateTaskRunChange(ctx, taskexe.OnCreateTaskRunChangeReq{ + err = proc.OnTaskRunCreated(ctx, taskexe.OnTaskRunCreatedReq{ CurrentTask: taskPO, - RunType: task.TaskRunTypeNewData, + RunType: entity.TaskRunTypeNewData, RunStartAt: taskRun.RunEndAt.UnixMilli(), RunEndAt: taskRun.RunEndAt.UnixMilli() + (taskRun.RunEndAt.UnixMilli() - taskRun.RunStartAt.UnixMilli()), }) if err != nil { - logs.CtxError(ctx, "OnCreateTaskRunChange err:%v", err) + logs.CtxError(ctx, "OnTaskRunCreated err:%v", err) continue } } } } } + return nil } -// syncTaskRunCounts synchronizes TaskRunCount data to the database -func (h *TraceHubServiceImpl) syncTaskRunCounts() { - if slices.Contains([]string{TracehubClusterName, InjectClusterName}, os.Getenv(TceCluster)) { - return - } - ctx := context.Background() - ctx = h.fillCtx(ctx) - - if h.locker != nil { - locked, lockErr := h.locker.Lock(ctx, syncTaskRunCountsLockKey, transformTaskStatusLockTTL) - if lockErr != nil { - logs.CtxError(ctx, "syncTaskRunCounts acquire lock failed", "err", lockErr) - return - } - if !locked { - logs.CtxInfo(ctx, "syncTaskRunCounts lock held by others, skip execution") - return - } - } +func (t *StatusCheckTask) syncTaskRunCount(ctx context.Context) error { logs.CtxInfo(ctx, "Start syncing TaskRunCounts to database...") // 1. Retrieve non-final task list - taskDOs, err := h.listSyncTaskRunTask(ctx) + taskDOs, err := t.listSyncTaskRunTask(ctx) if err != nil { logs.CtxError(ctx, "Failed to get non-final task list", "err", err) - return + return err } if len(taskDOs) == 0 { logs.CtxInfo(ctx, "No non-final tasks need syncing") - return + return nil } // 2. Collect all TaskRun information that needs syncing @@ -316,7 +299,7 @@ func (h *TraceHubServiceImpl) syncTaskRunCounts() { if len(taskRunInfos) == 0 { logs.CtxInfo(ctx, "No TaskRun requires syncing") - return + return nil } logs.CtxInfo(ctx, "Number of TaskRun entries requiring syncing:%d", len(taskRunInfos)) @@ -330,44 +313,115 @@ func (h *TraceHubServiceImpl) syncTaskRunCounts() { } batch := taskRunInfos[i:end] - h.processBatch(ctx, batch) + t.processBatch(ctx, batch) } + return nil } -func (h *TraceHubServiceImpl) syncTaskCache() { - ctx := context.Background() - ctx = h.fillCtx(ctx) +func (t *StatusCheckTask) listSyncTaskRunTask(ctx context.Context) ([]*entity.ObservabilityTask, error) { + var taskDOs []*entity.ObservabilityTask + taskDOs, err := t.listNonFinalTask(ctx) + if err != nil { + logs.CtxError(ctx, "Failed to get non-final task list", "err", err) + return nil, err + } + var offset int32 = 0 + const limit int32 = 1000 + // Paginate through all tasks + for { + tasklist, _, err := t.taskRepo.ListTasks(ctx, repo.ListTaskParam{ + ReqLimit: limit, + ReqOffset: offset, + TaskFilters: &entity.TaskFilterFields{ + FilterFields: []*entity.TaskFilterField{ + { + FieldName: ptr.Of(entity.TaskFieldNameTaskStatus), + Values: []string{ + string(entity.TaskStatusSuccess), + string(entity.TaskStatusDisabled), + }, + QueryType: ptr.Of(entity.QueryTypeIn), + FieldType: ptr.Of(entity.FieldTypeString), + }, + { + FieldName: ptr.Of(entity.TaskFieldName("updated_at")), + Values: []string{ + fmt.Sprintf("%d", time.Now().Add(-24*time.Hour).UnixMilli()), + }, + QueryType: ptr.Of(entity.QueryTypeGt), + FieldType: ptr.Of(entity.FieldTypeLong), + }, + }, + }, + }) + if err != nil { + logs.CtxError(ctx, "Failed to get non-final task list", "err", err) + break + } - logs.CtxInfo(ctx, "Start syncing task cache...") + // Add tasks from the current page to the full list + taskDOs = append(taskDOs, tasklist...) - // 1. Retrieve spaceID, botID, and task information for all non-final tasks from the database - spaceIDs, botIDs, tasks := h.taskRepo.GetObjListWithTask(ctx) - logs.CtxInfo(ctx, "Retrieved task information, taskCount:%d, spaceCount:%d, botCount:%d", len(tasks), len(spaceIDs), len(botIDs)) + // If fewer tasks than limit are returned, this is the last page + if len(tasklist) < int(limit) { + break + } - // 2. Build a new cache map - newCache := TaskCacheInfo{ - WorkspaceIDs: spaceIDs, - BotIDs: botIDs, - Tasks: tasks, - UpdateTime: time.Now(), // Set the current time as the update time + // Move to the next page, increasing offset by 1000 + offset += limit } + return taskDOs, nil +} + +func (t *StatusCheckTask) listNonFinalTask(ctx context.Context) ([]*entity.ObservabilityTask, error) { + var taskPOs []*entity.ObservabilityTask + var offset int32 = 0 + const limit int32 = 500 + // Paginate through all tasks + for { + tasklist, _, err := t.taskRepo.ListTasks(ctx, repo.ListTaskParam{ + ReqLimit: limit, + ReqOffset: offset, + TaskFilters: &entity.TaskFilterFields{ + FilterFields: []*entity.TaskFilterField{ + { + FieldName: ptr.Of(entity.TaskFieldNameTaskStatus), + Values: []string{ + string(entity.TaskStatusUnstarted), + string(entity.TaskStatusRunning), + string(entity.TaskStatusPending), + }, + QueryType: ptr.Of(entity.QueryTypeIn), + FieldType: ptr.Of(entity.FieldTypeString), + }, + }, + }, + }) + if err != nil { + logs.CtxError(ctx, "Failed to get non-final task list", "err", err) + return nil, err + } - // 3. Clear old cache and update with new cache - h.taskCacheLock.Lock() - defer h.taskCacheLock.Unlock() + // Add tasks from the current page to the full list + taskPOs = append(taskPOs, tasklist...) - // 4. Write new cache into local cache - h.taskCache.Store("ObjListWithTask", newCache) + // If fewer tasks than limit are returned, this is the last page + if len(tasklist) < int(limit) { + break + } - logs.CtxInfo(ctx, "Task cache sync completed, taskCount:%d, updateTime:%s", len(tasks), newCache.UpdateTime.Format(time.RFC3339)) + // Move to the next page, increasing offset by 1000 + offset += limit + } + return taskPOs, nil } // processBatch synchronizes TaskRun counts in batches -func (h *TraceHubServiceImpl) processBatch(ctx context.Context, batch []*TaskRunCountInfo) { +func (t *StatusCheckTask) processBatch(ctx context.Context, batch []*TaskRunCountInfo) { // 1. Read Redis count data in batch for _, info := range batch { // Read taskruncount - count, err := h.taskRepo.GetTaskRunCount(ctx, info.TaskID, info.TaskRunID) + count, err := t.taskRepo.GetTaskRunCount(ctx, info.TaskID, info.TaskRunID) if err != nil || count == -1 { logs.CtxWarn(ctx, "Failed to get TaskRunCount, taskID:%d, taskRunID:%d, err:%v", info.TaskID, info.TaskRunID, err) } else { @@ -375,7 +429,7 @@ func (h *TraceHubServiceImpl) processBatch(ctx context.Context, batch []*TaskRun } // Read taskrun success count - successCount, err := h.taskRepo.GetTaskRunSuccessCount(ctx, info.TaskID, info.TaskRunID) + successCount, err := t.taskRepo.GetTaskRunSuccessCount(ctx, info.TaskID, info.TaskRunID) if err != nil || successCount == -1 { logs.CtxWarn(ctx, "Failed to get TaskRunSuccessCount, taskID:%d, taskRunID:%d, err:%v", info.TaskID, info.TaskRunID, err) } else { @@ -383,7 +437,7 @@ func (h *TraceHubServiceImpl) processBatch(ctx context.Context, batch []*TaskRun } // Read taskrun fail count - failCount, err := h.taskRepo.GetTaskRunFailCount(ctx, info.TaskID, info.TaskRunID) + failCount, err := t.taskRepo.GetTaskRunFailCount(ctx, info.TaskID, info.TaskRunID) if err != nil || failCount == -1 { logs.CtxWarn(ctx, "Failed to get TaskRunFailCount, taskID:%d, taskRunID:%d, err:%v", info.TaskID, info.TaskRunID, err) } else { @@ -400,7 +454,7 @@ func (h *TraceHubServiceImpl) processBatch(ctx context.Context, batch []*TaskRun logs.CtxInfo(ctx, "Start updating TaskRun detail in batch, batchSize:%d, batch:%v", len(batch), batch) // 2. Update database in batch for _, info := range batch { - err := h.updateTaskRunDetail(ctx, info) + err := t.updateTaskRunDetail(ctx, info) if err != nil { logs.CtxError(ctx, "Failed to update TaskRun detail", "taskID", info.TaskID, @@ -417,7 +471,7 @@ func (h *TraceHubServiceImpl) processBatch(ctx context.Context, batch []*TaskRun } // updateTaskRunDetail updates the run_detail field of TaskRun -func (h *TraceHubServiceImpl) updateTaskRunDetail(ctx context.Context, info *TaskRunCountInfo) error { +func (t *StatusCheckTask) updateTaskRunDetail(ctx context.Context, info *TaskRunCountInfo) error { // Build run_detail JSON data runDetail := map[string]interface{}{ "total_count": info.TaskRunCount, @@ -426,8 +480,8 @@ func (h *TraceHubServiceImpl) updateTaskRunDetail(ctx context.Context, info *Tas } // Update using optimistic locking - err := h.taskRepo.UpdateTaskRunWithOCC(ctx, info.TaskRunID, 0, map[string]interface{}{ - "run_detail": ToJSONString(ctx, runDetail), + err := t.taskRepo.UpdateTaskRunWithOCC(ctx, info.TaskRunID, 0, map[string]interface{}{ + "run_detail": json.MarshalStringIgnoreErr(runDetail), }) if err != nil { return errors.Wrap(err, "Failed to update TaskRun") @@ -435,126 +489,3 @@ func (h *TraceHubServiceImpl) updateTaskRunDetail(ctx context.Context, info *Tas return nil } - -func (h *TraceHubServiceImpl) listNonFinalTaskByRedis(ctx context.Context, spaceID string) ([]*entity.ObservabilityTask, error) { - var taskPOs []*entity.ObservabilityTask - nonFinalTaskIDs, err := h.taskRepo.ListNonFinalTask(ctx, spaceID) - if err != nil { - logs.CtxError(ctx, "Failed to get non-final task list", "err", err) - return nil, err - } - logs.CtxInfo(ctx, "Start listing non-final tasks, taskCount:%d, nonFinalTaskIDs:%v", len(nonFinalTaskIDs), nonFinalTaskIDs) - if len(nonFinalTaskIDs) == 0 { - return taskPOs, nil - } - for _, taskID := range nonFinalTaskIDs { - taskPO, err := h.taskRepo.GetTaskByRedis(ctx, taskID) - if err != nil { - logs.CtxError(ctx, "Failed to get task", "err", err) - return nil, err - } - if taskPO == nil { - continue - } - taskPOs = append(taskPOs, taskPO) - } - return taskPOs, nil -} - -func (h *TraceHubServiceImpl) listNonFinalTask(ctx context.Context) ([]*entity.ObservabilityTask, error) { - var taskPOs []*entity.ObservabilityTask - var offset int32 = 0 - const limit int32 = 500 - // Paginate through all tasks - for { - tasklist, _, err := h.taskRepo.ListTasks(ctx, mysql.ListTaskParam{ - ReqLimit: limit, - ReqOffset: offset, - TaskFilters: &filter.TaskFilterFields{ - FilterFields: []*filter.TaskFilterField{ - { - FieldName: ptr.Of(filter.TaskFieldNameTaskStatus), - Values: []string{ - string(task.TaskStatusUnstarted), - string(task.TaskStatusRunning), - string(task.TaskStatusPending), - }, - QueryType: ptr.Of(filter.QueryTypeIn), - FieldType: ptr.Of(filter.FieldTypeString), - }, - }, - }, - }) - if err != nil { - logs.CtxError(ctx, "Failed to get non-final task list", "err", err) - return nil, err - } - - // Add tasks from the current page to the full list - taskPOs = append(taskPOs, tasklist...) - - // If fewer tasks than limit are returned, this is the last page - if len(tasklist) < int(limit) { - break - } - - // Move to the next page, increasing offset by 1000 - offset += limit - } - return taskPOs, nil -} - -func (h *TraceHubServiceImpl) listSyncTaskRunTask(ctx context.Context) ([]*entity.ObservabilityTask, error) { - var taskDOs []*entity.ObservabilityTask - taskDOs, err := h.listNonFinalTask(ctx) - if err != nil { - logs.CtxError(ctx, "Failed to get non-final task list", "err", err) - return nil, err - } - var offset int32 = 0 - const limit int32 = 1000 - // Paginate through all tasks - for { - tasklist, _, err := h.taskRepo.ListTasks(ctx, mysql.ListTaskParam{ - ReqLimit: limit, - ReqOffset: offset, - TaskFilters: &filter.TaskFilterFields{ - FilterFields: []*filter.TaskFilterField{ - { - FieldName: ptr.Of(filter.TaskFieldNameTaskStatus), - Values: []string{ - string(task.TaskStatusSuccess), - string(task.TaskStatusDisabled), - }, - QueryType: ptr.Of(filter.QueryTypeIn), - FieldType: ptr.Of(filter.FieldTypeString), - }, - { - FieldName: ptr.Of("updated_at"), - Values: []string{ - fmt.Sprintf("%d", time.Now().Add(-24*time.Hour).UnixMilli()), - }, - QueryType: ptr.Of(filter.QueryTypeGt), - FieldType: ptr.Of(filter.FieldTypeLong), - }, - }, - }, - }) - if err != nil { - logs.CtxError(ctx, "Failed to get non-final task list", "err", err) - break - } - - // Add tasks from the current page to the full list - taskDOs = append(taskDOs, tasklist...) - - // If fewer tasks than limit are returned, this is the last page - if len(tasklist) < int(limit) { - break - } - - // Move to the next page, increasing offset by 1000 - offset += limit - } - return taskDOs, nil -} diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/backfill.go b/backend/modules/observability/domain/task/service/taskexe/tracehub/backfill.go index a15dcc6e3..b66c316c1 100644 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/backfill.go +++ b/backend/modules/observability/domain/task/service/taskexe/tracehub/backfill.go @@ -11,8 +11,6 @@ import ( "github.com/coze-dev/coze-loop/backend/infra/middleware/session" "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/task" - "github.com/coze-dev/coze-loop/backend/modules/observability/application/convertor" - tconv "github.com/coze-dev/coze-loop/backend/modules/observability/application/convertor/task" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/loop_span" @@ -35,99 +33,92 @@ const ( // 定时任务+锁 func (h *TraceHubServiceImpl) BackFill(ctx context.Context, event *entity.BackFillEvent) error { // 1. Set the current task context - ctx = h.fillCtx(ctx) logs.CtxInfo(ctx, "BackFill msg %+v", event) var ( lockKey string lockCancel func() ) - if h.locker != nil && event != nil { - lockKey = fmt.Sprintf(backfillLockKeyTemplate, event.TaskID) - locked, lockCtx, cancel, lockErr := h.locker.LockWithRenew(ctx, lockKey, backfillLockTTL, backfillLockMaxHold) - if lockErr != nil { - logs.CtxError(ctx, "backfill acquire lock failed", "task_id", event.TaskID, "err", lockErr) - return lockErr + + if h.locker != nil { + var err error + ctx, lockCancel, lockKey, err = h.acquireBackfillLock(ctx, event.TaskID) + if err != nil { + return err } - if !locked { - logs.CtxInfo(ctx, "backfill lock held by others, skip execution", "task_id", event.TaskID) + + // 如果lockKey不为空,说明成功获取了锁,需要在函数退出时释放 + if lockKey != "" { + defer func(cancel func(), key string) { + if cancel != nil { + cancel() + } else if key != "" { + if _, err := h.locker.Unlock(key); err != nil { + logs.CtxWarn(ctx, "backfill release lock failed", "task_id", event.TaskID, "err", err) + } + } + }(lockCancel, lockKey) + } else if lockCancel == nil { + // 如果lockKey为空且lockCancel为nil,说明锁被其他实例持有,直接返回 return nil } - lockCancel = cancel - ctx = lockCtx - defer func(cancel func()) { - if cancel != nil { - cancel() - } else if lockKey != "" { - if _, err := h.locker.Unlock(lockKey); err != nil { - logs.CtxWarn(ctx, "backfill release lock failed", "task_id", event.TaskID, "err", err) - } - } - }(lockCancel) } - sub, err := h.setBackfillTask(ctx, event) + sub, err := h.buildSubscriber(ctx, event) if err != nil { return err } + if sub == nil || sub.t == nil { + return errors.New("subscriber or task config not found") + } - if sub != nil && sub.t != nil && sub.t.GetBaseInfo() != nil && sub.t.GetBaseInfo().GetCreatedBy() != nil { - ctx = session.WithCtxUser(ctx, &session.User{ID: sub.t.GetBaseInfo().GetCreatedBy().GetUserID()}) + // todo tyf 是否需要 + if sub.t != nil && sub.t.CreatedBy != "" { + ctx = session.WithCtxUser(ctx, &session.User{ID: sub.t.CreatedBy}) } // 2. Determine whether the backfill task is completed to avoid repeated execution isDone, err := h.isBackfillDone(ctx, sub) if err != nil { - logs.CtxError(ctx, "check backfill task done failed, task_id=%d, err=%v", sub.t.GetID(), err) + logs.CtxError(ctx, "check backfill task done failed, task_id=%d, err=%v", sub.t.ID, err) return err } if isDone { - logs.CtxInfo(ctx, "backfill already completed, task_id=%d", sub.t.GetID()) + logs.CtxInfo(ctx, "backfill already completed, task_id=%d", sub.t.ID) return nil } - // 顺序执行时重置 flush 错误收集器 - h.flushErrLock.Lock() - h.flushErr = nil - h.flushErrLock.Unlock() - // 5. Retrieve span data from the observability service - listErr := h.listAndSendSpans(ctx, sub) - if listErr != nil { - logs.CtxError(ctx, "list spans failed, task_id=%d, err=%v", sub.t.GetID(), listErr) - } + err = h.listAndSendSpans(ctx, sub) - // 6. Synchronously wait for completion to ensure all data is processed - return h.onHandleDone(ctx, listErr, sub) + return h.onHandleDone(ctx, err, sub) } -// setBackfillTask sets the context for the current backfill task -func (h *TraceHubServiceImpl) setBackfillTask(ctx context.Context, event *entity.BackFillEvent) (*spanSubscriber, error) { - taskConfig, err := h.taskRepo.GetTask(ctx, event.TaskID, nil, nil) +// buildSubscriber sets the context for the current backfill task +func (h *TraceHubServiceImpl) buildSubscriber(ctx context.Context, event *entity.BackFillEvent) (*spanSubscriber, error) { + taskDO, err := h.taskRepo.GetTask(ctx, event.TaskID, nil, nil) if err != nil { logs.CtxError(ctx, "get task config failed, task_id=%d, err=%v", event.TaskID, err) return nil, err } - if taskConfig == nil { + if taskDO == nil { return nil, errors.New("task config not found") } - taskConfigDO := tconv.TaskDO2DTO(ctx, taskConfig, nil) - taskRun, err := h.taskRepo.GetBackfillTaskRun(ctx, ptr.Of(taskConfigDO.GetWorkspaceID()), taskConfigDO.GetID()) - if err != nil { - logs.CtxError(ctx, "get backfill task run failed, task_id=%d, err=%v", taskConfigDO.GetID(), err) - return nil, err + + taskRun := taskDO.GetBackfillTaskRun() + if taskRun == nil { + logs.CtxError(ctx, "get backfill task run failed, task_id=%d, err=%v", taskDO.ID) + return nil, errors.New("get backfill task run not found") } - taskRunDTO := tconv.TaskRunDO2DTO(ctx, taskRun, nil) - proc := h.taskProcessor.GetTaskProcessor(taskConfig.TaskType) + + proc := h.taskProcessor.GetTaskProcessor(taskDO.TaskType) sub := &spanSubscriber{ - taskID: taskConfigDO.GetID(), - t: taskConfigDO, - tr: taskRunDTO, - processor: proc, - bufCap: 0, - maxFlushInterval: time.Second * 5, - taskRepo: h.taskRepo, - runType: task.TaskRunTypeBackFill, + taskID: taskDO.ID, + t: taskDO, + tr: taskRun, + processor: proc, + taskRepo: h.taskRepo, + runType: entity.TaskRunTypeBackFill, } return sub, nil @@ -136,7 +127,7 @@ func (h *TraceHubServiceImpl) setBackfillTask(ctx context.Context, event *entity // isBackfillDone checks whether the backfill task has been completed func (h *TraceHubServiceImpl) isBackfillDone(ctx context.Context, sub *spanSubscriber) (bool, error) { if sub.tr == nil { - logs.CtxError(ctx, "get backfill task run failed, task_id=%d, err=%v", sub.t.GetID(), nil) + logs.CtxError(ctx, "get backfill task run failed, task_id=%d, err=%v", sub.t.ID, nil) return true, nil } @@ -144,10 +135,10 @@ func (h *TraceHubServiceImpl) isBackfillDone(ctx context.Context, sub *spanSubsc } func (h *TraceHubServiceImpl) listAndSendSpans(ctx context.Context, sub *spanSubscriber) error { - backfillTime := sub.t.GetRule().GetBackfillEffectiveTime() - tenants, err := h.getTenants(ctx, loop_span.PlatformType(sub.t.GetRule().GetSpanFilters().GetPlatformType())) + backfillTime := sub.t.BackfillEffectiveTime + tenants, err := h.getTenants(ctx, sub.t.SpanFilter.PlatformType) if err != nil { - logs.CtxError(ctx, "get tenants failed, task_id=%d, err=%v", sub.t.GetID(), err) + logs.CtxError(ctx, "get tenants failed, task_id=%d, err=%v", sub.t.ID, err) return err } @@ -155,18 +146,47 @@ func (h *TraceHubServiceImpl) listAndSendSpans(ctx context.Context, sub *spanSub listParam := &repo.ListSpansParam{ Tenants: tenants, Filters: h.buildSpanFilters(ctx, sub.t), - StartAt: backfillTime.GetStartAt(), - EndAt: backfillTime.GetEndAt(), + StartAt: backfillTime.StartAt, + EndAt: backfillTime.EndAt, Limit: pageSize, // Page size DescByStartTime: true, NotQueryAnnotation: true, // No annotation query required during backfill } - if sub.tr.BackfillRunDetail != nil && sub.tr.BackfillRunDetail.LastSpanPageToken != nil { - listParam.PageToken = *sub.tr.BackfillRunDetail.LastSpanPageToken + if sub.tr.BackfillDetail != nil && sub.tr.BackfillDetail.LastSpanPageToken != nil { + listParam.PageToken = *sub.tr.BackfillDetail.LastSpanPageToken + } + + totalCount := int64(0) + for { + logs.CtxInfo(ctx, "TaskID: %d, ListSpansParam:%v", sub.t.ID, listParam) + spans, pageToken, err := h.fetchSpans(ctx, listParam, sub) + if err != nil { + logs.CtxError(ctx, "list spans failed, task_id=%d, err=%v", sub.t.ID, err) + return err + } + + err, shouldFinish := h.flushSpans(ctx, spans, sub) + if err != nil { + return err + } + + totalCount += int64(len(spans)) + logs.CtxInfo(ctx, "Processed %d spans completed, total=%d, task_id=%d", len(spans), totalCount, sub.t.ID) + + if pageToken == "" || shouldFinish { + logs.CtxInfo(ctx, "no more spans to process, task_id=%d", sub.t.ID) + if err = sub.processor.OnTaskFinished(ctx, taskexe.OnTaskFinishedReq{ + Task: sub.t, + TaskRun: sub.tr, + IsFinish: false, + }); err != nil { + return err + } + return nil + } + listParam.PageToken = pageToken } - // Paginate query and send data - return h.fetchAndSendSpans(ctx, listParam, sub) } type ListSpansReq struct { @@ -183,22 +203,26 @@ type ListSpansReq struct { } // buildSpanFilters constructs span filter conditions -func (h *TraceHubServiceImpl) buildSpanFilters(ctx context.Context, taskConfig *task.Task) *loop_span.FilterFields { +func (h *TraceHubServiceImpl) buildSpanFilters(ctx context.Context, taskConfig *entity.ObservabilityTask) *loop_span.FilterFields { // More complex filters can be built based on the task configuration // Simplified here: return nil to indicate no additional filters - platformFilter, err := h.buildHelper.BuildPlatformRelatedFilter(ctx, loop_span.PlatformType(taskConfig.GetRule().GetSpanFilters().GetPlatformType())) + platformFilter, err := h.buildHelper.BuildPlatformRelatedFilter(ctx, taskConfig.SpanFilter.PlatformType) if err != nil { + logs.CtxError(ctx, "build platform filter failed, task_id=%d, err=%v", taskConfig.ID, err) + // 不需要重试 return nil } builtinFilter, err := h.buildBuiltinFilters(ctx, platformFilter, &ListSpansReq{ - WorkspaceID: taskConfig.GetWorkspaceID(), - SpanListType: loop_span.SpanListType(taskConfig.GetRule().GetSpanFilters().GetSpanListType()), + WorkspaceID: taskConfig.WorkspaceID, + SpanListType: taskConfig.SpanFilter.SpanListType, }) if err != nil { + logs.CtxError(ctx, "build builtin filter failed, task_id=%d, err=%v", taskConfig.ID, err) + // 不需要重试 return nil } - filters := h.combineFilters(builtinFilter, convertor.FilterFieldsDTO2DO(taskConfig.GetRule().GetSpanFilters().GetFilters())) + filters := h.combineFilters(builtinFilter, &taskConfig.SpanFilter.Filters) return filters } @@ -262,138 +286,86 @@ func (h *TraceHubServiceImpl) combineFilters(filters ...*loop_span.FilterFields) return filterAggr } -// fetchAndSendSpans paginates and sends span data -func (h *TraceHubServiceImpl) fetchAndSendSpans(ctx context.Context, listParam *repo.ListSpansParam, sub *spanSubscriber) error { - totalCount := int64(0) - pageToken := listParam.PageToken - for { - logs.CtxInfo(ctx, "ListSpansParam:%v", listParam) - result, err := h.traceRepo.ListSpans(ctx, listParam) - if err != nil { - logs.CtxError(ctx, "list spans failed, task_id=%d, page_token=%s, err=%v", sub.t.GetID(), pageToken, err) - return err - } - spans := result.Spans - processors, err := h.buildHelper.BuildGetTraceProcessors(ctx, span_processor.Settings{ - WorkspaceId: sub.t.GetWorkspaceID(), - PlatformType: loop_span.PlatformType(sub.t.GetRule().GetSpanFilters().GetPlatformType()), - QueryStartTime: listParam.StartAt, - QueryEndTime: listParam.EndAt, - }) - if err != nil { - return errorx.WrapByCode(err, obErrorx.CommercialCommonInternalErrorCodeCode) - } - for _, p := range processors { - spans, err = p.Transform(ctx, spans) - if err != nil { - return errorx.WrapByCode(err, obErrorx.CommercialCommonInternalErrorCodeCode) - } - } - - if len(spans) > 0 { - flush := &flushReq{ - retrievedSpanCount: int64(len(spans)), - pageToken: result.PageToken, - spans: spans, - noMore: !result.HasMore, - } - - if err = h.flushSpans(ctx, flush, sub); err != nil { - if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { - return err - } - } - - totalCount += int64(len(spans)) - logs.CtxInfo(ctx, "processed %d spans, total=%d, task_id=%d", len(spans), totalCount, sub.t.GetID()) - } - - if !result.HasMore { - logs.CtxInfo(ctx, "completed listing spans, total_count=%d, task_id=%d", totalCount, sub.t.GetID()) - break - } - listParam.PageToken = result.PageToken - pageToken = result.PageToken +// fetchSpans paginates span data +func (h *TraceHubServiceImpl) fetchSpans(ctx context.Context, listParam *repo.ListSpansParam, + sub *spanSubscriber) ([]*loop_span.Span, string, error) { + result, err := h.traceRepo.ListSpans(ctx, listParam) + if err != nil { + logs.CtxError(ctx, "List spans failed, parma=%v, err=%v", listParam, err) + return nil, "", err } - - return nil -} - -func (h *TraceHubServiceImpl) flushSpans(ctx context.Context, fr *flushReq, sub *spanSubscriber) error { - if ctx.Err() != nil { - return ctx.Err() + logs.CtxInfo(ctx, "Fetch %d spans", len(result.Spans)) + spans := result.Spans + if len(spans) == 0 { + return nil, "", nil } - _, _, err := h.doFlush(ctx, fr, sub) + processors, err := h.buildHelper.BuildGetTraceProcessors(ctx, span_processor.Settings{ + WorkspaceId: sub.t.WorkspaceID, + PlatformType: sub.t.SpanFilter.PlatformType, + QueryStartTime: listParam.StartAt, + QueryEndTime: listParam.EndAt, + }) if err != nil { - logs.CtxError(ctx, "flush spans failed, task_id=%d, err=%v", sub.t.GetID(), err) - h.flushErrLock.Lock() - h.flushErr = append(h.flushErr, err) - h.flushErrLock.Unlock() + return nil, "", errorx.WrapByCode(err, obErrorx.CommercialCommonInternalErrorCodeCode) + } + for _, p := range processors { + spans, err = p.Transform(ctx, spans) + if err != nil { + return nil, "", errorx.WrapByCode(err, obErrorx.CommercialCommonInternalErrorCodeCode) + } } - return nil + if !result.HasMore { + logs.CtxInfo(ctx, "Completed listing spans, task_id=%d", sub.t.ID) + return spans, "", nil + } + + return spans, result.PageToken, nil } -func (h *TraceHubServiceImpl) doFlush(ctx context.Context, fr *flushReq, sub *spanSubscriber) (flushed, sampled int, _ error) { - if fr == nil || len(fr.spans) == 0 { - return 0, 0, nil +func (h *TraceHubServiceImpl) flushSpans(ctx context.Context, spans []*loop_span.Span, sub *spanSubscriber) (err error, shouldFinish bool) { + logs.CtxInfo(ctx, "Start processing %d spans for backfill, task_id=%d", len(spans), sub.t.ID) + if len(spans) == 0 { + return nil, false } - logs.CtxInfo(ctx, "processing %d spans for backfill, task_id=%d", len(fr.spans), sub.t.GetID()) - // Apply sampling logic - sampledSpans := h.applySampling(fr.spans, sub) + sampledSpans := h.applySampling(spans, sub) if len(sampledSpans) == 0 { - logs.CtxInfo(ctx, "no spans after sampling, task_id=%d", sub.t.GetID()) - return len(fr.spans), 0, nil + logs.CtxInfo(ctx, "no spans after sampling, task_id=%d", sub.t.ID) + return nil, false } // Execute specific business logic - err := h.processSpansForBackfill(ctx, sampledSpans, sub) + err, shouldFinish = h.processSpansForBackfill(ctx, sampledSpans, sub) if err != nil { - logs.CtxError(ctx, "process spans failed, task_id=%d, err=%v", sub.t.GetID(), err) - return len(fr.spans), len(sampledSpans), err + logs.CtxError(ctx, "process spans failed, task_id=%d, err=%v", sub.t.ID, err) + return } - sub.tr.BackfillRunDetail = &task.BackfillDetail{ - LastSpanPageToken: ptr.Of(fr.pageToken), - } + // todo 不应该这里直接写po字段 err = h.taskRepo.UpdateTaskRunWithOCC(ctx, sub.tr.ID, sub.tr.WorkspaceID, map[string]interface{}{ - "backfill_detail": ToJSONString(ctx, sub.tr.BackfillRunDetail), + "backfill_detail": ToJSONString(ctx, sub.tr.BackfillDetail), }) if err != nil { - logs.CtxError(ctx, "update task run failed, task_id=%d, err=%v", sub.t.GetID(), err) - return len(fr.spans), len(sampledSpans), err - } - if fr.noMore { - logs.CtxInfo(ctx, "no more spans to process, task_id=%d", sub.t.GetID()) - if err = sub.processor.OnFinishTaskChange(ctx, taskexe.OnFinishTaskChangeReq{ - Task: tconv.TaskDTO2DO(sub.t, "", nil), - TaskRun: tconv.TaskRunDTO2DO(sub.tr), - IsFinish: false, - }); err != nil { - return len(fr.spans), len(sampledSpans), err - } + logs.CtxError(ctx, "update task run failed, task_id=%d, err=%v", sub.t.ID, err) + return } logs.CtxInfo(ctx, "successfully processed %d spans (sampled from %d), task_id=%d", - len(sampledSpans), len(fr.spans), sub.t.GetID()) - return len(fr.spans), len(sampledSpans), nil + len(sampledSpans), len(spans), sub.t.ID) + return } // applySampling applies sampling logic func (h *TraceHubServiceImpl) applySampling(spans []*loop_span.Span, sub *spanSubscriber) []*loop_span.Span { - if sub.t == nil || sub.t.Rule == nil { - return spans - } - - sampler := sub.t.GetRule().GetSampler() + sampler := sub.t.Sampler if sampler == nil { return spans } - sampleRate := sampler.GetSampleRate() + sampleRate := sampler.SampleRate if sampleRate >= 1.0 { return spans // 100% sampling } @@ -416,7 +388,7 @@ func (h *TraceHubServiceImpl) applySampling(spans []*loop_span.Span, sub *spanSu } // processSpansForBackfill handles spans for backfill -func (h *TraceHubServiceImpl) processSpansForBackfill(ctx context.Context, spans []*loop_span.Span, sub *spanSubscriber) error { +func (h *TraceHubServiceImpl) processSpansForBackfill(ctx context.Context, spans []*loop_span.Span, sub *spanSubscriber) (err error, shouldFinish bool) { // Batch processing spans for efficiency const batchSize = 50 @@ -427,79 +399,63 @@ func (h *TraceHubServiceImpl) processSpansForBackfill(ctx context.Context, spans } batch := spans[i:end] - if err := h.processBatchSpans(ctx, batch, sub); err != nil { + err, shouldFinish = h.processBatchSpans(ctx, batch, sub) + if err != nil { logs.CtxError(ctx, "process batch spans failed, task_id=%d, batch_start=%d, err=%v", - sub.t.GetID(), i, err) - // Continue with the next batch without stopping due to a single failure - continue + sub.t.ID, i, err) + return + } + if shouldFinish { + return } // ml_flow rate-limited: 50/5s time.Sleep(5 * time.Second) } - return nil + return err, shouldFinish } // processBatchSpans processes a batch of span data -func (h *TraceHubServiceImpl) processBatchSpans(ctx context.Context, spans []*loop_span.Span, sub *spanSubscriber) error { +func (h *TraceHubServiceImpl) processBatchSpans(ctx context.Context, spans []*loop_span.Span, sub *spanSubscriber) (err error, shouldFinish bool) { for _, span := range spans { // Execute processing logic according to the task type logs.CtxInfo(ctx, "processing span for backfill, span_id=%s, trace_id=%s, task_id=%d", - span.SpanID, span.TraceID, sub.t.GetID()) + span.SpanID, span.TraceID, sub.t.ID) taskCount, _ := h.taskRepo.GetTaskCount(ctx, sub.taskID) - taskRunCount, _ := h.taskRepo.GetTaskRunCount(ctx, sub.taskID, sub.tr.GetID()) - sampler := sub.t.GetRule().GetSampler() - if taskCount+1 > sampler.GetSampleSize() { - logs.CtxWarn(ctx, "taskCount+1 > sampler.GetSampleSize(), task_id=%d,SampleSize=%d", sub.taskID, sampler.GetSampleSize()) - if err := sub.processor.OnFinishTaskChange(ctx, taskexe.OnFinishTaskChangeReq{ - Task: tconv.TaskDTO2DO(sub.t, "", nil), - TaskRun: tconv.TaskRunDTO2DO(sub.tr), - IsFinish: true, - }); err != nil { - return err - } - break + sampler := sub.t.Sampler + if taskCount+1 > sampler.SampleSize { + logs.CtxInfo(ctx, "taskCount+1 > sampler.GetSampleSize(), task_id=%d,SampleSize=%d", sub.taskID, sampler.SampleSize) + return nil, true } - logs.CtxInfo(ctx, "preDispatch, task_id=%d, taskCount=%d, taskRunCount=%d", sub.taskID, taskCount, taskRunCount) - if err := h.dispatch(ctx, span, []*spanSubscriber{sub}); err != nil { - return err + if err = h.dispatch(ctx, span, []*spanSubscriber{sub}); err != nil { + return err, false } } - return nil + return nil, false } // onHandleDone handles completion callback -func (h *TraceHubServiceImpl) onHandleDone(ctx context.Context, listErr error, sub *spanSubscriber) error { - // Collect all errors - h.flushErrLock.Lock() - allErrors := append([]error{}, h.flushErr...) - if listErr != nil { - allErrors = append(allErrors, listErr) - } - h.flushErrLock.Unlock() - - if len(allErrors) > 0 { - backfillEvent := &entity.BackFillEvent{ - SpaceID: sub.t.GetWorkspaceID(), - TaskID: sub.t.GetID(), - } - - // Send MQ message asynchronously without blocking task creation flow - go func() { - if time.Now().UnixMilli()-(sub.tr.RunEndAt-sub.tr.RunStartAt) < sub.tr.RunEndAt { - if err := h.sendBackfillMessage(context.Background(), backfillEvent); err != nil { - logs.CtxWarn(ctx, "send backfill message failed, task_id=%d, err=%v", sub.t.GetID(), err) - } - } - }() - logs.CtxWarn(ctx, "backfill completed with %d errors, task_id=%d", len(allErrors), sub.t.GetID()) - // Return the first error as a representative - return allErrors[0] +func (h *TraceHubServiceImpl) onHandleDone(ctx context.Context, err error, sub *spanSubscriber) error { + if err == nil { + logs.CtxInfo(ctx, "backfill completed successfully, task_id=%d", sub.t.ID) + return nil + } + // failed, need retry + logs.CtxWarn(ctx, "backfill completed with error: %v, task_id=%d", err, sub.t.ID) + backfillEvent := &entity.BackFillEvent{ + SpaceID: sub.t.WorkspaceID, + TaskID: sub.t.ID, } - logs.CtxInfo(ctx, "backfill completed successfully, task_id=%d", sub.t.GetID()) + if time.Now().UnixMilli()-(sub.tr.RunEndAt.UnixMilli()-sub.tr.RunStartAt.UnixMilli()) < sub.tr.RunEndAt.UnixMilli() { + if sendErr := h.sendBackfillMessage(context.Background(), backfillEvent); sendErr != nil { + logs.CtxWarn(ctx, "send backfill message failed, task_id=%d, err=%v", sub.t.ID, sendErr) + return sendErr + } + } + // 依靠MQ进行重试 return nil } @@ -511,3 +467,21 @@ func (h *TraceHubServiceImpl) sendBackfillMessage(ctx context.Context, event *en return h.backfillProducer.SendBackfill(ctx, event) } + +// acquireBackfillLock 尝试获取回填任务的分布式锁 +// 返回值: 新的上下文, 取消函数, 锁键, 错误 +func (h *TraceHubServiceImpl) acquireBackfillLock(ctx context.Context, taskID int64) (context.Context, func(), string, error) { + lockKey := fmt.Sprintf(backfillLockKeyTemplate, taskID) + locked, lockCtx, cancel, lockErr := h.locker.LockWithRenew(ctx, lockKey, backfillLockTTL, backfillLockMaxHold) + if lockErr != nil { + logs.CtxError(ctx, "backfill acquire lock failed", "task_id", taskID, "err", lockErr) + return ctx, nil, "", lockErr + } + + if !locked { + logs.CtxInfo(ctx, "backfill lock held by others, skip execution", "task_id", taskID) + return ctx, nil, "", nil + } + + return lockCtx, cancel, lockKey, nil +} diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/backfill_test.go b/backend/modules/observability/domain/task/service/taskexe/tracehub/backfill_test.go index dc0efd4a4..fa586c272 100755 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/backfill_test.go +++ b/backend/modules/observability/domain/task/service/taskexe/tracehub/backfill_test.go @@ -15,18 +15,17 @@ import ( lockmock "github.com/coze-dev/coze-loop/backend/infra/lock/mocks" "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/common" - "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/filter" - "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/task" tenant_mocks "github.com/coze-dev/coze-loop/backend/modules/observability/domain/component/tenant/mocks" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" taskrepo "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/repo" repo_mocks "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/repo/mocks" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe/processor" - repo "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/repo" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/repo" trepo_mocks "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/repo/mocks" builder_mocks "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/service/mocks" spanfilter_mocks "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/service/trace/span_filter/mocks" - span_processor "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/service/trace/span_processor" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/service/trace/span_processor" obErrorx "github.com/coze-dev/coze-loop/backend/modules/observability/pkg/errno" "github.com/coze-dev/coze-loop/backend/pkg/errorx" "github.com/coze-dev/coze-loop/backend/pkg/lang/ptr" @@ -43,7 +42,7 @@ func TestTraceHubServiceImpl_SetBackfillTask(t *testing.T) { mockRepo := repo_mocks.NewMockITaskRepo(ctrl) taskProcessor := processor.NewTaskProcessor() proc := &stubProcessor{} - taskProcessor.Register(task.TaskTypeAutoEval, proc) + taskProcessor.Register(entity.TaskTypeAutoEval, proc) impl := &TraceHubServiceImpl{ taskRepo: mockRepo, @@ -54,7 +53,8 @@ func TestTraceHubServiceImpl_SetBackfillTask(t *testing.T) { obsTask := &entity.ObservabilityTask{ ID: 1, WorkspaceID: 1, - TaskType: task.TaskTypeAutoEval, + TaskType: entity.TaskTypeAutoEval, + TaskStatus: entity.TaskStatusRunning, SpanFilter: &entity.SpanFilterFields{ Filters: loop_span.FilterFields{ QueryAndOr: ptr.Of(loop_span.QueryAndOrEnumAnd), @@ -70,20 +70,21 @@ func TestTraceHubServiceImpl_SetBackfillTask(t *testing.T) { ID: 2, TaskID: 1, WorkspaceID: 1, - TaskType: task.TaskRunTypeBackFill, - RunStatus: task.RunStatusRunning, + TaskType: entity.TaskRunTypeBackFill, + RunStatus: entity.TaskRunStatusRunning, RunStartAt: now.Add(-time.Minute), RunEndAt: now.Add(time.Minute), } + obsTask.TaskRuns = []*entity.TaskRun{backfillRun} + mockRepo.EXPECT().GetTask(gomock.Any(), int64(1), gomock.Nil(), gomock.Nil()).Return(obsTask, nil) - mockRepo.EXPECT().GetBackfillTaskRun(gomock.Any(), gomock.AssignableToTypeOf(ptr.Of(int64(0))), int64(1)).Return(backfillRun, nil) - sub, err := impl.setBackfillTask(context.Background(), &entity.BackFillEvent{TaskID: 1}) + sub, err := impl.buildSubscriber(context.Background(), &entity.BackFillEvent{TaskID: 1}) require.NoError(t, err) require.NotNil(t, sub) require.Equal(t, int64(1), sub.taskID) - require.Equal(t, task.TaskRunTypeBackFill, sub.runType) + require.Equal(t, entity.TaskRunTypeBackFill, sub.runType) } func TestTraceHubServiceImpl_SetBackfillTaskNotFound(t *testing.T) { @@ -97,7 +98,7 @@ func TestTraceHubServiceImpl_SetBackfillTaskNotFound(t *testing.T) { mockRepo.EXPECT().GetTask(gomock.Any(), int64(1), gomock.Nil(), gomock.Nil()).Return(nil, nil) - _, err := impl.setBackfillTask(context.Background(), &entity.BackFillEvent{TaskID: 1}) + _, err := impl.buildSubscriber(context.Background(), &entity.BackFillEvent{TaskID: 1}) require.Error(t, err) } @@ -113,48 +114,56 @@ func TestTraceHubServiceImpl_ProcessBatchSpans_TaskLimit(t *testing.T) { impl := &TraceHubServiceImpl{taskRepo: mockRepo} now := time.Now() - sampler := &task.Sampler{ - SampleRate: floatPtr(1), - SampleSize: int64Ptr(1), - IsCycle: boolPtr(false), - CycleInterval: int64Ptr(0), + sampler := &entity.Sampler{ + SampleRate: 1, + SampleSize: 1, + IsCycle: false, + CycleInterval: 0, } - taskDTO := &task.Task{ - ID: ptr.Of(int64(1)), - WorkspaceID: ptr.Of(int64(1)), - TaskType: task.TaskTypeAutoEval, - TaskStatus: ptr.Of(task.TaskStatusRunning), - Rule: &task.Rule{ - Sampler: sampler, - EffectiveTime: &task.EffectiveTime{ - StartAt: ptr.Of(now.Add(-time.Hour).UnixMilli()), - EndAt: ptr.Of(now.Add(time.Hour).UnixMilli()), - }, - }, + taskDO := &entity.ObservabilityTask{ + ID: 1, + WorkspaceID: 1, + TaskType: entity.TaskTypeAutoEval, + TaskStatus: entity.TaskStatusRunning, + Sampler: sampler, + EffectiveTime: &entity.EffectiveTime{StartAt: now.Add(-time.Hour).UnixMilli(), EndAt: now.Add(time.Hour).UnixMilli()}, } - taskRunDTO := &task.TaskRun{ - ID: 10, - TaskRunConfig: &task.TaskRunConfig{}, - RunStatus: task.RunStatusRunning, - RunStartAt: now.Add(-time.Minute).UnixMilli(), - RunEndAt: now.Add(time.Minute).UnixMilli(), + taskRun := &entity.TaskRun{ + ID: 10, + TaskID: 1, + WorkspaceID: 1, + TaskType: entity.TaskRunTypeBackFill, + RunStatus: entity.TaskRunStatusRunning, + RunStartAt: now.Add(-time.Minute), + RunEndAt: now.Add(time.Minute), } sub := &spanSubscriber{ taskID: 1, - t: taskDTO, - tr: taskRunDTO, + t: taskDO, + tr: taskRun, processor: proc, taskRepo: mockRepo, + runType: entity.TaskRunTypeBackFill, } - mockRepo.EXPECT().GetTaskCount(gomock.Any(), int64(1)).Return(int64(1), nil) - mockRepo.EXPECT().GetTaskRunCount(gomock.Any(), int64(1), int64(10)).Return(int64(0), nil) + mockRepo.EXPECT().GetTaskCount(gomock.Any(), int64(1)).Return(int64(0), nil).AnyTimes() + mockRepo.EXPECT().GetBackfillTaskRun(gomock.Any(), gomock.Nil(), int64(1)).Return(&entity.TaskRun{ + ID: 10, + TaskID: 1, + WorkspaceID: 2, + TaskType: entity.TaskRunTypeBackFill, + RunStatus: entity.TaskRunStatusRunning, + RunStartAt: time.Now().Add(-time.Minute), + RunEndAt: time.Now().Add(time.Minute), + }, nil) - spans := []*loop_span.Span{{SpanID: "span-1"}} + spans := []*loop_span.Span{{SpanID: "span-1", StartTime: time.Now().UnixMilli()}} ctx := context.Background() - require.NoError(t, impl.processBatchSpans(ctx, spans, sub)) - require.Equal(t, 1, proc.finishChangeInvoked) + err, shouldFinish := impl.processBatchSpans(ctx, spans, sub) + require.NoError(t, err) + require.False(t, shouldFinish) + require.True(t, proc.invokeCalled) } func TestTraceHubServiceImpl_ProcessBatchSpans_DispatchError(t *testing.T) { @@ -169,37 +178,35 @@ func TestTraceHubServiceImpl_ProcessBatchSpans_DispatchError(t *testing.T) { impl := &TraceHubServiceImpl{taskRepo: mockRepo} now := time.Now() - sampler := &task.Sampler{ - SampleRate: floatPtr(1), - SampleSize: int64Ptr(2), - IsCycle: boolPtr(false), - CycleInterval: int64Ptr(0), + sampler := &entity.Sampler{ + SampleRate: 1, + SampleSize: 2, + IsCycle: false, + CycleInterval: 0, } - taskDTO := &task.Task{ - ID: ptr.Of(int64(1)), - WorkspaceID: ptr.Of(int64(1)), - TaskType: task.TaskTypeAutoEval, - TaskStatus: ptr.Of(task.TaskStatusRunning), - Rule: &task.Rule{ - Sampler: sampler, - EffectiveTime: &task.EffectiveTime{ - StartAt: ptr.Of(now.Add(-time.Hour).UnixMilli()), - EndAt: ptr.Of(now.Add(time.Hour).UnixMilli()), - }, - }, + taskDO := &entity.ObservabilityTask{ + ID: 1, + WorkspaceID: 1, + TaskType: entity.TaskTypeAutoEval, + TaskStatus: entity.TaskStatusRunning, + Sampler: sampler, + EffectiveTime: &entity.EffectiveTime{StartAt: now.Add(-time.Hour).UnixMilli(), EndAt: now.Add(time.Hour).UnixMilli()}, } - taskRunDTO := &task.TaskRun{ - ID: 10, - RunStatus: task.RunStatusRunning, - RunStartAt: now.Add(-time.Minute).UnixMilli(), - RunEndAt: now.Add(time.Minute).UnixMilli(), + taskRun := &entity.TaskRun{ + ID: 10, + TaskID: 1, + WorkspaceID: 1, + TaskType: entity.TaskRunTypeNewData, + RunStatus: entity.TaskRunStatusRunning, + RunStartAt: now.Add(-time.Minute), + RunEndAt: now.Add(time.Minute), } sub := &spanSubscriber{ taskID: 1, - t: taskDTO, - tr: taskRunDTO, + t: taskDO, + tr: taskRun, processor: proc, - runType: task.TaskRunTypeNewData, + runType: entity.TaskRunTypeNewData, taskRepo: mockRepo, } @@ -207,19 +214,18 @@ func TestTraceHubServiceImpl_ProcessBatchSpans_DispatchError(t *testing.T) { ID: 20, TaskID: 1, WorkspaceID: 1, - TaskType: task.TaskRunTypeNewData, - RunStatus: task.RunStatusRunning, + TaskType: entity.TaskRunTypeNewData, + RunStatus: entity.TaskRunStatusRunning, RunStartAt: now.Add(-time.Minute), RunEndAt: now.Add(time.Minute), } mockRepo.EXPECT().GetTaskCount(gomock.Any(), int64(1)).Return(int64(0), nil) - mockRepo.EXPECT().GetTaskRunCount(gomock.Any(), int64(1), int64(10)).Return(int64(0), nil) mockRepo.EXPECT().GetLatestNewDataTaskRun(gomock.Any(), gomock.Nil(), int64(1)).Return(spanRun, nil) spans := []*loop_span.Span{{SpanID: "span-1", StartTime: now.Add(10 * time.Millisecond).UnixMilli(), WorkspaceID: "space", TraceID: "trace"}} - err := impl.processBatchSpans(context.Background(), spans, sub) + err, _ := impl.processBatchSpans(context.Background(), spans, sub) require.Error(t, err) require.ErrorContains(t, err, "invoke fail") } @@ -268,27 +274,22 @@ func TestTraceHubServiceImpl_ListAndSendSpans_GetTenantsError(t *testing.T) { impl := &TraceHubServiceImpl{tenantProvider: tenantProvider} now := time.Now() - taskStatus := task.TaskStatusRunning + spanFilters := &entity.SpanFilterFields{ + PlatformType: loop_span.PlatformType(common.PlatformTypeCozeBot), + SpanListType: loop_span.SpanListType(common.SpanListTypeRootSpan), + Filters: loop_span.FilterFields{FilterFields: []*loop_span.FilterField{}}, + } sub := &spanSubscriber{ - t: &task.Task{ - ID: ptr.Of(int64(1)), - Name: "task", - WorkspaceID: ptr.Of(int64(2)), - TaskType: task.TaskTypeAutoEval, - TaskStatus: &taskStatus, - Rule: &task.Rule{ - SpanFilters: &filter.SpanFilterFields{ - PlatformType: ptr.Of(common.PlatformType(common.PlatformTypeCozeBot)), - SpanListType: ptr.Of(common.SpanListTypeRootSpan), - Filters: &filter.FilterFields{FilterFields: []*filter.FilterField{}}, - }, - BackfillEffectiveTime: &task.EffectiveTime{ - StartAt: ptr.Of(now.Add(-time.Hour).UnixMilli()), - EndAt: ptr.Of(now.UnixMilli()), - }, - }, + t: &entity.ObservabilityTask{ + ID: 1, + Name: "task", + WorkspaceID: 2, + TaskType: entity.TaskTypeAutoEval, + TaskStatus: entity.TaskStatusRunning, + SpanFilter: spanFilters, + BackfillEffectiveTime: &entity.EffectiveTime{StartAt: now.Add(-time.Hour).UnixMilli(), EndAt: now.UnixMilli()}, }, - tr: &task.TaskRun{}, + tr: &entity.TaskRun{}, } tenantErr := errors.New("tenant failed") @@ -319,7 +320,7 @@ func TestTraceHubServiceImpl_ListAndSendSpans_Success(t *testing.T) { now := time.Now() sub, proc := newBackfillSubscriber(mockTaskRepo, now) - sub.tr.BackfillRunDetail = &task.BackfillDetail{LastSpanPageToken: ptr.Of("prev")} + sub.tr.BackfillDetail = &entity.BackfillDetail{LastSpanPageToken: ptr.Of("prev")} domainRun := newDomainBackfillTaskRun(now) span := newTestSpan(now) @@ -341,34 +342,46 @@ func TestTraceHubServiceImpl_ListAndSendSpans_Success(t *testing.T) { }) mockTaskRepo.EXPECT().GetTaskCount(gomock.Any(), int64(1)).Return(int64(0), nil) - mockTaskRepo.EXPECT().GetTaskRunCount(gomock.Any(), int64(1), sub.tr.ID).Return(int64(0), nil) mockTaskRepo.EXPECT().GetBackfillTaskRun(gomock.Any(), gomock.Nil(), int64(1)).Return(domainRun, nil) mockTaskRepo.EXPECT().UpdateTaskRunWithOCC(gomock.Any(), sub.tr.ID, sub.tr.WorkspaceID, gomock.AssignableToTypeOf(map[string]interface{}{})).Return(nil) err := impl.listAndSendSpans(context.Background(), sub) require.NoError(t, err) require.True(t, proc.invokeCalled) - require.NotNil(t, sub.tr.BackfillRunDetail) - require.Equal(t, "next", sub.tr.BackfillRunDetail.GetLastSpanPageToken()) + require.NotNil(t, sub.tr.BackfillDetail) + require.NotNil(t, sub.tr.BackfillDetail.LastSpanPageToken) + require.Equal(t, "prev", ptr.From(sub.tr.BackfillDetail.LastSpanPageToken)) } -func TestTraceHubServiceImpl_FetchAndSendSpans_ListError(t *testing.T) { +func TestTraceHubServiceImpl_ListAndSendSpans_ListError(t *testing.T) { ctrl := gomock.NewController(t) t.Cleanup(ctrl.Finish) mockTaskRepo := repo_mocks.NewMockITaskRepo(ctrl) mockTraceRepo := trepo_mocks.NewMockITraceRepo(ctrl) + mockTenant := tenant_mocks.NewMockITenantProvider(ctrl) + mockBuilder := builder_mocks.NewMockTraceFilterProcessorBuilder(ctrl) + filterMock := spanfilter_mocks.NewMockFilter(ctrl) + impl := &TraceHubServiceImpl{ - taskRepo: mockTaskRepo, - traceRepo: mockTraceRepo, + taskRepo: mockTaskRepo, + traceRepo: mockTraceRepo, + tenantProvider: mockTenant, + buildHelper: mockBuilder, } now := time.Now() sub, _ := newBackfillSubscriber(mockTaskRepo, now) + mockBuilder.EXPECT().BuildPlatformRelatedFilter(gomock.Any(), loop_span.PlatformType(common.PlatformTypeCozeBot)). + Return(filterMock, nil) + filterMock.EXPECT().BuildBasicSpanFilter(gomock.Any(), gomock.Any()).Return([]*loop_span.FilterField{}, true, nil) + filterMock.EXPECT().BuildRootSpanFilter(gomock.Any(), gomock.Any()).Return([]*loop_span.FilterField{}, nil) + mockTenant.EXPECT().GetTenantsByPlatformType(gomock.Any(), loop_span.PlatformType(common.PlatformTypeCozeBot)).Return([]string{"tenant"}, nil) + mockTraceRepo.EXPECT().ListSpans(gomock.Any(), gomock.Any()).Return(nil, errors.New("list failed")) - err := impl.fetchAndSendSpans(context.Background(), &repo.ListSpansParam{Tenants: []string{"tenant"}}, sub) + err := impl.listAndSendSpans(context.Background(), sub) require.Error(t, err) } @@ -377,9 +390,19 @@ func TestTraceHubServiceImpl_FlushSpans_ContextCanceled(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() - err := impl.flushSpans(ctx, &flushReq{}, &spanSubscriber{}) - require.Error(t, err) - require.ErrorIs(t, err, context.Canceled) + sub := &spanSubscriber{ + t: &entity.ObservabilityTask{ + ID: 1, + WorkspaceID: 1, + }, + tr: &entity.TaskRun{ + ID: 1, + WorkspaceID: 1, + }, + } + + err, _ := impl.flushSpans(ctx, []*loop_span.Span{}, sub) + require.NoError(t, err) } func TestTraceHubServiceImpl_DoFlush_UpdateTaskRunError(t *testing.T) { @@ -390,19 +413,18 @@ func TestTraceHubServiceImpl_DoFlush_UpdateTaskRunError(t *testing.T) { impl := &TraceHubServiceImpl{taskRepo: mockTaskRepo} now := time.Now() - sub, _ := newBackfillSubscriber(mockTaskRepo, now) + sub, proc := newBackfillSubscriber(mockTaskRepo, now) span := newTestSpan(now) domainRun := newDomainBackfillTaskRun(now) mockTaskRepo.EXPECT().GetTaskCount(gomock.Any(), int64(1)).Return(int64(0), nil) - mockTaskRepo.EXPECT().GetTaskRunCount(gomock.Any(), int64(1), sub.tr.ID).Return(int64(0), nil) mockTaskRepo.EXPECT().GetBackfillTaskRun(gomock.Any(), gomock.Nil(), int64(1)).Return(domainRun, nil) mockTaskRepo.EXPECT().UpdateTaskRunWithOCC(gomock.Any(), sub.tr.ID, sub.tr.WorkspaceID, gomock.AssignableToTypeOf(map[string]interface{}{})).Return(errors.New("update fail")) - flushed, sampled, err := impl.doFlush(context.Background(), &flushReq{retrievedSpanCount: 1, pageToken: "token", spans: []*loop_span.Span{span}}, sub) - require.Equal(t, 1, flushed) - require.Equal(t, 1, sampled) + err, _ := impl.flushSpans(context.Background(), []*loop_span.Span{span}, sub) require.Error(t, err) + require.ErrorContains(t, err, "update fail") + require.True(t, proc.invokeCalled) } func TestTraceHubServiceImpl_DoFlush_NoMoreFinishError(t *testing.T) { @@ -419,39 +441,46 @@ func TestTraceHubServiceImpl_DoFlush_NoMoreFinishError(t *testing.T) { domainRun := newDomainBackfillTaskRun(now) mockTaskRepo.EXPECT().GetTaskCount(gomock.Any(), int64(1)).Return(int64(0), nil) - mockTaskRepo.EXPECT().GetTaskRunCount(gomock.Any(), int64(1), sub.tr.ID).Return(int64(0), nil) mockTaskRepo.EXPECT().GetBackfillTaskRun(gomock.Any(), gomock.Nil(), int64(1)).Return(domainRun, nil) mockTaskRepo.EXPECT().UpdateTaskRunWithOCC(gomock.Any(), sub.tr.ID, sub.tr.WorkspaceID, gomock.AssignableToTypeOf(map[string]interface{}{})).Return(nil) - flushed, sampled, err := impl.doFlush(context.Background(), &flushReq{retrievedSpanCount: 1, pageToken: "token", spans: []*loop_span.Span{span}, noMore: true}, sub) - require.Equal(t, 1, flushed) - require.Equal(t, 1, sampled) - require.Error(t, err) - require.ErrorContains(t, err, "finish fail") + // 调用flushSpans,然后手动调用OnTaskFinished来触发finish错误 + err, _ := impl.flushSpans(context.Background(), []*loop_span.Span{span}, sub) + require.NoError(t, err) // flushSpans本身不应该返回错误 + + // 手动调用OnTaskFinished来触发finish错误 + finishErr := sub.processor.OnTaskFinished(context.Background(), taskexe.OnTaskFinishedReq{ + Task: sub.t, + TaskRun: sub.tr, + IsFinish: true, + }) + require.Error(t, finishErr) + require.ErrorContains(t, finishErr, "finish fail") + require.True(t, proc.invokeCalled) } -func TestTraceHubServiceImpl_DoFlush_SamplingZero(t *testing.T) { +func TestTraceHubServiceImpl_FlushSpans_SamplingZero(t *testing.T) { impl := &TraceHubServiceImpl{} sub := &spanSubscriber{ - t: &task.Task{Rule: &task.Rule{Sampler: &task.Sampler{SampleRate: ptr.Of(float64(0))}}}, + t: &entity.ObservabilityTask{ + Sampler: &entity.Sampler{SampleRate: 0}, + }, } - fr := &flushReq{retrievedSpanCount: 2, spans: []*loop_span.Span{{SpanID: "s1"}, {SpanID: "s2"}}} + spans := []*loop_span.Span{{SpanID: "s1"}, {SpanID: "s2"}} - flushed, sampled, err := impl.doFlush(context.Background(), fr, sub) + err, _ := impl.flushSpans(context.Background(), spans, sub) require.NoError(t, err) - require.Equal(t, 2, flushed) - require.Zero(t, sampled) } func TestTraceHubServiceImpl_IsBackfillDone(t *testing.T) { t.Parallel() impl := &TraceHubServiceImpl{} - taskDTO := &task.Task{ID: ptr.Of(int64(1))} + taskDO := &entity.ObservabilityTask{ID: 1} t.Run("nil task run", func(t *testing.T) { t.Parallel() - sub := &spanSubscriber{t: taskDTO} + sub := &spanSubscriber{t: taskDO} isDone, err := impl.isBackfillDone(context.Background(), sub) require.NoError(t, err) require.True(t, isDone) @@ -459,7 +488,7 @@ func TestTraceHubServiceImpl_IsBackfillDone(t *testing.T) { t.Run("status running", func(t *testing.T) { t.Parallel() - sub := &spanSubscriber{t: taskDTO, tr: &task.TaskRun{RunStatus: task.RunStatusRunning}} + sub := &spanSubscriber{t: taskDO, tr: &entity.TaskRun{RunStatus: entity.TaskRunStatusRunning}} isDone, err := impl.isBackfillDone(context.Background(), sub) require.NoError(t, err) require.False(t, isDone) @@ -467,7 +496,7 @@ func TestTraceHubServiceImpl_IsBackfillDone(t *testing.T) { t.Run("status done", func(t *testing.T) { t.Parallel() - sub := &spanSubscriber{t: taskDTO, tr: &task.TaskRun{RunStatus: task.RunStatusDone}} + sub := &spanSubscriber{t: taskDO, tr: &entity.TaskRun{RunStatus: entity.TaskRunStatusDone}} isDone, err := impl.isBackfillDone(context.Background(), sub) require.NoError(t, err) require.True(t, isDone) @@ -556,15 +585,15 @@ func TestTraceHubServiceImpl_ApplySampling(t *testing.T) { impl := &TraceHubServiceImpl{} spans := []*loop_span.Span{{SpanID: "1"}, {SpanID: "2"}, {SpanID: "3"}} - sub := &spanSubscriber{t: &task.Task{Rule: &task.Rule{Sampler: &task.Sampler{SampleRate: ptr.Of(float64(1.0))}}}} + sub := &spanSubscriber{t: &entity.ObservabilityTask{Sampler: &entity.Sampler{SampleRate: 1.0}}} res := impl.applySampling(spans, sub) require.Len(t, res, 3) - subZero := &spanSubscriber{t: &task.Task{Rule: &task.Rule{Sampler: &task.Sampler{SampleRate: ptr.Of(float64(0.0))}}}} + subZero := &spanSubscriber{t: &entity.ObservabilityTask{Sampler: &entity.Sampler{SampleRate: 0}}} resZero := impl.applySampling(spans, subZero) require.Nil(t, resZero) - subHalf := &spanSubscriber{t: &task.Task{Rule: &task.Rule{Sampler: &task.Sampler{SampleRate: ptr.Of(float64(0.4))}}}} + subHalf := &spanSubscriber{t: &entity.ObservabilityTask{Sampler: &entity.Sampler{SampleRate: 0.4}}} resHalf := impl.applySampling(spans, subHalf) require.Len(t, resHalf, 1) require.Equal(t, spans[:1], resHalf) @@ -579,24 +608,22 @@ func TestTraceHubServiceImpl_OnHandleDone(t *testing.T) { now := time.Now() impl := &TraceHubServiceImpl{ backfillProducer: &stubBackfillProducer{ch: ch}, - flushErr: []error{errors.New("flush err"), errors.New("other")}, } sub := &spanSubscriber{ - t: &task.Task{ID: ptr.Of(int64(10)), WorkspaceID: ptr.Of(int64(20))}, - tr: &task.TaskRun{ + t: &entity.ObservabilityTask{ID: 10, WorkspaceID: 20}, + tr: &entity.TaskRun{ ID: 1, WorkspaceID: 20, TaskID: 10, - TaskType: task.TaskRunTypeBackFill, - RunStatus: task.RunStatusRunning, - RunStartAt: now.Add(-time.Hour).UnixMilli(), - RunEndAt: now.Add(time.Hour).UnixMilli(), + TaskType: entity.TaskRunTypeBackFill, + RunStatus: entity.TaskRunStatusRunning, + RunStartAt: now.Add(-time.Hour), + RunEndAt: now.Add(time.Hour), }, } - err := impl.onHandleDone(context.Background(), nil, sub) - require.Error(t, err) - require.EqualError(t, err, "flush err") + err := impl.onHandleDone(context.Background(), errors.New("flush err"), sub) + require.NoError(t, err) select { case msg := <-ch: @@ -613,15 +640,15 @@ func TestTraceHubServiceImpl_OnHandleDone(t *testing.T) { now := time.Now() impl := &TraceHubServiceImpl{backfillProducer: &stubBackfillProducer{ch: ch}} sub := &spanSubscriber{ - t: &task.Task{ID: ptr.Of(int64(10)), WorkspaceID: ptr.Of(int64(20))}, - tr: &task.TaskRun{ + t: &entity.ObservabilityTask{ID: 10, WorkspaceID: 20}, + tr: &entity.TaskRun{ ID: 1, WorkspaceID: 20, TaskID: 10, - TaskType: task.TaskRunTypeBackFill, - RunStatus: task.RunStatusDone, - RunStartAt: now.Add(-time.Hour).UnixMilli(), - RunEndAt: now.UnixMilli(), + TaskType: entity.TaskRunTypeBackFill, + RunStatus: entity.TaskRunStatusRunning, + RunStartAt: now.Add(-time.Hour), + RunEndAt: now.Add(time.Hour), }, } @@ -648,63 +675,57 @@ func TestTraceHubServiceImpl_SendBackfillMessage(t *testing.T) { } func newBackfillSubscriber(taskRepo taskrepo.ITaskRepo, now time.Time) (*spanSubscriber, *stubProcessor) { - sampler := &task.Sampler{ - SampleRate: ptr.Of(float64(1)), - SampleSize: ptr.Of(int64(5)), + sampler := &entity.Sampler{ + SampleRate: 1, + SampleSize: 5, } - filters := &filter.FilterFields{FilterFields: []*filter.FilterField{}} - spanFilters := &filter.SpanFilterFields{ - PlatformType: ptr.Of(common.PlatformType(common.PlatformTypeCozeBot)), - SpanListType: ptr.Of(common.SpanListTypeRootSpan), - Filters: filters, - } - rule := &task.Rule{ - Sampler: sampler, - SpanFilters: spanFilters, - BackfillEffectiveTime: &task.EffectiveTime{ - StartAt: ptr.Of(now.Add(-time.Hour).UnixMilli()), - EndAt: ptr.Of(now.UnixMilli()), - }, + spanFilters := &entity.SpanFilterFields{ + PlatformType: loop_span.PlatformType(common.PlatformTypeCozeBot), + SpanListType: loop_span.SpanListType(common.SpanListTypeRootSpan), + Filters: loop_span.FilterFields{FilterFields: []*loop_span.FilterField{}}, } - status := task.TaskStatusRunning - taskDTO := &task.Task{ - ID: ptr.Of(int64(1)), - Name: "task", - WorkspaceID: ptr.Of(int64(2)), - TaskType: task.TaskTypeAutoEval, - TaskStatus: &status, - Rule: rule, + taskDO := &entity.ObservabilityTask{ + ID: 1, + Name: "task", + WorkspaceID: 2, + TaskType: entity.TaskTypeAutoEval, + TaskStatus: entity.TaskStatusRunning, + Sampler: sampler, + SpanFilter: spanFilters, + BackfillEffectiveTime: &entity.EffectiveTime{StartAt: now.Add(-time.Hour).UnixMilli(), EndAt: now.UnixMilli()}, } - taskRun := &task.TaskRun{ - ID: 10, - WorkspaceID: 2, - TaskID: 1, - TaskType: task.TaskRunTypeBackFill, - RunStatus: task.RunStatusRunning, - RunStartAt: now.Add(-time.Minute).UnixMilli(), - RunEndAt: now.Add(time.Minute).UnixMilli(), + taskRun := &entity.TaskRun{ + ID: 10, + WorkspaceID: 2, + TaskID: 1, + TaskType: entity.TaskRunTypeBackFill, + RunStatus: entity.TaskRunStatusRunning, + RunStartAt: now.Add(-time.Minute), + RunEndAt: now.Add(time.Minute), + BackfillDetail: &entity.BackfillDetail{}, } proc := &stubProcessor{} sub := &spanSubscriber{ taskID: 1, - t: taskDTO, + t: taskDO, tr: taskRun, processor: proc, taskRepo: taskRepo, - runType: task.TaskRunTypeBackFill, + runType: entity.TaskRunTypeBackFill, } return sub, proc } func newDomainBackfillTaskRun(now time.Time) *entity.TaskRun { return &entity.TaskRun{ - ID: 10, - TaskID: 1, - WorkspaceID: 2, - TaskType: task.TaskRunTypeBackFill, - RunStatus: task.RunStatusRunning, - RunStartAt: now.Add(-time.Minute), - RunEndAt: now.Add(time.Minute), + ID: 10, + TaskID: 1, + WorkspaceID: 2, + TaskType: entity.TaskRunTypeBackFill, + RunStatus: entity.TaskRunStatusRunning, + RunStartAt: now.Add(-time.Minute), + RunEndAt: now.Add(time.Minute), + BackfillDetail: &entity.BackfillDetail{}, } } diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/callback.go b/backend/modules/observability/domain/task/service/taskexe/tracehub/callback.go deleted file mode 100644 index 70f454685..000000000 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/callback.go +++ /dev/null @@ -1,170 +0,0 @@ -// Copyright (c) 2025 coze-dev Authors -// SPDX-License-Identifier: Apache-2.0 - -package tracehub - -import ( - "context" - "fmt" - "time" - - "github.com/coze-dev/coze-loop/backend/infra/external/benefit" - "github.com/coze-dev/coze-loop/backend/infra/middleware/session" - "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" - "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/loop_span" - "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/repo" - "github.com/coze-dev/coze-loop/backend/pkg/logs" - "github.com/samber/lo" -) - -func (h *TraceHubServiceImpl) CallBack(ctx context.Context, event *entity.AutoEvalEvent) error { - for _, turn := range event.TurnEvalResults { - workspaceIDStr, workspaceID := turn.GetWorkspaceIDFromExt() - tenants, err := h.getTenants(ctx, loop_span.PlatformType("callback_all")) - if err != nil { - return err - } - var storageDuration int64 = 1 - res, err := h.benefitSvc.CheckTraceBenefit(ctx, &benefit.CheckTraceBenefitParams{ - ConnectorUID: turn.BaseInfo.CreatedBy.UserID, - SpaceID: workspaceID, - }) - if err != nil { - logs.CtxWarn(ctx, "fail to check trace benefit, %v", err) - } else if res == nil { - logs.CtxWarn(ctx, "fail to get trace benefit, got nil response") - } else if res != nil { - storageDuration = res.StorageDuration - } - - spans, err := h.getSpan(ctx, - tenants, - []string{turn.GetSpanIDFromExt()}, - turn.GetTraceIDFromExt(), - workspaceIDStr, - turn.GetStartTimeFromExt()/1000-(24*time.Duration(storageDuration)*time.Hour).Milliseconds(), - turn.GetStartTimeFromExt()/1000+10*time.Minute.Milliseconds(), - ) - if err != nil { - return err - } - if len(spans) == 0 { - logs.CtxWarn(ctx, "span not found, span_id: %s", turn.GetSpanIDFromExt()) - return fmt.Errorf("span not found, span_id: %s", turn.GetSpanIDFromExt()) - } - span := spans[0] - - // Newly added: write Redis counters based on the Status - err = h.updateTaskRunDetailsCount(ctx, turn.GetTaskIDFromExt(), turn, storageDuration*24*60*60) - if err != nil { - logs.CtxWarn(ctx, "更新TaskRun状态计数失败: taskID=%d, status=%d, err=%v", - turn.GetTaskIDFromExt(), turn.Status, err) - // Continue processing without interrupting the flow - } - - annotation := &loop_span.Annotation{ - SpanID: turn.GetSpanIDFromExt(), - TraceID: span.TraceID, - WorkspaceID: workspaceIDStr, - AnnotationType: loop_span.AnnotationTypeAutoEvaluate, - StartTime: time.UnixMicro(span.StartTime), - Key: fmt.Sprintf("%d:%d", turn.GetTaskIDFromExt(), turn.EvaluatorVersionID), - Value: loop_span.AnnotationValue{ - ValueType: loop_span.AnnotationValueTypeDouble, - FloatValue: turn.Score, - }, - Reasoning: turn.Reasoning, - Status: loop_span.AnnotationStatusNormal, - CreatedAt: time.Now(), - UpdatedAt: time.Now(), - Metadata: &loop_span.AutoEvaluateMetadata{ - TaskID: turn.GetTaskIDFromExt(), - EvaluatorRecordID: turn.EvaluatorRecordID, - EvaluatorVersionID: turn.EvaluatorVersionID, - }, - } - if err = annotation.GenID(); err != nil { - return err - } - - err = h.traceRepo.InsertAnnotations(ctx, &repo.InsertAnnotationParam{ - Tenant: span.GetTenant(), - TTL: span.GetTTL(ctx), - Annotations: []*loop_span.Annotation{annotation}, - }) - if err != nil { - return err - } - } - return nil -} - -func (h *TraceHubServiceImpl) Correction(ctx context.Context, event *entity.CorrectionEvent) error { - workspaceIDStr, workspaceID := event.GetWorkspaceIDFromExt() - if workspaceID == 0 { - return fmt.Errorf("workspace_id is empty") - } - tenants, err := h.getTenants(ctx, loop_span.PlatformType("callback_all")) - if err != nil { - return err - } - spans, err := h.getSpan(ctx, - tenants, - []string{event.GetSpanIDFromExt()}, - event.GetTraceIDFromExt(), - workspaceIDStr, - event.GetStartTimeFromExt()/1000-time.Second.Milliseconds(), - event.GetStartTimeFromExt()/1000+time.Second.Milliseconds(), - ) - if err != nil { - return err - } - if event.EvaluatorResult.Correction == nil || event.EvaluatorResult == nil { - return err - } - if len(spans) == 0 { - return fmt.Errorf("span not found, span_id: %s", event.GetSpanIDFromExt()) - } - span := spans[0] - annotations, err := h.traceRepo.ListAnnotations(ctx, &repo.ListAnnotationsParam{ - Tenants: tenants, - SpanID: event.GetSpanIDFromExt(), - TraceID: event.GetTraceIDFromExt(), - WorkspaceId: workspaceID, - StartAt: event.GetStartTimeFromExt() - 5*time.Second.Milliseconds(), - EndAt: event.GetStartTimeFromExt() + 5*time.Second.Milliseconds(), - }) - if err != nil { - return err - } - var annotation *loop_span.Annotation - for _, a := range annotations { - meta := a.GetAutoEvaluateMetadata() - if meta != nil && meta.EvaluatorRecordID == event.EvaluatorRecordID { - annotation = a - break - } - } - - updateBy := session.UserIDInCtxOrEmpty(ctx) - if updateBy == "" { - return err - } - annotation.CorrectAutoEvaluateScore(event.EvaluatorResult.Correction.Score, event.EvaluatorResult.Correction.Explain, updateBy) - - // Then synchronize the observability data - param := &repo.InsertAnnotationParam{ - Tenant: span.GetTenant(), - TTL: span.GetTTL(ctx), - Annotations: []*loop_span.Annotation{annotation}, - } - if err = h.traceRepo.InsertAnnotations(ctx, param); err != nil { - recordID := lo.Ternary(annotation.GetAutoEvaluateMetadata() != nil, annotation.GetAutoEvaluateMetadata().EvaluatorRecordID, 0) - // If the synchronous update fails, compensate asynchronously - // TODO: asynchronous processing has issues and may duplicate - logs.CtxWarn(ctx, "Sync upsert annotation failed, try async upsert. span_id=[%v], recored_id=[%v], err:%v", - annotation.SpanID, recordID, err) - return nil - } - return nil -} diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/callback_test.go b/backend/modules/observability/domain/task/service/taskexe/tracehub/callback_test.go deleted file mode 100755 index bbea9d3e7..000000000 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/callback_test.go +++ /dev/null @@ -1,131 +0,0 @@ -// Copyright (c) 2025 coze-dev Authors -// SPDX-License-Identifier: Apache-2.0 - -package tracehub - -import ( - "context" - "strconv" - "testing" - "time" - - "go.uber.org/mock/gomock" - - "github.com/coze-dev/coze-loop/backend/infra/external/benefit" - benefit_mocks "github.com/coze-dev/coze-loop/backend/infra/external/benefit/mocks" - tenant_mocks "github.com/coze-dev/coze-loop/backend/modules/observability/domain/component/tenant/mocks" - "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" - repo_mocks "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/repo/mocks" - "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/loop_span" - "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/repo" - trace_repo_mocks "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/repo/mocks" - "github.com/stretchr/testify/require" -) - -func TestTraceHubServiceImpl_CallBackSuccess(t *testing.T) { - t.Parallel() - - ctrl := gomock.NewController(t) - t.Cleanup(ctrl.Finish) - - mockBenefit := benefit_mocks.NewMockIBenefitService(ctrl) - mockTenant := tenant_mocks.NewMockITenantProvider(ctrl) - mockTraceRepo := trace_repo_mocks.NewMockITraceRepo(ctrl) - mockTaskRepo := repo_mocks.NewMockITaskRepo(ctrl) - - impl := &TraceHubServiceImpl{ - benefitSvc: mockBenefit, - tenantProvider: mockTenant, - traceRepo: mockTraceRepo, - taskRepo: mockTaskRepo, - } - - mockTenant.EXPECT().GetTenantsByPlatformType(gomock.Any(), gomock.Any()).Return([]string{"tenant"}, nil).AnyTimes() - mockBenefit.EXPECT().CheckTraceBenefit(gomock.Any(), gomock.Any()).Return(&benefit.CheckTraceBenefitResult{StorageDuration: 1}, nil).AnyTimes() - - now := time.Now() - span := &loop_span.Span{ - SpanID: "span-1", - TraceID: "trace-1", - SystemTagsString: map[string]string{loop_span.SpanFieldTenant: "tenant"}, - LogicDeleteTime: now.Add(24 * time.Hour).UnixMicro(), - StartTime: now.UnixMicro(), - } - - mockTraceRepo.EXPECT().ListSpans(gomock.Any(), gomock.AssignableToTypeOf(&repo.ListSpansParam{})).Return(&repo.ListSpansResult{Spans: loop_span.SpanList{span}}, nil) - mockTaskRepo.EXPECT().IncrTaskRunSuccessCount(gomock.Any(), int64(101), int64(202), gomock.Any()).Return(nil) - mockTraceRepo.EXPECT().InsertAnnotations(gomock.Any(), gomock.AssignableToTypeOf(&repo.InsertAnnotationParam{})).DoAndReturn( - func(_ context.Context, param *repo.InsertAnnotationParam) error { - require.Len(t, param.Annotations, 1) - require.Equal(t, loop_span.AnnotationTypeAutoEvaluate, param.Annotations[0].AnnotationType) - return nil - }, - ) - - startTime := now.Add(-time.Minute).UnixMilli() - event := &entity.AutoEvalEvent{ - TurnEvalResults: []*entity.OnlineExptTurnEvalResult{ - { - EvaluatorVersionID: 1, - Score: 0.9, - Reasoning: "ok", - Status: entity.EvaluatorRunStatus_Success, - BaseInfo: &entity.BaseInfo{ - CreatedBy: &entity.UserInfo{UserID: "user-1"}, - }, - Ext: map[string]string{ - "workspace_id": strconv.FormatInt(1, 10), - "span_id": "span-1", - "trace_id": "trace-1", - "start_time": strconv.FormatInt(startTime*1000, 10), - "task_id": strconv.FormatInt(101, 10), - "run_id": strconv.FormatInt(202, 10), - }, - }, - }, - } - - require.NoError(t, impl.CallBack(context.Background(), event)) -} - -func TestTraceHubServiceImpl_CallBackSpanNotFound(t *testing.T) { - t.Parallel() - - ctrl := gomock.NewController(t) - t.Cleanup(ctrl.Finish) - - mockBenefit := benefit_mocks.NewMockIBenefitService(ctrl) - mockTenant := tenant_mocks.NewMockITenantProvider(ctrl) - mockTraceRepo := trace_repo_mocks.NewMockITraceRepo(ctrl) - - impl := &TraceHubServiceImpl{ - benefitSvc: mockBenefit, - tenantProvider: mockTenant, - traceRepo: mockTraceRepo, - } - - mockTenant.EXPECT().GetTenantsByPlatformType(gomock.Any(), gomock.Any()).Return([]string{"tenant"}, nil).AnyTimes() - mockBenefit.EXPECT().CheckTraceBenefit(gomock.Any(), gomock.Any()).Return(&benefit.CheckTraceBenefitResult{StorageDuration: 1}, nil).AnyTimes() - mockTraceRepo.EXPECT().ListSpans(gomock.Any(), gomock.AssignableToTypeOf(&repo.ListSpansParam{})).Return(&repo.ListSpansResult{}, nil) - - event := &entity.AutoEvalEvent{ - TurnEvalResults: []*entity.OnlineExptTurnEvalResult{ - { - Status: entity.EvaluatorRunStatus_Success, - BaseInfo: &entity.BaseInfo{ - CreatedBy: &entity.UserInfo{UserID: "user-1"}, - }, - Ext: map[string]string{ - "workspace_id": "1", - "span_id": "span-1", - "trace_id": "trace-1", - "start_time": strconv.FormatInt(time.Now().UnixMilli()*1000, 10), - "task_id": "101", - "run_id": "202", - }, - }, - }, - } - - require.Error(t, impl.CallBack(context.Background(), event)) -} diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/local_cache.go b/backend/modules/observability/domain/task/service/taskexe/tracehub/local_cache.go new file mode 100644 index 000000000..dc1f09232 --- /dev/null +++ b/backend/modules/observability/domain/task/service/taskexe/tracehub/local_cache.go @@ -0,0 +1,54 @@ +// Copyright (c) 2025 coze-dev Authors +// SPDX-License-Identifier: Apache-2.0 + +package tracehub + +import ( + "context" + "sync" + "time" + + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" + "github.com/coze-dev/coze-loop/backend/pkg/logs" +) + +const CacheKeyObjListWithTask = "ObjListWithTask" + +// TaskCacheInfo represents task cache information +type TaskCacheInfo struct { + WorkspaceIDs []string + BotIDs []string + Tasks []*entity.ObservabilityTask + UpdateTime time.Time +} + +type LocalCache struct { + taskCache sync.Map +} + +func NewLocalCache() *LocalCache { + return &LocalCache{} +} + +func (l *LocalCache) StoneTaskCache(ctx context.Context, info TaskCacheInfo) { + l.taskCache.Store(CacheKeyObjListWithTask, info) +} + +func (l *LocalCache) LoadTaskCache(ctx context.Context) TaskCacheInfo { + // First, try to retrieve tasks from cache + objListWithTask, ok := l.taskCache.Load(CacheKeyObjListWithTask) + if !ok { + // Cache is empty, fallback to the database + logs.CtxError(ctx, "Cache is empty, retrieving task list from database") + return TaskCacheInfo{} + } + + cacheInfo, ok := objListWithTask.(TaskCacheInfo) + if !ok { + logs.CtxError(ctx, "Cache data type mismatch") + return TaskCacheInfo{} + } + + logs.CtxInfo(ctx, "Retrieve task list from cache, taskCount=%d, spaceCount=%d, botCount=%d", len(cacheInfo.Tasks), len(cacheInfo.WorkspaceIDs), len(cacheInfo.BotIDs)) + return cacheInfo +} diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/mocks/trace_hub_service.go b/backend/modules/observability/domain/task/service/taskexe/tracehub/mocks/trace_hub_service.go index a46b07d6b..db391b2af 100644 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/mocks/trace_hub_service.go +++ b/backend/modules/observability/domain/task/service/taskexe/tracehub/mocks/trace_hub_service.go @@ -14,6 +14,7 @@ import ( reflect "reflect" entity "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" + loop_span "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/loop_span" gomock "go.uber.org/mock/gomock" ) @@ -21,6 +22,7 @@ import ( type MockITraceHubService struct { ctrl *gomock.Controller recorder *MockITraceHubServiceMockRecorder + isgomock struct{} } // MockITraceHubServiceMockRecorder is the mock recorder for MockITraceHubService. @@ -41,57 +43,29 @@ func (m *MockITraceHubService) EXPECT() *MockITraceHubServiceMockRecorder { } // BackFill mocks base method. -func (m *MockITraceHubService) BackFill(arg0 context.Context, arg1 *entity.BackFillEvent) error { +func (m *MockITraceHubService) BackFill(ctx context.Context, event *entity.BackFillEvent) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "BackFill", arg0, arg1) + ret := m.ctrl.Call(m, "BackFill", ctx, event) ret0, _ := ret[0].(error) return ret0 } // BackFill indicates an expected call of BackFill. -func (mr *MockITraceHubServiceMockRecorder) BackFill(arg0, arg1 any) *gomock.Call { +func (mr *MockITraceHubServiceMockRecorder) BackFill(ctx, event any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BackFill", reflect.TypeOf((*MockITraceHubService)(nil).BackFill), arg0, arg1) -} - -// CallBack mocks base method. -func (m *MockITraceHubService) CallBack(arg0 context.Context, arg1 *entity.AutoEvalEvent) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CallBack", arg0, arg1) - ret0, _ := ret[0].(error) - return ret0 -} - -// CallBack indicates an expected call of CallBack. -func (mr *MockITraceHubServiceMockRecorder) CallBack(arg0, arg1 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CallBack", reflect.TypeOf((*MockITraceHubService)(nil).CallBack), arg0, arg1) -} - -// Correction mocks base method. -func (m *MockITraceHubService) Correction(arg0 context.Context, arg1 *entity.CorrectionEvent) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Correction", arg0, arg1) - ret0, _ := ret[0].(error) - return ret0 -} - -// Correction indicates an expected call of Correction. -func (mr *MockITraceHubServiceMockRecorder) Correction(arg0, arg1 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Correction", reflect.TypeOf((*MockITraceHubService)(nil).Correction), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BackFill", reflect.TypeOf((*MockITraceHubService)(nil).BackFill), ctx, event) } // SpanTrigger mocks base method. -func (m *MockITraceHubService) SpanTrigger(arg0 context.Context, arg1 *entity.RawSpan) error { +func (m *MockITraceHubService) SpanTrigger(ctx context.Context, span *loop_span.Span) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SpanTrigger", arg0, arg1) + ret := m.ctrl.Call(m, "SpanTrigger", ctx, span) ret0, _ := ret[0].(error) return ret0 } // SpanTrigger indicates an expected call of SpanTrigger. -func (mr *MockITraceHubServiceMockRecorder) SpanTrigger(arg0, arg1 any) *gomock.Call { +func (mr *MockITraceHubServiceMockRecorder) SpanTrigger(ctx, span any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SpanTrigger", reflect.TypeOf((*MockITraceHubService)(nil).SpanTrigger), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SpanTrigger", reflect.TypeOf((*MockITraceHubService)(nil).SpanTrigger), ctx, span) } diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/span_trigger.go b/backend/modules/observability/domain/task/service/taskexe/tracehub/span_trigger.go index 21374e027..8d15b254b 100644 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/span_trigger.go +++ b/backend/modules/observability/domain/task/service/taskexe/tracehub/span_trigger.go @@ -6,13 +6,9 @@ package tracehub import ( "context" "fmt" - "sync" "time" "github.com/bytedance/gg/gslice" - "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/task" - tconv "github.com/coze-dev/coze-loop/backend/modules/observability/application/convertor/task" - "github.com/coze-dev/coze-loop/backend/modules/observability/domain/component/config" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/loop_span" @@ -21,57 +17,54 @@ import ( "github.com/pkg/errors" ) -func (h *TraceHubServiceImpl) SpanTrigger(ctx context.Context, rawSpan *entity.RawSpan) error { - ctx = h.fillCtx(ctx) - logSuffix := fmt.Sprintf("log_id=%s, trace_id=%s, span_id=%s", rawSpan.LogID, rawSpan.TraceID, rawSpan.SpanID) - logs.CtxInfo(ctx, "auto_task start, log_suffix=%s", logSuffix) - // 1、Convert to standard span and perform initial filtering based on space_id - span := rawSpan.RawSpanConvertToLoopSpan() +func (h *TraceHubServiceImpl) SpanTrigger(ctx context.Context, span *loop_span.Span) error { + logSuffix := fmt.Sprintf("log_id=%s, trace_id=%s, span_id=%s", span.LogID, span.TraceID, span.SpanID) + logs.CtxInfo(ctx, "auto_task start, %s", logSuffix) + + // 1. perform initial filtering based on space_id // 1.1 Filter out spans that do not belong to any space or bot - spaceIDs, botIDs, _ := h.getObjListWithTaskFromCache(ctx) + cacheInfo := h.localCache.LoadTaskCache(ctx) + spaceIDs, botIDs := cacheInfo.WorkspaceIDs, cacheInfo.BotIDs if !gslice.Contains(spaceIDs, span.WorkspaceID) && !gslice.Contains(botIDs, span.TagsString["bot_id"]) { - logs.CtxInfo(ctx, "no space or bot found for span, space_id=%s,bot_id=%s, log_suffix=%s", span.WorkspaceID, span.TagsString["bot_id"], logSuffix) + logs.CtxInfo(ctx, "no space or bot found for span, space_id=%s, bot_id=%s, %s", span.WorkspaceID, span.TagsString["bot_id"], logSuffix) return nil } // 1.2 Filter out spans of type Evaluator - if gslice.Contains([]string{"Evaluator"}, span.CallType) { + if gslice.Contains([]string{loop_span.CallTypeEvaluator}, span.CallType) { return nil } + // 2、Match spans against task rules - subs, err := h.getSubscriberOfSpan(ctx, span) + subs, err := h.buildSubscriberOfSpan(ctx, span) if err != nil { logs.CtxWarn(ctx, "get subscriber of flow span failed, %s, err: %v", logSuffix, err) + return err } logs.CtxInfo(ctx, "%d subscriber of flow span found, %s", len(subs), logSuffix) if len(subs) == 0 { return nil } - // 3、Sample - subs = gslice.Filter(subs, func(sub *spanSubscriber) bool { return sub.Sampled() }) - logs.CtxInfo(ctx, "%d subscriber of flow span sampled, %s", len(subs), logSuffix) - if len(subs) == 0 { - return nil - } + // 3. PreDispatch - err = h.preDispatch(ctx, span, subs) - if err != nil { + if err = h.preDispatch(ctx, subs); err != nil { logs.CtxWarn(ctx, "preDispatch flow span failed, %s, err: %v", logSuffix, err) + return err } logs.CtxInfo(ctx, "%d preDispatch success, %v", len(subs), subs) + // 4、Dispatch if err = h.dispatch(ctx, span, subs); err != nil { logs.CtxError(ctx, "dispatch flow span failed, %s, err: %v", logSuffix, err) - // Dispatch failed, continue to the next span - return nil + return err } return nil } -func (h *TraceHubServiceImpl) getSubscriberOfSpan(ctx context.Context, span *loop_span.Span) ([]*spanSubscriber, error) { - const key = "consumer_listening" - cfg := &config.ConsumerListening{} - if err := h.loader.UnmarshalKey(ctx, key, cfg); err != nil { +func (h *TraceHubServiceImpl) buildSubscriberOfSpan(ctx context.Context, span *loop_span.Span) ([]*spanSubscriber, error) { + cfg, err := h.config.GetConsumerListening(ctx) + if err != nil { + logs.CtxError(ctx, "Failed to get consumer listening config, err: %v", err) return nil, err } @@ -81,23 +74,26 @@ func (h *TraceHubServiceImpl) getSubscriberOfSpan(ctx context.Context, span *loo logs.CtxError(ctx, "Failed to get non-final task list, err: %v", err) return nil, err } - taskList := tconv.TaskDOs2DTOs(ctx, taskDOs, nil) - for _, taskDO := range taskList { - if !cfg.IsAllSpace && !gslice.Contains(cfg.SpaceList, taskDO.GetWorkspaceID()) { + for _, taskDO := range taskDOs { + if !cfg.IsAllSpace && !gslice.Contains(cfg.SpaceList, taskDO.WorkspaceID) { continue } + if taskDO.EffectiveTime == nil || taskDO.EffectiveTime.StartAt == 0 { + continue + } + if span.StartTime < taskDO.EffectiveTime.StartAt { + logs.CtxInfo(ctx, "span start time is before task cycle start time, trace_id=%s, span_id=%s", span.TraceID, span.SpanID) + continue + } + proc := h.taskProcessor.GetTaskProcessor(taskDO.TaskType) subscribers = append(subscribers, &spanSubscriber{ - taskID: taskDO.GetID(), - RWMutex: sync.RWMutex{}, - t: taskDO, - processor: proc, - bufCap: 0, - flushWait: sync.WaitGroup{}, - maxFlushInterval: time.Second * 5, - taskRepo: h.taskRepo, - runType: task.TaskRunTypeNewData, - buildHelper: h.buildHelper, + taskID: taskDO.ID, + t: taskDO, + processor: proc, + taskRepo: h.taskRepo, + runType: entity.TaskRunTypeNewData, + buildHelper: h.buildHelper, }) } @@ -114,69 +110,66 @@ func (h *TraceHubServiceImpl) getSubscriberOfSpan(ctx context.Context, span *loo continue } if ok { - subscribers[keep] = s - keep++ + if s.Sampled() { + subscribers[keep] = s + keep++ + } else { + logs.CtxInfo(ctx, "span not sampled, task_id=%d, trace_id=%s, span_id=%s", s.taskID, span.TraceID, span.SpanID) + } } } return subscribers[:keep], merr.ErrorOrNil() } -func (h *TraceHubServiceImpl) preDispatch(ctx context.Context, span *loop_span.Span, subs []*spanSubscriber) error { +func (h *TraceHubServiceImpl) preDispatch(ctx context.Context, subs []*spanSubscriber) error { merr := &multierror.Error{} for _, sub := range subs { - if sub.t.GetRule().GetEffectiveTime() == nil || sub.t.GetRule().GetEffectiveTime().GetStartAt() == 0 { - continue - } - if span.StartTime < sub.t.GetRule().GetEffectiveTime().GetStartAt() { - logs.CtxWarn(ctx, "span start time is before task cycle start time, trace_id=%s, span_id=%s", span.TraceID, span.SpanID) - continue - } // First step: lock for task status change // Task run status var runStartAt, runEndAt int64 - if sub.t.GetTaskStatus() == task.TaskStatusUnstarted { + if sub.t.TaskStatus == entity.TaskStatusUnstarted { logs.CtxWarn(ctx, "task is unstarted, need sub.Creative") - runStartAt = sub.t.GetRule().GetEffectiveTime().GetStartAt() - if !sub.t.GetRule().GetSampler().GetIsCycle() { - runEndAt = sub.t.GetRule().GetEffectiveTime().GetEndAt() + runStartAt = sub.t.EffectiveTime.StartAt + if !sub.t.Sampler.IsCycle { + runEndAt = sub.t.EffectiveTime.EndAt } else { - switch *sub.t.GetRule().GetSampler().CycleTimeUnit { - case task.TimeUnitDay: - runEndAt = runStartAt + (*sub.t.GetRule().GetSampler().CycleInterval)*24*time.Hour.Milliseconds() - case task.TimeUnitWeek: - runEndAt = runStartAt + (*sub.t.GetRule().GetSampler().CycleInterval)*7*24*time.Hour.Milliseconds() + switch sub.t.Sampler.CycleTimeUnit { + case entity.TimeUnitDay: + runEndAt = runStartAt + (sub.t.Sampler.CycleInterval)*24*time.Hour.Milliseconds() + case entity.TimeUnitWeek: + runEndAt = runStartAt + (sub.t.Sampler.CycleInterval)*7*24*time.Hour.Milliseconds() default: - runEndAt = runStartAt + (*sub.t.GetRule().GetSampler().CycleInterval)*10*time.Minute.Milliseconds() + runEndAt = runStartAt + (sub.t.Sampler.CycleInterval)*10*time.Minute.Milliseconds() } } if err := sub.Creative(ctx, runStartAt, runEndAt); err != nil { merr = multierror.Append(merr, errors.WithMessagef(err, "task is unstarted, need sub.Creative,creative processor, task_id=%d", sub.taskID)) continue } - if err := sub.processor.OnUpdateTaskChange(ctx, tconv.TaskDTO2DO(sub.t, "", nil), task.TaskStatusRunning); err != nil { - logs.CtxWarn(ctx, "OnUpdateTaskChange, task_id=%d, err=%v", sub.taskID, err) + if err := sub.processor.OnTaskUpdated(ctx, sub.t, entity.TaskStatusRunning); err != nil { + logs.CtxWarn(ctx, "OnTaskUpdated, task_id=%d, err=%v", sub.taskID, err) continue } } // Fetch the corresponding task config - taskRunConfig, err := h.taskRepo.GetLatestNewDataTaskRun(ctx, sub.t.WorkspaceID, sub.taskID) + taskRunConfig, err := h.taskRepo.GetLatestNewDataTaskRun(ctx, &sub.t.WorkspaceID, sub.taskID) if err != nil { logs.CtxWarn(ctx, "GetLatestNewDataTaskRun, task_id=%d, err=%v", sub.taskID, err) continue } if taskRunConfig == nil { logs.CtxWarn(ctx, "task run config not found, task_id=%d", sub.taskID) - runStartAt = sub.t.GetRule().GetEffectiveTime().GetStartAt() - if !sub.t.GetRule().GetSampler().GetIsCycle() { - runEndAt = sub.t.GetRule().GetEffectiveTime().GetEndAt() + runStartAt = sub.t.EffectiveTime.StartAt + if !sub.t.Sampler.IsCycle { + runEndAt = sub.t.EffectiveTime.EndAt } else { - switch *sub.t.GetRule().GetSampler().CycleTimeUnit { - case task.TimeUnitDay: - runEndAt = runStartAt + (*sub.t.GetRule().GetSampler().CycleInterval)*24*time.Hour.Milliseconds() - case task.TimeUnitWeek: - runEndAt = runStartAt + (*sub.t.GetRule().GetSampler().CycleInterval)*7*24*time.Hour.Milliseconds() + switch sub.t.Sampler.CycleTimeUnit { + case entity.TimeUnitDay: + runEndAt = runStartAt + sub.t.Sampler.CycleInterval*24*time.Hour.Milliseconds() + case entity.TimeUnitWeek: + runEndAt = runStartAt + sub.t.Sampler.CycleInterval*7*24*time.Hour.Milliseconds() default: - runEndAt = runStartAt + (*sub.t.GetRule().GetSampler().CycleInterval)*10*time.Minute.Milliseconds() + runEndAt = runStartAt + sub.t.Sampler.CycleInterval*10*time.Minute.Milliseconds() } } if err = sub.Creative(ctx, runStartAt, runEndAt); err != nil { @@ -184,17 +177,13 @@ func (h *TraceHubServiceImpl) preDispatch(ctx context.Context, span *loop_span.S } continue } - sampler := sub.t.GetRule().GetSampler() - // Fetch the corresponding task count and subtask count - taskCount, _ := h.taskRepo.GetTaskCount(ctx, sub.taskID) - taskRunCount, _ := h.taskRepo.GetTaskRunCount(ctx, sub.taskID, taskRunConfig.ID) - logs.CtxInfo(ctx, "preDispatch, task_id=%d, taskCount=%d, taskRunCount=%d", sub.taskID, taskCount, taskRunCount) - endTime := time.UnixMilli(sub.t.GetRule().GetEffectiveTime().GetEndAt()) + + endTime := time.UnixMilli(sub.t.EffectiveTime.EndAt) // Reached task time limit if time.Now().After(endTime) { - logs.CtxWarn(ctx, "[OnFinishTaskChange]time.Now().After(endTime) Finish processor, task_id=%d, endTime=%v, now=%v", sub.taskID, endTime, time.Now()) - if err := sub.processor.OnFinishTaskChange(ctx, taskexe.OnFinishTaskChangeReq{ - Task: tconv.TaskDTO2DO(sub.t, "", nil), + logs.CtxWarn(ctx, "[OnTaskFinished]time.Now().After(endTime) Finish processor, task_id=%d, endTime=%v, now=%v", sub.taskID, endTime, time.Now()) + if err := sub.processor.OnTaskFinished(ctx, taskexe.OnTaskFinishedReq{ + Task: sub.t, TaskRun: taskRunConfig, IsFinish: true, }); err != nil { @@ -203,11 +192,17 @@ func (h *TraceHubServiceImpl) preDispatch(ctx context.Context, span *loop_span.S continue } } + + sampler := sub.t.Sampler + // Fetch the corresponding task count and subtask count + taskCount, _ := h.taskRepo.GetTaskCount(ctx, sub.taskID) + taskRunCount, _ := h.taskRepo.GetTaskRunCount(ctx, sub.taskID, taskRunConfig.ID) + logs.CtxInfo(ctx, "preDispatch, task_id=%d, taskCount=%d, taskRunCount=%d", sub.taskID, taskCount, taskRunCount) // Reached task limit - if taskCount+1 > sampler.GetSampleSize() { - logs.CtxWarn(ctx, "[OnFinishTaskChange]taskCount+1 > sampler.GetSampleSize() Finish processor, task_id=%d", sub.taskID) - if err := sub.processor.OnFinishTaskChange(ctx, taskexe.OnFinishTaskChangeReq{ - Task: tconv.TaskDTO2DO(sub.t, "", nil), + if taskCount+1 > sampler.SampleSize { + logs.CtxWarn(ctx, "[OnTaskFinished]taskCount+1 > sampler.GetSampleSize() Finish processor, task_id=%d", sub.taskID) + if err := sub.processor.OnTaskFinished(ctx, taskexe.OnTaskFinishedReq{ + Task: sub.t, TaskRun: taskRunConfig, IsFinish: true, }); err != nil { @@ -215,13 +210,13 @@ func (h *TraceHubServiceImpl) preDispatch(ctx context.Context, span *loop_span.S continue } } - if sampler.GetIsCycle() { + if sampler.IsCycle { cycleEndTime := time.Unix(0, taskRunConfig.RunEndAt.UnixMilli()*1e6) // Reached single cycle task time limit if time.Now().After(cycleEndTime) { - logs.CtxInfo(ctx, "[OnFinishTaskChange]time.Now().After(cycleEndTime) Finish processor, task_id=%d", sub.taskID) - if err := sub.processor.OnFinishTaskChange(ctx, taskexe.OnFinishTaskChangeReq{ - Task: tconv.TaskDTO2DO(sub.t, "", nil), + logs.CtxInfo(ctx, "[OnTaskFinished]time.Now().After(cycleEndTime) Finish processor, task_id=%d", sub.taskID) + if err := sub.processor.OnTaskFinished(ctx, taskexe.OnTaskFinishedReq{ + Task: sub.t, TaskRun: taskRunConfig, IsFinish: false, }); err != nil { @@ -236,10 +231,10 @@ func (h *TraceHubServiceImpl) preDispatch(ctx context.Context, span *loop_span.S } } // Reached single cycle task limit - if taskRunCount+1 > sampler.GetCycleCount() { - logs.CtxWarn(ctx, "[OnFinishTaskChange]taskRunCount+1 > sampler.GetCycleCount(), task_id=%d", sub.taskID) - if err := sub.processor.OnFinishTaskChange(ctx, taskexe.OnFinishTaskChangeReq{ - Task: tconv.TaskDTO2DO(sub.t, "", nil), + if taskRunCount+1 > sampler.CycleCount { + logs.CtxWarn(ctx, "[OnTaskFinished]taskRunCount+1 > sampler.GetCycleCount(), task_id=%d", sub.taskID) + if err := sub.processor.OnTaskFinished(ctx, taskexe.OnTaskFinishedReq{ + Task: sub.t, TaskRun: taskRunConfig, IsFinish: false, }); err != nil { @@ -255,36 +250,42 @@ func (h *TraceHubServiceImpl) preDispatch(ctx context.Context, span *loop_span.S func (h *TraceHubServiceImpl) dispatch(ctx context.Context, span *loop_span.Span, subs []*spanSubscriber) error { merr := &multierror.Error{} for _, sub := range subs { - if sub.t.GetTaskStatus() != task.TaskStatusRunning { + if sub.t.TaskStatus != entity.TaskStatusRunning { continue } logs.CtxInfo(ctx, " sub.AddSpan: %v", sub) if err := sub.AddSpan(ctx, span); err != nil { - merr = multierror.Append(merr, errors.WithMessagef(err, "add span to subscriber, task_id=%d", sub.taskID)) - continue + merr = multierror.Append(merr, errors.WithMessagef(err, "add span to subscriber, log_id=%s, trace_id=%s, span_id=%s, task_id=%d", + span.LogID, span.TraceID, span.SpanID, sub.taskID)) + } else { + logs.CtxInfo(ctx, "add span to subscriber, task_id=%d, log_id=%s, trace_id=%s, span_id=%s", sub.taskID, + span.LogID, span.TraceID, span.SpanID) } - logs.CtxInfo(ctx, "add span to subscriber, task_id=%d, log_id=%s, trace_id=%s, span_id=%s", sub.taskID, - span.LogID, span.TraceID, span.SpanID) } return merr.ErrorOrNil() } -// getObjListWithTaskFromCache retrieves the task list from cache, falling back to the database if cache is empty -func (h *TraceHubServiceImpl) getObjListWithTaskFromCache(ctx context.Context) ([]string, []string, []*entity.ObservabilityTask) { - // First, try to retrieve tasks from cache - objListWithTask, ok := h.taskCache.Load("ObjListWithTask") - if !ok { - // Cache is empty, fallback to the database - logs.CtxError(ctx, "Cache is empty, retrieving task list from database") - return nil, nil, nil +func (h *TraceHubServiceImpl) listNonFinalTaskByRedis(ctx context.Context, spaceID string) ([]*entity.ObservabilityTask, error) { + var taskPOs []*entity.ObservabilityTask + nonFinalTaskIDs, err := h.taskRepo.ListNonFinalTaskBySpaceID(ctx, spaceID) + if err != nil { + logs.CtxError(ctx, "Failed to get non-final task list", "err", err) + return nil, err } - - cacheInfo, ok := objListWithTask.(TaskCacheInfo) - if !ok { - logs.CtxError(ctx, "Cache data type mismatch") - return nil, nil, nil + logs.CtxInfo(ctx, "Start listing non-final tasks, taskCount:%d, nonFinalTaskIDs:%v", len(nonFinalTaskIDs), nonFinalTaskIDs) + if len(nonFinalTaskIDs) == 0 { + return taskPOs, nil } - - logs.CtxInfo(ctx, "Retrieve task list from cache, taskCount=%d, spaceCount=%d, botCount=%d", len(cacheInfo.Tasks), len(cacheInfo.WorkspaceIDs), len(cacheInfo.BotIDs)) - return cacheInfo.WorkspaceIDs, cacheInfo.BotIDs, cacheInfo.Tasks + for _, taskID := range nonFinalTaskIDs { + taskPO, err := h.taskRepo.GetTaskByCache(ctx, taskID) + if err != nil { + logs.CtxError(ctx, "Failed to get task", "err", err) + return nil, err + } + if taskPO == nil { + continue + } + taskPOs = append(taskPOs, taskPO) + } + return taskPOs, nil } diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/span_trigger_test.go b/backend/modules/observability/domain/task/service/taskexe/tracehub/span_trigger_test.go index e9d29792f..2e4a8913d 100755 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/span_trigger_test.go +++ b/backend/modules/observability/domain/task/service/taskexe/tracehub/span_trigger_test.go @@ -13,6 +13,7 @@ import ( "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/common" "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/task" + taskconvertor "github.com/coze-dev/coze-loop/backend/modules/observability/application/convertor/task" componentconfig "github.com/coze-dev/coze-loop/backend/modules/observability/domain/component/config" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" repo_mocks "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/repo/mocks" @@ -45,7 +46,7 @@ func TestTraceHubServiceImpl_SpanTriggerSkipNoWorkspace(t *testing.T) { ServerEnv: &entity.ServerInRawSpan{}, } - require.NoError(t, impl.SpanTrigger(context.Background(), raw)) + require.NoError(t, impl.SpanTrigger(context.Background(), raw.RawSpanConvertToLoopSpan())) } func TestTraceHubServiceImpl_SpanTriggerDispatchError(t *testing.T) { @@ -64,11 +65,11 @@ func TestTraceHubServiceImpl_SpanTriggerDispatchError(t *testing.T) { taskDO := &entity.ObservabilityTask{ ID: 1, WorkspaceID: workspaceID, - TaskType: task.TaskTypeAutoEval, - TaskStatus: task.TaskStatusRunning, + TaskType: entity.TaskTypeAutoEval, + TaskStatus: entity.TaskStatusRunning, SpanFilter: &entity.SpanFilterFields{ - PlatformType: common.PlatformTypeLoopAll, - SpanListType: common.SpanListTypeAllSpan, + PlatformType: loop_span.PlatformDefault, + SpanListType: loop_span.SpanListTypeAllSpan, Filters: loop_span.FilterFields{ QueryAndOr: ptr.Of(loop_span.QueryAndOrEnumAnd), FilterFields: []*loop_span.FilterField{}, @@ -88,14 +89,16 @@ func TestTraceHubServiceImpl_SpanTriggerDispatchError(t *testing.T) { ID: 101, TaskID: 1, WorkspaceID: workspaceID, - TaskType: task.TaskRunTypeNewData, - RunStatus: task.TaskStatusRunning, + TaskType: entity.TaskRunTypeNewData, + RunStatus: entity.TaskRunStatusRunning, RunStartAt: now.Add(-30 * time.Minute), RunEndAt: now.Add(30 * time.Minute), }, }, } + mockRepo.EXPECT().ListNonFinalTaskBySpaceID(gomock.Any(), gomock.Any()).Return([]int64{taskDO.ID}, nil).AnyTimes() + configLoader.EXPECT().UnmarshalKey(gomock.Any(), "consumer_listening", gomock.Any()).DoAndReturn( func(_ context.Context, _ string, value any, _ ...pkgconf.DecodeOptionFn) error { cfg := value.(*componentconfig.ConsumerListening) @@ -103,8 +106,7 @@ func TestTraceHubServiceImpl_SpanTriggerDispatchError(t *testing.T) { return nil }, ).AnyTimes() - mockRepo.EXPECT().ListNonFinalTask(gomock.Any(), "space-1").Return([]int64{taskDO.ID}, nil).AnyTimes() - mockRepo.EXPECT().GetTaskByRedis(gomock.Any(), taskDO.ID).Return(taskDO, nil).AnyTimes() + mockRepo.EXPECT().GetTaskByCache(gomock.Any(), taskDO.ID).Return(taskDO, nil).AnyTimes() mockFilter.EXPECT().BuildBasicSpanFilter(gomock.Any(), gomock.Any()).Return(nil, false, nil).AnyTimes() mockFilter.EXPECT().BuildALLSpanFilter(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() mockBuilder.EXPECT().BuildPlatformRelatedFilter(gomock.Any(), gomock.Any()).Return(mockFilter, nil).AnyTimes() @@ -113,8 +115,8 @@ func TestTraceHubServiceImpl_SpanTriggerDispatchError(t *testing.T) { ID: 201, TaskID: 1, WorkspaceID: workspaceID, - TaskType: task.TaskRunTypeNewData, - RunStatus: task.TaskStatusRunning, + TaskType: entity.TaskRunTypeNewData, + RunStatus: entity.TaskRunStatusRunning, RunStartAt: now.Add(-15 * time.Minute), RunEndAt: now.Add(15 * time.Minute), } @@ -124,7 +126,7 @@ func TestTraceHubServiceImpl_SpanTriggerDispatchError(t *testing.T) { proc := &stubProcessor{invokeErr: errors.New("invoke error"), createTaskRunErr: errors.New("create run error")} taskProcessor := processor.NewTaskProcessor() - taskProcessor.Register(task.TaskTypeAutoEval, proc) + taskProcessor.Register(entity.TaskTypeAutoEval, proc) impl := &TraceHubServiceImpl{ taskRepo: mockRepo, @@ -151,7 +153,7 @@ func TestTraceHubServiceImpl_SpanTriggerDispatchError(t *testing.T) { ServerEnv: &entity.ServerInRawSpan{}, } - err := impl.SpanTrigger(context.Background(), raw) + err := impl.SpanTrigger(context.Background(), raw.RawSpanConvertToLoopSpan()) require.NoError(t, err) require.True(t, proc.invokeCalled) } @@ -187,25 +189,25 @@ func TestTraceHubServiceImpl_preDispatchHandlesUnstartedAndLimits(t *testing.T) } sub := &spanSubscriber{ - taskID: taskID, - t: &task.Task{ - ID: ptr.Of(taskID), - WorkspaceID: ptr.Of(workspaceID), - TaskType: task.TaskTypeAutoEval, - TaskStatus: ptr.Of(task.TaskStatusUnstarted), - Rule: rule, - BaseInfo: &common.BaseInfo{}, - }, + taskID: taskID, processor: stubProc, taskRepo: mockRepo, - runType: task.TaskRunTypeNewData, + runType: entity.TaskRunTypeNewData, } + sub.t = toObservabilityTask(&task.Task{ + ID: ptr.Of(taskID), + WorkspaceID: ptr.Of(workspaceID), + TaskType: task.TaskTypeAutoEval, + TaskStatus: ptr.Of(task.TaskStatusUnstarted), + Rule: rule, + BaseInfo: &common.BaseInfo{}, + }) taskRunConfig := &entity.TaskRun{ ID: 303, TaskID: taskID, WorkspaceID: workspaceID, - TaskType: task.TaskRunTypeNewData, + TaskType: entity.TaskRunTypeNewData, RunStatus: task.TaskStatusRunning, RunStartAt: now.Add(-90 * time.Minute), RunEndAt: now.Add(-30 * time.Minute), @@ -264,19 +266,19 @@ func TestTraceHubServiceImpl_preDispatchHandlesMissingTaskRunConfig(t *testing.T } sub := &spanSubscriber{ - taskID: taskID, - t: &task.Task{ - ID: ptr.Of(taskID), - WorkspaceID: ptr.Of(workspaceID), - TaskType: task.TaskTypeAutoEval, - TaskStatus: ptr.Of(task.TaskStatusRunning), - Rule: rule, - BaseInfo: &common.BaseInfo{}, - }, + taskID: taskID, processor: stubProc, taskRepo: mockRepo, - runType: task.TaskRunTypeNewData, + runType: entity.TaskRunTypeNewData, } + sub.t = toObservabilityTask(&task.Task{ + ID: ptr.Of(taskID), + WorkspaceID: ptr.Of(workspaceID), + TaskType: task.TaskTypeAutoEval, + TaskStatus: ptr.Of(task.TaskStatusRunning), + Rule: rule, + BaseInfo: &common.BaseInfo{}, + }) mockRepo.EXPECT().GetLatestNewDataTaskRun(gomock.Any(), gomock.AssignableToTypeOf(ptr.Of(int64(0))), taskID).Return(nil, nil) @@ -324,25 +326,25 @@ func TestTraceHubServiceImpl_preDispatchHandlesNonCycle(t *testing.T) { } sub := &spanSubscriber{ - taskID: taskID, - t: &task.Task{ - ID: ptr.Of(taskID), - WorkspaceID: ptr.Of(workspaceID), - TaskType: task.TaskTypeAutoEval, - TaskStatus: ptr.Of(task.TaskStatusUnstarted), - Rule: rule, - BaseInfo: &common.BaseInfo{}, - }, + taskID: taskID, processor: stubProc, taskRepo: mockRepo, - runType: task.TaskRunTypeNewData, + runType: entity.TaskRunTypeNewData, } + sub.t = toObservabilityTask(&task.Task{ + ID: ptr.Of(taskID), + WorkspaceID: ptr.Of(workspaceID), + TaskType: task.TaskTypeAutoEval, + TaskStatus: ptr.Of(task.TaskStatusUnstarted), + Rule: rule, + BaseInfo: &common.BaseInfo{}, + }) taskRunConfig := &entity.TaskRun{ ID: 707, TaskID: taskID, WorkspaceID: workspaceID, - TaskType: task.TaskRunTypeNewData, + TaskType: entity.TaskRunTypeNewData, RunStatus: task.TaskStatusRunning, RunStartAt: now.Add(-30 * time.Minute), RunEndAt: now.Add(30 * time.Minute), @@ -393,19 +395,19 @@ func TestTraceHubServiceImpl_preDispatchHandlesCycleDefaultUnit(t *testing.T) { } sub := &spanSubscriber{ - taskID: taskID, - t: &task.Task{ - ID: ptr.Of(taskID), - WorkspaceID: ptr.Of(workspaceID), - TaskType: task.TaskTypeAutoEval, - TaskStatus: ptr.Of(task.TaskStatusUnstarted), - Rule: rule, - BaseInfo: &common.BaseInfo{}, - }, + taskID: taskID, processor: stubProc, taskRepo: mockRepo, - runType: task.TaskRunTypeNewData, + runType: entity.TaskRunTypeNewData, } + sub.t = toObservabilityTask(&task.Task{ + ID: ptr.Of(taskID), + WorkspaceID: ptr.Of(workspaceID), + TaskType: task.TaskTypeAutoEval, + TaskStatus: ptr.Of(task.TaskStatusUnstarted), + Rule: rule, + BaseInfo: &common.BaseInfo{}, + }) mockRepo.EXPECT().GetLatestNewDataTaskRun(gomock.Any(), gomock.AssignableToTypeOf(ptr.Of(int64(0))), taskID).Return(nil, nil) @@ -455,25 +457,25 @@ func TestTraceHubServiceImpl_preDispatchTimeLimitFinishError(t *testing.T) { } sub := &spanSubscriber{ - taskID: taskID, - t: &task.Task{ - ID: ptr.Of(taskID), - WorkspaceID: ptr.Of(workspaceID), - TaskType: task.TaskTypeAutoEval, - TaskStatus: ptr.Of(task.TaskStatusRunning), - Rule: rule, - BaseInfo: &common.BaseInfo{}, - }, + taskID: taskID, processor: stubProc, taskRepo: mockRepo, - runType: task.TaskRunTypeNewData, + runType: entity.TaskRunTypeNewData, } + sub.t = toObservabilityTask(&task.Task{ + ID: ptr.Of(taskID), + WorkspaceID: ptr.Of(workspaceID), + TaskType: task.TaskTypeAutoEval, + TaskStatus: ptr.Of(task.TaskStatusRunning), + Rule: rule, + BaseInfo: &common.BaseInfo{}, + }) taskRunConfig := &entity.TaskRun{ ID: 1101, TaskID: taskID, WorkspaceID: workspaceID, - TaskType: task.TaskRunTypeNewData, + TaskType: entity.TaskRunTypeNewData, RunStatus: task.TaskStatusRunning, RunStartAt: now.Add(-3 * time.Hour), RunEndAt: now.Add(-2 * time.Hour), @@ -522,25 +524,25 @@ func TestTraceHubServiceImpl_preDispatchSampleLimitFinishError(t *testing.T) { } sub := &spanSubscriber{ - taskID: taskID, - t: &task.Task{ - ID: ptr.Of(taskID), - WorkspaceID: ptr.Of(workspaceID), - TaskType: task.TaskTypeAutoEval, - TaskStatus: ptr.Of(task.TaskStatusRunning), - Rule: rule, - BaseInfo: &common.BaseInfo{}, - }, + taskID: taskID, processor: stubProc, taskRepo: mockRepo, - runType: task.TaskRunTypeNewData, + runType: entity.TaskRunTypeNewData, } + sub.t = toObservabilityTask(&task.Task{ + ID: ptr.Of(taskID), + WorkspaceID: ptr.Of(workspaceID), + TaskType: task.TaskTypeAutoEval, + TaskStatus: ptr.Of(task.TaskStatusRunning), + Rule: rule, + BaseInfo: &common.BaseInfo{}, + }) taskRunConfig := &entity.TaskRun{ ID: 1404, TaskID: taskID, WorkspaceID: workspaceID, - TaskType: task.TaskRunTypeNewData, + TaskType: entity.TaskRunTypeNewData, RunStatus: task.TaskStatusRunning, RunStartAt: now.Add(-30 * time.Minute), RunEndAt: now.Add(30 * time.Minute), @@ -589,25 +591,25 @@ func TestTraceHubServiceImpl_preDispatchCycleTimeLimitFinishError(t *testing.T) } sub := &spanSubscriber{ - taskID: taskID, - t: &task.Task{ - ID: ptr.Of(taskID), - WorkspaceID: ptr.Of(workspaceID), - TaskType: task.TaskTypeAutoEval, - TaskStatus: ptr.Of(task.TaskStatusRunning), - Rule: rule, - BaseInfo: &common.BaseInfo{}, - }, + taskID: taskID, processor: stubProc, taskRepo: mockRepo, - runType: task.TaskRunTypeNewData, + runType: entity.TaskRunTypeNewData, } + sub.t = toObservabilityTask(&task.Task{ + ID: ptr.Of(taskID), + WorkspaceID: ptr.Of(workspaceID), + TaskType: task.TaskTypeAutoEval, + TaskStatus: ptr.Of(task.TaskStatusRunning), + Rule: rule, + BaseInfo: &common.BaseInfo{}, + }) taskRunConfig := &entity.TaskRun{ ID: 1707, TaskID: taskID, WorkspaceID: workspaceID, - TaskType: task.TaskRunTypeNewData, + TaskType: entity.TaskRunTypeNewData, RunStatus: task.TaskStatusRunning, RunStartAt: now.Add(-2 * time.Hour), RunEndAt: now.Add(-time.Minute), @@ -656,25 +658,25 @@ func TestTraceHubServiceImpl_preDispatchCycleCountFinishError(t *testing.T) { } sub := &spanSubscriber{ - taskID: taskID, - t: &task.Task{ - ID: ptr.Of(taskID), - WorkspaceID: ptr.Of(workspaceID), - TaskType: task.TaskTypeAutoEval, - TaskStatus: ptr.Of(task.TaskStatusRunning), - Rule: rule, - BaseInfo: &common.BaseInfo{}, - }, + taskID: taskID, processor: stubProc, taskRepo: mockRepo, - runType: task.TaskRunTypeNewData, + runType: entity.TaskRunTypeNewData, } + sub.t = toObservabilityTask(&task.Task{ + ID: ptr.Of(taskID), + WorkspaceID: ptr.Of(workspaceID), + TaskType: task.TaskTypeAutoEval, + TaskStatus: ptr.Of(task.TaskStatusRunning), + Rule: rule, + BaseInfo: &common.BaseInfo{}, + }) taskRunConfig := &entity.TaskRun{ ID: 2009, TaskID: taskID, WorkspaceID: workspaceID, - TaskType: task.TaskRunTypeNewData, + TaskType: entity.TaskRunTypeNewData, RunStatus: task.TaskStatusRunning, RunStartAt: now.Add(-30 * time.Minute), RunEndAt: now.Add(30 * time.Minute), @@ -719,19 +721,19 @@ func TestTraceHubServiceImpl_preDispatchCreativeError(t *testing.T) { } sub := &spanSubscriber{ - taskID: taskID, - t: &task.Task{ - ID: ptr.Of(taskID), - WorkspaceID: ptr.Of(workspaceID), - TaskType: task.TaskTypeAutoEval, - TaskStatus: ptr.Of(task.TaskStatusUnstarted), - Rule: rule, - BaseInfo: &common.BaseInfo{}, - }, + taskID: taskID, processor: stubProc, taskRepo: mockRepo, - runType: task.TaskRunTypeNewData, + runType: entity.TaskRunTypeNewData, } + sub.t = toObservabilityTask(&task.Task{ + ID: ptr.Of(taskID), + WorkspaceID: ptr.Of(workspaceID), + TaskType: task.TaskTypeAutoEval, + TaskStatus: ptr.Of(task.TaskStatusUnstarted), + Rule: rule, + BaseInfo: &common.BaseInfo{}, + }) impl := &TraceHubServiceImpl{taskRepo: mockRepo} span := &loop_span.Span{StartTime: now.UnixMilli(), TraceID: "trace", SpanID: "span"} @@ -742,6 +744,10 @@ func TestTraceHubServiceImpl_preDispatchCreativeError(t *testing.T) { require.Equal(t, 1, len(stubProc.createTaskRunReqs)) } +func toObservabilityTask(dto *task.Task) *entity.ObservabilityTask { + return taskconvertor.TaskDTO2DO(dto) +} + func TestTraceHubServiceImpl_preDispatchAggregatesErrors(t *testing.T) { ctrl := gomock.NewController(t) t.Cleanup(ctrl.Finish) @@ -758,22 +764,22 @@ func TestTraceHubServiceImpl_preDispatchAggregatesErrors(t *testing.T) { CycleTimeUnit: &firstSamplerUnit, } firstSub := &spanSubscriber{ - taskID: 11, - t: &task.Task{ - ID: ptr.Of(int64(11)), - WorkspaceID: ptr.Of(int64(21)), - TaskType: task.TaskTypeAutoEval, - TaskStatus: ptr.Of(task.TaskStatusUnstarted), - Rule: &task.Rule{ - EffectiveTime: &task.EffectiveTime{StartAt: ptr.Of(firstStartAt), EndAt: ptr.Of(now.Add(time.Hour).UnixMilli())}, - Sampler: firstSampler, - }, - BaseInfo: &common.BaseInfo{}, - }, + taskID: 11, processor: firstProc, taskRepo: mockRepo, - runType: task.TaskRunTypeNewData, + runType: entity.TaskRunTypeNewData, } + firstSub.t = toObservabilityTask(&task.Task{ + ID: ptr.Of(int64(11)), + WorkspaceID: ptr.Of(int64(21)), + TaskType: task.TaskTypeAutoEval, + TaskStatus: ptr.Of(task.TaskStatusUnstarted), + Rule: &task.Rule{ + EffectiveTime: &task.EffectiveTime{StartAt: ptr.Of(firstStartAt), EndAt: ptr.Of(now.Add(time.Hour).UnixMilli())}, + Sampler: firstSampler, + }, + BaseInfo: &common.BaseInfo{}, + }) secondStartAt := now.Add(-2 * time.Hour).UnixMilli() secondEndAt := now.Add(-time.Minute).UnixMilli() @@ -790,29 +796,29 @@ func TestTraceHubServiceImpl_preDispatchAggregatesErrors(t *testing.T) { ID: 101, TaskID: secondTaskID, WorkspaceID: secondWorkspaceID, - TaskType: task.TaskRunTypeNewData, + TaskType: entity.TaskRunTypeNewData, RunStatus: task.TaskStatusRunning, RunStartAt: now.Add(-3 * time.Hour), RunEndAt: now.Add(-90 * time.Minute), } secondProc := &stubProcessor{finishErrSeq: []error{errors.New("second fail")}} secondSub := &spanSubscriber{ - taskID: secondTaskID, - t: &task.Task{ - ID: ptr.Of(secondTaskID), - WorkspaceID: ptr.Of(secondWorkspaceID), - TaskType: task.TaskTypeAutoEval, - TaskStatus: ptr.Of(task.TaskStatusRunning), - Rule: &task.Rule{ - EffectiveTime: &task.EffectiveTime{StartAt: ptr.Of(secondStartAt), EndAt: ptr.Of(secondEndAt)}, - Sampler: secondSampler, - }, - BaseInfo: &common.BaseInfo{}, - }, + taskID: secondTaskID, processor: secondProc, taskRepo: mockRepo, - runType: task.TaskRunTypeNewData, + runType: entity.TaskRunTypeNewData, } + secondSub.t = toObservabilityTask(&task.Task{ + ID: ptr.Of(secondTaskID), + WorkspaceID: ptr.Of(secondWorkspaceID), + TaskType: task.TaskTypeAutoEval, + TaskStatus: ptr.Of(task.TaskStatusRunning), + Rule: &task.Rule{ + EffectiveTime: &task.EffectiveTime{StartAt: ptr.Of(secondStartAt), EndAt: ptr.Of(secondEndAt)}, + Sampler: secondSampler, + }, + BaseInfo: &common.BaseInfo{}, + }) mockRepo.EXPECT().GetLatestNewDataTaskRun(gomock.Any(), gomock.AssignableToTypeOf(ptr.Of(int64(0))), secondTaskID).Return(secondRun, nil) mockRepo.EXPECT().GetTaskCount(gomock.Any(), secondTaskID).Return(int64(0), nil) @@ -856,19 +862,19 @@ func TestTraceHubServiceImpl_preDispatchUpdateError(t *testing.T) { } sub := &spanSubscriber{ - taskID: taskID, - t: &task.Task{ - ID: ptr.Of(taskID), - WorkspaceID: ptr.Of(workspaceID), - TaskType: task.TaskTypeAutoEval, - TaskStatus: ptr.Of(task.TaskStatusUnstarted), - Rule: rule, - BaseInfo: &common.BaseInfo{}, - }, + taskID: taskID, processor: stubProc, taskRepo: mockRepo, - runType: task.TaskRunTypeNewData, + runType: entity.TaskRunTypeNewData, } + sub.t = toObservabilityTask(&task.Task{ + ID: ptr.Of(taskID), + WorkspaceID: ptr.Of(workspaceID), + TaskType: task.TaskTypeAutoEval, + TaskStatus: ptr.Of(task.TaskStatusUnstarted), + Rule: rule, + BaseInfo: &common.BaseInfo{}, + }) impl := &TraceHubServiceImpl{taskRepo: mockRepo} span := &loop_span.Span{StartTime: now.UnixMilli(), TraceID: "trace", SpanID: "span"} @@ -902,19 +908,19 @@ func TestTraceHubServiceImpl_preDispatchListTaskRunError(t *testing.T) { } sub := &spanSubscriber{ - taskID: taskID, - t: &task.Task{ - ID: ptr.Of(taskID), - WorkspaceID: ptr.Of(workspaceID), - TaskType: task.TaskTypeAutoEval, - TaskStatus: ptr.Of(task.TaskStatusRunning), - Rule: rule, - BaseInfo: &common.BaseInfo{}, - }, + taskID: taskID, processor: stubProc, taskRepo: mockRepo, - runType: task.TaskRunTypeNewData, + runType: entity.TaskRunTypeNewData, } + sub.t = toObservabilityTask(&task.Task{ + ID: ptr.Of(taskID), + WorkspaceID: ptr.Of(workspaceID), + TaskType: task.TaskTypeAutoEval, + TaskStatus: ptr.Of(task.TaskStatusRunning), + Rule: rule, + BaseInfo: &common.BaseInfo{}, + }) mockRepo.EXPECT().GetLatestNewDataTaskRun(gomock.Any(), gomock.AssignableToTypeOf(ptr.Of(int64(0))), taskID).Return(nil, errors.New("repo fail")) @@ -952,19 +958,19 @@ func TestTraceHubServiceImpl_preDispatchTaskRunConfigDay(t *testing.T) { } sub := &spanSubscriber{ - taskID: taskID, - t: &task.Task{ - ID: ptr.Of(taskID), - WorkspaceID: ptr.Of(workspaceID), - TaskType: task.TaskTypeAutoEval, - TaskStatus: ptr.Of(task.TaskStatusRunning), - Rule: rule, - BaseInfo: &common.BaseInfo{}, - }, + taskID: taskID, processor: stubProc, taskRepo: mockRepo, - runType: task.TaskRunTypeNewData, + runType: entity.TaskRunTypeNewData, } + sub.t = toObservabilityTask(&task.Task{ + ID: ptr.Of(taskID), + WorkspaceID: ptr.Of(workspaceID), + TaskType: task.TaskTypeAutoEval, + TaskStatus: ptr.Of(task.TaskStatusRunning), + Rule: rule, + BaseInfo: &common.BaseInfo{}, + }) mockRepo.EXPECT().GetLatestNewDataTaskRun(gomock.Any(), gomock.AssignableToTypeOf(ptr.Of(int64(0))), taskID).Return(nil, nil) @@ -1009,25 +1015,25 @@ func TestTraceHubServiceImpl_preDispatchCycleCreativeError(t *testing.T) { } sub := &spanSubscriber{ - taskID: taskID, - t: &task.Task{ - ID: ptr.Of(taskID), - WorkspaceID: ptr.Of(workspaceID), - TaskType: task.TaskTypeAutoEval, - TaskStatus: ptr.Of(task.TaskStatusRunning), - Rule: rule, - BaseInfo: &common.BaseInfo{}, - }, + taskID: taskID, processor: stubProc, taskRepo: mockRepo, - runType: task.TaskRunTypeNewData, + runType: entity.TaskRunTypeNewData, } + sub.t = toObservabilityTask(&task.Task{ + ID: ptr.Of(taskID), + WorkspaceID: ptr.Of(workspaceID), + TaskType: task.TaskTypeAutoEval, + TaskStatus: ptr.Of(task.TaskStatusRunning), + Rule: rule, + BaseInfo: &common.BaseInfo{}, + }) taskRunConfig := &entity.TaskRun{ ID: 3102, TaskID: taskID, WorkspaceID: workspaceID, - TaskType: task.TaskRunTypeNewData, + TaskType: entity.TaskRunTypeNewData, RunStatus: task.TaskStatusRunning, RunStartAt: now.Add(-2 * time.Hour), RunEndAt: now.Add(-time.Minute), diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/subscriber.go b/backend/modules/observability/domain/task/service/taskexe/tracehub/subscriber.go index 8e05b1831..13f4f0a39 100644 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/subscriber.go +++ b/backend/modules/observability/domain/task/service/taskexe/tracehub/subscriber.go @@ -6,12 +6,8 @@ package tracehub import ( "context" "math/rand" - "sync" "time" - "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/task" - "github.com/coze-dev/coze-loop/backend/modules/observability/application/convertor" - tconv "github.com/coze-dev/coze-loop/backend/modules/observability/application/convertor/task" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/repo" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe" @@ -25,39 +21,28 @@ import ( ) type spanSubscriber struct { - taskID int64 - sync.RWMutex // protect t, buf - t *task.Task - tr *task.TaskRun - processor taskexe.Processor - bufCap int // max buffer size - - flushWait sync.WaitGroup - maxFlushInterval time.Duration - taskRepo repo.ITaskRepo - runType task.TaskRunType - buildHelper service.TraceFilterProcessorBuilder + taskID int64 + t *entity.ObservabilityTask + tr *entity.TaskRun + processor taskexe.Processor + + taskRepo repo.ITaskRepo + runType entity.TaskRunType + buildHelper service.TraceFilterProcessorBuilder } // Sampled determines whether a span is sampled based on the sampling rate; the sample size will be validated during flush. func (s *spanSubscriber) Sampled() bool { - t := s.getTask() - if t == nil || t.Rule == nil || t.Rule.Sampler == nil { + if s.t == nil || s.t.Sampler == nil { return false } const base = 10000 - threshold := int64(float64(base) * t.GetRule().GetSampler().GetSampleRate()) + threshold := int64(float64(base) * s.t.Sampler.SampleRate) r := rand.Int63n(base) return r <= threshold } -func (s *spanSubscriber) getTask() *task.Task { - s.RLock() - defer s.RUnlock() - return s.t -} - func combineFilters(filters ...*loop_span.FilterFields) *loop_span.FilterFields { filterAggr := &loop_span.FilterFields{ QueryAndOr: ptr.Of(loop_span.QueryAndOrEnumAnd), @@ -77,7 +62,7 @@ func combineFilters(filters ...*loop_span.FilterFields) *loop_span.FilterFields // Match checks whether the span matches the task filter. func (s *spanSubscriber) Match(ctx context.Context, span *loop_span.Span) (bool, error) { task := s.t - if task == nil || task.Rule == nil { + if task == nil { return false, nil } @@ -90,22 +75,22 @@ func (s *spanSubscriber) Match(ctx context.Context, span *loop_span.Span) (bool, return true, nil } -func (s *spanSubscriber) buildSpanFilters(ctx context.Context, taskConfig *task.Task) *loop_span.FilterFields { +func (s *spanSubscriber) buildSpanFilters(ctx context.Context, taskDO *entity.ObservabilityTask) *loop_span.FilterFields { // Additional filters can be constructed based on task configuration if needed. // Simplified handling here: returning nil means no extra filters are applied. filters := &loop_span.FilterFields{} - platformFilter, err := s.buildHelper.BuildPlatformRelatedFilter(ctx, loop_span.PlatformType(taskConfig.GetRule().GetSpanFilters().GetPlatformType())) + platformFilter, err := s.buildHelper.BuildPlatformRelatedFilter(ctx, taskDO.SpanFilter.PlatformType) if err != nil { return filters } builtinFilter, err := buildBuiltinFilters(ctx, platformFilter, &ListSpansReq{ - WorkspaceID: taskConfig.GetWorkspaceID(), - SpanListType: loop_span.SpanListType(taskConfig.GetRule().GetSpanFilters().GetSpanListType()), + WorkspaceID: taskDO.WorkspaceID, + SpanListType: taskDO.SpanFilter.SpanListType, }) if err != nil { return filters } - filters = combineFilters(builtinFilter, convertor.FilterFieldsDTO2DO(taskConfig.GetRule().GetSpanFilters().GetFilters())) + filters = combineFilters(builtinFilter, &taskDO.SpanFilter.Filters) return filters } @@ -154,8 +139,8 @@ func buildBuiltinFilters(ctx context.Context, f span_filter.Filter, req *ListSpa } func (s *spanSubscriber) Creative(ctx context.Context, runStartAt, runEndAt int64) error { - err := s.processor.OnCreateTaskRunChange(ctx, taskexe.OnCreateTaskRunChangeReq{ - CurrentTask: tconv.TaskDTO2DO(s.t, "", nil), + err := s.processor.OnTaskRunCreated(ctx, taskexe.OnTaskRunCreatedReq{ + CurrentTask: s.t, RunType: s.runType, RunStartAt: runStartAt, RunEndAt: runEndAt, @@ -169,16 +154,16 @@ func (s *spanSubscriber) Creative(ctx context.Context, runStartAt, runEndAt int6 func (s *spanSubscriber) AddSpan(ctx context.Context, span *loop_span.Span) error { var taskRunConfig *entity.TaskRun var err error - if s.runType == task.TaskRunTypeNewData { - taskRunConfig, err = s.taskRepo.GetLatestNewDataTaskRun(ctx, nil, s.t.GetID()) + if s.runType == entity.TaskRunTypeNewData { + taskRunConfig, err = s.taskRepo.GetLatestNewDataTaskRun(ctx, nil, s.t.ID) if err != nil { - logs.CtxWarn(ctx, "get latest new data task run failed, task_id=%d, err: %v", s.t.GetID(), err) + logs.CtxWarn(ctx, "get latest new data task run failed, task_id=%d, err: %v", s.t.ID, err) return err } } else { - taskRunConfig, err = s.taskRepo.GetBackfillTaskRun(ctx, nil, s.t.GetID()) + taskRunConfig, err = s.taskRepo.GetBackfillTaskRun(ctx, nil, s.t.ID) if err != nil { - logs.CtxWarn(ctx, "get backfill task run failed, task_id=%d, err: %v", s.t.GetID(), err) + logs.CtxWarn(ctx, "get backfill task run failed, task_id=%d, err: %v", s.t.ID, err) return err } } @@ -195,7 +180,7 @@ func (s *spanSubscriber) AddSpan(ctx context.Context, span *loop_span.Span) erro logs.CtxWarn(ctx, "span start time is before task cycle start time, trace_id=%s, span_id=%s", span.TraceID, span.SpanID) return nil } - trigger := &taskexe.Trigger{Task: tconv.TaskDTO2DO(s.t, "", nil), Span: span, TaskRun: taskRunConfig} + trigger := &taskexe.Trigger{Task: s.t, Span: span, TaskRun: taskRunConfig} logs.CtxInfo(ctx, "invoke processor, trigger: %v", trigger) err = s.processor.Invoke(ctx, trigger) if err != nil { diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/test_helpers_test.go b/backend/modules/observability/domain/task/service/taskexe/tracehub/test_helpers_test.go index 15134e6c9..5664d53bc 100755 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/test_helpers_test.go +++ b/backend/modules/observability/domain/task/service/taskexe/tracehub/test_helpers_test.go @@ -29,8 +29,8 @@ type stubProcessor struct { createTaskRunErr error finishChangeInvoked int invokeCalled bool - createTaskRunReqs []taskexe.OnCreateTaskRunChangeReq - finishChangeReqs []taskexe.OnFinishTaskChangeReq + createTaskRunReqs []taskexe.OnTaskRunCreatedReq + finishChangeReqs []taskexe.OnTaskFinishedReq updateCallCount int createTaskRunErrSeq []error finishErrSeq []error @@ -45,16 +45,16 @@ func (s *stubProcessor) Invoke(context.Context, *taskexe.Trigger) error { return s.invokeErr } -func (s *stubProcessor) OnCreateTaskChange(context.Context, *entity.ObservabilityTask) error { +func (s *stubProcessor) OnTaskCreated(context.Context, *entity.ObservabilityTask) error { return s.createTaskErr } -func (s *stubProcessor) OnUpdateTaskChange(context.Context, *entity.ObservabilityTask, string) error { +func (s *stubProcessor) OnTaskUpdated(context.Context, *entity.ObservabilityTask, entity.TaskStatus) error { s.updateCallCount++ return s.updateErr } -func (s *stubProcessor) OnFinishTaskChange(_ context.Context, req taskexe.OnFinishTaskChangeReq) error { +func (s *stubProcessor) OnTaskFinished(_ context.Context, req taskexe.OnTaskFinishedReq) error { idx := len(s.finishChangeReqs) s.finishChangeReqs = append(s.finishChangeReqs, req) s.finishChangeInvoked++ @@ -64,7 +64,7 @@ func (s *stubProcessor) OnFinishTaskChange(_ context.Context, req taskexe.OnFini return s.finishErr } -func (s *stubProcessor) OnCreateTaskRunChange(_ context.Context, req taskexe.OnCreateTaskRunChangeReq) error { +func (s *stubProcessor) OnTaskRunCreated(_ context.Context, req taskexe.OnTaskRunCreatedReq) error { s.createTaskRunReqs = append(s.createTaskRunReqs, req) idx := len(s.createTaskRunReqs) - 1 if idx >= 0 && idx < len(s.createTaskRunErrSeq) { @@ -75,7 +75,7 @@ func (s *stubProcessor) OnCreateTaskRunChange(_ context.Context, req taskexe.OnC return s.createTaskRunErr } -func (s *stubProcessor) OnFinishTaskRunChange(context.Context, taskexe.OnFinishTaskRunChangeReq) error { +func (s *stubProcessor) OnTaskRunFinished(context.Context, taskexe.OnTaskRunFinishedReq) error { return s.finishTaskRunErr } diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/trace_hub.go b/backend/modules/observability/domain/task/service/taskexe/tracehub/trace_hub.go index 13da88d47..8b9863543 100644 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/trace_hub.go +++ b/backend/modules/observability/domain/task/service/taskexe/tracehub/trace_hub.go @@ -5,11 +5,9 @@ package tracehub import ( "context" - "sync" - "time" - "github.com/coze-dev/coze-loop/backend/infra/external/benefit" "github.com/coze-dev/coze-loop/backend/infra/lock" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/component/config" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/component/mq" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/component/tenant" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" @@ -18,16 +16,14 @@ import ( "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/loop_span" trace_repo "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/repo" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/service" - "github.com/coze-dev/coze-loop/backend/pkg/conf" ) //go:generate mockgen -destination=mocks/trace_hub_service.go -package=mocks . ITraceHubService type ITraceHubService interface { - SpanTrigger(ctx context.Context, event *entity.RawSpan) error - CallBack(ctx context.Context, event *entity.AutoEvalEvent) error - Correction(ctx context.Context, event *entity.CorrectionEvent) error + SpanTrigger(ctx context.Context, span *loop_span.Span) error BackFill(ctx context.Context, event *entity.BackFillEvent) error + StoneTaskCache(ctx context.Context, cacheInfo TaskCacheInfo) error } func NewTraceHubImpl( @@ -36,71 +32,44 @@ func NewTraceHubImpl( tenantProvider tenant.ITenantProvider, buildHelper service.TraceFilterProcessorBuilder, taskProcessor *processor.TaskProcessor, - benefitSvc benefit.IBenefitService, aid int32, backfillProducer mq.IBackfillProducer, locker lock.ILocker, - loader conf.IConfigLoader, + config config.ITraceConfig, ) (ITraceHubService, error) { - // Create two independent timers with different intervals - scheduledTaskTicker := time.NewTicker(5 * time.Minute) // Task status lifecycle management - 5-minute interval - syncTaskTicker := time.NewTicker(2 * time.Minute) // Data synchronization - 1-minute interval impl := &TraceHubServiceImpl{ - taskRepo: tRepo, - scheduledTaskTicker: scheduledTaskTicker, - syncTaskTicker: syncTaskTicker, - stopChan: make(chan struct{}), - traceRepo: traceRepo, - tenantProvider: tenantProvider, - buildHelper: buildHelper, - taskProcessor: taskProcessor, - benefitSvc: benefitSvc, - aid: aid, - backfillProducer: backfillProducer, - locker: locker, - loader: loader, + taskRepo: tRepo, + traceRepo: traceRepo, + tenantProvider: tenantProvider, + buildHelper: buildHelper, + taskProcessor: taskProcessor, + aid: aid, + backfillProducer: backfillProducer, + locker: locker, + config: config, + localCache: NewLocalCache(), } - // Start the scheduled tasks immediately - impl.startScheduledTask() - - // default+lane?+新集群?——定时任务和任务处理分开——内场 return impl, nil } type TraceHubServiceImpl struct { - scheduledTaskTicker *time.Ticker // Task status lifecycle management timer - 5-minute interval - syncTaskTicker *time.Ticker // Data synchronization timer - 1-minute interval - stopChan chan struct{} - taskRepo repo.ITaskRepo - traceRepo trace_repo.ITraceRepo - tenantProvider tenant.ITenantProvider - taskProcessor *processor.TaskProcessor - buildHelper service.TraceFilterProcessorBuilder - benefitSvc benefit.IBenefitService - backfillProducer mq.IBackfillProducer - locker lock.ILocker - loader conf.IConfigLoader - - flushErrLock sync.Mutex - flushErr []error + taskRepo repo.ITaskRepo + traceRepo trace_repo.ITraceRepo + tenantProvider tenant.ITenantProvider + taskProcessor *processor.TaskProcessor + buildHelper service.TraceFilterProcessorBuilder + backfillProducer mq.IBackfillProducer + locker lock.ILocker + config config.ITraceConfig // Local cache - caching non-terminal task information - taskCache sync.Map - taskCacheLock sync.RWMutex + localCache *LocalCache aid int32 } -type flushReq struct { - retrievedSpanCount int64 - pageToken string - spans []*loop_span.Span - noMore bool -} - -const TagKeyResult = "tag_key" - -func (h *TraceHubServiceImpl) Close() { - close(h.stopChan) +func (h *TraceHubServiceImpl) StoneTaskCache(ctx context.Context, cacheInfo TaskCacheInfo) error { + h.localCache.StoneTaskCache(ctx, cacheInfo) + return nil } diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/trace_hub_test.go b/backend/modules/observability/domain/task/service/taskexe/tracehub/trace_hub_test.go index 5740db074..3c8bd53e1 100755 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/trace_hub_test.go +++ b/backend/modules/observability/domain/task/service/taskexe/tracehub/trace_hub_test.go @@ -5,17 +5,13 @@ package tracehub import ( "context" - "errors" "testing" "go.uber.org/mock/gomock" - "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/task" - entity "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" repo_mocks "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/repo/mocks" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/loop_span" - "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/repo" - trace_repo_mocks "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/repo/mocks" "github.com/stretchr/testify/require" ) @@ -72,107 +68,15 @@ func TestTraceHubServiceImpl_applySampling(t *testing.T) { spans := []*loop_span.Span{{SpanID: "1"}, {SpanID: "2"}, {SpanID: "3"}} impl := &TraceHubServiceImpl{} - fullRate := &spanSubscriber{ - t: &task.Task{ - Rule: &task.Rule{Sampler: &task.Sampler{SampleRate: floatPtr(1.0)}}, - }, - } - zeroRate := &spanSubscriber{ - t: &task.Task{ - Rule: &task.Rule{Sampler: &task.Sampler{SampleRate: floatPtr(0.0)}}, - }, - } - halfRate := &spanSubscriber{ - t: &task.Task{ - Rule: &task.Rule{Sampler: &task.Sampler{SampleRate: floatPtr(0.5)}}, - }, - } + fullRate := &spanSubscriber{t: &entity.ObservabilityTask{Sampler: &entity.Sampler{SampleRate: 1.0}}} + zeroRate := &spanSubscriber{t: &entity.ObservabilityTask{Sampler: &entity.Sampler{SampleRate: 0}}} + halfRate := &spanSubscriber{t: &entity.ObservabilityTask{Sampler: &entity.Sampler{SampleRate: 0.5}}} require.Len(t, impl.applySampling(spans, fullRate), len(spans)) require.Nil(t, impl.applySampling(spans, zeroRate)) require.Len(t, impl.applySampling(spans, halfRate), 1) } -func TestTraceHubServiceImpl_updateTaskRunDetailsCount(t *testing.T) { - t.Parallel() - - ctx := context.Background() - taskID := int64(101) - runIDStr := "202" - runID := int64(202) - - tests := []struct { - name string - status entity.EvaluatorRunStatus - expectSuccess bool - expectFail bool - expectErr bool - }{ - { - name: "success_status", - status: entity.EvaluatorRunStatus_Success, - expectSuccess: true, - }, - { - name: "fail_status", - status: entity.EvaluatorRunStatus_Fail, - expectFail: true, - }, - { - name: "unknown_status", - status: entity.EvaluatorRunStatus_Unknown, - }, - } - - for _, tt := range tests { - tt := tt - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - ctrl := gomock.NewController(t) - t.Cleanup(ctrl.Finish) - - mockRepo := repo_mocks.NewMockITaskRepo(ctrl) - impl := &TraceHubServiceImpl{taskRepo: mockRepo} - - turn := &entity.OnlineExptTurnEvalResult{ - Status: tt.status, - Ext: map[string]string{ - "run_id": runIDStr, - }, - } - - if tt.expectSuccess { - mockRepo.EXPECT().IncrTaskRunSuccessCount(ctx, taskID, runID, gomock.Any()).Return(nil) - } - if tt.expectFail { - mockRepo.EXPECT().IncrTaskRunFailCount(ctx, taskID, runID, gomock.Any()).Return(nil) - } - - err := impl.updateTaskRunDetailsCount(ctx, taskID, turn, 0) - if tt.expectErr { - require.Error(t, err) - } else { - require.NoError(t, err) - } - }) - } - - t.Run("missing_run_id", func(t *testing.T) { - t.Parallel() - impl := &TraceHubServiceImpl{} - err := impl.updateTaskRunDetailsCount(ctx, taskID, &entity.OnlineExptTurnEvalResult{Ext: map[string]string{}}, 0) - require.Error(t, err) - }) - - t.Run("invalid_run_id", func(t *testing.T) { - t.Parallel() - impl := &TraceHubServiceImpl{} - err := impl.updateTaskRunDetailsCount(ctx, taskID, &entity.OnlineExptTurnEvalResult{Ext: map[string]string{"run_id": "abc"}}, 0) - require.Error(t, err) - }) -} - func TestTraceHubServiceImpl_sendBackfillMessage(t *testing.T) { t.Parallel() @@ -188,105 +92,6 @@ func TestTraceHubServiceImpl_sendBackfillMessage(t *testing.T) { require.Equal(t, evt, fake.event) } -func TestTraceHubServiceImpl_getSpan(t *testing.T) { - t.Parallel() - - ctx := context.Background() - tenants := []string{"tenant"} - spanIDs := []string{"span-1"} - traceID := "trace-1" - workspaceID := "ws-1" - start := int64(1000) - end := int64(2000) - - t.Run("with_trace_id", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - t.Cleanup(ctrl.Finish) - mockTraceRepo := trace_repo_mocks.NewMockITraceRepo(ctrl) - impl := &TraceHubServiceImpl{traceRepo: mockTraceRepo} - expectedSpan := &loop_span.Span{SpanID: spanIDs[0], TraceID: traceID} - - mockTraceRepo.EXPECT().ListSpans(gomock.Any(), gomock.AssignableToTypeOf(&repo.ListSpansParam{})).DoAndReturn( - func(_ context.Context, param *repo.ListSpansParam) (*repo.ListSpansResult, error) { - require.Equal(t, tenants, param.Tenants) - require.Equal(t, start, param.StartAt) - require.Equal(t, end, param.EndAt) - require.True(t, param.NotQueryAnnotation) - require.Equal(t, int32(2), param.Limit) - require.Len(t, param.Filters.FilterFields, 3) - return &repo.ListSpansResult{Spans: loop_span.SpanList{expectedSpan}}, nil - }, - ) - - spans, err := impl.getSpan(ctx, tenants, spanIDs, traceID, workspaceID, start, end) - require.NoError(t, err) - require.Equal(t, []*loop_span.Span{expectedSpan}, spans) - }) - - t.Run("without_trace_id", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - t.Cleanup(ctrl.Finish) - mockTraceRepo := trace_repo_mocks.NewMockITraceRepo(ctrl) - impl := &TraceHubServiceImpl{traceRepo: mockTraceRepo} - expectedSpan := &loop_span.Span{SpanID: spanIDs[0]} - - mockTraceRepo.EXPECT().ListSpans(gomock.Any(), gomock.AssignableToTypeOf(&repo.ListSpansParam{})).DoAndReturn( - func(_ context.Context, param *repo.ListSpansParam) (*repo.ListSpansResult, error) { - require.Equal(t, tenants, param.Tenants) - require.Len(t, param.Filters.FilterFields, 2) - return &repo.ListSpansResult{Spans: loop_span.SpanList{expectedSpan}}, nil - }, - ) - - spans, err := impl.getSpan(ctx, tenants, spanIDs, "", workspaceID, start, end) - require.NoError(t, err) - require.Equal(t, []*loop_span.Span{expectedSpan}, spans) - }) - - t.Run("empty_span_ids", func(t *testing.T) { - t.Parallel() - impl := &TraceHubServiceImpl{} - _, err := impl.getSpan(ctx, tenants, nil, traceID, workspaceID, start, end) - require.Error(t, err) - }) - - t.Run("empty_workspace", func(t *testing.T) { - t.Parallel() - impl := &TraceHubServiceImpl{} - _, err := impl.getSpan(ctx, tenants, spanIDs, traceID, "", start, end) - require.Error(t, err) - }) - - t.Run("repo_error", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - t.Cleanup(ctrl.Finish) - mockTraceRepo := trace_repo_mocks.NewMockITraceRepo(ctrl) - impl := &TraceHubServiceImpl{traceRepo: mockTraceRepo} - - mockTraceRepo.EXPECT().ListSpans(gomock.Any(), gomock.AssignableToTypeOf(&repo.ListSpansParam{})).Return(nil, errors.New("list error")) - - _, err := impl.getSpan(ctx, tenants, spanIDs, traceID, workspaceID, start, end) - require.Error(t, err) - }) - - t.Run("no_data", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - t.Cleanup(ctrl.Finish) - mockTraceRepo := trace_repo_mocks.NewMockITraceRepo(ctrl) - impl := &TraceHubServiceImpl{traceRepo: mockTraceRepo} - - mockTraceRepo.EXPECT().ListSpans(gomock.Any(), gomock.AssignableToTypeOf(&repo.ListSpansParam{})).Return(&repo.ListSpansResult{}, nil) - - spans, err := impl.getSpan(ctx, tenants, spanIDs, traceID, workspaceID, start, end) - require.NoError(t, err) - require.Nil(t, spans) - }) -} - type fakeBackfillProducer struct { event *entity.BackFillEvent } diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/utils.go b/backend/modules/observability/domain/task/service/taskexe/tracehub/utils.go index 5a3487489..3dc04a816 100644 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/utils.go +++ b/backend/modules/observability/domain/task/service/taskexe/tracehub/utils.go @@ -5,18 +5,12 @@ package tracehub import ( "context" - "fmt" "os" "strconv" "github.com/bytedance/gopkg/cloud/metainfo" "github.com/bytedance/sonic" - "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/loop_span" - "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/repo" - obErrorx "github.com/coze-dev/coze-loop/backend/modules/observability/pkg/errno" - "github.com/coze-dev/coze-loop/backend/pkg/errorx" - "github.com/coze-dev/coze-loop/backend/pkg/lang/ptr" "github.com/coze-dev/coze-loop/backend/pkg/logs" ) @@ -42,6 +36,7 @@ func ToJSONString(ctx context.Context, obj interface{}) string { return jsonStr } +// todo 看看有没有更好的写法 func (h *TraceHubServiceImpl) fillCtx(ctx context.Context) context.Context { logID := logs.NewLogID() ctx = logs.SetLogID(ctx, logID) @@ -55,77 +50,3 @@ func (h *TraceHubServiceImpl) fillCtx(ctx context.Context) context.Context { func (h *TraceHubServiceImpl) getTenants(ctx context.Context, platform loop_span.PlatformType) ([]string, error) { return h.tenantProvider.GetTenantsByPlatformType(ctx, platform) } - -func (h *TraceHubServiceImpl) getSpan(ctx context.Context, tenants []string, spanIds []string, traceId, workspaceId string, startAt, endAt int64) ([]*loop_span.Span, error) { - if len(spanIds) == 0 || workspaceId == "" { - return nil, errorx.NewByCode(obErrorx.CommercialCommonInvalidParamCodeCode) - } - var filterFields []*loop_span.FilterField - filterFields = append(filterFields, &loop_span.FilterField{ - FieldName: loop_span.SpanFieldSpanId, - FieldType: loop_span.FieldTypeString, - Values: spanIds, - QueryType: ptr.Of(loop_span.QueryTypeEnumIn), - }) - filterFields = append(filterFields, &loop_span.FilterField{ - FieldName: loop_span.SpanFieldSpaceId, - FieldType: loop_span.FieldTypeString, - Values: []string{workspaceId}, - QueryType: ptr.Of(loop_span.QueryTypeEnumEq), - }) - if traceId != "" { - filterFields = append(filterFields, &loop_span.FilterField{ - FieldName: loop_span.SpanFieldTraceId, - FieldType: loop_span.FieldTypeString, - Values: []string{traceId}, - - QueryType: ptr.Of(loop_span.QueryTypeEnumEq), - }) - } - var spans []*loop_span.Span - for _, tenant := range tenants { - res, err := h.traceRepo.ListSpans(ctx, &repo.ListSpansParam{ - Tenants: []string{tenant}, - Filters: &loop_span.FilterFields{ - FilterFields: filterFields, - }, - StartAt: startAt, - EndAt: endAt, - NotQueryAnnotation: true, - Limit: 2, - }) - if err != nil { - logs.CtxError(ctx, "failed to list span, %v", err) - return spans, err - } - spans = append(spans, res.Spans...) - } - logs.CtxInfo(ctx, "list span, spans: %v", spans) - - return spans, nil -} - -// updateTaskRunStatusCount updates the Redis count based on Status -func (h *TraceHubServiceImpl) updateTaskRunDetailsCount(ctx context.Context, taskID int64, turn *entity.OnlineExptTurnEvalResult, ttl int64) error { - // Retrieve taskRunID from Ext - taskRunIDStr := turn.Ext["run_id"] - if taskRunIDStr == "" { - return fmt.Errorf("task_run_id not found in ext") - } - - taskRunID, err := strconv.ParseInt(taskRunIDStr, 10, 64) - if err != nil { - return fmt.Errorf("invalid task_run_id: %s, err: %v", taskRunIDStr, err) - } - // Increase the corresponding counter based on Status - switch turn.Status { - case entity.EvaluatorRunStatus_Success: - return h.taskRepo.IncrTaskRunSuccessCount(ctx, taskID, taskRunID, ttl) - case entity.EvaluatorRunStatus_Fail: - return h.taskRepo.IncrTaskRunFailCount(ctx, taskID, taskRunID, ttl) - default: - logs.CtxDebug(ctx, "未知的评估状态,跳过计数: taskID=%d, taskRunID=%d, status=%d", - taskID, taskRunID, turn.Status) - return nil - } -} diff --git a/backend/modules/observability/domain/task/service/taskexe/types.go b/backend/modules/observability/domain/task/service/taskexe/types.go deleted file mode 100644 index 40a7dee5c..000000000 --- a/backend/modules/observability/domain/task/service/taskexe/types.go +++ /dev/null @@ -1,56 +0,0 @@ -// Copyright (c) 2025 coze-dev Authors -// SPDX-License-Identifier: Apache-2.0 - -package taskexe - -import ( - "context" - "errors" - - "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/task" - task_entity "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" - "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/loop_span" -) - -type Trigger struct { - Task *task_entity.ObservabilityTask - Span *loop_span.Span - TaskRun *task_entity.TaskRun -} - -var ( - ErrInvalidConfig = errors.New("invalid config") - ErrInvalidTrigger = errors.New("invalid span trigger") -) - -type OnCreateTaskRunChangeReq struct { - CurrentTask *task_entity.ObservabilityTask - RunType task.TaskRunType - RunStartAt int64 - RunEndAt int64 -} -type OnFinishTaskRunChangeReq struct { - Task *task_entity.ObservabilityTask - TaskRun *task_entity.TaskRun -} -type OnFinishTaskChangeReq struct { - Task *task_entity.ObservabilityTask - TaskRun *task_entity.TaskRun - IsFinish bool -} - -type Processor interface { - ValidateConfig(ctx context.Context, config any) error - Invoke(ctx context.Context, trigger *Trigger) error - - OnCreateTaskChange(ctx context.Context, currentTask *task_entity.ObservabilityTask) error - OnUpdateTaskChange(ctx context.Context, currentTask *task_entity.ObservabilityTask, taskOp task.TaskStatus) error - OnFinishTaskChange(ctx context.Context, param OnFinishTaskChangeReq) error - - OnCreateTaskRunChange(ctx context.Context, param OnCreateTaskRunChangeReq) error - OnFinishTaskRunChange(ctx context.Context, param OnFinishTaskRunChangeReq) error -} - -type ProcessorUnion interface { - Processor -} diff --git a/backend/modules/observability/domain/trace/entity/common/page.go b/backend/modules/observability/domain/trace/entity/common/page.go new file mode 100644 index 000000000..ad3ade039 --- /dev/null +++ b/backend/modules/observability/domain/trace/entity/common/page.go @@ -0,0 +1,9 @@ +// Copyright (c) 2025 coze-dev Authors +// SPDX-License-Identifier: Apache-2.0 + +package common + +type OrderBy struct { + Field string + IsAsc bool +} diff --git a/backend/modules/observability/domain/trace/entity/loop_span/annotation.go b/backend/modules/observability/domain/trace/entity/loop_span/annotation.go index ed8b7a5c1..2615f6dc0 100644 --- a/backend/modules/observability/domain/trace/entity/loop_span/annotation.go +++ b/backend/modules/observability/domain/trace/entity/loop_span/annotation.go @@ -312,6 +312,16 @@ func (a AnnotationList) Uniq() AnnotationList { }) } +func (a AnnotationList) FindByEvaluatorRecordID(evaluatorRecordID int64) (*Annotation, bool) { + for _, annotation := range a { + meta := annotation.GetAutoEvaluateMetadata() + if meta != nil && meta.EvaluatorRecordID == evaluatorRecordID { + return annotation, true + } + } + return nil, false +} + func NewStringValue(v string) AnnotationValue { return AnnotationValue{ ValueType: AnnotationValueTypeString, diff --git a/backend/modules/observability/domain/trace/entity/loop_span/filter.go b/backend/modules/observability/domain/trace/entity/loop_span/filter.go index bef15eb80..c70273889 100644 --- a/backend/modules/observability/domain/trace/entity/loop_span/filter.go +++ b/backend/modules/observability/domain/trace/entity/loop_span/filter.go @@ -364,6 +364,15 @@ func (f *FilterField) CheckValue(val any) bool { } } +func (f *FilterField) SetHidden(hidden bool) { + f.Hidden = hidden + if f.SubFilter != nil { + for _, subFilters := range f.SubFilter.FilterFields { + subFilters.SetHidden(hidden) + } + } +} + func CompareBool(val bool, values []bool, qType QueryTypeEnum) bool { switch qType { case QueryTypeEnumEq: diff --git a/backend/modules/observability/domain/trace/entity/loop_span/span.go b/backend/modules/observability/domain/trace/entity/loop_span/span.go index ebc602a49..de9fc8dee 100644 --- a/backend/modules/observability/domain/trace/entity/loop_span/span.go +++ b/backend/modules/observability/domain/trace/entity/loop_span/span.go @@ -77,6 +77,8 @@ const ( MaxKeySize = 100 MaxTextSize = 1024 * 1024 MaxCommonValueSize = 1024 + + CallTypeEvaluator = "Evaluator" ) type TTL string @@ -425,6 +427,35 @@ func (s *Span) AddManualDatasetAnnotation(datasetID int64, userID string, annota return a, nil } +func (s *Span) AddAutoEvalAnnotation(taskID, evaluatorRecordID, evaluatorVersionID int64, score float64, reasoning, userID string) (*Annotation, error) { + a := &Annotation{} + a.SpanID = s.SpanID + a.TraceID = s.TraceID + a.StartTime = time.UnixMicro(s.StartTime) + a.WorkspaceID = s.WorkspaceID + a.AnnotationType = AnnotationTypeAutoEvaluate + a.Key = fmt.Sprintf("%d:%d", taskID, evaluatorVersionID) + a.Value = NewDoubleValue(score) + a.Reasoning = reasoning + a.Metadata = &AutoEvaluateMetadata{ + TaskID: taskID, + EvaluatorRecordID: evaluatorRecordID, + EvaluatorVersionID: evaluatorVersionID, + } + a.Status = AnnotationStatusNormal + a.CreatedAt = time.Now() + a.CreatedBy = userID + a.UpdatedAt = time.Now() + a.UpdatedBy = userID + + if err := a.GenID(); err != nil { + return nil, err + } + + s.AddAnnotation(a) + return a, nil +} + func (s *Span) ExtractByJsonpath(ctx context.Context, key string, jsonpath string) (string, error) { jsonpath = strings.TrimPrefix(jsonpath, key) jsonpath = strings.TrimPrefix(jsonpath, ".") diff --git a/backend/modules/observability/domain/trace/service/trace_service.go b/backend/modules/observability/domain/trace/service/trace_service.go index 7296cee40..b3abd2302 100644 --- a/backend/modules/observability/domain/trace/service/trace_service.go +++ b/backend/modules/observability/domain/trace/service/trace_service.go @@ -11,8 +11,7 @@ import ( "time" tconv "github.com/coze-dev/coze-loop/backend/modules/observability/application/convertor/task" - taskRepo "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/repo" - "github.com/coze-dev/coze-loop/backend/modules/observability/infra/repo/mysql" + taskrepo "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/repo" "golang.org/x/sync/errgroup" "github.com/bytedance/gg/gptr" @@ -37,7 +36,7 @@ import ( "github.com/coze-dev/coze-loop/backend/pkg/lang/goroutine" "github.com/coze-dev/coze-loop/backend/pkg/lang/ptr" "github.com/coze-dev/coze-loop/backend/pkg/logs" - time_util "github.com/coze-dev/coze-loop/backend/pkg/time" + timeutil "github.com/coze-dev/coze-loop/backend/pkg/time" "github.com/samber/lo" ) @@ -277,7 +276,7 @@ func NewTraceServiceImpl( buildHelper TraceFilterProcessorBuilder, tenantProvider tenant.ITenantProvider, evalSvc rpc.IEvaluatorRPCAdapter, - taskRepo taskRepo.ITaskRepo, + taskRepo taskrepo.ITaskRepo, ) (ITraceService, error) { return &TraceServiceImpl{ traceRepo: tRepo, @@ -301,7 +300,7 @@ type TraceServiceImpl struct { buildHelper TraceFilterProcessorBuilder tenantProvider tenant.ITenantProvider evalSvc rpc.IEvaluatorRPCAdapter - taskRepo taskRepo.ITaskRepo + taskRepo taskrepo.ITaskRepo } func (r *TraceServiceImpl) GetTrace(ctx context.Context, req *GetTraceReq) (*GetTraceResp, error) { @@ -1235,7 +1234,7 @@ func (r *TraceServiceImpl) ListAnnotationEvaluators(ctx context.Context, req *Li evaluators = append(evaluators, evaluatorList...) } else { // 没有name先查task - taskDOs, _, err := r.taskRepo.ListTasks(ctx, mysql.ListTaskParam{ + taskDOs, _, err := r.taskRepo.ListTasks(ctx, taskrepo.ListTaskParam{ WorkspaceIDs: []int64{req.WorkspaceID}, ReqLimit: int32(500), ReqOffset: int32(0), @@ -1437,7 +1436,7 @@ func processLatencyFilter(f *loop_span.FilterField) error { if err != nil { return fmt.Errorf("fail to parse long value %s, %v", val, err) } - integer = time_util.MillSec2MicroSec(integer) + integer = timeutil.MillSec2MicroSec(integer) micros = append(micros, strconv.FormatInt(integer, 10)) } f.Values = micros diff --git a/backend/modules/observability/domain/trace/service/trace_service_test.go b/backend/modules/observability/domain/trace/service/trace_service_test.go index 1d5839ded..0660f2839 100644 --- a/backend/modules/observability/domain/trace/service/trace_service_test.go +++ b/backend/modules/observability/domain/trace/service/trace_service_test.go @@ -52,7 +52,7 @@ func newTaskRepoMock(ctrl *gomock.Controller) *taskRepoMock { } func (m *taskRepoMock) ListNonFinalTask(context.Context, string) ([]int64, error) { - panic("unexpected call to ListNonFinalTask in taskRepoMock") + panic("unexpected call to ListNonFinalTaskBySpaceID in taskRepoMock") } func (m *taskRepoMock) AddNonFinalTask(context.Context, string, int64) error { @@ -63,8 +63,8 @@ func (m *taskRepoMock) RemoveNonFinalTask(context.Context, string, int64) error panic("unexpected call to RemoveNonFinalTask in taskRepoMock") } -func (m *taskRepoMock) GetTaskByRedis(context.Context, int64) (*taskentity.ObservabilityTask, error) { - panic("unexpected call to GetTaskByRedis in taskRepoMock") +func (m *taskRepoMock) GetTaskByCache(context.Context, int64) (*taskentity.ObservabilityTask, error) { + panic("unexpected call to GetTaskByCache in taskRepoMock") } func (m *taskRepoMock) SetTask(context.Context, *taskentity.ObservabilityTask) error { diff --git a/backend/modules/observability/infra/config/trace.go b/backend/modules/observability/infra/config/trace.go index 6732147fb..49cf7bc4d 100644 --- a/backend/modules/observability/infra/config/trace.go +++ b/backend/modules/observability/infra/config/trace.go @@ -26,6 +26,7 @@ const ( queryTraceRateLimitCfgKey = "query_trace_rate_limit_config" keySpanTypeCfgKey = "key_span_type" backfillMqProducerCfgKey = "backfill_mq_producer_config" + consumerListeningCfgKey = "consumer_listening" ) type TraceConfigCenter struct { @@ -171,6 +172,14 @@ func (t *TraceConfigCenter) GetKeySpanTypes(ctx context.Context) map[string][]st return keyColumns } +func (t *TraceConfigCenter) GetConsumerListening(ctx context.Context) (*config.ConsumerListening, error) { + consumerListening := new(config.ConsumerListening) + if err := t.UnmarshalKey(ctx, consumerListeningCfgKey, &consumerListening); err != nil { + return nil, err + } + return consumerListening, nil +} + func NewTraceConfigCenter(confP conf.IConfigLoader) config.ITraceConfig { ret := &TraceConfigCenter{ IConfigLoader: confP, diff --git a/backend/modules/observability/infra/mq/consumer/annotation_consumer.go b/backend/modules/observability/infra/mq/consumer/annotation_consumer.go index b0034d6ce..648694841 100644 --- a/backend/modules/observability/infra/mq/consumer/annotation_consumer.go +++ b/backend/modules/observability/infra/mq/consumer/annotation_consumer.go @@ -22,7 +22,7 @@ type AnnotationConsumer struct { conf.IConfigLoader } -func newAnnotationConsumer(handler obapp.IAnnotationQueueConsumer, loader conf.IConfigLoader) mq.IConsumerWorker { +func NewAnnotationConsumer(handler obapp.IAnnotationQueueConsumer, loader conf.IConfigLoader) mq.IConsumerWorker { return &AnnotationConsumer{ handler: handler, IConfigLoader: loader, diff --git a/backend/modules/observability/infra/mq/consumer/autotask_callback_consumer.go b/backend/modules/observability/infra/mq/consumer/autotask_callback_consumer.go index 5ae8ceee6..20f41a7b4 100644 --- a/backend/modules/observability/infra/mq/consumer/autotask_callback_consumer.go +++ b/backend/modules/observability/infra/mq/consumer/autotask_callback_consumer.go @@ -17,19 +17,19 @@ import ( "github.com/coze-dev/coze-loop/backend/pkg/logs" ) -type CallbackConsumer struct { +type AutoTaskCallbackConsumer struct { handler obapp.ITaskQueueConsumer conf.IConfigLoader } -func newCallbackConsumer(handler obapp.ITaskQueueConsumer, loader conf.IConfigLoader) mq.IConsumerWorker { - return &CallbackConsumer{ +func NewCallbackConsumer(handler obapp.ITaskQueueConsumer, loader conf.IConfigLoader) mq.IConsumerWorker { + return &AutoTaskCallbackConsumer{ handler: handler, IConfigLoader: loader, } } -func (e *CallbackConsumer) ConsumerCfg(ctx context.Context) (*mq.ConsumerConfig, error) { +func (e *AutoTaskCallbackConsumer) ConsumerCfg(ctx context.Context) (*mq.ConsumerConfig, error) { const key = "autotask_callback_mq_consumer_config" cfg := &config.MqConsumerCfg{} if err := e.UnmarshalKey(ctx, key, cfg); err != nil { @@ -46,7 +46,7 @@ func (e *CallbackConsumer) ConsumerCfg(ctx context.Context) (*mq.ConsumerConfig, return res, nil } -func (e *CallbackConsumer) HandleMessage(ctx context.Context, ext *mq.MessageExt) error { +func (e *AutoTaskCallbackConsumer) HandleMessage(ctx context.Context, ext *mq.MessageExt) error { logID := logs.NewLogID() ctx = logs.SetLogID(ctx, logID) event := new(entity.AutoEvalEvent) @@ -55,5 +55,5 @@ func (e *CallbackConsumer) HandleMessage(ctx context.Context, ext *mq.MessageExt return nil } logs.CtxInfo(ctx, "Callback msg, event: %v,msgID: %s", event, ext.MsgID) - return e.handler.CallBack(ctx, event) + return e.handler.AutoEvalCallback(ctx, event) } diff --git a/backend/modules/observability/infra/mq/consumer/backfill_consumer.go b/backend/modules/observability/infra/mq/consumer/backfill_consumer.go index fba8bd0e4..c5003e165 100644 --- a/backend/modules/observability/infra/mq/consumer/backfill_consumer.go +++ b/backend/modules/observability/infra/mq/consumer/backfill_consumer.go @@ -22,7 +22,7 @@ type BackFillConsumer struct { conf.IConfigLoader } -func newBackFillConsumer(handler obapp.ITaskQueueConsumer, loader conf.IConfigLoader) mq.IConsumerWorker { +func NewBackFillConsumer(handler obapp.ITaskQueueConsumer, loader conf.IConfigLoader) mq.IConsumerWorker { return &BackFillConsumer{ handler: handler, IConfigLoader: loader, diff --git a/backend/modules/observability/infra/mq/consumer/consumer.go b/backend/modules/observability/infra/mq/consumer/consumer.go index 93f47ff72..40c133afc 100644 --- a/backend/modules/observability/infra/mq/consumer/consumer.go +++ b/backend/modules/observability/infra/mq/consumer/consumer.go @@ -4,18 +4,9 @@ package consumer import ( - "context" - "os" - "github.com/coze-dev/coze-loop/backend/infra/mq" "github.com/coze-dev/coze-loop/backend/modules/observability/application" - "github.com/coze-dev/coze-loop/backend/modules/observability/domain/component/config" "github.com/coze-dev/coze-loop/backend/pkg/conf" - "github.com/coze-dev/coze-loop/backend/pkg/lang/slices" -) - -const ( - TceCluster = "TCE_CLUSTER" ) func NewConsumerWorkers( @@ -25,20 +16,12 @@ func NewConsumerWorkers( ) ([]mq.IConsumerWorker, error) { workers := []mq.IConsumerWorker{} workers = append(workers, - newAnnotationConsumer(handler, loader), + NewAnnotationConsumer(handler, loader), + NewTaskConsumer(taskConsumer, loader), + NewCallbackConsumer(taskConsumer, loader), + NewCorrectionConsumer(taskConsumer, loader), + NewBackFillConsumer(taskConsumer, loader), ) - const key = "consumer_listening" - cfg := &config.ConsumerListening{} - if err := loader.UnmarshalKey(context.Background(), key, cfg); err != nil { - return nil, err - } - if cfg.IsEnabled && slices.Contains(cfg.Clusters, os.Getenv(TceCluster)) { - workers = append(workers, - newTaskConsumer(taskConsumer, loader), - newCallbackConsumer(taskConsumer, loader), - newCorrectionConsumer(taskConsumer, loader), - newBackFillConsumer(taskConsumer, loader), - ) - } + return workers, nil } diff --git a/backend/modules/observability/infra/mq/consumer/correction_consumer.go b/backend/modules/observability/infra/mq/consumer/correction_consumer.go index a72ff61bf..9ab1eae9b 100644 --- a/backend/modules/observability/infra/mq/consumer/correction_consumer.go +++ b/backend/modules/observability/infra/mq/consumer/correction_consumer.go @@ -21,7 +21,7 @@ type CorrectionConsumer struct { conf.IConfigLoader } -func newCorrectionConsumer(handler obapp.ITaskQueueConsumer, loader conf.IConfigLoader) mq.IConsumerWorker { +func NewCorrectionConsumer(handler obapp.ITaskQueueConsumer, loader conf.IConfigLoader) mq.IConsumerWorker { return &CorrectionConsumer{ handler: handler, IConfigLoader: loader, @@ -50,9 +50,9 @@ func (e *CorrectionConsumer) HandleMessage(ctx context.Context, ext *mq.MessageE ctx = logs.SetLogID(ctx, logID) event := new(entity.CorrectionEvent) if err := json.Unmarshal(ext.Body, event); err != nil { - logs.CtxError(ctx, "Correction msg json unmarshal fail, raw: %v, err: %s", conv.UnsafeBytesToString(ext.Body), err) + logs.CtxError(ctx, "AutoEvalCorrection msg json unmarshal fail, raw: %v, err: %s", conv.UnsafeBytesToString(ext.Body), err) return nil } - logs.CtxInfo(ctx, "Correction msg, event: %v,msgID=%s", event, ext.MsgID) - return e.handler.Correction(ctx, event) + logs.CtxInfo(ctx, "AutoEvalCorrection msg, event: %v,msgID=%s", event, ext.MsgID) + return e.handler.AutoEvalCorrection(ctx, event) } diff --git a/backend/modules/observability/infra/mq/consumer/task_consumer.go b/backend/modules/observability/infra/mq/consumer/task_consumer.go index 048793ddd..e6f5554fd 100644 --- a/backend/modules/observability/infra/mq/consumer/task_consumer.go +++ b/backend/modules/observability/infra/mq/consumer/task_consumer.go @@ -22,7 +22,7 @@ type TaskConsumer struct { conf.IConfigLoader } -func newTaskConsumer(handler obapp.ITaskQueueConsumer, loader conf.IConfigLoader) mq.IConsumerWorker { +func NewTaskConsumer(handler obapp.ITaskQueueConsumer, loader conf.IConfigLoader) mq.IConsumerWorker { return &TaskConsumer{ handler: handler, IConfigLoader: loader, diff --git a/backend/modules/observability/infra/repo/mysql/convertor/task.go b/backend/modules/observability/infra/repo/mysql/convertor/task.go index f632d8d70..a788a3719 100644 --- a/backend/modules/observability/infra/repo/mysql/convertor/task.go +++ b/backend/modules/observability/infra/repo/mysql/convertor/task.go @@ -17,8 +17,8 @@ func TaskDO2PO(task *entity.ObservabilityTask) *model.ObservabilityTask { WorkspaceID: task.WorkspaceID, Name: task.Name, Description: task.Description, - TaskType: task.TaskType, - TaskStatus: task.TaskStatus, + TaskType: string(task.TaskType), + TaskStatus: string(task.TaskStatus), TaskDetail: ptr.Of(ToJSONString(task.TaskDetail)), SpanFilter: ptr.Of(ToJSONString(task.SpanFilter)), EffectiveTime: ptr.Of(ToJSONString(task.EffectiveTime)), @@ -38,8 +38,8 @@ func TaskPO2DO(task *model.ObservabilityTask) *entity.ObservabilityTask { WorkspaceID: task.WorkspaceID, Name: task.Name, Description: task.Description, - TaskType: task.TaskType, - TaskStatus: task.TaskStatus, + TaskType: entity.TaskType(task.TaskType), + TaskStatus: entity.TaskStatus(task.TaskStatus), TaskDetail: TaskDetailJSON2DO(task.TaskDetail), SpanFilter: SpanFilterJSON2DO(task.SpanFilter), EffectiveTime: EffectiveTimeJSON2DO(task.EffectiveTime), @@ -118,8 +118,8 @@ func TaskRunDO2PO(taskRun *entity.TaskRun) *model.ObservabilityTaskRun { ID: taskRun.ID, TaskID: taskRun.TaskID, WorkspaceID: taskRun.WorkspaceID, - TaskType: taskRun.TaskType, - RunStatus: taskRun.RunStatus, + TaskType: string(taskRun.TaskType), + RunStatus: string(taskRun.RunStatus), RunDetail: ptr.Of(ToJSONString(taskRun.RunDetail)), BackfillDetail: ptr.Of(ToJSONString(taskRun.BackfillDetail)), RunStartAt: taskRun.RunStartAt, @@ -135,8 +135,8 @@ func TaskRunPO2DO(taskRun *model.ObservabilityTaskRun) *entity.TaskRun { ID: taskRun.ID, TaskID: taskRun.TaskID, WorkspaceID: taskRun.WorkspaceID, - TaskType: taskRun.TaskType, - RunStatus: taskRun.RunStatus, + TaskType: entity.TaskRunType(taskRun.TaskType), + RunStatus: entity.TaskRunStatus(taskRun.RunStatus), RunDetail: RunDetailJSON2DO(taskRun.RunDetail), BackfillDetail: BackfillRunDetailJSON2DO(taskRun.BackfillDetail), RunStartAt: taskRun.RunStartAt, diff --git a/backend/modules/observability/infra/repo/mysql/task.go b/backend/modules/observability/infra/repo/mysql/task.go index b1fb01581..efa8ce8b8 100644 --- a/backend/modules/observability/infra/repo/mysql/task.go +++ b/backend/modules/observability/infra/repo/mysql/task.go @@ -11,9 +11,8 @@ import ( "time" "github.com/coze-dev/coze-loop/backend/infra/db" - "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/common" - "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/filter" - tconv "github.com/coze-dev/coze-loop/backend/modules/observability/application/convertor/task" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/common" "github.com/coze-dev/coze-loop/backend/modules/observability/infra/repo/mysql/gorm_gen/model" genquery "github.com/coze-dev/coze-loop/backend/modules/observability/infra/repo/mysql/gorm_gen/query" obErrorx "github.com/coze-dev/coze-loop/backend/modules/observability/pkg/errno" @@ -28,11 +27,14 @@ const ( DefaultLimit = 20 MaxLimit = 501 DefaultOffset = 0 + + MaxRetries = 3 + RetryDelay = 100 * time.Millisecond ) type ListTaskParam struct { WorkspaceIDs []int64 - TaskFilters *filter.TaskFilterFields + TaskFilters *entity.TaskFilterFields ReqLimit int32 ReqOffset int32 OrderBy *common.OrderBy @@ -46,7 +48,7 @@ type ITaskDao interface { DeleteTask(ctx context.Context, id int64, workspaceID int64, userID string) error ListTasks(ctx context.Context, param ListTaskParam) ([]*model.ObservabilityTask, int64, error) UpdateTaskWithOCC(ctx context.Context, id int64, workspaceID int64, updateMap map[string]interface{}) error - GetObjListWithTask(ctx context.Context) ([]string, []string, []*model.ObservabilityTask, error) + ListNonFinalTasks(ctx context.Context) ([]*model.ObservabilityTask, error) } func NewTaskDaoImpl(db db.Provider) ITaskDao { @@ -133,7 +135,13 @@ func (v *TaskDaoImpl) ListTasks(ctx context.Context, param ListTaskParam) ([]*mo return nil, 0, errorx.WrapByCode(err, obErrorx.CommonMySqlErrorCode) } // order by - qd = qd.Order(v.order(q, param.OrderBy.GetField(), param.OrderBy.GetIsAsc())) + orderField := "" + orderAsc := false + if param.OrderBy != nil { + orderField = param.OrderBy.Field + orderAsc = param.OrderBy.IsAsc + } + qd = qd.Order(v.order(q, orderField, orderAsc)) // 计算分页参数 limit, offset := calculatePagination(param.ReqLimit, param.ReqOffset) results, err := qd.Limit(limit).Offset(offset).Find() @@ -144,7 +152,7 @@ func (v *TaskDaoImpl) ListTasks(ctx context.Context, param ListTaskParam) ([]*mo } // 处理任务过滤条件 -func (v *TaskDaoImpl) applyTaskFilters(q *genquery.Query, taskFilters *filter.TaskFilterFields) (field.Expr, error) { +func (v *TaskDaoImpl) applyTaskFilters(q *genquery.Query, taskFilters *entity.TaskFilterFields) (field.Expr, error) { if taskFilters == nil || len(taskFilters.FilterFields) == 0 { return nil, nil } @@ -171,28 +179,28 @@ func (v *TaskDaoImpl) applyTaskFilters(q *genquery.Query, taskFilters *filter.Ta } // 构建单个过滤条件 -func (v *TaskDaoImpl) buildSingleFilterExpr(q *genquery.Query, f *filter.TaskFilterField) (field.Expr, error) { +func (v *TaskDaoImpl) buildSingleFilterExpr(q *genquery.Query, f *entity.TaskFilterField) (field.Expr, error) { if f.FieldName == nil || f.QueryType == nil { return nil, errorx.NewByCode(obErrorx.CommonInvalidParamCode, errorx.WithExtraMsg("field name or query type is nil")) } switch *f.FieldName { - case filter.TaskFieldNameTaskName: + case entity.TaskFieldNameTaskName: return v.buildTaskNameFilter(q, f) - case filter.TaskFieldNameTaskType: + case entity.TaskFieldNameTaskType: return v.buildTaskTypeFilter(q, f) - case filter.TaskFieldNameTaskStatus: + case entity.TaskFieldNameTaskStatus: return v.buildTaskStatusFilter(q, f) - case filter.TaskFieldNameCreatedBy: + case entity.TaskFieldNameCreatedBy: return v.buildCreatedByFilter(q, f) - case filter.TaskFieldNameSampleRate: + case entity.TaskFieldNameSampleRate: return v.buildSampleRateFilter(q, f) case "task_id": return v.buildTaskIDFilter(q, f) case "updated_at": return v.buildUpdateAtFilter(q, f) default: - return nil, errorx.NewByCode(obErrorx.CommonInvalidParamCode, errorx.WithMsgParam("invalid filter field name: %s", *f.FieldName)) + return nil, errorx.NewByCode(obErrorx.CommonInvalidParamCode, errorx.WithMsgParam("invalid filter field name: %s", string(*f.FieldName))) } } @@ -202,7 +210,7 @@ func (v *TaskDaoImpl) combineExpressions(expressions []field.Expr, relation stri return expressions[0] } - if relation == filter.QueryRelationOr { + if relation == string(entity.QueryRelationOr) { return field.Or(expressions...) } // 默认使用 AND 关系 @@ -210,15 +218,15 @@ func (v *TaskDaoImpl) combineExpressions(expressions []field.Expr, relation stri } // 构建任务名称过滤条件 -func (v *TaskDaoImpl) buildTaskNameFilter(q *genquery.Query, f *filter.TaskFilterField) (field.Expr, error) { +func (v *TaskDaoImpl) buildTaskNameFilter(q *genquery.Query, f *entity.TaskFilterField) (field.Expr, error) { if len(f.Values) == 0 { return nil, errorx.NewByCode(obErrorx.CommonInvalidParamCode, errorx.WithExtraMsg("no value provided for task name query")) } switch *f.QueryType { - case filter.QueryTypeEq: + case entity.QueryTypeEq: return q.ObservabilityTask.Name.Eq(f.Values[0]), nil - case filter.QueryTypeMatch: + case entity.QueryTypeMatch: return q.ObservabilityTask.Name.Like(fmt.Sprintf("%%%s%%", f.Values[0])), nil default: return nil, errorx.NewByCode(obErrorx.CommonInvalidParamCode, errorx.WithExtraMsg("invalid query type for task name")) @@ -226,15 +234,15 @@ func (v *TaskDaoImpl) buildTaskNameFilter(q *genquery.Query, f *filter.TaskFilte } // 构建任务类型过滤条件 -func (v *TaskDaoImpl) buildTaskTypeFilter(q *genquery.Query, f *filter.TaskFilterField) (field.Expr, error) { +func (v *TaskDaoImpl) buildTaskTypeFilter(q *genquery.Query, f *entity.TaskFilterField) (field.Expr, error) { if len(f.Values) == 0 { return nil, errorx.NewByCode(obErrorx.CommonInvalidParamCode, errorx.WithExtraMsg("no values provided for task type query")) } switch *f.QueryType { - case filter.QueryTypeIn: + case entity.QueryTypeIn: return q.ObservabilityTask.TaskType.In(f.Values...), nil - case filter.QueryTypeNotIn: + case entity.QueryTypeNotIn: return q.ObservabilityTask.TaskType.NotIn(f.Values...), nil default: return nil, errorx.NewByCode(obErrorx.CommonInvalidParamCode, errorx.WithExtraMsg("invalid query type for task type")) @@ -242,15 +250,15 @@ func (v *TaskDaoImpl) buildTaskTypeFilter(q *genquery.Query, f *filter.TaskFilte } // 构建任务状态过滤条件 -func (v *TaskDaoImpl) buildTaskStatusFilter(q *genquery.Query, f *filter.TaskFilterField) (field.Expr, error) { +func (v *TaskDaoImpl) buildTaskStatusFilter(q *genquery.Query, f *entity.TaskFilterField) (field.Expr, error) { if len(f.Values) == 0 { return nil, errorx.NewByCode(obErrorx.CommonInvalidParamCode, errorx.WithExtraMsg("no values provided for task status query")) } switch *f.QueryType { - case filter.QueryTypeIn: + case entity.QueryTypeIn: return q.ObservabilityTask.TaskStatus.In(f.Values...), nil - case filter.QueryTypeNotIn: + case entity.QueryTypeNotIn: return q.ObservabilityTask.TaskStatus.NotIn(f.Values...), nil default: return nil, errorx.NewByCode(obErrorx.CommonInvalidParamCode, errorx.WithExtraMsg("invalid query type for task status")) @@ -258,15 +266,15 @@ func (v *TaskDaoImpl) buildTaskStatusFilter(q *genquery.Query, f *filter.TaskFil } // 构建创建者过滤条件 -func (v *TaskDaoImpl) buildCreatedByFilter(q *genquery.Query, f *filter.TaskFilterField) (field.Expr, error) { +func (v *TaskDaoImpl) buildCreatedByFilter(q *genquery.Query, f *entity.TaskFilterField) (field.Expr, error) { if len(f.Values) == 0 { return nil, errorx.NewByCode(obErrorx.CommonInvalidParamCode, errorx.WithExtraMsg("no values provided for created_by query")) } switch *f.QueryType { - case filter.QueryTypeIn: + case entity.QueryTypeIn: return q.ObservabilityTask.CreatedBy.In(f.Values...), nil - case filter.QueryTypeNotIn: + case entity.QueryTypeNotIn: return q.ObservabilityTask.CreatedBy.NotIn(f.Values...), nil default: return nil, errorx.NewByCode(obErrorx.CommonInvalidParamCode, errorx.WithExtraMsg("invalid query type for created_by")) @@ -274,7 +282,7 @@ func (v *TaskDaoImpl) buildCreatedByFilter(q *genquery.Query, f *filter.TaskFilt } // 构建采样率过滤条件 -func (v *TaskDaoImpl) buildSampleRateFilter(q *genquery.Query, f *filter.TaskFilterField) (field.Expr, error) { +func (v *TaskDaoImpl) buildSampleRateFilter(q *genquery.Query, f *entity.TaskFilterField) (field.Expr, error) { if len(f.Values) == 0 { return nil, errorx.NewByCode(obErrorx.CommonInvalidParamCode, errorx.WithExtraMsg("no value provided for sample rate")) } @@ -287,13 +295,13 @@ func (v *TaskDaoImpl) buildSampleRateFilter(q *genquery.Query, f *filter.TaskFil // 构建 JSON_EXTRACT 表达式 switch *f.QueryType { - case filter.QueryTypeGte: + case entity.QueryTypeGte: return field.NewUnsafeFieldRaw("CAST(JSON_EXTRACT(?, '$.sample_rate') AS DECIMAL(10,4)) >= ?", q.ObservabilityTask.Sampler, sampleRate), nil - case filter.QueryTypeLte: + case entity.QueryTypeLte: return field.NewUnsafeFieldRaw("CAST(JSON_EXTRACT(?, '$.sample_rate') AS DECIMAL(10,4)) <= ?", q.ObservabilityTask.Sampler, sampleRate), nil - case filter.QueryTypeEq: + case entity.QueryTypeEq: return field.NewUnsafeFieldRaw("CAST(JSON_EXTRACT(?, '$.sample_rate') AS DECIMAL(10,4)) = ?", q.ObservabilityTask.Sampler, sampleRate), nil - case filter.QueryTypeNotEq: + case entity.QueryTypeNotEq: return field.NewUnsafeFieldRaw("CAST(JSON_EXTRACT(?, '$.sample_rate') AS DECIMAL(10,4)) != ?", q.ObservabilityTask.Sampler, sampleRate), nil default: return nil, errorx.NewByCode(obErrorx.CommonInvalidParamCode, errorx.WithExtraMsg("invalid query type for sample rate")) @@ -301,7 +309,7 @@ func (v *TaskDaoImpl) buildSampleRateFilter(q *genquery.Query, f *filter.TaskFil } // 构建任务ID过滤条件 -func (v *TaskDaoImpl) buildTaskIDFilter(q *genquery.Query, f *filter.TaskFilterField) (field.Expr, error) { +func (v *TaskDaoImpl) buildTaskIDFilter(q *genquery.Query, f *entity.TaskFilterField) (field.Expr, error) { if len(f.Values) == 0 { return nil, errorx.NewByCode(obErrorx.CommonInvalidParamCode, errorx.WithExtraMsg("no value provided for task id")) } @@ -318,7 +326,7 @@ func (v *TaskDaoImpl) buildTaskIDFilter(q *genquery.Query, f *filter.TaskFilterF return q.ObservabilityTask.ID.In(taskIDs...), nil } -func (v *TaskDaoImpl) buildUpdateAtFilter(q *genquery.Query, f *filter.TaskFilterField) (field.Expr, error) { +func (v *TaskDaoImpl) buildUpdateAtFilter(q *genquery.Query, f *entity.TaskFilterField) (field.Expr, error) { if len(f.Values) == 0 { return nil, errorx.NewByCode(obErrorx.CommonInvalidParamCode, errorx.WithExtraMsg("no value provided for update at")) } @@ -328,9 +336,9 @@ func (v *TaskDaoImpl) buildUpdateAtFilter(q *genquery.Query, f *filter.TaskFilte return nil, errorx.NewByCode(obErrorx.CommonInvalidParamCode, errorx.WithMsgParam("invalid update at: %v", err.Error())) } switch *f.QueryType { - case filter.QueryTypeGt: + case entity.QueryTypeGt: return q.ObservabilityTask.UpdatedAt.Gt(time.UnixMilli(updateAtLatest)), nil - case filter.QueryTypeLt: + case entity.QueryTypeLt: return q.ObservabilityTask.UpdatedAt.Lt(time.UnixMilli(updateAtLatest)), nil default: return nil, errorx.NewByCode(obErrorx.CommonInvalidParamCode, errorx.WithExtraMsg("invalid query type for update at")) @@ -367,7 +375,6 @@ func (d *TaskDaoImpl) order(q *genquery.Query, orderBy string, asc bool) field.E } func (v *TaskDaoImpl) UpdateTaskWithOCC(ctx context.Context, id int64, workspaceID int64, updateMap map[string]interface{}) error { - // todo[xun]: 乐观锁 logs.CtxInfo(ctx, "UpdateTaskWithOCC, id:%d, workspaceID:%d, updateMap:%+v", id, workspaceID, updateMap) q := genquery.Use(v.dbMgr.NewSession(ctx)).ObservabilityTask qd := q.WithContext(ctx) @@ -392,46 +399,16 @@ func (v *TaskDaoImpl) UpdateTaskWithOCC(ctx context.Context, id int64, workspace return errorx.NewByCode(obErrorx.CommonMySqlErrorCode, errorx.WithExtraMsg("TaskRun update failed with OCC")) } -func (v *TaskDaoImpl) GetObjListWithTask(ctx context.Context) ([]string, []string, []*model.ObservabilityTask, error) { +func (v *TaskDaoImpl) ListNonFinalTasks(ctx context.Context) ([]*model.ObservabilityTask, error) { q := genquery.Use(v.dbMgr.NewSession(ctx)) qd := q.WithContext(ctx).ObservabilityTask - // 查询非终态任务的workspace_id,使用DISTINCT去重 - qd = qd.Where(q.ObservabilityTask.TaskStatus.NotIn("success", "disabled")) - // qd = qd.Select(q.ObservabilityTask.WorkspaceID).Distinct() + // 查询非终态任务 + qd = qd.Where(q.ObservabilityTask.TaskStatus.NotIn(string(entity.TaskStatusSuccess), string(entity.TaskStatusDisabled))) results, err := qd.Find() if err != nil { - return nil, nil, nil, errorx.WrapByCode(err, obErrorx.CommonMySqlErrorCode) - } - - // 转换为字符串数组 - var spaceList []string - var botList []string - for _, task := range results { - spaceList = append(spaceList, strconv.FormatInt(task.WorkspaceID, 10)) - spanFilter := tconv.SpanFilterPO2DO(ctx, task.SpanFilter) - if spanFilter != nil && spanFilter.Filters.FilterFields != nil { - extractBotIDFromFilters(spanFilter.Filters.FilterFields, &botList) - } - } - - return spaceList, botList, nil, nil -} - -// extractBotIDFromFilters 递归提取过滤器中的 bot_id 值,包括 SubFilter -func extractBotIDFromFilters(filterFields []*filter.FilterField, botList *[]string) { - for _, filterField := range filterFields { - if filterField == nil { - continue - } - // 检查当前 FilterField 的 FieldName - if filterField.FieldName != nil && *filterField.FieldName == "bot_id" { - *botList = append(*botList, filterField.Values...) - } - // 递归处理 SubFilter - if filterField.SubFilter != nil && filterField.SubFilter.FilterFields != nil { - extractBotIDFromFilters(filterField.SubFilter.FilterFields, botList) - } + return nil, errorx.WrapByCode(err, obErrorx.CommonMySqlErrorCode) } + return results, nil } diff --git a/backend/modules/observability/infra/repo/mysql/task_run.go b/backend/modules/observability/infra/repo/mysql/task_run.go index 3f8ab0dd4..9cce07d99 100755 --- a/backend/modules/observability/infra/repo/mysql/task_run.go +++ b/backend/modules/observability/infra/repo/mysql/task_run.go @@ -9,8 +9,8 @@ import ( "time" "github.com/coze-dev/coze-loop/backend/infra/db" - "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/common" - "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/task" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" + tracecommon "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/common" "github.com/coze-dev/coze-loop/backend/modules/observability/infra/repo/mysql/gorm_gen/model" genquery "github.com/coze-dev/coze-loop/backend/modules/observability/infra/repo/mysql/gorm_gen/query" obErrorx "github.com/coze-dev/coze-loop/backend/modules/observability/pkg/errno" @@ -30,10 +30,10 @@ const ( type ListTaskRunParam struct { WorkspaceID *int64 TaskID *int64 - TaskRunStatus *task.RunStatus + TaskRunStatus *entity.TaskRunStatus ReqLimit int32 ReqOffset int32 - OrderBy *common.OrderBy + OrderBy *tracecommon.OrderBy } //go:generate mockgen -destination=mocks/task_run.go -package=mocks . ITaskRunDao @@ -57,20 +57,6 @@ type TaskRunDaoImpl struct { dbMgr db.Provider } -// TaskRun非终态状态定义 -var NonFinalTaskRunStatuses = []string{ - "pending", // 等待执行 - "running", // 执行中 - "paused", // 暂停 - "retrying", // 重试中 -} - -// 活跃状态定义(非终态状态的子集) -var ActiveTaskRunStatuses = []string{ - "running", // 执行中 - "retrying", // 重试中 -} - // 计算分页参数 func calculateTaskRunPagination(reqLimit, reqOffset int32) (int, int) { limit := DefaultTaskRunLimit @@ -88,7 +74,7 @@ func calculateTaskRunPagination(reqLimit, reqOffset int32) (int, int) { func (v *TaskRunDaoImpl) GetBackfillTaskRun(ctx context.Context, workspaceID *int64, taskID int64) (*model.ObservabilityTaskRun, error) { q := genquery.Use(v.dbMgr.NewSession(ctx)).ObservabilityTaskRun - qd := q.WithContext(ctx).Where(q.TaskType.Eq(task.TaskRunTypeBackFill)).Where(q.TaskID.Eq(taskID)) + qd := q.WithContext(ctx).Where(q.TaskType.Eq(string(entity.TaskRunTypeBackFill))).Where(q.TaskID.Eq(taskID)) if workspaceID != nil { qd = qd.Where(q.WorkspaceID.Eq(*workspaceID)) @@ -106,7 +92,7 @@ func (v *TaskRunDaoImpl) GetBackfillTaskRun(ctx context.Context, workspaceID *in func (v *TaskRunDaoImpl) GetLatestNewDataTaskRun(ctx context.Context, workspaceID *int64, taskID int64) (*model.ObservabilityTaskRun, error) { q := genquery.Use(v.dbMgr.NewSession(ctx)).ObservabilityTaskRun - qd := q.WithContext(ctx).Where(q.TaskType.Eq(task.TaskRunTypeNewData)).Where(q.TaskID.Eq(taskID)) + qd := q.WithContext(ctx).Where(q.TaskType.Eq(string(entity.TaskRunTypeNewData))).Where(q.TaskID.Eq(taskID)) if workspaceID != nil { qd = qd.Where(q.WorkspaceID.Eq(*workspaceID)) @@ -150,12 +136,15 @@ func (v *TaskRunDaoImpl) ListTaskRuns(ctx context.Context, param ListTaskRunPara var total int64 // TaskID过滤 - if param.TaskID != nil { - qd = qd.Where(q.ObservabilityTaskRun.TaskID.Eq(*param.TaskID)) + if param.TaskID == nil { + logs.CtxError(ctx, "TaskID is nil") + return nil, 0, errorx.NewByCode(obErrorx.CommonInvalidParamCode, errorx.WithExtraMsg("TaskID is nil")) } + qd = qd.Where(q.ObservabilityTaskRun.TaskID.Eq(*param.TaskID)) + // TaskRunStatus过滤 if param.TaskRunStatus != nil { - qd = qd.Where(q.ObservabilityTaskRun.RunStatus.Eq(*param.TaskRunStatus)) + qd = qd.Where(q.ObservabilityTaskRun.RunStatus.Eq(string(*param.TaskRunStatus))) } // workspaceID过滤 if param.WorkspaceID != nil { @@ -163,7 +152,13 @@ func (v *TaskRunDaoImpl) ListTaskRuns(ctx context.Context, param ListTaskRunPara } // 排序 - qd = qd.Order(v.order(q, param.OrderBy.GetField(), param.OrderBy.GetIsAsc())) + orderField := "" + orderAsc := false + if param.OrderBy != nil { + orderField = param.OrderBy.Field + orderAsc = param.OrderBy.IsAsc + } + qd = qd.Order(v.order(q, orderField, orderAsc)) // 计算总数 total, err := qd.Count() @@ -200,11 +195,6 @@ func (d *TaskRunDaoImpl) order(q *genquery.Query, orderBy string, asc bool) fiel return orderExpr.Desc() } -const ( - MaxRetries = 3 - RetryDelay = 100 * time.Millisecond -) - // UpdateTaskRunWithOCC 乐观并发控制更新 func (v *TaskRunDaoImpl) UpdateTaskRunWithOCC(ctx context.Context, id int64, workspaceID int64, updateMap map[string]interface{}) error { q := genquery.Use(v.dbMgr.NewSession(ctx)).ObservabilityTaskRun diff --git a/backend/modules/observability/infra/repo/redis/dao/task.go b/backend/modules/observability/infra/repo/redis/task.go similarity index 99% rename from backend/modules/observability/infra/repo/redis/dao/task.go rename to backend/modules/observability/infra/repo/redis/task.go index 99fc27f1e..a96ce325a 100755 --- a/backend/modules/observability/infra/repo/redis/dao/task.go +++ b/backend/modules/observability/infra/repo/redis/task.go @@ -1,7 +1,7 @@ // Copyright (c) 2025 coze-dev Authors // SPDX-License-Identifier: Apache-2.0 -package dao +package redis import ( "context" @@ -61,6 +61,7 @@ func (q *TaskDAOImpl) makeTaskCacheKey(taskID int64) string { return fmt.Sprintf(taskDetailCacheKeyPattern, taskID) } +// 为了兼容旧版,redis key必须保持一致,无法增加前缀 func (q *TaskDAOImpl) makeTaskCountCacheKey(taskID int64) string { return fmt.Sprintf("count_%d", taskID) } diff --git a/backend/modules/observability/infra/repo/redis/dao/task_run.go b/backend/modules/observability/infra/repo/redis/task_run.go similarity index 99% rename from backend/modules/observability/infra/repo/redis/dao/task_run.go rename to backend/modules/observability/infra/repo/redis/task_run.go index 2b7c41154..04263fd7e 100755 --- a/backend/modules/observability/infra/repo/redis/dao/task_run.go +++ b/backend/modules/observability/infra/repo/redis/task_run.go @@ -1,7 +1,7 @@ // Copyright (c) 2025 coze-dev Authors // SPDX-License-Identifier: Apache-2.0 -package dao +package redis import ( "context" diff --git a/backend/modules/observability/infra/repo/task.go b/backend/modules/observability/infra/repo/task.go index 3791716c7..8e18f7626 100644 --- a/backend/modules/observability/infra/repo/task.go +++ b/backend/modules/observability/infra/repo/task.go @@ -13,12 +13,12 @@ import ( "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/repo" "github.com/coze-dev/coze-loop/backend/modules/observability/infra/repo/mysql" "github.com/coze-dev/coze-loop/backend/modules/observability/infra/repo/mysql/convertor" - "github.com/coze-dev/coze-loop/backend/modules/observability/infra/repo/redis/dao" + "github.com/coze-dev/coze-loop/backend/modules/observability/infra/repo/redis" "github.com/coze-dev/coze-loop/backend/pkg/lang/ptr" "github.com/coze-dev/coze-loop/backend/pkg/logs" ) -func NewTaskRepoImpl(TaskDao mysql.ITaskDao, idGenerator idgen.IIDGenerator, taskRedisDao dao.ITaskDAO, taskRunDao mysql.ITaskRunDao, taskRunRedisDao dao.ITaskRunDAO) repo.ITaskRepo { +func NewTaskRepoImpl(TaskDao mysql.ITaskDao, idGenerator idgen.IIDGenerator, taskRedisDao redis.ITaskDAO, taskRunDao mysql.ITaskRunDao, taskRunRedisDao redis.ITaskRunDAO) repo.ITaskRepo { return &TaskRepoImpl{ TaskDao: TaskDao, idGenerator: idGenerator, @@ -31,23 +31,11 @@ func NewTaskRepoImpl(TaskDao mysql.ITaskDao, idGenerator idgen.IIDGenerator, tas type TaskRepoImpl struct { TaskDao mysql.ITaskDao TaskRunDao mysql.ITaskRunDao - TaskRedisDao dao.ITaskDAO - TaskRunRedisDao dao.ITaskRunDAO + TaskRedisDao redis.ITaskDAO + TaskRunRedisDao redis.ITaskRunDAO idGenerator idgen.IIDGenerator } -// 缓存 TTL 常量 -const ( - TaskDetailTTL = 30 * time.Minute // 单个任务缓存30分钟 - NonFinalTaskListTTL = 1 * time.Minute // 非最终状态任务缓存1分钟 - TaskCountTTL = 10 * time.Minute // 任务计数缓存10分钟 -) - -// 任务运行计数TTL常量 -const ( - TaskRunCountTTL = 10 * time.Minute // 任务运行计数缓存10分钟 -) - func (v *TaskRepoImpl) GetTask(ctx context.Context, id int64, workspaceID *int64, userID *string) (*entity.ObservabilityTask, error) { TaskPO, err := v.TaskDao.GetTask(ctx, id, workspaceID, userID) if err != nil { @@ -59,7 +47,7 @@ func (v *TaskRepoImpl) GetTask(ctx context.Context, id int64, workspaceID *int64 TaskRunPO, _, err := v.TaskRunDao.ListTaskRuns(ctx, mysql.ListTaskRunParam{ WorkspaceID: ptr.Of(taskDO.WorkspaceID), TaskID: ptr.Of(taskDO.ID), - ReqLimit: 1000, + ReqLimit: 500, ReqOffset: 0, }) @@ -71,8 +59,14 @@ func (v *TaskRepoImpl) GetTask(ctx context.Context, id int64, workspaceID *int64 return taskDO, nil } -func (v *TaskRepoImpl) ListTasks(ctx context.Context, param mysql.ListTaskParam) ([]*entity.ObservabilityTask, int64, error) { - results, total, err := v.TaskDao.ListTasks(ctx, param) +func (v *TaskRepoImpl) ListTasks(ctx context.Context, param repo.ListTaskParam) ([]*entity.ObservabilityTask, int64, error) { + results, total, err := v.TaskDao.ListTasks(ctx, mysql.ListTaskParam{ + WorkspaceIDs: param.WorkspaceIDs, + TaskFilters: param.TaskFilters, + ReqLimit: param.ReqLimit, + ReqOffset: param.ReqOffset, + OrderBy: param.OrderBy, + }) if err != nil { return nil, 0, err } @@ -80,12 +74,13 @@ func (v *TaskRepoImpl) ListTasks(ctx context.Context, param mysql.ListTaskParam) for i, result := range results { resp[i] = convertor.TaskPO2DO(result) } + // todo 待优化 for _, t := range resp { taskRuns, _, err := v.TaskRunDao.ListTaskRuns(ctx, mysql.ListTaskRunParam{ WorkspaceID: ptr.Of(t.WorkspaceID), TaskID: ptr.Of(t.ID), - ReqLimit: param.ReqLimit, - ReqOffset: param.ReqOffset, + ReqLimit: 500, + ReqOffset: 0, }) if err != nil { logs.CtxError(ctx, "ListTaskRuns err, taskID:%d, err:%v", t.ID, err) @@ -154,21 +149,6 @@ func (v *TaskRepoImpl) UpdateTaskWithOCC(ctx context.Context, id int64, workspac return nil } -func (v *TaskRepoImpl) GetObjListWithTask(ctx context.Context) ([]string, []string, []*entity.ObservabilityTask) { - var tasks []*entity.ObservabilityTask - spaceList, botList, results, err := v.TaskDao.GetObjListWithTask(ctx) - if err != nil { - logs.CtxWarn(ctx, "failed to get obj list with task from mysql", "err", err) - return nil, nil, nil - } - tasks = make([]*entity.ObservabilityTask, len(results)) - for i, result := range results { - tasks[i] = convertor.TaskPO2DO(result) - } - - return spaceList, botList, tasks -} - func (v *TaskRepoImpl) DeleteTask(ctx context.Context, do *entity.ObservabilityTask) error { // 先执行数据库删除操作 err := v.TaskDao.DeleteTask(ctx, do.ID, do.WorkspaceID, do.CreatedBy) @@ -183,16 +163,29 @@ func (v *TaskRepoImpl) DeleteTask(ctx context.Context, do *entity.ObservabilityT return nil } +func (v *TaskRepoImpl) ListNonFinalTasks(ctx context.Context) ([]*entity.ObservabilityTask, error) { + result, err := v.TaskDao.ListNonFinalTasks(ctx) + if err != nil { + return nil, err + } + + resp := make([]*entity.ObservabilityTask, len(result)) + for i, t := range result { + resp[i] = convertor.TaskPO2DO(t) + } + return resp, nil +} + func (v *TaskRepoImpl) CreateTaskRun(ctx context.Context, do *entity.TaskRun) (int64, error) { // 1. 生成ID id, err := v.idGenerator.GenID(ctx) if err != nil { return 0, err } + do.ID = id // 2. 转换并设置ID taskRunPo := convertor.TaskRunDO2PO(do) - taskRunPo.ID = id // 3. 数据库创建 createdID, err := v.TaskRunDao.CreateTaskRun(ctx, taskRunPo) @@ -200,8 +193,6 @@ func (v *TaskRepoImpl) CreateTaskRun(ctx context.Context, do *entity.TaskRun) (i return 0, err } - // 4. 异步更新缓存 - do.ID = createdID return createdID, nil } @@ -324,7 +315,7 @@ func (v *TaskRepoImpl) IncrTaskRunFailCount(ctx context.Context, taskID, taskRun return v.TaskRunRedisDao.IncrTaskRunFailCount(ctx, taskID, taskRunID, time.Duration(ttl)*time.Second) } -func (v *TaskRepoImpl) ListNonFinalTask(ctx context.Context, spaceID string) ([]int64, error) { +func (v *TaskRepoImpl) ListNonFinalTaskBySpaceID(ctx context.Context, spaceID string) ([]int64, error) { return v.TaskRedisDao.ListNonFinalTask(ctx, spaceID) } @@ -336,7 +327,7 @@ func (v *TaskRepoImpl) RemoveNonFinalTask(ctx context.Context, spaceID string, t return v.TaskRedisDao.RemoveNonFinalTask(ctx, spaceID, taskID) } -func (v *TaskRepoImpl) GetTaskByRedis(ctx context.Context, taskID int64) (*entity.ObservabilityTask, error) { +func (v *TaskRepoImpl) GetTaskByCache(ctx context.Context, taskID int64) (*entity.ObservabilityTask, error) { taskDO, err := v.TaskRedisDao.GetTask(ctx, taskID) if err != nil { logs.CtxError(ctx, "Failed to get task", "err", err) @@ -360,7 +351,3 @@ func (v *TaskRepoImpl) GetTaskByRedis(ctx context.Context, taskID int64) (*entit } return taskDO, nil } - -func (v *TaskRepoImpl) SetTask(ctx context.Context, task *entity.ObservabilityTask) error { - return v.TaskRedisDao.SetTask(ctx, task) -} diff --git a/backend/modules/observability/infra/repo/task_test.go b/backend/modules/observability/infra/repo/task_test.go index e27aa5ea5..c8aae58ef 100755 --- a/backend/modules/observability/infra/repo/task_test.go +++ b/backend/modules/observability/infra/repo/task_test.go @@ -12,7 +12,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" - mysql "github.com/coze-dev/coze-loop/backend/modules/observability/infra/repo/mysql" + "github.com/coze-dev/coze-loop/backend/modules/observability/infra/repo/mysql" mysqlconv "github.com/coze-dev/coze-loop/backend/modules/observability/infra/repo/mysql/convertor" mysqlmodel "github.com/coze-dev/coze-loop/backend/modules/observability/infra/repo/mysql/gorm_gen/model" ) @@ -38,6 +38,7 @@ type stubTaskDao struct { getTaskFunc func(ctx context.Context, id int64, workspaceID *int64, userID *string) (*mysqlmodel.ObservabilityTask, error) listTasksFunc func(ctx context.Context, param mysql.ListTaskParam) ([]*mysqlmodel.ObservabilityTask, int64, error) getObjListWithTaskFunc func(ctx context.Context) ([]string, []string, []*mysqlmodel.ObservabilityTask, error) + listNonFinalTasksFunc func(ctx context.Context) ([]*mysqlmodel.ObservabilityTask, error) } func (s *stubTaskDao) CreateTask(ctx context.Context, po *mysqlmodel.ObservabilityTask) (int64, error) { @@ -82,11 +83,11 @@ func (s *stubTaskDao) ListTasks(ctx context.Context, param mysql.ListTaskParam) return nil, 0, nil } -func (s *stubTaskDao) GetObjListWithTask(ctx context.Context) ([]string, []string, []*mysqlmodel.ObservabilityTask, error) { - if s.getObjListWithTaskFunc != nil { - return s.getObjListWithTaskFunc(ctx) +func (s *stubTaskDao) ListNonFinalTasks(ctx context.Context) ([]*mysqlmodel.ObservabilityTask, error) { + if s.listNonFinalTasksFunc != nil { + return s.listNonFinalTasksFunc(ctx) } - return nil, nil, nil, nil + return nil, nil } type stubTaskRedisDao struct { @@ -443,7 +444,7 @@ func TestTaskRepoImpl_NonFinalTaskWrappers(t *testing.T) { TaskRunRedisDao: stubTaskRunRedisDao{}, } - list, err := repo.ListNonFinalTask(context.Background(), "space") + list, err := repo.ListNonFinalTaskBySpaceID(context.Background(), "space") assert.NoError(t, err) assert.Equal(t, expected, list) @@ -558,7 +559,7 @@ func TestTaskRepoImpl_GetTaskByRedis(t *testing.T) { } } - got, err := repo.GetTaskByRedis(context.Background(), 100) + got, err := repo.GetTaskByCache(context.Background(), 100) if tt.expectErr != nil { assert.EqualError(t, err, tt.expectErr.Error()) } else { @@ -570,23 +571,3 @@ func TestTaskRepoImpl_GetTaskByRedis(t *testing.T) { }) } } - -func TestTaskRepoImpl_SetTask(t *testing.T) { - t.Parallel() - - called := false - redisDao := &stubTaskRedisDao{setTaskFunc: func(ctx context.Context, task *entity.ObservabilityTask) error { - called = true - assert.Equal(t, int64(1), task.ID) - return nil - }} - repo := &TaskRepoImpl{ - TaskDao: &stubTaskDao{}, - TaskRunDao: stubTaskRunDao{}, - TaskRedisDao: redisDao, - TaskRunRedisDao: stubTaskRunRedisDao{}, - } - - assert.NoError(t, repo.SetTask(context.Background(), &entity.ObservabilityTask{ID: 1})) - assert.True(t, called) -}