diff --git a/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/admin.go b/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/admin.go index 29de745acf..fbe0a8c1a6 100644 --- a/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/admin.go +++ b/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/admin.go @@ -15,6 +15,7 @@ import ( "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service" evtErr "github.com/flyteorg/flyte/flytepropeller/events/errors" + "github.com/flyteorg/flyte/flytepropeller/pkg/compiler/transformers/k8s" "github.com/flyteorg/flyte/flytestdlib/cache" stdErr "github.com/flyteorg/flyte/flytestdlib/errors" "github.com/flyteorg/flyte/flytestdlib/logger" @@ -114,6 +115,15 @@ func (a *adminLaunchPlanExecutor) Launch(ctx context.Context, launchCtx LaunchCo }) } + // Make a copy of the labels with shard-key removed. This ensures that the shard-key is re-computed for each + // instead of being copied from the parent. + labels := make(map[string]string) + for key, value := range launchCtx.Labels { + if key != k8s.ShardKeyLabel { + labels[key] = value + } + } + req := &admin.ExecutionCreateRequest{ Project: executionID.Project, Domain: executionID.Domain, @@ -127,7 +137,7 @@ func (a *adminLaunchPlanExecutor) Launch(ctx context.Context, launchCtx LaunchCo Principal: launchCtx.Principal, ParentNodeExecution: launchCtx.ParentNodeExecution, }, - Labels: &admin.Labels{Values: launchCtx.Labels}, + Labels: &admin.Labels{Values: labels}, Annotations: &admin.Annotations{Values: launchCtx.Annotations}, SecurityContext: &launchCtx.SecurityContext, MaxParallelism: int32(launchCtx.MaxParallelism), diff --git a/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/admin_test.go b/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/admin_test.go index 89bb0e2477..2a442e3262 100644 --- a/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/admin_test.go +++ b/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/admin_test.go @@ -2,6 +2,7 @@ package launchplan import ( "context" + "reflect" "testing" "time" @@ -162,10 +163,14 @@ func TestAdminLaunchPlanExecutor_Launch(t *testing.T) { ctx, mock.MatchedBy(func(o *admin.ExecutionCreateRequest) bool { return o.Project == "p" && o.Domain == "d" && o.Name == "n" && o.Spec.Inputs == nil && - o.Spec.Metadata.Mode == admin.ExecutionMetadata_CHILD_WORKFLOW + o.Spec.Metadata.Mode == admin.ExecutionMetadata_CHILD_WORKFLOW && + reflect.DeepEqual(o.Spec.Labels.Values, map[string]string{"foo": "bar"}) // Ensure shard-key was removed. }), ).Return(nil, nil) assert.NoError(t, err) + + var labels = map[string]string{"foo": "bar", "shard-key": "1"} + err = exec.Launch(ctx, LaunchContext{ ParentNodeExecution: &core.NodeExecutionIdentifier{ @@ -176,12 +181,15 @@ func TestAdminLaunchPlanExecutor_Launch(t *testing.T) { Name: "w", }, }, + Labels: labels, }, id, &core.Identifier{}, nil, ) assert.NoError(t, err) + // Ensure we haven't mutated the state of the parent workflow. + assert.True(t, reflect.DeepEqual(labels, map[string]string{"foo": "bar", "shard-key": "1"})) }) t.Run("happy recover", func(t *testing.T) {