From 4244f20d2137322b56277b9d478174de35c00772 Mon Sep 17 00:00:00 2001
From: Dan Rammer
Date: Mon, 7 Feb 2022 23:02:52 -0600
Subject: [PATCH] Track parallelism in k8s array plugin (#239)

* add parallelism tracking in k8s array map task

Signed-off-by: Daniel Rammer

* fixed lint issues

Signed-off-by: Daniel Rammer

* added unit tests

Signed-off-by: Daniel Rammer

* fixed lint issues

Signed-off-by: Daniel Rammer
---
 .../go/tasks/plugins/array/core/state.go      |  5 +-
 .../go/tasks/plugins/array/core/state_test.go |  4 +-
 .../go/tasks/plugins/array/k8s/monitor.go     | 66 +++++++++++------
 .../tasks/plugins/array/k8s/monitor_test.go   | 70 +++++++++++++++++--
 .../go/tasks/plugins/array/k8s/task_test.go   |  2 +-
 5 files changed, 115 insertions(+), 32 deletions(-)

diff --git a/flyteplugins/go/tasks/plugins/array/core/state.go b/flyteplugins/go/tasks/plugins/array/core/state.go
index 67fea65142..cd12b76b1c 100644
--- a/flyteplugins/go/tasks/plugins/array/core/state.go
+++ b/flyteplugins/go/tasks/plugins/array/core/state.go
@@ -142,9 +142,8 @@ const (
 func ToArrayJob(structObj *structpb.Struct, taskTypeVersion int32) (*idlPlugins.ArrayJob, error) {
 	if structObj == nil {
 		if taskTypeVersion == 0 {
-
 			return &idlPlugins.ArrayJob{
-				Parallelism: 1,
+				Parallelism: 0,
 				Size:        1,
 				SuccessCriteria: &idlPlugins.ArrayJob_MinSuccesses{
 					MinSuccesses: 1,
@@ -152,7 +151,7 @@ func ToArrayJob(structObj *structpb.Struct, taskTypeVersion int32) (*idlPlugins.
 			}, nil
 		}
 		return &idlPlugins.ArrayJob{
-			Parallelism: 1,
+			Parallelism: 0,
 			Size:        1,
 			SuccessCriteria: &idlPlugins.ArrayJob_MinSuccessRatio{
 				MinSuccessRatio: 1.0,
diff --git a/flyteplugins/go/tasks/plugins/array/core/state_test.go b/flyteplugins/go/tasks/plugins/array/core/state_test.go
index 879e02d458..8f7af42d26 100644
--- a/flyteplugins/go/tasks/plugins/array/core/state_test.go
+++ b/flyteplugins/go/tasks/plugins/array/core/state_test.go
@@ -295,7 +295,7 @@ func TestToArrayJob(t *testing.T) {
 		arrayJob, err := ToArrayJob(nil, 0)
 		assert.NoError(t, err)
 		assert.True(t, proto.Equal(arrayJob, &plugins.ArrayJob{
-			Parallelism: 1,
+			Parallelism: 0,
 			Size:        1,
 			SuccessCriteria: &plugins.ArrayJob_MinSuccesses{
 				MinSuccesses: 1,
@@ -307,7 +307,7 @@ func TestToArrayJob(t *testing.T) {
 		arrayJob, err := ToArrayJob(nil, 1)
 		assert.NoError(t, err)
 		assert.True(t, proto.Equal(arrayJob, &plugins.ArrayJob{
-			Parallelism: 1,
+			Parallelism: 0,
 			Size:        1,
 			SuccessCriteria: &plugins.ArrayJob_MinSuccessRatio{
 				MinSuccessRatio: 1.0,
diff --git a/flyteplugins/go/tasks/plugins/array/k8s/monitor.go b/flyteplugins/go/tasks/plugins/array/k8s/monitor.go
index 6e92d78291..55f42324e7 100644
--- a/flyteplugins/go/tasks/plugins/array/k8s/monitor.go
+++ b/flyteplugins/go/tasks/plugins/array/k8s/monitor.go
@@ -5,30 +5,27 @@ import (
 	"fmt"
 	"time"
 
-	"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/tasklog"
-
-	"github.com/flyteorg/flytestdlib/logger"
-	"github.com/flyteorg/flytestdlib/storage"
+	idlCore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
+	"github.com/flyteorg/flyteplugins/go/tasks/errors"
+	"github.com/flyteorg/flyteplugins/go/tasks/logs"
+	"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"
+	"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s"
+	"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/tasklog"
+	"github.com/flyteorg/flyteplugins/go/tasks/plugins/array/arraystatus"
 
 	arrayCore "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/core"
+	"github.com/flyteorg/flyteplugins/go/tasks/plugins/array/errorcollector"
"github.com/flyteorg/flytestdlib/bitarray" - - "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/arraystatus" - "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/errorcollector" + "github.com/flyteorg/flytestdlib/logger" + "github.com/flyteorg/flytestdlib/storage" v1 "k8s.io/api/core/v1" + k8serrors "k8s.io/apimachinery/pkg/api/errors" metaV1 "k8s.io/apimachinery/pkg/apis/meta/v1" k8sTypes "k8s.io/apimachinery/pkg/types" - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s" - - idlCore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" errors2 "github.com/flyteorg/flytestdlib/errors" - k8serrors "k8s.io/apimachinery/pkg/api/errors" - - "github.com/flyteorg/flyteplugins/go/tasks/logs" - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" ) const ( @@ -88,6 +85,22 @@ func LaunchAndCheckSubTasksState(ctx context.Context, tCtx core.TaskExecutionCon return currentState, logLinks, subTaskIDs, err } + // identify max parallelism + taskTemplate, err := tCtx.TaskReader().Read(ctx) + if err != nil { + return currentState, logLinks, subTaskIDs, err + } else if taskTemplate == nil { + return currentState, logLinks, subTaskIDs, errors.Errorf(errors.BadTaskSpecification, "Required value not set, taskTemplate is nil") + } + + arrayJob, err := arrayCore.ToArrayJob(taskTemplate.GetCustom(), taskTemplate.TaskTypeVersion) + if err != nil { + return currentState, logLinks, subTaskIDs, err + } + + currentParallelism := 0 + maxParallelism := int(arrayJob.Parallelism) + for childIdx, existingPhaseIdx := range currentState.GetArrayStatus().Detailed.GetItems() { existingPhase := core.Phases[existingPhaseIdx] originalIdx := arrayCore.CalculateOriginalIndex(childIdx, newState.GetIndexesToCache()) @@ -191,18 +204,27 @@ func LaunchAndCheckSubTasksState(ctx context.Context, tCtx core.TaskExecutionCon } return currentState, logLinks, subTaskIDs, err } - } - newState = newState.SetArrayStatus(*newArrayStatus) + // validate map task parallelism + newSubtaskPhase := core.Phases[newArrayStatus.Detailed.GetItem(childIdx)] + if !newSubtaskPhase.IsTerminal() || newSubtaskPhase == core.PhaseRetryableFailure { + currentParallelism++ + } - // Check that the taskTemplate is valid - taskTemplate, err := tCtx.TaskReader().Read(ctx) - if err != nil { - return currentState, logLinks, subTaskIDs, err - } else if taskTemplate == nil { - return currentState, logLinks, subTaskIDs, fmt.Errorf("required value not set, taskTemplate is nil") + if maxParallelism != 0 && currentParallelism >= maxParallelism { + // If max parallelism has been achieved we need to fill the subtask phase summary with + // the remaining subtasks so the overall map task phase can be accurately identified. 
+			for i := childIdx + 1; i < len(currentState.GetArrayStatus().Detailed.GetItems()); i++ {
+				childSubtaskPhase := core.Phases[newArrayStatus.Detailed.GetItem(i)]
+				newArrayStatus.Summary.Inc(childSubtaskPhase)
+			}
+
+			break
+		}
 	}
 
+	newState = newState.SetArrayStatus(*newArrayStatus)
+
 	phase := arrayCore.SummaryToPhase(ctx, currentState.GetOriginalMinSuccesses()-currentState.GetOriginalArraySize()+int64(currentState.GetExecutionArraySize()), newArrayStatus.Summary)
 	if phase == arrayCore.PhaseWriteToDiscoveryThenFail {
 		errorMsg := msg.Summary(GetConfig().MaxErrorStringLength)
diff --git a/flyteplugins/go/tasks/plugins/array/k8s/monitor_test.go b/flyteplugins/go/tasks/plugins/array/k8s/monitor_test.go
index 87347a260a..282e0036fb 100644
--- a/flyteplugins/go/tasks/plugins/array/k8s/monitor_test.go
+++ b/flyteplugins/go/tasks/plugins/array/k8s/monitor_test.go
@@ -23,6 +23,8 @@ import (
 	"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks"
 	"github.com/stretchr/testify/assert"
 	"golang.org/x/net/context"
+
+	structpb "google.golang.org/protobuf/types/known/structpb"
 )
 
 func createSampleContainerTask() *core2.Container {
@@ -33,9 +35,14 @@ func createSampleContainerTask() *core2.Container {
 	}
 }
 
-func getMockTaskExecutionContext(ctx context.Context) *mocks.TaskExecutionContext {
+func getMockTaskExecutionContext(ctx context.Context, parallelism int) *mocks.TaskExecutionContext {
+	customStruct, _ := structpb.NewStruct(map[string]interface{}{
+		"parallelism": fmt.Sprintf("%d", parallelism),
+	})
+
 	tr := &mocks.TaskReader{}
 	tr.OnRead(ctx).Return(&core2.TaskTemplate{
+		Custom: customStruct,
 		Target: &core2.TaskTemplate_Container{
 			Container: createSampleContainerTask(),
 		},
@@ -103,7 +110,7 @@ func getMockTaskExecutionContex
 
 func TestGetNamespaceForExecution(t *testing.T) {
 	ctx := context.Background()
-	tCtx := getMockTaskExecutionContext(ctx)
+	tCtx := getMockTaskExecutionContext(ctx, 0)
 
 	assert.Equal(t, GetNamespaceForExecution(tCtx, ""), tCtx.TaskExecutionMetadata().GetNamespace())
 	assert.Equal(t, GetNamespaceForExecution(tCtx, "abcd"), "abcd")
@@ -122,7 +129,7 @@ func testSubTaskIDs(t *testing.T, actual []*string) {
 
 func TestCheckSubTasksState(t *testing.T) {
 	ctx := context.Background()
-	tCtx := getMockTaskExecutionContext(ctx)
+	tCtx := getMockTaskExecutionContext(ctx, 0)
 	kubeClient := mocks.KubeClient{}
 	kubeClient.OnGetClient().Return(mocks.NewFakeKubeClient())
 	kubeClient.OnGetCache().Return(mocks.NewFakeKubeCache())
@@ -289,10 +296,65 @@ func TestCheckSubTasksState(t *testing.T) {
 	})
 }
 
+func TestCheckSubTasksStateParallelism(t *testing.T) {
+	subtaskCount := 5
+
+	for i := 1; i <= subtaskCount; i++ {
+		t.Run(fmt.Sprintf("Parallelism-%d", i), func(t *testing.T) {
+			// construct task context
+			ctx := context.Background()
+
+			tCtx := getMockTaskExecutionContext(ctx, i)
+			kubeClient := mocks.KubeClient{}
+			kubeClient.OnGetClient().Return(mocks.NewFakeKubeClient())
+			kubeClient.OnGetCache().Return(mocks.NewFakeKubeCache())
+
+			resourceManager := mocks.ResourceManager{}
+			resourceManager.OnAllocateResourceMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(core.AllocationStatusExhausted, nil)
+			tCtx.OnResourceManager().Return(&resourceManager)
+
+			// evaluate one round of subtask launch and monitor
+			config := Config{
+				MaxArrayJobSize: 100,
+			}
+
+			retryAttemptsArray, err := bitarray.NewCompactArray(uint(subtaskCount), bitarray.Item(0))
+			assert.NoError(t, err)
+
+			cacheIndexes := bitarray.NewBitSet(uint(subtaskCount))
+			newState, _, _, err := LaunchAndCheckSubTasksState(ctx, tCtx, &kubeClient, &config, nil, "/prefix/", "/prefix-sand/", &arrayCore.State{
+				CurrentPhase:         arrayCore.PhaseCheckingSubTaskExecutions,
+				ExecutionArraySize:   subtaskCount,
+				OriginalArraySize:    int64(subtaskCount * 2),
+				OriginalMinSuccesses: int64(subtaskCount * 2),
+				IndexesToCache:       cacheIndexes,
+				ArrayStatus: arraystatus.ArrayStatus{
+					Detailed: arrayCore.NewPhasesCompactArray(uint(subtaskCount)),
+				},
+				RetryAttempts: retryAttemptsArray,
+			})
+
+			assert.Nil(t, err)
+			p, _ := newState.GetPhase()
+			assert.Equal(t, arrayCore.PhaseCheckingSubTaskExecutions.String(), p.String())
+
+			// validate only i subtasks were processed
+			executed := 0
+			for _, existingPhaseIdx := range newState.GetArrayStatus().Detailed.GetItems() {
+				if core.Phases[existingPhaseIdx] != core.PhaseUndefined {
+					executed++
+				}
+			}
+
+			assert.Equal(t, i, executed)
+		})
+	}
+}
+
 func TestCheckSubTasksStateResourceGranted(t *testing.T) {
 	ctx := context.Background()
 
-	tCtx := getMockTaskExecutionContext(ctx)
+	tCtx := getMockTaskExecutionContext(ctx, 0)
 	kubeClient := mocks.KubeClient{}
 	kubeClient.OnGetClient().Return(mocks.NewFakeKubeClient())
 	kubeClient.OnGetCache().Return(mocks.NewFakeKubeCache())
diff --git a/flyteplugins/go/tasks/plugins/array/k8s/task_test.go b/flyteplugins/go/tasks/plugins/array/k8s/task_test.go
index 8a37be5b62..874caedacb 100644
--- a/flyteplugins/go/tasks/plugins/array/k8s/task_test.go
+++ b/flyteplugins/go/tasks/plugins/array/k8s/task_test.go
@@ -20,7 +20,7 @@ import (
 
 func TestFinalize(t *testing.T) {
 	ctx := context.Background()
-	tCtx := getMockTaskExecutionContext(ctx)
+	tCtx := getMockTaskExecutionContext(ctx, 0)
 	kubeClient := mocks.KubeClient{}
 	kubeClient.OnGetClient().Return(mocks.NewFakeKubeClient())
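
For readers skimming the monitor.go hunks above, here is a minimal, self-contained Go sketch of the parallelism-capping pattern the patch introduces in LaunchAndCheckSubTasksState. It is an illustration only: the Phase type and launchRound helper below are simplified stand-ins assumed for this sketch, not the plugin's actual core.Phase API. The idea mirrors the new loop: every subtask that is still active after a round consumes one slot of the ArrayJob parallelism budget (0 now means unbounded), and once the budget is exhausted the remaining subtasks are skipped for this round but still folded into the phase summary so the overall map task phase is computed over all of them.

package main

import "fmt"

// Phase is a simplified stand-in for the plugin's core.Phase values.
type Phase int

const (
	PhaseUndefined Phase = iota
	PhaseRunning
	PhaseRetryableFailure
	PhaseSuccess
)

// IsTerminal reports whether a subtask has finished. In the real plugin a
// retryable failure is terminal too, which is why the budget check below
// still counts it as occupying a slot (it will be relaunched).
func (p Phase) IsTerminal() bool {
	return p == PhaseSuccess || p == PhaseRetryableFailure
}

// launchRound models one monitoring round over the subtask phases. Subtasks
// that remain active (or will be retried) consume the parallelism budget;
// maxParallelism == 0 means unbounded, matching the new ArrayJob default.
func launchRound(phases []Phase, maxParallelism int) map[Phase]int {
	summary := make(map[Phase]int)
	currentParallelism := 0

	for idx, phase := range phases {
		summary[phase]++

		if !phase.IsTerminal() || phase == PhaseRetryableFailure {
			currentParallelism++
		}

		if maxParallelism != 0 && currentParallelism >= maxParallelism {
			// Budget exhausted: fold the untouched subtasks into the summary
			// so the overall map task phase still reflects every subtask.
			for _, rest := range phases[idx+1:] {
				summary[rest]++
			}
			break
		}
	}
	return summary
}

func main() {
	phases := []Phase{PhaseUndefined, PhaseUndefined, PhaseUndefined, PhaseUndefined, PhaseUndefined}
	// With a budget of 2 only two subtasks occupy slots this round; the other
	// three stay PhaseUndefined in the summary.
	fmt.Println(launchRound(phases, 2))
}

Counting a retryable failure against the budget, as the patch does, keeps a subtask holding its slot until it reaches a truly terminal phase, which is the same behavior TestCheckSubTasksStateParallelism exercises by asserting that exactly i subtasks are processed per round.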