diff --git a/flyteadmin/pkg/async/cloudevent/implementations/cloudevent_publisher.go b/flyteadmin/pkg/async/cloudevent/implementations/cloudevent_publisher.go index 52b6bff4f93..0aee1336c63 100644 --- a/flyteadmin/pkg/async/cloudevent/implementations/cloudevent_publisher.go +++ b/flyteadmin/pkg/async/cloudevent/implementations/cloudevent_publisher.go @@ -160,6 +160,14 @@ func (c *CloudEventWrappedPublisher) TransformWorkflowExecutionEvent(ctx context logger.Warningf(ctx, "workflow id is nil for execution [%+v]", ex) return nil, fmt.Errorf("workflow id is nil for execution [%+v]", ex) } + + if ex.GetSpec().GetLaunchPlan().GetResourceType() == core.ResourceType_TASK { + logger.Debugf(ctx, "skipping single task execution workflow event [%+v]", rawEvent.ExecutionId) + return &event.CloudEventWorkflowExecution{ + RawEvent: rawEvent, + }, nil + } + workflowModel, err := c.db.WorkflowRepo().Get(ctx, repositoryInterfaces.Identifier{ Org: ex.Closure.WorkflowId.Org, Project: ex.Closure.WorkflowId.Project, @@ -171,6 +179,7 @@ func (c *CloudEventWrappedPublisher) TransformWorkflowExecutionEvent(ctx context logger.Warningf(ctx, "couldn't find workflow [%+v] for cloud event processing", ex.Closure.WorkflowId) return nil, err } + var workflowInterface core.TypedInterface if workflowModel.TypedInterface != nil && len(workflowModel.TypedInterface) > 0 { err = proto.Unmarshal(workflowModel.TypedInterface, &workflowInterface) diff --git a/flyteadmin/pkg/async/cloudevent/implementations/cloudevent_publisher_test.go b/flyteadmin/pkg/async/cloudevent/implementations/cloudevent_publisher_test.go index d9108aa3ff6..eb8a6a861d4 100644 --- a/flyteadmin/pkg/async/cloudevent/implementations/cloudevent_publisher_test.go +++ b/flyteadmin/pkg/async/cloudevent/implementations/cloudevent_publisher_test.go @@ -15,7 +15,13 @@ import ( "github.com/golang/protobuf/proto" "github.com/golang/protobuf/ptypes" "github.com/stretchr/testify/assert" + "google.golang.org/protobuf/types/known/timestamppb" + "github.com/flyteorg/flyte/flyteadmin/pkg/data/mocks" + repositoryInterfaces "github.com/flyteorg/flyte/flyteadmin/pkg/repositories/interfaces" + repoMocks "github.com/flyteorg/flyte/flyteadmin/pkg/repositories/mocks" + "github.com/flyteorg/flyte/flyteadmin/pkg/repositories/models" + "github.com/flyteorg/flyte/flyteadmin/pkg/runtime/interfaces" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/event" @@ -199,3 +205,100 @@ func TestCloudEventPublisher_PublishError(t *testing.T) { assert.Equal(t, errorPublish, currentEventPublisher.Publish(context.Background(), proto.MessageName(taskRequest), taskRequest)) } + +type DummyRepositories struct { + repositoryInterfaces.Repository + RepoExecution repositoryInterfaces.ExecutionRepoInterface +} + +func (r *DummyRepositories) ExecutionRepo() repositoryInterfaces.ExecutionRepoInterface { + return r.RepoExecution +} + +func getMockSingleTaskSpec() *admin.ExecutionSpec { + return &admin.ExecutionSpec{ + LaunchPlan: &core.Identifier{ + ResourceType: core.ResourceType_TASK, + Project: "project", + Domain: "domain", + Name: "name", + Version: "version", + }, + RawOutputDataConfig: &admin.RawOutputDataConfig{OutputLocationPrefix: "default_raw_output"}, + } +} + +func getMockExecutionModel() models.Execution { + spec := getMockSingleTaskSpec() + specBytes, _ := proto.Marshal(spec) + startedAt := time.Date(2018, 8, 30, 0, 0, 0, 0, time.UTC) + createdAt := time.Date(2022, 01, 18, 0, 0, 0, 0, time.UTC) + startedAtProto := timestamppb.New(startedAt) + createdAtProto := timestamppb.New(createdAt) + + closure := admin.ExecutionClosure{ + Phase: core.WorkflowExecution_RUNNING, + StartedAt: startedAtProto, + StateChangeDetails: &admin.ExecutionStateChangeDetails{ + State: admin.ExecutionState_EXECUTION_ACTIVE, + OccurredAt: createdAtProto, + }, + WorkflowId: &core.Identifier{ + ResourceType: core.ResourceType_WORKFLOW, + Project: "project", + Domain: "domain", + Name: "name", + Version: "version", + Org: "", + }, + ResolvedSpec: spec, + } + closureBytes, _ := proto.Marshal(&closure) + stateInt := int32(admin.ExecutionState_EXECUTION_ACTIVE) + executionModel := models.Execution{ + Spec: specBytes, + Phase: core.WorkflowExecution_SUCCEEDED.String(), + Closure: closureBytes, + LaunchPlanID: uint(1), + WorkflowID: uint(2), + StartedAt: &startedAt, + State: &stateInt, + } + return executionModel +} + +func TestCloudEventsPublisher_TransformWorkflow(t *testing.T) { + testScope := promutils.NewTestScope() + ctx := context.Background() + + mockURLData := mocks.NewMockRemoteURL() + dummyDataConfig := interfaces.RemoteDataConfig{} + cloudEventPublisher := NewCloudEventsWrappedPublisher(nil, mockPubSubSender, testScope, nil, mockURLData, dummyDataConfig) + + t.Run("single task should skip", func(t *testing.T) { + mockExecutionRepo := repoMocks.NewMockExecutionRepo() + mockDB := &DummyRepositories{RepoExecution: mockExecutionRepo} + + mockExecutionRepo.(*repoMocks.MockExecutionRepo).SetGetCallback(func(ctx context.Context, input repositoryInterfaces.Identifier) (models.Execution, error) { + assert.Equal(t, input.Org, executionID.Org) + assert.Equal(t, input.Project, executionID.Project) + assert.Equal(t, input.Domain, executionID.Domain) + assert.Equal(t, input.Name, executionID.Name) + dummyModel := getMockExecutionModel() + + return dummyModel, nil + }) + + rawEvent := &event.WorkflowExecutionEvent{ + Phase: core.WorkflowExecution_SUCCEEDED, + ExecutionId: &executionID, + } + + casted := cloudEventPublisher.(*CloudEventWrappedPublisher) + casted.db = mockDB + ceEvent, err := casted.TransformWorkflowExecutionEvent(ctx, rawEvent) + assert.Nil(t, err) + assert.NotNil(t, ceEvent) + assert.Nil(t, ceEvent.GetOutputInterface()) + }) +}