diff --git a/go/tasks/logs/logging_utils.go b/go/tasks/logs/logging_utils.go index f6b3e06db..805c6273c 100755 --- a/go/tasks/logs/logging_utils.go +++ b/go/tasks/logs/logging_utils.go @@ -18,12 +18,7 @@ type logPlugin struct { } // Internal -func GetLogsForContainerInPod(ctx context.Context, pod *v1.Pod, index uint32, nameSuffix string) ([]*core.TaskLog, error) { - logPlugin, err := InitializeLogPlugins(GetLogConfig()) - if err != nil { - return nil, err - } - +func GetLogsForContainerInPod(ctx context.Context, logPlugin tasklog.Plugin, pod *v1.Pod, index uint32, nameSuffix string) ([]*core.TaskLog, error) { if logPlugin == nil { return nil, nil } diff --git a/go/tasks/logs/logging_utils_test.go b/go/tasks/logs/logging_utils_test.go index 3bc37a8e3..02f5d778b 100755 --- a/go/tasks/logs/logging_utils_test.go +++ b/go/tasks/logs/logging_utils_test.go @@ -15,29 +15,32 @@ import ( const podName = "PodName" func TestGetLogsForContainerInPod_NoPlugins(t *testing.T) { - assert.NoError(t, SetLogConfig(&LogConfig{})) - l, err := GetLogsForContainerInPod(context.TODO(), nil, 0, " Suffix") + logPlugin, err := InitializeLogPlugins(&LogConfig{}) + assert.NoError(t, err) + l, err := GetLogsForContainerInPod(context.TODO(), logPlugin, nil, 0, " Suffix") assert.NoError(t, err) assert.Nil(t, l) } func TestGetLogsForContainerInPod_NoLogs(t *testing.T) { - assert.NoError(t, SetLogConfig(&LogConfig{ + logPlugin, err := InitializeLogPlugins(&LogConfig{ IsCloudwatchEnabled: true, CloudwatchRegion: "us-east-1", CloudwatchLogGroup: "/kubernetes/flyte-production", - })) - p, err := GetLogsForContainerInPod(context.TODO(), nil, 0, " Suffix") + }) + assert.NoError(t, err) + p, err := GetLogsForContainerInPod(context.TODO(), logPlugin, nil, 0, " Suffix") assert.NoError(t, err) assert.Nil(t, p) } func TestGetLogsForContainerInPod_BadIndex(t *testing.T) { - assert.NoError(t, SetLogConfig(&LogConfig{ + logPlugin, err := InitializeLogPlugins(&LogConfig{ IsCloudwatchEnabled: true, CloudwatchRegion: "us-east-1", CloudwatchLogGroup: "/kubernetes/flyte-production", - })) + }) + assert.NoError(t, err) pod := &v1.Pod{ Spec: v1.PodSpec{ @@ -57,17 +60,18 @@ func TestGetLogsForContainerInPod_BadIndex(t *testing.T) { } pod.Name = podName - p, err := GetLogsForContainerInPod(context.TODO(), pod, 1, " Suffix") + p, err := GetLogsForContainerInPod(context.TODO(), logPlugin, pod, 1, " Suffix") assert.NoError(t, err) assert.Nil(t, p) } func TestGetLogsForContainerInPod_MissingStatus(t *testing.T) { - assert.NoError(t, SetLogConfig(&LogConfig{ + logPlugin, err := InitializeLogPlugins(&LogConfig{ IsCloudwatchEnabled: true, CloudwatchRegion: "us-east-1", CloudwatchLogGroup: "/kubernetes/flyte-production", - })) + }) + assert.NoError(t, err) pod := &v1.Pod{ Spec: v1.PodSpec{ @@ -81,16 +85,17 @@ func TestGetLogsForContainerInPod_MissingStatus(t *testing.T) { } pod.Name = podName - p, err := GetLogsForContainerInPod(context.TODO(), pod, 1, " Suffix") + p, err := GetLogsForContainerInPod(context.TODO(), logPlugin, pod, 1, " Suffix") assert.NoError(t, err) assert.Nil(t, p) } func TestGetLogsForContainerInPod_Cloudwatch(t *testing.T) { - assert.NoError(t, SetLogConfig(&LogConfig{IsCloudwatchEnabled: true, + logPlugin, err := InitializeLogPlugins(&LogConfig{IsCloudwatchEnabled: true, CloudwatchRegion: "us-east-1", CloudwatchLogGroup: "/kubernetes/flyte-production", - })) + }) + assert.NoError(t, err) pod := &v1.Pod{ Spec: v1.PodSpec{ @@ -110,16 +115,17 @@ func TestGetLogsForContainerInPod_Cloudwatch(t *testing.T) { } pod.Name = podName - logs, err := GetLogsForContainerInPod(context.TODO(), pod, 0, " Suffix") + logs, err := GetLogsForContainerInPod(context.TODO(), logPlugin, pod, 0, " Suffix") assert.Nil(t, err) assert.Len(t, logs, 1) } func TestGetLogsForContainerInPod_K8s(t *testing.T) { - assert.NoError(t, SetLogConfig(&LogConfig{ + logPlugin, err := InitializeLogPlugins(&LogConfig{ IsKubernetesEnabled: true, KubernetesURL: "k8s.com", - })) + }) + assert.NoError(t, err) pod := &v1.Pod{ Spec: v1.PodSpec{ @@ -139,19 +145,20 @@ func TestGetLogsForContainerInPod_K8s(t *testing.T) { } pod.Name = podName - logs, err := GetLogsForContainerInPod(context.TODO(), pod, 0, " Suffix") + logs, err := GetLogsForContainerInPod(context.TODO(), logPlugin, pod, 0, " Suffix") assert.Nil(t, err) assert.Len(t, logs, 1) } func TestGetLogsForContainerInPod_All(t *testing.T) { - assert.NoError(t, SetLogConfig(&LogConfig{ + logPlugin, err := InitializeLogPlugins(&LogConfig{ IsKubernetesEnabled: true, KubernetesURL: "k8s.com", IsCloudwatchEnabled: true, CloudwatchRegion: "us-east-1", CloudwatchLogGroup: "/kubernetes/flyte-production", - })) + }) + assert.NoError(t, err) pod := &v1.Pod{ Spec: v1.PodSpec{ @@ -171,18 +178,18 @@ func TestGetLogsForContainerInPod_All(t *testing.T) { } pod.Name = podName - logs, err := GetLogsForContainerInPod(context.TODO(), pod, 0, " Suffix") + logs, err := GetLogsForContainerInPod(context.TODO(), logPlugin, pod, 0, " Suffix") assert.Nil(t, err) assert.Len(t, logs, 2) } func TestGetLogsForContainerInPod_Stackdriver(t *testing.T) { - - assert.NoError(t, SetLogConfig(&LogConfig{ + logPlugin, err := InitializeLogPlugins(&LogConfig{ IsStackDriverEnabled: true, GCPProjectName: "myGCPProject", StackdriverLogResourceName: "aws_ec2_instance", - })) + }) + assert.NoError(t, err) pod := &v1.Pod{ Spec: v1.PodSpec{ @@ -202,7 +209,7 @@ func TestGetLogsForContainerInPod_Stackdriver(t *testing.T) { } pod.Name = podName - logs, err := GetLogsForContainerInPod(context.TODO(), pod, 0, " Suffix") + logs, err := GetLogsForContainerInPod(context.TODO(), logPlugin, pod, 0, " Suffix") assert.Nil(t, err) assert.Len(t, logs, 1) } @@ -252,7 +259,8 @@ func TestGetLogsForContainerInPod_LegacyTemplate(t *testing.T) { } func assertTestSucceeded(tb testing.TB, config *LogConfig, expectedTaskLogs []*core.TaskLog) { - assert.NoError(tb, SetLogConfig(config)) + logPlugin, err := InitializeLogPlugins(config) + assert.NoError(tb, err) pod := &v1.Pod{ ObjectMeta: v12.ObjectMeta{ @@ -275,7 +283,7 @@ func assertTestSucceeded(tb testing.TB, config *LogConfig, expectedTaskLogs []*c }, } - logs, err := GetLogsForContainerInPod(context.TODO(), pod, 0, " my-Suffix") + logs, err := GetLogsForContainerInPod(context.TODO(), logPlugin, pod, 0, " my-Suffix") assert.Nil(tb, err) assert.Len(tb, logs, len(expectedTaskLogs)) if diff := deep.Equal(logs, expectedTaskLogs); len(diff) > 0 { diff --git a/go/tasks/pluginmachinery/core/mocks/fake_k8s_client.go b/go/tasks/pluginmachinery/core/mocks/fake_k8s_client.go index 36ada0c9d..120b9c6d7 100644 --- a/go/tasks/pluginmachinery/core/mocks/fake_k8s_client.go +++ b/go/tasks/pluginmachinery/core/mocks/fake_k8s_client.go @@ -6,9 +6,9 @@ import ( "reflect" "sync" - "k8s.io/apimachinery/pkg/api/meta" - + v1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/errors" + "k8s.io/apimachinery/pkg/api/meta" "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/runtime/schema" @@ -89,6 +89,20 @@ func (m *FakeKubeClient) Create(ctx context.Context, obj client.Object, opts ... m.syncObj.Lock() defer m.syncObj.Unlock() + // if obj is a *v1.Pod then append a ContainerStatus for each Container + pod, ok := obj.(*v1.Pod) + if ok { + for i := range pod.Spec.Containers { + if len(pod.Status.ContainerStatuses) > i { + continue + } + + pod.Status.ContainerStatuses = append(pod.Status.ContainerStatuses, v1.ContainerStatus{ + ContainerID: "docker://container-name", + }) + } + } + accessor, err := meta.Accessor(obj) if err != nil { return err diff --git a/go/tasks/plugins/array/k8s/executor.go b/go/tasks/plugins/array/k8s/executor.go index ceb760ca1..c4f2444d1 100644 --- a/go/tasks/plugins/array/k8s/executor.go +++ b/go/tasks/plugins/array/k8s/executor.go @@ -3,20 +3,19 @@ package k8s import ( "context" - "sigs.k8s.io/controller-runtime/pkg/cache" - "sigs.k8s.io/controller-runtime/pkg/client" - idlCore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + "github.com/flyteorg/flyteplugins/go/tasks/errors" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" "github.com/flyteorg/flyteplugins/go/tasks/plugins/array" arrayCore "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/core" + "github.com/flyteorg/flytestdlib/logger" "github.com/flyteorg/flytestdlib/promutils" - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery" - - "github.com/flyteorg/flyteplugins/go/tasks/errors" - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" + "sigs.k8s.io/controller-runtime/pkg/cache" + "sigs.k8s.io/controller-runtime/pkg/client" ) const executorName = "k8s-array" @@ -145,18 +144,21 @@ func (e Executor) Handle(ctx context.Context, tCtx core.TaskExecutionContext) (c } func (e Executor) Abort(ctx context.Context, tCtx core.TaskExecutionContext) error { - return nil + pluginState := &arrayCore.State{} + if _, err := tCtx.PluginStateReader().Get(pluginState); err != nil { + return errors.Wrapf(errors.CorruptedPluginState, err, "Failed to read unmarshal custom state") + } + + return TerminateSubTasks(ctx, tCtx, e.kubeClient, GetConfig(), abortSubtask, pluginState) } func (e Executor) Finalize(ctx context.Context, tCtx core.TaskExecutionContext) error { - pluginConfig := GetConfig() - pluginState := &arrayCore.State{} if _, err := tCtx.PluginStateReader().Get(pluginState); err != nil { return errors.Wrapf(errors.CorruptedPluginState, err, "Failed to read unmarshal custom state") } - return TerminateSubTasks(ctx, tCtx, e.kubeClient, pluginConfig, pluginState) + return TerminateSubTasks(ctx, tCtx, e.kubeClient, GetConfig(), finalizeSubtask, pluginState) } func (e Executor) Start(ctx context.Context) error { diff --git a/go/tasks/plugins/array/k8s/integration_test.go b/go/tasks/plugins/array/k8s/integration_test.go index 9b76265c5..17edf2077 100644 --- a/go/tasks/plugins/array/k8s/integration_test.go +++ b/go/tasks/plugins/array/k8s/integration_test.go @@ -1,33 +1,29 @@ package k8s import ( + "context" "strconv" "testing" - "github.com/flyteorg/flyteplugins/go/tasks/plugins/array" - - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io" - - "github.com/flyteorg/flytestdlib/storage" + "github.com/flyteorg/flyteidl/clients/go/coreutils" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" - - "context" - - v1 "k8s.io/api/core/v1" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "sigs.k8s.io/controller-runtime/pkg/client" - + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/workqueue" + "github.com/flyteorg/flyteplugins/go/tasks/plugins/array" "github.com/flyteorg/flytestdlib/contextutils" + "github.com/flyteorg/flytestdlib/promutils" "github.com/flyteorg/flytestdlib/promutils/labeled" + "github.com/flyteorg/flytestdlib/storage" - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks" - "github.com/flyteorg/flytestdlib/promutils" "github.com/stretchr/testify/assert" - "github.com/flyteorg/flyteidl/clients/go/coreutils" + v1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + "sigs.k8s.io/controller-runtime/pkg/client" ) func init() { diff --git a/go/tasks/plugins/array/k8s/launcher.go b/go/tasks/plugins/array/k8s/launcher.go deleted file mode 100644 index 3b2d9510a..000000000 --- a/go/tasks/plugins/array/k8s/launcher.go +++ /dev/null @@ -1,101 +0,0 @@ -package k8s - -import ( - "context" - "fmt" - "strconv" - - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/utils" - - "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/errorcollector" - - arrayCore "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/core" - - errors2 "github.com/flyteorg/flytestdlib/errors" - - corev1 "k8s.io/api/core/v1" - - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" -) - -const ( - ErrBuildPodTemplate errors2.ErrorCode = "POD_TEMPLATE_FAILED" - ErrReplaceCmdTemplate errors2.ErrorCode = "CMD_TEMPLATE_FAILED" - ErrSubmitJob errors2.ErrorCode = "SUBMIT_JOB_FAILED" - ErrGetTaskTypeVersion errors2.ErrorCode = "GET_TASK_TYPE_VERSION_FAILED" - JobIndexVarName string = "BATCH_JOB_ARRAY_INDEX_VAR_NAME" - FlyteK8sArrayIndexVarName string = "FLYTE_K8S_ARRAY_INDEX" -) - -var arrayJobEnvVars = []corev1.EnvVar{ - { - Name: JobIndexVarName, - Value: FlyteK8sArrayIndexVarName, - }, -} - -func formatSubTaskName(_ context.Context, parentName string, index int, retryAttempt uint64) (subTaskName string) { - indexStr := strconv.Itoa(index) - - // If the retryAttempt is 0 we do not include it in the pod name. The gives us backwards - // compatibility in the ability to dynamically transition running map tasks to use subtask retries. - if retryAttempt == 0 { - return utils.ConvertToDNS1123SubdomainCompatibleString(fmt.Sprintf("%v-%v", parentName, indexStr)) - } - - retryAttemptStr := strconv.FormatUint(retryAttempt, 10) - return utils.ConvertToDNS1123SubdomainCompatibleString(fmt.Sprintf("%v-%v-%v", parentName, indexStr, retryAttemptStr)) -} - -func ApplyPodPolicies(_ context.Context, cfg *Config, pod *corev1.Pod) *corev1.Pod { - if len(cfg.DefaultScheduler) > 0 { - pod.Spec.SchedulerName = cfg.DefaultScheduler - } - - return pod -} - -func applyNodeSelectorLabels(_ context.Context, cfg *Config, pod *corev1.Pod) *corev1.Pod { - if len(cfg.NodeSelector) != 0 { - pod.Spec.NodeSelector = cfg.NodeSelector - } - - return pod -} - -func applyPodTolerations(_ context.Context, cfg *Config, pod *corev1.Pod) *corev1.Pod { - if len(cfg.Tolerations) != 0 { - pod.Spec.Tolerations = cfg.Tolerations - } - - return pod -} - -func TerminateSubTasks(ctx context.Context, tCtx core.TaskExecutionContext, kubeClient core.KubeClient, config *Config, - currentState *arrayCore.State) error { - - size := currentState.GetExecutionArraySize() - errs := errorcollector.NewErrorMessageCollector() - for childIdx := 0; childIdx < size; childIdx++ { - task := Task{ - ChildIdx: childIdx, - Config: config, - State: currentState, - } - - err := task.Abort(ctx, tCtx, kubeClient) - if err != nil { - errs.Collect(childIdx, err.Error()) - } - err = task.Finalize(ctx, tCtx, kubeClient) - if err != nil { - errs.Collect(childIdx, err.Error()) - } - } - - if errs.Length() > 0 { - return fmt.Errorf(errs.Summary(config.MaxErrorStringLength)) - } - - return nil -} diff --git a/go/tasks/plugins/array/k8s/launcher_test.go b/go/tasks/plugins/array/k8s/launcher_test.go deleted file mode 100644 index 729f4c0ea..000000000 --- a/go/tasks/plugins/array/k8s/launcher_test.go +++ /dev/null @@ -1,64 +0,0 @@ -package k8s - -import ( - "context" - "fmt" - "testing" - - "github.com/stretchr/testify/assert" - corev1 "k8s.io/api/core/v1" - v1 "k8s.io/api/core/v1" -) - -func TestApplyNodeSelectorLabels(t *testing.T) { - ctx := context.Background() - cfg := &Config{ - NodeSelector: map[string]string{ - "disktype": "ssd", - }, - } - pod := &corev1.Pod{} - - pod = applyNodeSelectorLabels(ctx, cfg, pod) - - assert.Equal(t, pod.Spec.NodeSelector, cfg.NodeSelector) -} - -func TestApplyPodTolerations(t *testing.T) { - ctx := context.Background() - cfg := &Config{ - Tolerations: []v1.Toleration{{ - Key: "reserved", - Operator: "equal", - Value: "value", - Effect: "NoSchedule", - }}, - } - pod := &corev1.Pod{} - - pod = applyPodTolerations(ctx, cfg, pod) - - assert.Equal(t, pod.Spec.Tolerations, cfg.Tolerations) -} - -func TestFormatSubTaskName(t *testing.T) { - ctx := context.Background() - parentName := "foo" - - tests := []struct { - index int - retryAttempt uint64 - want string - }{ - {0, 0, fmt.Sprintf("%v-%v", parentName, 0)}, - {1, 0, fmt.Sprintf("%v-%v", parentName, 1)}, - {0, 1, fmt.Sprintf("%v-%v-%v", parentName, 0, 1)}, - {1, 1, fmt.Sprintf("%v-%v-%v", parentName, 1, 1)}, - } - - for i, tt := range tests { - t.Run(fmt.Sprintf("format-subtask-name-%v", i), func(t *testing.T) { - assert.Equal(t, tt.want, formatSubTaskName(ctx, parentName, tt.index, tt.retryAttempt)) - }) - } -} diff --git a/go/tasks/plugins/array/k8s/management.go b/go/tasks/plugins/array/k8s/management.go new file mode 100644 index 000000000..e5b7dd2cf --- /dev/null +++ b/go/tasks/plugins/array/k8s/management.go @@ -0,0 +1,301 @@ +package k8s + +import ( + "context" + "fmt" + "time" + + idlCore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + + "github.com/flyteorg/flyteplugins/go/tasks/errors" + "github.com/flyteorg/flyteplugins/go/tasks/logs" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" + "github.com/flyteorg/flyteplugins/go/tasks/plugins/array" + "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/arraystatus" + arrayCore "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/core" + "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/errorcollector" + + "github.com/flyteorg/flytestdlib/bitarray" + "github.com/flyteorg/flytestdlib/logger" + "github.com/flyteorg/flytestdlib/storage" +) + +// allocateResource attempts to allot resources for the specified parameter with the +// TaskExecutionContexts ResourceManager. +func allocateResource(ctx context.Context, tCtx core.TaskExecutionContext, config *Config, podName string) (core.AllocationStatus, error) { + if !IsResourceConfigSet(config.ResourceConfig) { + return core.AllocationStatusGranted, nil + } + + resourceNamespace := core.ResourceNamespace(config.ResourceConfig.PrimaryLabel) + resourceConstraintSpec := core.ResourceConstraintsSpec{ + ProjectScopeResourceConstraint: nil, + NamespaceScopeResourceConstraint: nil, + } + + allocationStatus, err := tCtx.ResourceManager().AllocateResource(ctx, resourceNamespace, podName, resourceConstraintSpec) + if err != nil { + return core.AllocationUndefined, err + } + + return allocationStatus, nil +} + +// deallocateResource attempts to release resources for the specified parameter with the +// TaskExecutionContexts ResourceManager. +func deallocateResource(ctx context.Context, tCtx core.TaskExecutionContext, config *Config, podName string) error { + if !IsResourceConfigSet(config.ResourceConfig) { + return nil + } + resourceNamespace := core.ResourceNamespace(config.ResourceConfig.PrimaryLabel) + + err := tCtx.ResourceManager().ReleaseResource(ctx, resourceNamespace, podName) + if err != nil { + logger.Errorf(ctx, "Error releasing token [%s]. error %s", podName, err) + return err + } + + return nil +} + +// LaunchAndCheckSubTasksState iterates over each subtask performing operations to transition them +// to a terminal state. This may include creating new k8s resources, monitoring existing k8s +// resources, retrying failed attempts, or declaring a permanent failure among others. +func LaunchAndCheckSubTasksState(ctx context.Context, tCtx core.TaskExecutionContext, kubeClient core.KubeClient, + config *Config, dataStore *storage.DataStore, outputPrefix, baseOutputDataSandbox storage.DataReference, currentState *arrayCore.State) ( + newState *arrayCore.State, logLinks []*idlCore.TaskLog, subTaskIDs []*string, err error) { + if int64(currentState.GetExecutionArraySize()) > config.MaxArrayJobSize { + ee := fmt.Errorf("array size > max allowed. Requested [%v]. Allowed [%v]", currentState.GetExecutionArraySize(), config.MaxArrayJobSize) + logger.Info(ctx, ee) + currentState = currentState.SetPhase(arrayCore.PhasePermanentFailure, 0).SetReason(ee.Error()) + return currentState, logLinks, subTaskIDs, nil + } + + logLinks = make([]*idlCore.TaskLog, 0, 4) + newState = currentState + messageCollector := errorcollector.NewErrorMessageCollector() + newArrayStatus := &arraystatus.ArrayStatus{ + Summary: arraystatus.ArraySummary{}, + Detailed: arrayCore.NewPhasesCompactArray(uint(currentState.GetExecutionArraySize())), + } + subTaskIDs = make([]*string, 0, len(currentState.GetArrayStatus().Detailed.GetItems())) + + // If we have arrived at this state for the first time then currentState has not been + // initialized with number of sub tasks. + if len(currentState.GetArrayStatus().Detailed.GetItems()) == 0 { + currentState.ArrayStatus = *newArrayStatus + } + + // If the current State is newly minted then we must initialize RetryAttempts to track how many + // times each subtask is executed. + if len(currentState.RetryAttempts.GetItems()) == 0 { + count := uint(currentState.GetExecutionArraySize()) + maxValue := bitarray.Item(tCtx.TaskExecutionMetadata().GetMaxAttempts()) + + retryAttemptsArray, err := bitarray.NewCompactArray(count, maxValue) + if err != nil { + logger.Errorf(context.Background(), "Failed to create attempts compact array with [count: %v, maxValue: %v]", count, maxValue) + return currentState, logLinks, subTaskIDs, nil + } + + // Initialize subtask retryAttempts to 0 so that, in tandem with the podName logic, we + // maintain backwards compatibility. + for i := 0; i < currentState.GetExecutionArraySize(); i++ { + retryAttemptsArray.SetItem(i, 0) + } + + currentState.RetryAttempts = retryAttemptsArray + } + + // initialize log plugin + logPlugin, err := logs.InitializeLogPlugins(&config.LogConfig.Config) + if err != nil { + return currentState, logLinks, subTaskIDs, err + } + + // identify max parallelism + taskTemplate, err := tCtx.TaskReader().Read(ctx) + if err != nil { + return currentState, logLinks, subTaskIDs, err + } else if taskTemplate == nil { + return currentState, logLinks, subTaskIDs, errors.Errorf(errors.BadTaskSpecification, "Required value not set, taskTemplate is nil") + } + + arrayJob, err := arrayCore.ToArrayJob(taskTemplate.GetCustom(), taskTemplate.TaskTypeVersion) + if err != nil { + return currentState, logLinks, subTaskIDs, err + } + + currentParallelism := 0 + maxParallelism := int(arrayJob.Parallelism) + + for childIdx, existingPhaseIdx := range currentState.GetArrayStatus().Detailed.GetItems() { + existingPhase := core.Phases[existingPhaseIdx] + retryAttempt := currentState.RetryAttempts.GetItem(childIdx) + + if existingPhase == core.PhaseRetryableFailure { + retryAttempt++ + newState.RetryAttempts.SetItem(childIdx, retryAttempt) + } else if existingPhase.IsTerminal() { + newArrayStatus.Detailed.SetItem(childIdx, bitarray.Item(existingPhase)) + continue + } + + originalIdx := arrayCore.CalculateOriginalIndex(childIdx, newState.GetIndexesToCache()) + stCtx := newSubTaskExecutionContext(tCtx, taskTemplate, childIdx, originalIdx, retryAttempt) + podName := stCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName() + + // depending on the existing subtask phase we either a launch new k8s resource or monitor + // an existing instance + var phaseInfo core.PhaseInfo + var perr error + if existingPhase == core.PhaseUndefined || existingPhase == core.PhaseWaitingForResources || existingPhase == core.PhaseRetryableFailure { + // attempt to allocateResource + allocationStatus, err := allocateResource(ctx, stCtx, config, podName) + if err != nil { + logger.Errorf(ctx, "Resource manager failed for TaskExecId [%s] token [%s]. error %s", + stCtx.TaskExecutionMetadata().GetTaskExecutionID().GetID(), podName, err) + return currentState, logLinks, subTaskIDs, err + } + + logger.Infof(ctx, "Allocation result for [%s] is [%s]", podName, allocationStatus) + if allocationStatus != core.AllocationStatusGranted { + phaseInfo = core.PhaseInfoWaitingForResourcesInfo(time.Now(), core.DefaultPhaseVersion, "Exceeded ResourceManager quota", nil) + } else { + phaseInfo, perr = launchSubtask(ctx, stCtx, config, kubeClient) + + // if launchSubtask fails we attempt to deallocate the (previously allocated) + // resource to mitigate leaks + if perr != nil { + perr = deallocateResource(ctx, stCtx, config, podName) + if perr != nil { + logger.Errorf(ctx, "Error releasing allocation token [%s] in Finalize [%s]", podName, err) + } + } + } + } else { + phaseInfo, perr = getSubtaskPhaseInfo(ctx, stCtx, config, kubeClient, logPlugin) + } + + // validate and process phaseInfo and perr + if perr != nil { + return currentState, logLinks, subTaskIDs, perr + } + + if phaseInfo.Err() != nil { + messageCollector.Collect(childIdx, phaseInfo.Err().String()) + } + + subTaskIDs = append(subTaskIDs, &podName) + if phaseInfo.Info() != nil { + logLinks = append(logLinks, phaseInfo.Info().Logs...) + } + + // process subtask phase + actualPhase := phaseInfo.Phase() + if actualPhase.IsSuccess() { + actualPhase, err = array.CheckTaskOutput(ctx, dataStore, outputPrefix, baseOutputDataSandbox, childIdx, originalIdx) + if err != nil { + return currentState, logLinks, subTaskIDs, err + } + } + + if actualPhase == core.PhaseRetryableFailure && uint32(retryAttempt+1) >= stCtx.TaskExecutionMetadata().GetMaxAttempts() { + // If we see a retryable failure we must check if the number of retries exceeds the maximum + // attempts. If so, transition to a permanent failure so that is not attempted again. + newArrayStatus.Detailed.SetItem(childIdx, bitarray.Item(core.PhasePermanentFailure)) + } else { + newArrayStatus.Detailed.SetItem(childIdx, bitarray.Item(actualPhase)) + } + + if actualPhase.IsTerminal() { + err = deallocateResource(ctx, stCtx, config, podName) + if err != nil { + logger.Errorf(ctx, "Error releasing allocation token [%s] in Finalize [%s]", podName, err) + return currentState, logLinks, subTaskIDs, err + } + + err = finalizeSubtask(ctx, stCtx, config, kubeClient) + if err != nil { + logger.Errorf(ctx, "Error finalizing resource [%s] in Finalize [%s]", podName, err) + return currentState, logLinks, subTaskIDs, err + } + } + + // validate parallelism + if !actualPhase.IsTerminal() || actualPhase == core.PhaseRetryableFailure { + currentParallelism++ + } + + if maxParallelism != 0 && currentParallelism >= maxParallelism { + break + } + } + + // compute task phase from array status summary + for _, phaseIdx := range newArrayStatus.Detailed.GetItems() { + newArrayStatus.Summary.Inc(core.Phases[phaseIdx]) + } + + phase := arrayCore.SummaryToPhase(ctx, currentState.GetOriginalMinSuccesses()-currentState.GetOriginalArraySize()+int64(currentState.GetExecutionArraySize()), newArrayStatus.Summary) + + // process new state + newState = newState.SetArrayStatus(*newArrayStatus) + if phase == arrayCore.PhaseWriteToDiscoveryThenFail { + errorMsg := messageCollector.Summary(GetConfig().MaxErrorStringLength) + newState = newState.SetReason(errorMsg) + } + + if phase == arrayCore.PhaseCheckingSubTaskExecutions { + newPhaseVersion := uint32(0) + + // For now, the only changes to PhaseVersion and PreviousSummary occur for running array jobs. + for phase, count := range newState.GetArrayStatus().Summary { + newPhaseVersion += uint32(phase) * uint32(count) + } + + newState = newState.SetPhase(phase, newPhaseVersion).SetReason("Task is still running.") + } else { + newState = newState.SetPhase(phase, core.DefaultPhaseVersion) + } + + return newState, logLinks, subTaskIDs, nil +} + +// TerminateSubTasks performs operations to gracefully terminate all subtasks. This may include +// aborting and finalizing active k8s resources. +func TerminateSubTasks(ctx context.Context, tCtx core.TaskExecutionContext, kubeClient core.KubeClient, config *Config, + terminateFunction func(context.Context, SubTaskExecutionContext, *Config, core.KubeClient) error, currentState *arrayCore.State) error { + + taskTemplate, err := tCtx.TaskReader().Read(ctx) + if err != nil { + return err + } else if taskTemplate == nil { + return errors.Errorf(errors.BadTaskSpecification, "Required value not set, taskTemplate is nil") + } + + messageCollector := errorcollector.NewErrorMessageCollector() + for childIdx, existingPhaseIdx := range currentState.GetArrayStatus().Detailed.GetItems() { + existingPhase := core.Phases[existingPhaseIdx] + retryAttempt := currentState.RetryAttempts.GetItem(childIdx) + + // return immediately if subtask has completed or not yet started + if existingPhase.IsTerminal() || existingPhase == core.PhaseUndefined { + continue + } + + originalIdx := arrayCore.CalculateOriginalIndex(childIdx, currentState.GetIndexesToCache()) + stCtx := newSubTaskExecutionContext(tCtx, taskTemplate, childIdx, originalIdx, retryAttempt) + + err := terminateFunction(ctx, stCtx, config, kubeClient) + if err != nil { + messageCollector.Collect(childIdx, err.Error()) + } + } + + if messageCollector.Length() > 0 { + return fmt.Errorf(messageCollector.Summary(config.MaxErrorStringLength)) + } + + return nil +} diff --git a/go/tasks/plugins/array/k8s/management_test.go b/go/tasks/plugins/array/k8s/management_test.go new file mode 100644 index 000000000..8e5cd6b1e --- /dev/null +++ b/go/tasks/plugins/array/k8s/management_test.go @@ -0,0 +1,501 @@ +package k8s + +import ( + "fmt" + "testing" + + core2 "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + + "github.com/flyteorg/flyteplugins/go/tasks/logs" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s" + mocks2 "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io/mocks" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/workqueue" + "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/arraystatus" + arrayCore "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/core" + + "github.com/flyteorg/flytestdlib/bitarray" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + + "golang.org/x/net/context" + + structpb "google.golang.org/protobuf/types/known/structpb" + + v1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/resource" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +func createSampleContainerTask() *core2.Container { + return &core2.Container{ + Command: []string{"cmd"}, + Args: []string{"{{$inputPrefix}}"}, + Image: "img1", + } +} + +func getMockTaskExecutionContext(ctx context.Context, parallelism int) *mocks.TaskExecutionContext { + customStruct, _ := structpb.NewStruct(map[string]interface{}{ + "parallelism": fmt.Sprintf("%d", parallelism), + }) + + tr := &mocks.TaskReader{} + tr.OnRead(ctx).Return(&core2.TaskTemplate{ + Custom: customStruct, + Target: &core2.TaskTemplate_Container{ + Container: createSampleContainerTask(), + }, + }, nil) + + tID := &mocks.TaskExecutionID{} + tID.OnGetGeneratedName().Return("notfound") + tID.OnGetID().Return(core2.TaskExecutionIdentifier{ + TaskId: &core2.Identifier{ + ResourceType: core2.ResourceType_TASK, + Project: "a", + Domain: "d", + Name: "n", + Version: "abc", + }, + NodeExecutionId: &core2.NodeExecutionIdentifier{ + NodeId: "node1", + ExecutionId: &core2.WorkflowExecutionIdentifier{ + Project: "a", + Domain: "d", + Name: "exec", + }, + }, + RetryAttempt: 0, + }) + + overrides := &mocks.TaskOverrides{} + overrides.OnGetResources().Return(&v1.ResourceRequirements{ + Requests: v1.ResourceList{ + v1.ResourceCPU: resource.MustParse("10"), + }, + }) + + tMeta := &mocks.TaskExecutionMetadata{} + tMeta.OnGetTaskExecutionID().Return(tID) + tMeta.OnGetOverrides().Return(overrides) + tMeta.OnIsInterruptible().Return(false) + tMeta.OnGetK8sServiceAccount().Return("s") + tMeta.OnGetSecurityContext().Return(core2.SecurityContext{}) + + tMeta.OnGetMaxAttempts().Return(2) + tMeta.OnGetNamespace().Return("n") + tMeta.OnGetLabels().Return(nil) + tMeta.OnGetAnnotations().Return(nil) + tMeta.OnGetOwnerReference().Return(metav1.OwnerReference{}) + tMeta.OnGetPlatformResources().Return(&v1.ResourceRequirements{}) + + ow := &mocks2.OutputWriter{} + ow.OnGetOutputPrefixPath().Return("/prefix/") + ow.OnGetRawOutputPrefix().Return("/raw_prefix/") + ow.OnGetCheckpointPrefix().Return("/checkpoint") + ow.OnGetPreviousCheckpointsPrefix().Return("/prev") + + ir := &mocks2.InputReader{} + ir.OnGetInputPrefixPath().Return("/prefix/") + ir.OnGetInputPath().Return("/prefix/inputs.pb") + ir.OnGetMatch(mock.Anything).Return(&core2.LiteralMap{}, nil) + + tCtx := &mocks.TaskExecutionContext{} + tCtx.OnTaskReader().Return(tr) + tCtx.OnTaskExecutionMetadata().Return(tMeta) + tCtx.OnOutputWriter().Return(ow) + tCtx.OnInputReader().Return(ir) + return tCtx +} + +func TestCheckSubTasksState(t *testing.T) { + ctx := context.Background() + subtaskCount := 5 + + config := Config{ + MaxArrayJobSize: int64(subtaskCount * 10), + ResourceConfig: ResourceConfig{ + PrimaryLabel: "p", + Limit: subtaskCount, + }, + } + + fakeKubeClient := mocks.NewFakeKubeClient() + fakeKubeCache := mocks.NewFakeKubeCache() + + for i := 0; i < subtaskCount; i++ { + pod := flytek8s.BuildIdentityPod() + pod.SetName(fmt.Sprintf("notfound-%d", i)) + pod.SetNamespace("a-n-b") + pod.Spec.Containers = append(pod.Spec.Containers, v1.Container{Name: "foo"}) + + pod.Status.Phase = v1.PodRunning + _ = fakeKubeClient.Create(ctx, pod) + _ = fakeKubeCache.Create(ctx, pod) + } + + failureFakeKubeClient := mocks.NewFakeKubeClient() + failureFakeKubeCache := mocks.NewFakeKubeCache() + + for i := 0; i < subtaskCount; i++ { + pod := flytek8s.BuildIdentityPod() + pod.SetName(fmt.Sprintf("notfound-%d", i)) + pod.SetNamespace("a-n-b") + pod.Spec.Containers = append(pod.Spec.Containers, v1.Container{Name: "foo"}) + + pod.Status.Phase = v1.PodFailed + _ = failureFakeKubeClient.Create(ctx, pod) + _ = failureFakeKubeCache.Create(ctx, pod) + } + + t.Run("Launch", func(t *testing.T) { + // initialize metadata + kubeClient := mocks.KubeClient{} + kubeClient.OnGetClient().Return(mocks.NewFakeKubeClient()) + kubeClient.OnGetCache().Return(mocks.NewFakeKubeCache()) + + resourceManager := mocks.ResourceManager{} + resourceManager.OnAllocateResourceMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(core.AllocationStatusGranted, nil) + + tCtx := getMockTaskExecutionContext(ctx, 0) + tCtx.OnResourceManager().Return(&resourceManager) + + currentState := &arrayCore.State{ + CurrentPhase: arrayCore.PhaseCheckingSubTaskExecutions, + ExecutionArraySize: subtaskCount, + OriginalArraySize: int64(subtaskCount), + OriginalMinSuccesses: int64(subtaskCount), + ArrayStatus: arraystatus.ArrayStatus{ + Detailed: arrayCore.NewPhasesCompactArray(uint(subtaskCount)), // set all tasks to core.PhaseUndefined + }, + IndexesToCache: arrayCore.InvertBitSet(bitarray.NewBitSet(uint(subtaskCount)), uint(subtaskCount)), // set all tasks to be cached + } + + // execute + newState, _, _, err := LaunchAndCheckSubTasksState(ctx, tCtx, &kubeClient, &config, nil, "/prefix/", "/prefix-sand/", currentState) + + // validate results + assert.Nil(t, err) + p, _ := newState.GetPhase() + assert.Equal(t, arrayCore.PhaseCheckingSubTaskExecutions.String(), p.String()) + resourceManager.AssertNumberOfCalls(t, "AllocateResource", subtaskCount) + for _, subtaskPhaseIndex := range newState.GetArrayStatus().Detailed.GetItems() { + assert.Equal(t, core.PhaseQueued, core.Phases[subtaskPhaseIndex]) + } + }) + + for i := 1; i <= subtaskCount; i++ { + t.Run(fmt.Sprintf("LaunchParallelism%d", i), func(t *testing.T) { + // initialize metadata + kubeClient := mocks.KubeClient{} + kubeClient.OnGetClient().Return(mocks.NewFakeKubeClient()) + kubeClient.OnGetCache().Return(mocks.NewFakeKubeCache()) + + resourceManager := mocks.ResourceManager{} + resourceManager.OnAllocateResourceMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(core.AllocationStatusGranted, nil) + + tCtx := getMockTaskExecutionContext(ctx, i) + tCtx.OnResourceManager().Return(&resourceManager) + + currentState := &arrayCore.State{ + CurrentPhase: arrayCore.PhaseCheckingSubTaskExecutions, + ExecutionArraySize: subtaskCount, + OriginalArraySize: int64(subtaskCount), + OriginalMinSuccesses: int64(subtaskCount), + ArrayStatus: arraystatus.ArrayStatus{ + Detailed: arrayCore.NewPhasesCompactArray(uint(subtaskCount)), // set all tasks to core.PhaseUndefined + }, + IndexesToCache: arrayCore.InvertBitSet(bitarray.NewBitSet(uint(subtaskCount)), uint(subtaskCount)), // set all tasks to be cached + } + + // execute + newState, _, _, err := LaunchAndCheckSubTasksState(ctx, tCtx, &kubeClient, &config, nil, "/prefix/", "/prefix-sand/", currentState) + + // validate results + assert.Nil(t, err) + p, _ := newState.GetPhase() + assert.Equal(t, arrayCore.PhaseCheckingSubTaskExecutions.String(), p.String()) + + executed := 0 + for _, existingPhaseIdx := range newState.GetArrayStatus().Detailed.GetItems() { + if core.Phases[existingPhaseIdx] != core.PhaseUndefined { + executed++ + } + } + + assert.Equal(t, i, executed) + }) + } + + t.Run("LaunchResourcesExhausted", func(t *testing.T) { + // initialize metadata + kubeClient := mocks.KubeClient{} + kubeClient.OnGetClient().Return(mocks.NewFakeKubeClient()) + kubeClient.OnGetCache().Return(mocks.NewFakeKubeCache()) + + resourceManager := mocks.ResourceManager{} + resourceManager.OnAllocateResourceMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(core.AllocationStatusExhausted, nil) + + tCtx := getMockTaskExecutionContext(ctx, 0) + tCtx.OnResourceManager().Return(&resourceManager) + + currentState := &arrayCore.State{ + CurrentPhase: arrayCore.PhaseCheckingSubTaskExecutions, + ExecutionArraySize: subtaskCount, + OriginalArraySize: int64(subtaskCount), + OriginalMinSuccesses: int64(subtaskCount), + ArrayStatus: arraystatus.ArrayStatus{ + Detailed: arrayCore.NewPhasesCompactArray(uint(subtaskCount)), // set all tasks to core.PhaseUndefined + }, + IndexesToCache: arrayCore.InvertBitSet(bitarray.NewBitSet(uint(subtaskCount)), uint(subtaskCount)), // set all tasks to be cached + } + + // execute + newState, _, _, err := LaunchAndCheckSubTasksState(ctx, tCtx, &kubeClient, &config, nil, "/prefix/", "/prefix-sand/", currentState) + + // validate results + assert.Nil(t, err) + p, _ := newState.GetPhase() + assert.Equal(t, arrayCore.PhaseWaitingForResources.String(), p.String()) + resourceManager.AssertNumberOfCalls(t, "AllocateResource", subtaskCount) + for _, subtaskPhaseIndex := range newState.GetArrayStatus().Detailed.GetItems() { + assert.Equal(t, core.PhaseWaitingForResources, core.Phases[subtaskPhaseIndex]) + } + + // execute again - with resources available and validate results + nresourceManager := mocks.ResourceManager{} + nresourceManager.OnAllocateResourceMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(core.AllocationStatusGranted, nil) + + ntCtx := getMockTaskExecutionContext(ctx, 0) + ntCtx.OnResourceManager().Return(&nresourceManager) + + lastState, _, _, err := LaunchAndCheckSubTasksState(ctx, ntCtx, &kubeClient, &config, nil, "/prefix/", "/prefix-sand/", newState) + assert.Nil(t, err) + np, _ := lastState.GetPhase() + assert.Equal(t, arrayCore.PhaseCheckingSubTaskExecutions.String(), np.String()) + resourceManager.AssertNumberOfCalls(t, "AllocateResource", subtaskCount) + for _, subtaskPhaseIndex := range lastState.GetArrayStatus().Detailed.GetItems() { + assert.Equal(t, core.PhaseQueued, core.Phases[subtaskPhaseIndex]) + } + }) + + t.Run("LaunchRetryableFailures", func(t *testing.T) { + // initialize metadata + kubeClient := mocks.KubeClient{} + kubeClient.OnGetClient().Return(fakeKubeClient) + kubeClient.OnGetCache().Return(fakeKubeCache) + + resourceManager := mocks.ResourceManager{} + resourceManager.OnAllocateResourceMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(core.AllocationStatusGranted, nil) + + tCtx := getMockTaskExecutionContext(ctx, 0) + tCtx.OnResourceManager().Return(&resourceManager) + + detailed := arrayCore.NewPhasesCompactArray(uint(subtaskCount)) + for i := 0; i < subtaskCount; i++ { + detailed.SetItem(i, bitarray.Item(core.PhaseRetryableFailure)) // set all tasks to core.PhaseRetryableFailure + } + + retryAttemptsArray, err := bitarray.NewCompactArray(uint(subtaskCount), bitarray.Item(1)) + assert.NoError(t, err) + + currentState := &arrayCore.State{ + CurrentPhase: arrayCore.PhaseCheckingSubTaskExecutions, + ExecutionArraySize: subtaskCount, + OriginalArraySize: int64(subtaskCount), + OriginalMinSuccesses: int64(subtaskCount), + ArrayStatus: arraystatus.ArrayStatus{ + Detailed: detailed, + }, + IndexesToCache: arrayCore.InvertBitSet(bitarray.NewBitSet(uint(subtaskCount)), uint(subtaskCount)), // set all tasks to be cached + RetryAttempts: retryAttemptsArray, + } + + // execute + newState, _, _, err := LaunchAndCheckSubTasksState(ctx, tCtx, &kubeClient, &config, nil, "/prefix/", "/prefix-sand/", currentState) + + // validate results + assert.Nil(t, err) + p, _ := newState.GetPhase() + assert.Equal(t, arrayCore.PhaseCheckingSubTaskExecutions.String(), p.String()) + resourceManager.AssertNumberOfCalls(t, "AllocateResource", subtaskCount) + for i, subtaskPhaseIndex := range newState.GetArrayStatus().Detailed.GetItems() { + assert.Equal(t, core.PhaseQueued, core.Phases[subtaskPhaseIndex]) + assert.Equal(t, bitarray.Item(1), newState.RetryAttempts.GetItem(i)) + } + }) + + t.Run("RunningLogLinksAndSubtaskIDs", func(t *testing.T) { + // initialize metadata + config := Config{ + MaxArrayJobSize: 100, + MaxErrorStringLength: 200, + NamespaceTemplate: "a-{{.namespace}}-b", + OutputAssembler: workqueue.Config{ + Workers: 2, + MaxRetries: 0, + IndexCacheMaxItems: 100, + }, + ErrorAssembler: workqueue.Config{ + Workers: 2, + MaxRetries: 0, + IndexCacheMaxItems: 100, + }, + LogConfig: LogConfig{ + Config: logs.LogConfig{ + IsCloudwatchEnabled: true, + CloudwatchTemplateURI: "https://console.aws.amazon.com/cloudwatch/home?region=us-east-1#logStream:group=/kubernetes/flyte;prefix=var.log.containers.{{ .podName }};streamFilter=typeLogStreamPrefix", + IsKubernetesEnabled: true, + KubernetesTemplateURI: "k8s/log/{{.namespace}}/{{.podName}}/pod?namespace={{.namespace}}", + }}, + } + + kubeClient := mocks.KubeClient{} + kubeClient.OnGetClient().Return(fakeKubeClient) + kubeClient.OnGetCache().Return(fakeKubeCache) + + resourceManager := mocks.ResourceManager{} + resourceManager.OnAllocateResourceMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(core.AllocationStatusExhausted, nil) + + tCtx := getMockTaskExecutionContext(ctx, 0) + tCtx.OnResourceManager().Return(&resourceManager) + + detailed := arrayCore.NewPhasesCompactArray(uint(subtaskCount)) + for i := 0; i < subtaskCount; i++ { + detailed.SetItem(i, bitarray.Item(core.PhaseRunning)) // set all tasks to core.PhaseRunning + } + + currentState := &arrayCore.State{ + CurrentPhase: arrayCore.PhaseCheckingSubTaskExecutions, + ExecutionArraySize: subtaskCount, + OriginalArraySize: int64(subtaskCount), + OriginalMinSuccesses: int64(subtaskCount), + ArrayStatus: arraystatus.ArrayStatus{ + Detailed: detailed, + }, + IndexesToCache: arrayCore.InvertBitSet(bitarray.NewBitSet(uint(subtaskCount)), uint(subtaskCount)), // set all tasks to be cached + } + + // execute + newState, logLinks, subTaskIDs, err := LaunchAndCheckSubTasksState(ctx, tCtx, &kubeClient, &config, nil, "/prefix/", "/prefix-sand/", currentState) + + // validate results + assert.Nil(t, err) + p, _ := newState.GetPhase() + assert.Equal(t, arrayCore.PhaseCheckingSubTaskExecutions.String(), p.String()) + + resourceManager.AssertNumberOfCalls(t, "AllocateResource", 0) + resourceManager.AssertNumberOfCalls(t, "ReleaseResource", 0) + + assert.NotEmpty(t, logLinks) + assert.Equal(t, subtaskCount*2, len(logLinks)) + for i := 0; i < subtaskCount*2; i = i + 2 { + assert.Equal(t, fmt.Sprintf("Kubernetes Logs #0-%d (PhaseRunning)", i/2), logLinks[i].Name) + assert.Equal(t, fmt.Sprintf("k8s/log/a-n-b/notfound-%d/pod?namespace=a-n-b", i/2), logLinks[i].Uri) + + assert.Equal(t, fmt.Sprintf("Cloudwatch Logs #0-%d (PhaseRunning)", i/2), logLinks[i+1].Name) + assert.Equal(t, fmt.Sprintf("https://console.aws.amazon.com/cloudwatch/home?region=us-east-1#logStream:group=/kubernetes/flyte;prefix=var.log.containers.notfound-%d;streamFilter=typeLogStreamPrefix", i/2), logLinks[i+1].Uri) + } + + for i := 0; i < subtaskCount; i++ { + assert.Equal(t, fmt.Sprintf("notfound-%d", i), *subTaskIDs[i]) + } + }) + + t.Run("RunningRetryableFailures", func(t *testing.T) { + // initialize metadata + kubeClient := mocks.KubeClient{} + kubeClient.OnGetClient().Return(failureFakeKubeClient) + kubeClient.OnGetCache().Return(failureFakeKubeCache) + + resourceManager := mocks.ResourceManager{} + resourceManager.OnReleaseResourceMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) + + tCtx := getMockTaskExecutionContext(ctx, 0) + tCtx.OnResourceManager().Return(&resourceManager) + + detailed := arrayCore.NewPhasesCompactArray(uint(subtaskCount)) + for i := 0; i < subtaskCount; i++ { + detailed.SetItem(i, bitarray.Item(core.PhaseRunning)) // set all tasks to core.PhaseRunning + } + + retryAttemptsArray, err := bitarray.NewCompactArray(uint(subtaskCount), bitarray.Item(1)) + assert.NoError(t, err) + + currentState := &arrayCore.State{ + CurrentPhase: arrayCore.PhaseCheckingSubTaskExecutions, + ExecutionArraySize: subtaskCount, + OriginalArraySize: int64(subtaskCount), + OriginalMinSuccesses: int64(subtaskCount), + ArrayStatus: arraystatus.ArrayStatus{ + Detailed: detailed, + }, + IndexesToCache: arrayCore.InvertBitSet(bitarray.NewBitSet(uint(subtaskCount)), uint(subtaskCount)), // set all tasks to be cached + RetryAttempts: retryAttemptsArray, + } + + // execute + newState, _, _, err := LaunchAndCheckSubTasksState(ctx, tCtx, &kubeClient, &config, nil, "/prefix/", "/prefix-sand/", currentState) + + // validate results + assert.Nil(t, err) + p, _ := newState.GetPhase() + assert.Equal(t, arrayCore.PhaseCheckingSubTaskExecutions.String(), p.String()) + resourceManager.AssertNumberOfCalls(t, "ReleaseResource", subtaskCount) + for _, subtaskPhaseIndex := range newState.GetArrayStatus().Detailed.GetItems() { + assert.Equal(t, core.PhaseRetryableFailure, core.Phases[subtaskPhaseIndex]) + } + }) + + t.Run("RunningPermanentFailures", func(t *testing.T) { + // initialize metadata + kubeClient := mocks.KubeClient{} + kubeClient.OnGetClient().Return(failureFakeKubeClient) + kubeClient.OnGetCache().Return(failureFakeKubeCache) + + resourceManager := mocks.ResourceManager{} + resourceManager.OnReleaseResourceMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) + + tCtx := getMockTaskExecutionContext(ctx, 0) + tCtx.OnResourceManager().Return(&resourceManager) + + detailed := arrayCore.NewPhasesCompactArray(uint(subtaskCount)) + for i := 0; i < subtaskCount; i++ { + detailed.SetItem(i, bitarray.Item(core.PhaseRunning)) // set all tasks to core.PhaseRunning + } + + retryAttemptsArray, err := bitarray.NewCompactArray(uint(subtaskCount), bitarray.Item(1)) + assert.NoError(t, err) + + for i := 0; i < subtaskCount; i++ { + retryAttemptsArray.SetItem(i, bitarray.Item(1)) + } + + currentState := &arrayCore.State{ + CurrentPhase: arrayCore.PhaseCheckingSubTaskExecutions, + ExecutionArraySize: subtaskCount, + OriginalArraySize: int64(subtaskCount), + OriginalMinSuccesses: int64(subtaskCount), + ArrayStatus: arraystatus.ArrayStatus{ + Detailed: detailed, + }, + IndexesToCache: arrayCore.InvertBitSet(bitarray.NewBitSet(uint(subtaskCount)), uint(subtaskCount)), // set all tasks to be cached + RetryAttempts: retryAttemptsArray, + } + + // execute + newState, _, _, err := LaunchAndCheckSubTasksState(ctx, tCtx, &kubeClient, &config, nil, "/prefix/", "/prefix-sand/", currentState) + + // validate results + assert.Nil(t, err) + p, _ := newState.GetPhase() + assert.Equal(t, arrayCore.PhaseWriteToDiscoveryThenFail.String(), p.String()) + resourceManager.AssertNumberOfCalls(t, "ReleaseResource", subtaskCount) + for _, subtaskPhaseIndex := range newState.GetArrayStatus().Detailed.GetItems() { + assert.Equal(t, core.PhasePermanentFailure, core.Phases[subtaskPhaseIndex]) + } + }) +} diff --git a/go/tasks/plugins/array/k8s/monitor.go b/go/tasks/plugins/array/k8s/monitor.go deleted file mode 100644 index 55f42324e..000000000 --- a/go/tasks/plugins/array/k8s/monitor.go +++ /dev/null @@ -1,348 +0,0 @@ -package k8s - -import ( - "context" - "fmt" - "time" - - idlCore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" - - "github.com/flyteorg/flyteplugins/go/tasks/errors" - "github.com/flyteorg/flyteplugins/go/tasks/logs" - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s" - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/tasklog" - "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/arraystatus" - arrayCore "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/core" - "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/errorcollector" - - "github.com/flyteorg/flytestdlib/bitarray" - "github.com/flyteorg/flytestdlib/logger" - "github.com/flyteorg/flytestdlib/storage" - - v1 "k8s.io/api/core/v1" - k8serrors "k8s.io/apimachinery/pkg/api/errors" - metaV1 "k8s.io/apimachinery/pkg/apis/meta/v1" - k8sTypes "k8s.io/apimachinery/pkg/types" - - errors2 "github.com/flyteorg/flytestdlib/errors" -) - -const ( - ErrCheckPodStatus errors2.ErrorCode = "CHECK_POD_FAILED" -) - -func LaunchAndCheckSubTasksState(ctx context.Context, tCtx core.TaskExecutionContext, kubeClient core.KubeClient, - config *Config, dataStore *storage.DataStore, outputPrefix, baseOutputDataSandbox storage.DataReference, currentState *arrayCore.State) ( - newState *arrayCore.State, logLinks []*idlCore.TaskLog, subTaskIDs []*string, err error) { - if int64(currentState.GetExecutionArraySize()) > config.MaxArrayJobSize { - ee := fmt.Errorf("array size > max allowed. Requested [%v]. Allowed [%v]", currentState.GetExecutionArraySize(), config.MaxArrayJobSize) - logger.Info(ctx, ee) - currentState = currentState.SetPhase(arrayCore.PhasePermanentFailure, 0).SetReason(ee.Error()) - return currentState, logLinks, subTaskIDs, nil - } - - logLinks = make([]*idlCore.TaskLog, 0, 4) - newState = currentState - msg := errorcollector.NewErrorMessageCollector() - newArrayStatus := &arraystatus.ArrayStatus{ - Summary: arraystatus.ArraySummary{}, - Detailed: arrayCore.NewPhasesCompactArray(uint(currentState.GetExecutionArraySize())), - } - subTaskIDs = make([]*string, 0, len(currentState.GetArrayStatus().Detailed.GetItems())) - - // If we have arrived at this state for the first time then currentState has not been - // initialized with number of sub tasks. - if len(currentState.GetArrayStatus().Detailed.GetItems()) == 0 { - currentState.ArrayStatus = *newArrayStatus - } - - // If the current State is newly minted then we must initialize RetryAttempts to track how many - // times each subtask is executed. - if len(currentState.RetryAttempts.GetItems()) == 0 { - count := uint(currentState.GetExecutionArraySize()) - maxValue := bitarray.Item(tCtx.TaskExecutionMetadata().GetMaxAttempts()) - - retryAttemptsArray, err := bitarray.NewCompactArray(count, maxValue) - if err != nil { - logger.Errorf(context.Background(), "Failed to create attempts compact array with [count: %v, maxValue: %v]", count, maxValue) - return currentState, logLinks, subTaskIDs, nil - } - - // Set subtask retryAttempts using the existing task context retry attempt. For new tasks - // this will initialize to 0, but running tasks will use the existing retry attempt. - retryAttempt := bitarray.Item(tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetID().RetryAttempt) - for i := 0; i < currentState.GetExecutionArraySize(); i++ { - retryAttemptsArray.SetItem(i, retryAttempt) - } - - currentState.RetryAttempts = retryAttemptsArray - } - - logPlugin, err := logs.InitializeLogPlugins(&config.LogConfig.Config) - if err != nil { - logger.Errorf(ctx, "Error initializing LogPlugins: [%s]", err) - return currentState, logLinks, subTaskIDs, err - } - - // identify max parallelism - taskTemplate, err := tCtx.TaskReader().Read(ctx) - if err != nil { - return currentState, logLinks, subTaskIDs, err - } else if taskTemplate == nil { - return currentState, logLinks, subTaskIDs, errors.Errorf(errors.BadTaskSpecification, "Required value not set, taskTemplate is nil") - } - - arrayJob, err := arrayCore.ToArrayJob(taskTemplate.GetCustom(), taskTemplate.TaskTypeVersion) - if err != nil { - return currentState, logLinks, subTaskIDs, err - } - - currentParallelism := 0 - maxParallelism := int(arrayJob.Parallelism) - - for childIdx, existingPhaseIdx := range currentState.GetArrayStatus().Detailed.GetItems() { - existingPhase := core.Phases[existingPhaseIdx] - originalIdx := arrayCore.CalculateOriginalIndex(childIdx, newState.GetIndexesToCache()) - - retryAttempt := currentState.RetryAttempts.GetItem(childIdx) - podName := formatSubTaskName(ctx, tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(), childIdx, retryAttempt) - - if existingPhase.IsTerminal() { - // If we get here it means we have already "processed" this terminal phase since we will only persist - // the phase after all processing is done (e.g. check outputs/errors file, record events... etc.). - - // Since we know we have already "processed" this terminal phase we can safely deallocate resource - err = deallocateResource(ctx, tCtx, config, podName) - if err != nil { - logger.Errorf(ctx, "Error releasing allocation token [%s] in LaunchAndCheckSubTasks [%s]", podName, err) - return currentState, logLinks, subTaskIDs, errors2.Wrapf(ErrCheckPodStatus, err, "Error releasing allocation token.") - } - - // If a subtask is marked as a retryable failure we check if the number of retries - // exceeds the maximum attempts. If so, transition the task to a permanent failure - // so that is not attempted again. If it can be retried, increment the retry attempts - // value and transition the task to "Undefined" so that it is reevaluated. - if existingPhase == core.PhaseRetryableFailure { - if uint32(retryAttempt+1) < tCtx.TaskExecutionMetadata().GetMaxAttempts() { - newState.RetryAttempts.SetItem(childIdx, retryAttempt+1) - - newArrayStatus.Summary.Inc(core.PhaseUndefined) - newArrayStatus.Detailed.SetItem(childIdx, bitarray.Item(core.PhaseUndefined)) - continue - } else { - existingPhase = core.PhasePermanentFailure - } - } - - newArrayStatus.Summary.Inc(existingPhase) - newArrayStatus.Detailed.SetItem(childIdx, bitarray.Item(existingPhase)) - - phaseInfo, err := FetchPodStatusAndLogs(ctx, kubeClient, - k8sTypes.NamespacedName{ - Name: podName, - Namespace: GetNamespaceForExecution(tCtx, config.NamespaceTemplate), - }, - originalIdx, - tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetID().RetryAttempt, - retryAttempt, - logPlugin) - - if err != nil { - return currentState, logLinks, subTaskIDs, err - } - - if phaseInfo.Info() != nil { - logLinks = append(logLinks, phaseInfo.Info().Logs...) - } - - continue - } - - task := &Task{ - State: newState, - NewArrayStatus: newArrayStatus, - Config: config, - ChildIdx: childIdx, - OriginalIndex: originalIdx, - MessageCollector: &msg, - SubTaskIDs: subTaskIDs, - } - - // The first time we enter this state we will launch every subtask. On subsequent rounds, the pod - // has already been created so we return a Success value and continue with the Monitor step. - var launchResult LaunchResult - launchResult, err = task.Launch(ctx, tCtx, kubeClient) - if err != nil { - logger.Errorf(ctx, "K8s array - Launch error %v", err) - return currentState, logLinks, subTaskIDs, err - } - - switch launchResult { - case LaunchSuccess: - // Continue with execution if successful - case LaunchError: - return currentState, logLinks, subTaskIDs, err - // If Resource manager is enabled and there are currently not enough resources we can skip this round - // for a subtask and wait until there are enough resources. - case LaunchWaiting: - continue - case LaunchReturnState: - return currentState, logLinks, subTaskIDs, nil - } - - monitorResult, taskLogs, err := task.Monitor(ctx, tCtx, kubeClient, dataStore, outputPrefix, baseOutputDataSandbox, logPlugin) - - if len(taskLogs) > 0 { - logLinks = append(logLinks, taskLogs...) - } - subTaskIDs = task.SubTaskIDs - - if monitorResult != MonitorSuccess { - if err != nil { - logger.Errorf(ctx, "K8s array - Monitor error %v", err) - } - return currentState, logLinks, subTaskIDs, err - } - - // validate map task parallelism - newSubtaskPhase := core.Phases[newArrayStatus.Detailed.GetItem(childIdx)] - if !newSubtaskPhase.IsTerminal() || newSubtaskPhase == core.PhaseRetryableFailure { - currentParallelism++ - } - - if maxParallelism != 0 && currentParallelism >= maxParallelism { - // If max parallelism has been achieved we need to fill the subtask phase summary with - // the remaining subtasks so the overall map task phase can be accurately identified. - for i := childIdx + 1; i < len(currentState.GetArrayStatus().Detailed.GetItems()); i++ { - childSubtaskPhase := core.Phases[newArrayStatus.Detailed.GetItem(i)] - newArrayStatus.Summary.Inc(childSubtaskPhase) - } - - break - } - } - - newState = newState.SetArrayStatus(*newArrayStatus) - - phase := arrayCore.SummaryToPhase(ctx, currentState.GetOriginalMinSuccesses()-currentState.GetOriginalArraySize()+int64(currentState.GetExecutionArraySize()), newArrayStatus.Summary) - if phase == arrayCore.PhaseWriteToDiscoveryThenFail { - errorMsg := msg.Summary(GetConfig().MaxErrorStringLength) - newState = newState.SetReason(errorMsg) - } - - if phase == arrayCore.PhaseCheckingSubTaskExecutions { - newPhaseVersion := uint32(0) - - // For now, the only changes to PhaseVersion and PreviousSummary occur for running array jobs. - for phase, count := range newState.GetArrayStatus().Summary { - newPhaseVersion += uint32(phase) * uint32(count) - } - - newState = newState.SetPhase(phase, newPhaseVersion).SetReason("Task is still running.") - } else { - newState = newState.SetPhase(phase, core.DefaultPhaseVersion) - } - - return newState, logLinks, subTaskIDs, nil -} - -func FetchPodStatusAndLogs(ctx context.Context, client core.KubeClient, name k8sTypes.NamespacedName, index int, retryAttempt uint32, subtaskRetryAttempt uint64, logPlugin tasklog.Plugin) ( - info core.PhaseInfo, err error) { - - pod := &v1.Pod{ - TypeMeta: metaV1.TypeMeta{ - Kind: PodKind, - APIVersion: v1.SchemeGroupVersion.String(), - }, - } - - err = client.GetClient().Get(ctx, name, pod) - now := time.Now() - - if err != nil { - if k8serrors.IsNotFound(err) { - // If the object disappeared at this point, it means it was manually removed or garbage collected. - // Mark it as a failure. - return core.PhaseInfoFailed(core.PhaseRetryableFailure, &idlCore.ExecutionError{ - Code: string(k8serrors.ReasonForError(err)), - Message: err.Error(), - Kind: idlCore.ExecutionError_SYSTEM, - }, &core.TaskInfo{ - OccurredAt: &now, - }), nil - } - - return info, err - } - - t := flytek8s.GetLastTransitionOccurredAt(pod).Time - taskInfo := core.TaskInfo{ - OccurredAt: &t, - } - - if pod.Status.Phase != v1.PodPending && pod.Status.Phase != v1.PodUnknown { - // We append the subtaskRetryAttempt to the log name only when it is > 0 to ensure backwards - // compatibility when dynamically transitioning running map tasks to use subtask retry attempts. - var logName string - if subtaskRetryAttempt == 0 { - logName = fmt.Sprintf(" #%d-%d", retryAttempt, index) - } else { - logName = fmt.Sprintf(" #%d-%d-%d", retryAttempt, index, subtaskRetryAttempt) - } - - if logPlugin != nil { - o, err := logPlugin.GetTaskLogs(tasklog.Input{ - PodName: pod.Name, - Namespace: pod.Namespace, - LogName: logName, - PodUnixStartTime: pod.CreationTimestamp.Unix(), - }) - - if err != nil { - return core.PhaseInfoUndefined, err - } - taskInfo.Logs = o.TaskLogs - } - } - - var phaseInfo core.PhaseInfo - var err2 error - - switch pod.Status.Phase { - case v1.PodSucceeded: - phaseInfo, err2 = flytek8s.DemystifySuccess(pod.Status, taskInfo) - case v1.PodFailed: - phaseInfo, err2 = flytek8s.DemystifyFailure(pod.Status, taskInfo) - case v1.PodPending: - phaseInfo, err2 = flytek8s.DemystifyPending(pod.Status) - case v1.PodUnknown: - phaseInfo = core.PhaseInfoUndefined - default: - primaryContainerName, ok := pod.GetAnnotations()[primaryContainerKey] - if ok { - // Special handling for determining the phase of an array job for a Pod task. - phaseInfo = flytek8s.DeterminePrimaryContainerPhase(primaryContainerName, pod.Status.ContainerStatuses, &taskInfo) - if phaseInfo.Phase() == core.PhaseRunning && len(taskInfo.Logs) > 0 { - return core.PhaseInfoRunning(core.DefaultPhaseVersion+1, phaseInfo.Info()), nil - } - return phaseInfo, nil - } - - if len(taskInfo.Logs) > 0 { - phaseInfo = core.PhaseInfoRunning(core.DefaultPhaseVersion+1, &taskInfo) - } else { - phaseInfo = core.PhaseInfoRunning(core.DefaultPhaseVersion, &taskInfo) - } - } - - if err2 == nil && phaseInfo.Info() != nil { - // Append sub-job status in Log Name for viz. - for _, log := range phaseInfo.Info().Logs { - log.Name += fmt.Sprintf(" (%s)", phaseInfo.Phase().String()) - } - } - - return phaseInfo, err2 - -} diff --git a/go/tasks/plugins/array/k8s/monitor_test.go b/go/tasks/plugins/array/k8s/monitor_test.go deleted file mode 100644 index 282e0036f..000000000 --- a/go/tasks/plugins/array/k8s/monitor_test.go +++ /dev/null @@ -1,437 +0,0 @@ -package k8s - -import ( - "fmt" - "testing" - - "github.com/flyteorg/flyteplugins/go/tasks/logs" - - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/workqueue" - - core2 "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" - mocks2 "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io/mocks" - "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/arraystatus" - "github.com/flyteorg/flytestdlib/bitarray" - "github.com/stretchr/testify/mock" - v1 "k8s.io/api/core/v1" - "k8s.io/apimachinery/pkg/api/resource" - v12 "k8s.io/apimachinery/pkg/apis/meta/v1" - - arrayCore "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/core" - - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks" - "github.com/stretchr/testify/assert" - "golang.org/x/net/context" - - structpb "google.golang.org/protobuf/types/known/structpb" -) - -func createSampleContainerTask() *core2.Container { - return &core2.Container{ - Command: []string{"cmd"}, - Args: []string{"{{$inputPrefix}}"}, - Image: "img1", - } -} - -func getMockTaskExecutionContext(ctx context.Context, parallelism int) *mocks.TaskExecutionContext { - customStruct, _ := structpb.NewStruct(map[string]interface{}{ - "parallelism": fmt.Sprintf("%d", parallelism), - }) - - tr := &mocks.TaskReader{} - tr.OnRead(ctx).Return(&core2.TaskTemplate{ - Custom: customStruct, - Target: &core2.TaskTemplate_Container{ - Container: createSampleContainerTask(), - }, - }, nil) - - tID := &mocks.TaskExecutionID{} - tID.OnGetGeneratedName().Return("notfound") - tID.OnGetID().Return(core2.TaskExecutionIdentifier{ - TaskId: &core2.Identifier{ - ResourceType: core2.ResourceType_TASK, - Project: "a", - Domain: "d", - Name: "n", - Version: "abc", - }, - NodeExecutionId: &core2.NodeExecutionIdentifier{ - NodeId: "node1", - ExecutionId: &core2.WorkflowExecutionIdentifier{ - Project: "a", - Domain: "d", - Name: "exec", - }, - }, - RetryAttempt: 0, - }) - - overrides := &mocks.TaskOverrides{} - overrides.OnGetResources().Return(&v1.ResourceRequirements{ - Requests: v1.ResourceList{ - v1.ResourceCPU: resource.MustParse("10"), - }, - }) - - tMeta := &mocks.TaskExecutionMetadata{} - tMeta.OnGetTaskExecutionID().Return(tID) - tMeta.OnGetOverrides().Return(overrides) - tMeta.OnIsInterruptible().Return(false) - tMeta.OnGetK8sServiceAccount().Return("s") - - tMeta.OnGetMaxAttempts().Return(2) - tMeta.OnGetNamespace().Return("n") - tMeta.OnGetLabels().Return(nil) - tMeta.OnGetAnnotations().Return(nil) - tMeta.OnGetOwnerReference().Return(v12.OwnerReference{}) - tMeta.OnGetPlatformResources().Return(&v1.ResourceRequirements{}) - - ow := &mocks2.OutputWriter{} - ow.OnGetOutputPrefixPath().Return("/prefix/") - ow.OnGetRawOutputPrefix().Return("/raw_prefix/") - ow.OnGetCheckpointPrefix().Return("/checkpoint") - ow.OnGetPreviousCheckpointsPrefix().Return("/prev") - - ir := &mocks2.InputReader{} - ir.OnGetInputPrefixPath().Return("/prefix/") - ir.OnGetInputPath().Return("/prefix/inputs.pb") - ir.OnGetMatch(mock.Anything).Return(&core2.LiteralMap{}, nil) - - tCtx := &mocks.TaskExecutionContext{} - tCtx.OnTaskReader().Return(tr) - tCtx.OnTaskExecutionMetadata().Return(tMeta) - tCtx.OnOutputWriter().Return(ow) - tCtx.OnInputReader().Return(ir) - return tCtx -} - -func TestGetNamespaceForExecution(t *testing.T) { - ctx := context.Background() - tCtx := getMockTaskExecutionContext(ctx, 0) - - assert.Equal(t, GetNamespaceForExecution(tCtx, ""), tCtx.TaskExecutionMetadata().GetNamespace()) - assert.Equal(t, GetNamespaceForExecution(tCtx, "abcd"), "abcd") - assert.Equal(t, GetNamespaceForExecution(tCtx, "a-{{.namespace}}-b"), fmt.Sprintf("a-%s-b", tCtx.TaskExecutionMetadata().GetNamespace())) -} - -func testSubTaskIDs(t *testing.T, actual []*string) { - var expected = make([]*string, 5) - for i := 0; i < len(expected); i++ { - subTaskID := fmt.Sprintf("notfound-%d", i) - expected[i] = &subTaskID - } - assert.EqualValues(t, expected, actual) -} - -func TestCheckSubTasksState(t *testing.T) { - ctx := context.Background() - - tCtx := getMockTaskExecutionContext(ctx, 0) - kubeClient := mocks.KubeClient{} - kubeClient.OnGetClient().Return(mocks.NewFakeKubeClient()) - kubeClient.OnGetCache().Return(mocks.NewFakeKubeCache()) - - resourceManager := mocks.ResourceManager{} - resourceManager.OnAllocateResourceMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(core.AllocationStatusExhausted, nil) - tCtx.OnResourceManager().Return(&resourceManager) - - t.Run("Happy case", func(t *testing.T) { - config := Config{ - MaxArrayJobSize: 100, - MaxErrorStringLength: 200, - NamespaceTemplate: "a-{{.namespace}}-b", - OutputAssembler: workqueue.Config{ - Workers: 2, - MaxRetries: 0, - IndexCacheMaxItems: 100, - }, - ErrorAssembler: workqueue.Config{ - Workers: 2, - MaxRetries: 0, - IndexCacheMaxItems: 100, - }, - LogConfig: LogConfig{ - Config: logs.LogConfig{ - IsCloudwatchEnabled: true, - CloudwatchTemplateURI: "https://console.aws.amazon.com/cloudwatch/home?region=us-east-1#logStream:group=/kubernetes/flyte;prefix=var.log.containers.{{ .podName }};streamFilter=typeLogStreamPrefix", - IsKubernetesEnabled: true, - KubernetesTemplateURI: "k8s/log/{{.namespace}}/{{.podName}}/pod?namespace={{.namespace}}", - }}, - } - cacheIndexes := bitarray.NewBitSet(5) - cacheIndexes.Set(0) - cacheIndexes.Set(1) - cacheIndexes.Set(2) - cacheIndexes.Set(3) - cacheIndexes.Set(4) - - newState, logLinks, subTaskIDs, err := LaunchAndCheckSubTasksState(ctx, tCtx, &kubeClient, &config, nil, "/prefix/", "/prefix-sand/", &arrayCore.State{ - CurrentPhase: arrayCore.PhaseCheckingSubTaskExecutions, - ExecutionArraySize: 5, - OriginalArraySize: 10, - OriginalMinSuccesses: 5, - IndexesToCache: cacheIndexes, - }) - - assert.Nil(t, err) - assert.NotEmpty(t, logLinks) - assert.Equal(t, 10, len(logLinks)) - for i := 0; i < 10; i = i + 2 { - assert.Equal(t, fmt.Sprintf("Kubernetes Logs #0-%d (PhaseRunning)", i/2), logLinks[i].Name) - assert.Equal(t, fmt.Sprintf("k8s/log/a-n-b/notfound-%d/pod?namespace=a-n-b", i/2), logLinks[i].Uri) - - assert.Equal(t, fmt.Sprintf("Cloudwatch Logs #0-%d (PhaseRunning)", i/2), logLinks[i+1].Name) - assert.Equal(t, fmt.Sprintf("https://console.aws.amazon.com/cloudwatch/home?region=us-east-1#logStream:group=/kubernetes/flyte;prefix=var.log.containers.notfound-%d;streamFilter=typeLogStreamPrefix", i/2), logLinks[i+1].Uri) - } - - p, _ := newState.GetPhase() - assert.Equal(t, arrayCore.PhaseCheckingSubTaskExecutions.String(), p.String()) - resourceManager.AssertNumberOfCalls(t, "AllocateResource", 0) - testSubTaskIDs(t, subTaskIDs) - }) - - t.Run("Resource exhausted", func(t *testing.T) { - config := Config{ - MaxArrayJobSize: 100, - ResourceConfig: ResourceConfig{ - PrimaryLabel: "p", - Limit: 10, - }, - } - - retryAttemptsArray, err := bitarray.NewCompactArray(5, bitarray.Item(0)) - assert.NoError(t, err) - - newState, _, subTaskIDs, err := LaunchAndCheckSubTasksState(ctx, tCtx, &kubeClient, &config, nil, "/prefix/", "/prefix-sand/", &arrayCore.State{ - CurrentPhase: arrayCore.PhaseCheckingSubTaskExecutions, - ExecutionArraySize: 5, - OriginalArraySize: 10, - OriginalMinSuccesses: 5, - ArrayStatus: arraystatus.ArrayStatus{ - Detailed: arrayCore.NewPhasesCompactArray(uint(5)), - }, - IndexesToCache: bitarray.NewBitSet(5), - RetryAttempts: retryAttemptsArray, - }) - - assert.Nil(t, err) - p, _ := newState.GetPhase() - assert.Equal(t, arrayCore.PhaseWaitingForResources.String(), p.String()) - resourceManager.AssertNumberOfCalls(t, "AllocateResource", 5) - assert.Empty(t, subTaskIDs, "subtask ids are only populated when monitor is called for a successfully launched task") - }) - - t.Run("RetryableSubtaskFailure", func(t *testing.T) { - failureIndex := 2 - - config := Config{ - MaxArrayJobSize: 100, - MaxErrorStringLength: 200, - } - - detailed := arrayCore.NewPhasesCompactArray(uint(5)) - detailed.SetItem(failureIndex, bitarray.Item(core.PhaseRetryableFailure)) - - retryAttemptsArray, err := bitarray.NewCompactArray(5, bitarray.Item(1)) - assert.NoError(t, err) - - cacheIndexes := bitarray.NewBitSet(5) - newState, _, _, err := LaunchAndCheckSubTasksState(ctx, tCtx, &kubeClient, &config, nil, "/prefix/", "/prefix-sand/", &arrayCore.State{ - CurrentPhase: arrayCore.PhaseCheckingSubTaskExecutions, - ExecutionArraySize: 5, - OriginalArraySize: 10, - OriginalMinSuccesses: 5, - IndexesToCache: cacheIndexes, - ArrayStatus: arraystatus.ArrayStatus{ - Detailed: detailed, - }, - RetryAttempts: retryAttemptsArray, - }) - - assert.Nil(t, err) - - p, _ := newState.GetPhase() - assert.Equal(t, arrayCore.PhaseCheckingSubTaskExecutions.String(), p.String()) - assert.Equal(t, core.PhaseUndefined, core.Phases[newState.ArrayStatus.Detailed.GetItem(failureIndex)]) - assert.Equal(t, uint64(1), newState.RetryAttempts.GetItem(failureIndex)) - }) - - t.Run("PermanentSubtaskFailure", func(t *testing.T) { - failureIndex := 2 - - config := Config{ - MaxArrayJobSize: 100, - MaxErrorStringLength: 200, - } - - detailed := arrayCore.NewPhasesCompactArray(uint(5)) - detailed.SetItem(failureIndex, bitarray.Item(core.PhaseRetryableFailure)) - - retryAttemptsArray, err := bitarray.NewCompactArray(5, bitarray.Item(1)) - assert.NoError(t, err) - retryAttemptsArray.SetItem(failureIndex, bitarray.Item(1)) - - cacheIndexes := bitarray.NewBitSet(5) - newState, _, _, err := LaunchAndCheckSubTasksState(ctx, tCtx, &kubeClient, &config, nil, "/prefix/", "/prefix-sand/", &arrayCore.State{ - CurrentPhase: arrayCore.PhaseCheckingSubTaskExecutions, - ExecutionArraySize: 5, - OriginalArraySize: 10, - OriginalMinSuccesses: 5, - IndexesToCache: cacheIndexes, - ArrayStatus: arraystatus.ArrayStatus{ - Detailed: detailed, - }, - RetryAttempts: retryAttemptsArray, - }) - - assert.Nil(t, err) - - p, _ := newState.GetPhase() - assert.Equal(t, arrayCore.PhaseCheckingSubTaskExecutions.String(), p.String()) - assert.Equal(t, core.PhasePermanentFailure, core.Phases[newState.ArrayStatus.Detailed.GetItem(failureIndex)]) - assert.Equal(t, uint64(1), newState.RetryAttempts.GetItem(failureIndex)) - }) -} - -func TestCheckSubTasksStateParallelism(t *testing.T) { - subtaskCount := 5 - - for i := 1; i <= subtaskCount; i++ { - t.Run(fmt.Sprintf("Parallelism-%d", i), func(t *testing.T) { - // construct task context - ctx := context.Background() - - tCtx := getMockTaskExecutionContext(ctx, i) - kubeClient := mocks.KubeClient{} - kubeClient.OnGetClient().Return(mocks.NewFakeKubeClient()) - kubeClient.OnGetCache().Return(mocks.NewFakeKubeCache()) - - resourceManager := mocks.ResourceManager{} - resourceManager.OnAllocateResourceMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(core.AllocationStatusExhausted, nil) - tCtx.OnResourceManager().Return(&resourceManager) - - // evaluate one round of subtask launch and monitor - config := Config{ - MaxArrayJobSize: 100, - } - - retryAttemptsArray, err := bitarray.NewCompactArray(uint(subtaskCount), bitarray.Item(0)) - assert.NoError(t, err) - - cacheIndexes := bitarray.NewBitSet(uint(subtaskCount)) - newState, _, _, err := LaunchAndCheckSubTasksState(ctx, tCtx, &kubeClient, &config, nil, "/prefix/", "/prefix-sand/", &arrayCore.State{ - CurrentPhase: arrayCore.PhaseCheckingSubTaskExecutions, - ExecutionArraySize: subtaskCount, - OriginalArraySize: int64(subtaskCount * 2), - OriginalMinSuccesses: int64(subtaskCount * 2), - IndexesToCache: cacheIndexes, - ArrayStatus: arraystatus.ArrayStatus{ - Detailed: arrayCore.NewPhasesCompactArray(uint(subtaskCount)), - }, - RetryAttempts: retryAttemptsArray, - }) - - assert.Nil(t, err) - p, _ := newState.GetPhase() - assert.Equal(t, arrayCore.PhaseCheckingSubTaskExecutions.String(), p.String()) - - // validate only i subtasks were processed - executed := 0 - for _, existingPhaseIdx := range newState.GetArrayStatus().Detailed.GetItems() { - if core.Phases[existingPhaseIdx] != core.PhaseUndefined { - executed++ - } - } - - assert.Equal(t, i, executed) - }) - } -} - -func TestCheckSubTasksStateResourceGranted(t *testing.T) { - ctx := context.Background() - - tCtx := getMockTaskExecutionContext(ctx, 0) - kubeClient := mocks.KubeClient{} - kubeClient.OnGetClient().Return(mocks.NewFakeKubeClient()) - kubeClient.OnGetCache().Return(mocks.NewFakeKubeCache()) - - resourceManager := mocks.ResourceManager{} - - resourceManager.OnAllocateResourceMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(core.AllocationStatusGranted, nil) - resourceManager.OnReleaseResourceMatch(mock.Anything, mock.Anything, mock.Anything).Return(nil) - tCtx.OnResourceManager().Return(&resourceManager) - - t.Run("Resource granted", func(t *testing.T) { - config := Config{ - MaxArrayJobSize: 100, - ResourceConfig: ResourceConfig{ - PrimaryLabel: "p", - Limit: 10, - }, - } - - retryAttemptsArray, err := bitarray.NewCompactArray(5, bitarray.Item(0)) - assert.NoError(t, err) - - cacheIndexes := bitarray.NewBitSet(5) - newState, _, subTaskIDs, err := LaunchAndCheckSubTasksState(ctx, tCtx, &kubeClient, &config, nil, "/prefix/", "/prefix-sand/", &arrayCore.State{ - CurrentPhase: arrayCore.PhaseCheckingSubTaskExecutions, - ExecutionArraySize: 5, - OriginalArraySize: 10, - OriginalMinSuccesses: 5, - IndexesToCache: cacheIndexes, - ArrayStatus: arraystatus.ArrayStatus{ - Detailed: arrayCore.NewPhasesCompactArray(uint(5)), - }, - RetryAttempts: retryAttemptsArray, - }) - - assert.Nil(t, err) - p, _ := newState.GetPhase() - assert.Equal(t, arrayCore.PhaseCheckingSubTaskExecutions.String(), p.String()) - resourceManager.AssertNumberOfCalls(t, "AllocateResource", 5) - testSubTaskIDs(t, subTaskIDs) - }) - - t.Run("All tasks success", func(t *testing.T) { - config := Config{ - MaxArrayJobSize: 100, - ResourceConfig: ResourceConfig{ - PrimaryLabel: "p", - Limit: 10, - }, - } - - arrayStatus := &arraystatus.ArrayStatus{ - Summary: arraystatus.ArraySummary{}, - Detailed: arrayCore.NewPhasesCompactArray(uint(5)), - } - for childIdx := range arrayStatus.Detailed.GetItems() { - arrayStatus.Detailed.SetItem(childIdx, bitarray.Item(core.PhaseSuccess)) - - } - cacheIndexes := bitarray.NewBitSet(5) - retryAttemptsArray, err := bitarray.NewCompactArray(5, bitarray.Item(0)) - assert.NoError(t, err) - - newState, _, subTaskIDs, err := LaunchAndCheckSubTasksState(ctx, tCtx, &kubeClient, &config, nil, "/prefix/", "/prefix-sand/", &arrayCore.State{ - CurrentPhase: arrayCore.PhaseCheckingSubTaskExecutions, - ExecutionArraySize: 5, - OriginalArraySize: 10, - OriginalMinSuccesses: 5, - ArrayStatus: *arrayStatus, - IndexesToCache: cacheIndexes, - RetryAttempts: retryAttemptsArray, - }) - - assert.Nil(t, err) - p, _ := newState.GetPhase() - assert.Equal(t, arrayCore.PhaseWriteToDiscovery.String(), p.String()) - resourceManager.AssertNumberOfCalls(t, "ReleaseResource", 5) - assert.Empty(t, subTaskIDs, "terminal phases don't need to collect subtask IDs") - }) -} diff --git a/go/tasks/plugins/array/k8s/subtask.go b/go/tasks/plugins/array/k8s/subtask.go new file mode 100644 index 000000000..f4e5c12d8 --- /dev/null +++ b/go/tasks/plugins/array/k8s/subtask.go @@ -0,0 +1,346 @@ +package k8s + +import ( + "context" + "fmt" + "regexp" + "strconv" + "strings" + "time" + + "github.com/flyteorg/flyteplugins/go/tasks/errors" + pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s/config" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/k8s" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/tasklog" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/utils" + podPlugin "github.com/flyteorg/flyteplugins/go/tasks/plugins/k8s/pod" + + stdErrors "github.com/flyteorg/flytestdlib/errors" + "github.com/flyteorg/flytestdlib/logger" + + v1 "k8s.io/api/core/v1" + k8serrors "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + k8stypes "k8s.io/apimachinery/pkg/types" + + "sigs.k8s.io/controller-runtime/pkg/client" +) + +const ( + ErrBuildPodTemplate stdErrors.ErrorCode = "POD_TEMPLATE_FAILED" + ErrReplaceCmdTemplate stdErrors.ErrorCode = "CMD_TEMPLATE_FAILED" + FlyteK8sArrayIndexVarName string = "FLYTE_K8S_ARRAY_INDEX" + finalizer string = "flyte/array" + JobIndexVarName string = "BATCH_JOB_ARRAY_INDEX_VAR_NAME" +) + +var ( + arrayJobEnvVars = []v1.EnvVar{ + { + Name: JobIndexVarName, + Value: FlyteK8sArrayIndexVarName, + }, + } + namespaceRegex = regexp.MustCompile("(?i){{.namespace}}(?i)") +) + +// addMetadata sets k8s pod metadata that is either specifically required by the k8s array plugin +// or defined in the plugin configuration. +func addMetadata(stCtx SubTaskExecutionContext, cfg *Config, k8sPluginCfg *config.K8sPluginConfig, pod *v1.Pod) { + taskExecutionMetadata := stCtx.TaskExecutionMetadata() + + // Default to parent namespace + namespace := taskExecutionMetadata.GetNamespace() + if cfg.NamespaceTemplate != "" { + if namespaceRegex.MatchString(cfg.NamespaceTemplate) { + namespace = namespaceRegex.ReplaceAllString(cfg.NamespaceTemplate, namespace) + } else { + namespace = cfg.NamespaceTemplate + } + } + + pod.SetNamespace(namespace) + pod.SetAnnotations(utils.UnionMaps(k8sPluginCfg.DefaultAnnotations, pod.GetAnnotations(), utils.CopyMap(taskExecutionMetadata.GetAnnotations()))) + pod.SetLabels(utils.UnionMaps(pod.GetLabels(), utils.CopyMap(taskExecutionMetadata.GetLabels()), k8sPluginCfg.DefaultLabels)) + pod.SetName(taskExecutionMetadata.GetTaskExecutionID().GetGeneratedName()) + + if !cfg.RemoteClusterConfig.Enabled { + pod.OwnerReferences = []metav1.OwnerReference{taskExecutionMetadata.GetOwnerReference()} + } + + if k8sPluginCfg.InjectFinalizer { + f := append(pod.GetFinalizers(), finalizer) + pod.SetFinalizers(f) + } + + if len(cfg.DefaultScheduler) > 0 { + pod.Spec.SchedulerName = cfg.DefaultScheduler + } + + // The legacy map task implemented these as overrides so they were left as such. May want to + // revist whether they would serve better as appends. + if len(cfg.NodeSelector) != 0 { + pod.Spec.NodeSelector = cfg.NodeSelector + } + if len(cfg.Tolerations) != 0 { + pod.Spec.Tolerations = cfg.Tolerations + } +} + +// abortSubtask attempts to interrupt the k8s pod defined by the SubTaskExecutionContext and Config +func abortSubtask(ctx context.Context, stCtx SubTaskExecutionContext, cfg *Config, kubeClient pluginsCore.KubeClient) error { + logger.Infof(ctx, "KillTask invoked. We will attempt to delete object [%v].", + stCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName()) + + var plugin k8s.Plugin = podPlugin.DefaultPodPlugin + o, err := plugin.BuildIdentityResource(ctx, stCtx.TaskExecutionMetadata()) + if err != nil { + // This will recurrent, so we will skip further finalize + logger.Errorf(ctx, "Failed to build the Resource with name: %v. Error: %v, when finalizing.", stCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(), err) + return nil + } + + addMetadata(stCtx, cfg, config.GetK8sPluginConfig(), o.(*v1.Pod)) + + deleteResource := true + abortOverride, hasAbortOverride := plugin.(k8s.PluginAbortOverride) + + resourceToFinalize := o + var behavior k8s.AbortBehavior + + if hasAbortOverride { + behavior, err = abortOverride.OnAbort(ctx, stCtx, o) + deleteResource = err == nil && behavior.DeleteResource + if err == nil && behavior.Resource != nil { + resourceToFinalize = behavior.Resource + } + } + + if err != nil { + } else if deleteResource { + err = kubeClient.GetClient().Delete(ctx, resourceToFinalize) + } else { + if behavior.Patch != nil && behavior.Update == nil { + err = kubeClient.GetClient().Patch(ctx, resourceToFinalize, behavior.Patch.Patch, behavior.Patch.Options...) + } else if behavior.Patch == nil && behavior.Update != nil { + err = kubeClient.GetClient().Update(ctx, resourceToFinalize, behavior.Update.Options...) + } else { + err = errors.Errorf(errors.RuntimeFailure, "AbortBehavior for resource %v must specify either a Patch and an Update operation if Delete is set to false. Only one can be supplied.", resourceToFinalize.GetName()) + } + if behavior.DeleteOnErr && err != nil { + logger.Warningf(ctx, "Failed to apply AbortBehavior for resource %v with error %v. Will attempt to delete resource.", resourceToFinalize.GetName(), err) + err = kubeClient.GetClient().Delete(ctx, resourceToFinalize) + } + } + + if err != nil && !isK8sObjectNotExists(err) { + logger.Warningf(ctx, "Failed to clear finalizers for Resource with name: %v/%v. Error: %v", + resourceToFinalize.GetNamespace(), resourceToFinalize.GetName(), err) + return err + } + + return nil +} + +// clearFinalizers removes finalizers (if they exist) from the k8s resource +func clearFinalizers(ctx context.Context, o client.Object, kubeClient pluginsCore.KubeClient) error { + if len(o.GetFinalizers()) > 0 { + o.SetFinalizers([]string{}) + err := kubeClient.GetClient().Update(ctx, o) + if err != nil && !isK8sObjectNotExists(err) { + logger.Warningf(ctx, "Failed to clear finalizers for Resource with name: %v/%v. Error: %v", o.GetNamespace(), o.GetName(), err) + return err + } + } else { + logger.Debugf(ctx, "Finalizers are already empty for Resource with name: %v/%v", o.GetNamespace(), o.GetName()) + } + return nil +} + +// launchSubtask creates a k8s pod defined by the SubTaskExecutionContext and Config. +func launchSubtask(ctx context.Context, stCtx SubTaskExecutionContext, cfg *Config, kubeClient pluginsCore.KubeClient) (pluginsCore.PhaseInfo, error) { + o, err := podPlugin.DefaultPodPlugin.BuildResource(ctx, stCtx) + pod := o.(*v1.Pod) + if err != nil { + return pluginsCore.PhaseInfoUndefined, err + } + + addMetadata(stCtx, cfg, config.GetK8sPluginConfig(), pod) + + // inject maptask specific container environment variables + if len(pod.Spec.Containers) == 0 { + return pluginsCore.PhaseInfoUndefined, stdErrors.Wrapf(ErrReplaceCmdTemplate, err, "No containers found in podSpec.") + } + + containerIndex, err := getTaskContainerIndex(pod) + if err != nil { + return pluginsCore.PhaseInfoUndefined, err + } + + pod.Spec.Containers[containerIndex].Env = append(pod.Spec.Containers[containerIndex].Env, v1.EnvVar{ + Name: FlyteK8sArrayIndexVarName, + // Use the OriginalIndex which represents the position of the subtask in the original user's map task before + // compacting indexes caused by catalog-cache-check. + Value: strconv.Itoa(stCtx.originalIndex), + }) + + pod.Spec.Containers[containerIndex].Env = append(pod.Spec.Containers[containerIndex].Env, arrayJobEnvVars...) + + logger.Infof(ctx, "Creating Object: Type:[%v], Object:[%v/%v]", pod.GetObjectKind().GroupVersionKind(), pod.GetNamespace(), pod.GetName()) + err = kubeClient.GetClient().Create(ctx, pod) + if err != nil && !k8serrors.IsAlreadyExists(err) { + if k8serrors.IsForbidden(err) { + if strings.Contains(err.Error(), "exceeded quota") { + logger.Warnf(ctx, "Failed to launch job, resource quota exceeded and the operation is not guarded by back-off. err: %v", err) + return pluginsCore.PhaseInfoWaitingForResourcesInfo(time.Now(), pluginsCore.DefaultPhaseVersion, fmt.Sprintf("Exceeded resourcequota: %s", err.Error()), nil), nil + } + return pluginsCore.PhaseInfoRetryableFailure("RuntimeFailure", err.Error(), nil), nil + } else if k8serrors.IsBadRequest(err) || k8serrors.IsInvalid(err) { + logger.Errorf(ctx, "Badly formatted resource for plugin [%s], err %s", executorName, err) + // return pluginsCore.DoTransition(pluginsCore.PhaseInfoFailure("BadTaskFormat", err.Error(), nil)), nil + } else if k8serrors.IsRequestEntityTooLargeError(err) { + logger.Errorf(ctx, "Badly formatted resource for plugin [%s], err %s", executorName, err) + return pluginsCore.PhaseInfoFailure("EntityTooLarge", err.Error(), nil), nil + } + reason := k8serrors.ReasonForError(err) + logger.Errorf(ctx, "Failed to launch job, system error. err: %v", err) + return pluginsCore.PhaseInfoUndefined, errors.Wrapf(stdErrors.ErrorCode(reason), err, "failed to create resource") + } + + return pluginsCore.PhaseInfoQueued(time.Now(), pluginsCore.DefaultPhaseVersion, "task submitted to K8s"), nil +} + +// finalizeSubtask performs operations to complete the k8s pod defined by the SubTaskExecutionContext +// and Config. These may include removing finalizers and deleting the k8s resource. +func finalizeSubtask(ctx context.Context, stCtx SubTaskExecutionContext, cfg *Config, kubeClient pluginsCore.KubeClient) error { + errs := stdErrors.ErrorCollection{} + var pod *v1.Pod + var nsName k8stypes.NamespacedName + k8sPluginCfg := config.GetK8sPluginConfig() + if k8sPluginCfg.InjectFinalizer || k8sPluginCfg.DeleteResourceOnFinalize { + o, err := podPlugin.DefaultPodPlugin.BuildIdentityResource(ctx, stCtx.TaskExecutionMetadata()) + if err != nil { + // This will recurrent, so we will skip further finalize + logger.Errorf(ctx, "Failed to build the Resource with name: %v. Error: %v, when finalizing.", stCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(), err) + return nil + } + + pod = o.(*v1.Pod) + + addMetadata(stCtx, cfg, config.GetK8sPluginConfig(), pod) + nsName = k8stypes.NamespacedName{Namespace: pod.GetNamespace(), Name: pod.GetName()} + } + + // In InjectFinalizer is on, it means we may have added the finalizers when we launched this resource. Attempt to + // clear them to allow the object to be deleted/garbage collected. If InjectFinalizer was turned on (through config) + // after the resource was created, we will not find any finalizers to clear and the object may have already been + // deleted at this point. Therefore, account for these cases and do not consider them errors. + if k8sPluginCfg.InjectFinalizer { + // Attempt to get resource from informer cache, if not found, retrieve it from API server. + if err := kubeClient.GetClient().Get(ctx, nsName, pod); err != nil { + if isK8sObjectNotExists(err) { + return nil + } + // This happens sometimes because a node gets removed and K8s deletes the pod. This will result in a + // Pod does not exist error. This should be retried using the retry policy + logger.Warningf(ctx, "Failed in finalizing get Resource with name: %v. Error: %v", nsName, err) + return err + } + + // This must happen after sending admin event. It's safe against partial failures because if the event failed, we will + // simply retry in the next round. If the event succeeded but this failed, we will try again the next round to send + // the same event (idempotent) and then come here again... + err := clearFinalizers(ctx, pod, kubeClient) + if err != nil { + errs.Append(err) + } + } + + // If we should delete the resource when finalize is called, do a best effort delete. + if k8sPluginCfg.DeleteResourceOnFinalize { + // Attempt to delete resource, if not found, return success. + if err := kubeClient.GetClient().Delete(ctx, pod); err != nil { + if isK8sObjectNotExists(err) { + return errs.ErrorOrDefault() + } + + // This happens sometimes because a node gets removed and K8s deletes the pod. This will result in a + // Pod does not exist error. This should be retried using the retry policy + logger.Warningf(ctx, "Failed in finalizing. Failed to delete Resource with name: %v. Error: %v", nsName, err) + errs.Append(fmt.Errorf("finalize: failed to delete resource with name [%v]. Error: %w", nsName, err)) + } + } + + return errs.ErrorOrDefault() +} + +// getSubtaskPhaseInfo returns the PhaseInfo describing an existing k8s resource which is defined +// by the SubTaskExecutionContext and Config. +func getSubtaskPhaseInfo(ctx context.Context, stCtx SubTaskExecutionContext, cfg *Config, kubeClient pluginsCore.KubeClient, logPlugin tasklog.Plugin) (pluginsCore.PhaseInfo, error) { + o, err := podPlugin.DefaultPodPlugin.BuildIdentityResource(ctx, stCtx.TaskExecutionMetadata()) + if err != nil { + logger.Errorf(ctx, "Failed to build the Resource with name: %v. Error: %v", stCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(), err) + return pluginsCore.PhaseInfoFailure("BadTaskDefinition", fmt.Sprintf("Failed to build resource, caused by: %s", err.Error()), nil), nil + } + + pod := o.(*v1.Pod) + addMetadata(stCtx, cfg, config.GetK8sPluginConfig(), pod) + + // Attempt to get resource from informer cache, if not found, retrieve it from API server. + nsName := k8stypes.NamespacedName{Name: pod.GetName(), Namespace: pod.GetNamespace()} + if err := kubeClient.GetClient().Get(ctx, nsName, pod); err != nil { + if isK8sObjectNotExists(err) { + // This happens sometimes because a node gets removed and K8s deletes the pod. This will result in a + // Pod does not exist error. This should be retried using the retry policy + logger.Warnf(ctx, "Failed to find the Resource with name: %v. Error: %v", nsName, err) + failureReason := fmt.Sprintf("resource not found, name [%s]. reason: %s", nsName.String(), err.Error()) + return pluginsCore.PhaseInfoSystemRetryableFailure("ResourceDeletedExternally", failureReason, nil), nil + } + + logger.Warnf(ctx, "Failed to retrieve Resource Details with name: %v. Error: %v", nsName, err) + return pluginsCore.PhaseInfoUndefined, err + } + + stID, _ := stCtx.TaskExecutionMetadata().GetTaskExecutionID().(SubTaskExecutionID) + phaseInfo, err := podPlugin.DefaultPodPlugin.GetTaskPhaseWithLogs(ctx, stCtx, pod, logPlugin, stID.GetLogSuffix()) + if err != nil { + logger.Warnf(ctx, "failed to check status of resource in plugin [%s], with error: %s", executorName, err.Error()) + return pluginsCore.PhaseInfoUndefined, err + } + + if phaseInfo.Info() != nil { + // Append sub-job status in Log Name for viz. + for _, log := range phaseInfo.Info().Logs { + log.Name += fmt.Sprintf(" (%s)", phaseInfo.Phase().String()) + } + } + + return phaseInfo, err +} + +// getTaskContainerIndex returns the index of the primary container in a k8s pod. +func getTaskContainerIndex(pod *v1.Pod) (int, error) { + primaryContainerName, ok := pod.Annotations[podPlugin.PrimaryContainerKey] + // For tasks with a Container target, we only ever build one container as part of the pod + if !ok { + if len(pod.Spec.Containers) == 1 { + return 0, nil + } + // For tasks with a K8sPod task target, they may produce multiple containers but at least one must be the designated primary. + return -1, stdErrors.Errorf(ErrBuildPodTemplate, "Expected a specified primary container key when building an array job with a K8sPod spec target") + + } + + for idx, container := range pod.Spec.Containers { + if container.Name == primaryContainerName { + return idx, nil + } + } + return -1, stdErrors.Errorf(ErrBuildPodTemplate, "Couldn't find any container matching the primary container key when building an array job with a K8sPod spec target") +} + +// isK8sObjectNotExists returns true if the error is one which describes a non existent k8s object. +func isK8sObjectNotExists(err error) bool { + return k8serrors.IsNotFound(err) || k8serrors.IsGone(err) || k8serrors.IsResourceExpired(err) +} diff --git a/go/tasks/plugins/array/k8s/subtask_exec_context.go b/go/tasks/plugins/array/k8s/subtask_exec_context.go new file mode 100644 index 000000000..1d2ea1e85 --- /dev/null +++ b/go/tasks/plugins/array/k8s/subtask_exec_context.go @@ -0,0 +1,139 @@ +package k8s + +import ( + "context" + "fmt" + "strconv" + + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + + pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/utils" + "github.com/flyteorg/flyteplugins/go/tasks/plugins/array" + podPlugin "github.com/flyteorg/flyteplugins/go/tasks/plugins/k8s/pod" +) + +// SubTaskExecutionContext wraps the core TaskExecutionContext so that the k8s array task context +// can be used within the pod plugin +type SubTaskExecutionContext struct { + pluginsCore.TaskExecutionContext + arrayInputReader io.InputReader + metadataOverride pluginsCore.TaskExecutionMetadata + originalIndex int + subtaskReader SubTaskReader +} + +// InputReader overrides the base TaskExecutionContext to return a custom InputReader +func (s SubTaskExecutionContext) InputReader() io.InputReader { + return s.arrayInputReader +} + +// TaskExecutionMetadata overrides the base TaskExecutionContext to return custom +// TaskExecutionMetadata +func (s SubTaskExecutionContext) TaskExecutionMetadata() pluginsCore.TaskExecutionMetadata { + return s.metadataOverride +} + +// TaskReader overrides the base TaskExecutionContext to return a custom TaskReader +func (s SubTaskExecutionContext) TaskReader() pluginsCore.TaskReader { + return s.subtaskReader +} + +// newSubtaskExecutionContext constructs a SubTaskExecutionContext using the provided parameters +func newSubTaskExecutionContext(tCtx pluginsCore.TaskExecutionContext, taskTemplate *core.TaskTemplate, + executionIndex, originalIndex int, retryAttempt uint64) SubTaskExecutionContext { + + arrayInputReader := array.GetInputReader(tCtx, taskTemplate) + taskExecutionMetadata := tCtx.TaskExecutionMetadata() + taskExecutionID := taskExecutionMetadata.GetTaskExecutionID() + metadataOverride := SubTaskExecutionMetadata{ + taskExecutionMetadata, + SubTaskExecutionID{ + taskExecutionID, + executionIndex, + taskExecutionID.GetGeneratedName(), + retryAttempt, + taskExecutionID.GetID().RetryAttempt, + }, + } + + subtaskTemplate := &core.TaskTemplate{} + *subtaskTemplate = *taskTemplate + + if subtaskTemplate != nil { + subtaskTemplate.TaskTypeVersion = 2 + if subtaskTemplate.GetContainer() != nil { + subtaskTemplate.Type = podPlugin.ContainerTaskType + } else if taskTemplate.GetK8SPod() != nil { + subtaskTemplate.Type = podPlugin.SidecarTaskType + } + } + + subtaskReader := SubTaskReader{tCtx.TaskReader(), subtaskTemplate} + + return SubTaskExecutionContext{ + TaskExecutionContext: tCtx, + arrayInputReader: arrayInputReader, + metadataOverride: metadataOverride, + originalIndex: originalIndex, + subtaskReader: subtaskReader, + } +} + +// SubTaskReader wraps the core TaskReader to customize the task template task type and version +type SubTaskReader struct { + pluginsCore.TaskReader + subtaskTemplate *core.TaskTemplate +} + +// Read overrides the base TaskReader to return a custom TaskTemplate +func (s SubTaskReader) Read(ctx context.Context) (*core.TaskTemplate, error) { + return s.subtaskTemplate, nil +} + +// SubTaskExecutionID wraps the core TaskExecutionID to customize the generated pod name +type SubTaskExecutionID struct { + pluginsCore.TaskExecutionID + executionIndex int + parentName string + subtaskRetryAttempt uint64 + taskRetryAttempt uint32 +} + +// GetGeneratedName overrides the base TaskExecutionID to append the subtask index and retryAttempt +func (s SubTaskExecutionID) GetGeneratedName() string { + indexStr := strconv.Itoa(s.executionIndex) + + // If the retryAttempt is 0 we do not include it in the pod name. The gives us backwards + // compatibility in the ability to dynamically transition running map tasks to use subtask retries. + if s.subtaskRetryAttempt == 0 { + return utils.ConvertToDNS1123SubdomainCompatibleString(fmt.Sprintf("%v-%v", s.parentName, indexStr)) + } + + retryAttemptStr := strconv.FormatUint(s.subtaskRetryAttempt, 10) + return utils.ConvertToDNS1123SubdomainCompatibleString(fmt.Sprintf("%v-%v-%v", s.parentName, indexStr, retryAttemptStr)) +} + +// GetLogSuffix returns the suffix which should be appended to subtask log names +func (s SubTaskExecutionID) GetLogSuffix() string { + // Append the retry attempt and executionIndex so that log names coincide with pod names per + // https://github.com/flyteorg/flyteplugins/pull/186#discussion_r666569825. To maintain + // backwards compatibility we append the subtaskRetryAttempt if it is not 0. + if s.subtaskRetryAttempt == 0 { + return fmt.Sprintf(" #%d-%d", s.taskRetryAttempt, s.executionIndex) + } + + return fmt.Sprintf(" #%d-%d-%d", s.taskRetryAttempt, s.executionIndex, s.subtaskRetryAttempt) +} + +// SubTaskExecutionMetadata wraps the core TaskExecutionMetadata to customize the TaskExecutionID +type SubTaskExecutionMetadata struct { + pluginsCore.TaskExecutionMetadata + subtaskExecutionID SubTaskExecutionID +} + +// GetTaskExecutionID overrides the base TaskExecutionMetadata to return a custom TaskExecutionID +func (s SubTaskExecutionMetadata) GetTaskExecutionID() pluginsCore.TaskExecutionID { + return s.subtaskExecutionID +} diff --git a/go/tasks/plugins/array/k8s/subtask_exec_context_test.go b/go/tasks/plugins/array/k8s/subtask_exec_context_test.go new file mode 100644 index 000000000..24ecbf425 --- /dev/null +++ b/go/tasks/plugins/array/k8s/subtask_exec_context_test.go @@ -0,0 +1,32 @@ +package k8s + +import ( + "context" + "fmt" + "testing" + + podPlugin "github.com/flyteorg/flyteplugins/go/tasks/plugins/k8s/pod" + + "github.com/stretchr/testify/assert" +) + +func TestSubTaskExecutionContext(t *testing.T) { + ctx := context.Background() + + tCtx := getMockTaskExecutionContext(ctx, 0) + taskTemplate, err := tCtx.TaskReader().Read(ctx) + assert.Nil(t, err) + + executionIndex := 0 + originalIndex := 5 + retryAttempt := uint64(1) + + stCtx := newSubTaskExecutionContext(tCtx, taskTemplate, executionIndex, originalIndex, retryAttempt) + + assert.Equal(t, stCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(), fmt.Sprintf("notfound-%d-%d", executionIndex, retryAttempt)) + + subtaskTemplate, err := stCtx.TaskReader().Read(ctx) + assert.Nil(t, err) + assert.Equal(t, int32(2), subtaskTemplate.TaskTypeVersion) + assert.Equal(t, podPlugin.ContainerTaskType, subtaskTemplate.Type) +} diff --git a/go/tasks/plugins/array/k8s/task.go b/go/tasks/plugins/array/k8s/task.go deleted file mode 100644 index f29baeae5..000000000 --- a/go/tasks/plugins/array/k8s/task.go +++ /dev/null @@ -1,331 +0,0 @@ -package k8s - -import ( - "context" - "strconv" - "strings" - - metaV1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "sigs.k8s.io/controller-runtime/pkg/client" - - idlCore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/tasklog" - - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" - "github.com/flyteorg/flyteplugins/go/tasks/plugins/array" - "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/arraystatus" - arrayCore "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/core" - "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/errorcollector" - "github.com/flyteorg/flytestdlib/bitarray" - errors2 "github.com/flyteorg/flytestdlib/errors" - "github.com/flyteorg/flytestdlib/logger" - "github.com/flyteorg/flytestdlib/storage" - corev1 "k8s.io/api/core/v1" - v1 "k8s.io/api/core/v1" - k8serrors "k8s.io/apimachinery/pkg/api/errors" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - k8sTypes "k8s.io/apimachinery/pkg/types" -) - -type Task struct { - State *arrayCore.State - NewArrayStatus *arraystatus.ArrayStatus - Config *Config - ChildIdx int - OriginalIndex int - MessageCollector *errorcollector.ErrorMessageCollector - SubTaskIDs []*string -} - -type LaunchResult int8 -type MonitorResult int8 - -const ( - LaunchSuccess LaunchResult = iota - LaunchError - LaunchWaiting - LaunchReturnState -) - -const finalizer = "flyte/array" - -func addPodFinalizer(pod *corev1.Pod) *corev1.Pod { - pod.Finalizers = append(pod.Finalizers, finalizer) - return pod -} - -func removeString(list []string, target string) []string { - ret := make([]string, 0) - for _, s := range list { - if s != target { - ret = append(ret, s) - } - } - - return ret -} - -func clearFinalizer(pod *corev1.Pod) *corev1.Pod { - pod.Finalizers = removeString(pod.Finalizers, finalizer) - return pod -} - -const ( - MonitorSuccess MonitorResult = iota - MonitorError -) - -func getTaskContainerIndex(pod *v1.Pod) (int, error) { - primaryContainerName, ok := pod.Annotations[primaryContainerKey] - // For tasks with a Container target, we only ever build one container as part of the pod - if !ok { - if len(pod.Spec.Containers) == 1 { - return 0, nil - } - // For tasks with a K8sPod task target, they may produce multiple containers but at least one must be the designated primary. - return -1, errors2.Errorf(ErrBuildPodTemplate, "Expected a specified primary container key when building an array job with a K8sPod spec target") - - } - - for idx, container := range pod.Spec.Containers { - if container.Name == primaryContainerName { - return idx, nil - } - } - return -1, errors2.Errorf(ErrBuildPodTemplate, "Couldn't find any container matching the primary container key when building an array job with a K8sPod spec target") -} - -func (t Task) Launch(ctx context.Context, tCtx core.TaskExecutionContext, kubeClient core.KubeClient) (LaunchResult, error) { - podTemplate, _, err := FlyteArrayJobToK8sPodTemplate(ctx, tCtx, t.Config.NamespaceTemplate) - if err != nil { - return LaunchError, errors2.Wrapf(ErrBuildPodTemplate, err, "Failed to convert task template to a pod template for a task") - } - // Remove owner references for remote cluster execution - if t.Config.RemoteClusterConfig.Enabled { - podTemplate.OwnerReferences = nil - } - - if len(podTemplate.Spec.Containers) == 0 { - return LaunchError, errors2.Wrapf(ErrReplaceCmdTemplate, err, "No containers found in podSpec.") - } - - containerIndex, err := getTaskContainerIndex(&podTemplate) - if err != nil { - return LaunchError, err - } - - podName := formatSubTaskName(ctx, tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(), t.ChildIdx, t.State.RetryAttempts.GetItem(t.ChildIdx)) - allocationStatus, err := allocateResource(ctx, tCtx, t.Config, podName) - if err != nil { - return LaunchError, err - } - - if allocationStatus != core.AllocationStatusGranted { - t.NewArrayStatus.Detailed.SetItem(t.ChildIdx, bitarray.Item(core.PhaseWaitingForResources)) - t.NewArrayStatus.Summary.Inc(core.PhaseWaitingForResources) - return LaunchWaiting, nil - } - - pod := podTemplate.DeepCopy() - pod.Name = podName - pod.Spec.Containers[containerIndex].Env = append(pod.Spec.Containers[containerIndex].Env, corev1.EnvVar{ - Name: FlyteK8sArrayIndexVarName, - // Use the OriginalIndex which represents the position of the subtask in the original user's map task before - // compacting indexes caused by catalog-cache-check. - Value: strconv.Itoa(t.OriginalIndex), - }) - - pod.Spec.Containers[containerIndex].Env = append(pod.Spec.Containers[containerIndex].Env, arrayJobEnvVars...) - taskTemplate, err := tCtx.TaskReader().Read(ctx) - if err != nil { - return LaunchError, errors2.Wrapf(ErrGetTaskTypeVersion, err, "Unable to read task template") - } else if taskTemplate == nil { - return LaunchError, errors2.Wrapf(ErrGetTaskTypeVersion, err, "Missing task template") - } - - pod = ApplyPodPolicies(ctx, t.Config, pod) - pod = applyNodeSelectorLabels(ctx, t.Config, pod) - pod = applyPodTolerations(ctx, t.Config, pod) - pod = addPodFinalizer(pod) - - // Check for existing pods to prevent unnecessary Resource-Quota usage: https://github.com/kubernetes/kubernetes/issues/76787 - existingPod := &corev1.Pod{} - err = kubeClient.GetCache().Get(ctx, client.ObjectKey{ - Namespace: pod.GetNamespace(), - Name: pod.GetName(), - }, existingPod) - - if err != nil && k8serrors.IsNotFound(err) { - // Attempt creating non-existing pod. - err = kubeClient.GetClient().Create(ctx, pod) - if err != nil && !k8serrors.IsAlreadyExists(err) { - if k8serrors.IsForbidden(err) { - if strings.Contains(err.Error(), "exceeded quota") { - // TODO: Quota errors are retried forever, it would be good to have support for backoff strategy. - logger.Infof(ctx, "Failed to launch job, resource quota exceeded. Err: %v", err) - t.State = t.State.SetPhase(arrayCore.PhaseWaitingForResources, 0).SetReason("Not enough resources to launch job") - } else { - t.State = t.State.SetPhase(arrayCore.PhaseRetryableFailure, 0).SetReason("Failed to launch job.") - } - - t.State.SetReason(err.Error()) - return LaunchReturnState, nil - } - - return LaunchError, errors2.Wrapf(ErrSubmitJob, err, "Failed to submit job.") - } - } else if err != nil { - // Another error returned. - logger.Error(ctx, err) - return LaunchError, errors2.Wrapf(ErrSubmitJob, err, "Failed to submit job.") - } - - return LaunchSuccess, nil -} - -func (t *Task) Monitor(ctx context.Context, tCtx core.TaskExecutionContext, kubeClient core.KubeClient, dataStore *storage.DataStore, outputPrefix, baseOutputDataSandbox storage.DataReference, - logPlugin tasklog.Plugin) (MonitorResult, []*idlCore.TaskLog, error) { - retryAttempt := t.State.RetryAttempts.GetItem(t.ChildIdx) - podName := formatSubTaskName(ctx, tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(), t.ChildIdx, retryAttempt) - t.SubTaskIDs = append(t.SubTaskIDs, &podName) - var loglinks []*idlCore.TaskLog - - // Use original-index for log-name/links - originalIdx := arrayCore.CalculateOriginalIndex(t.ChildIdx, t.State.GetIndexesToCache()) - phaseInfo, err := FetchPodStatusAndLogs(ctx, kubeClient, - k8sTypes.NamespacedName{ - Name: podName, - Namespace: GetNamespaceForExecution(tCtx, t.Config.NamespaceTemplate), - }, - originalIdx, - tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetID().RetryAttempt, - retryAttempt, - logPlugin) - if err != nil { - return MonitorError, loglinks, errors2.Wrapf(ErrCheckPodStatus, err, "Failed to check pod status.") - } - - if phaseInfo.Info() != nil { - loglinks = phaseInfo.Info().Logs - } - - if phaseInfo.Err() != nil { - t.MessageCollector.Collect(t.ChildIdx, phaseInfo.Err().String()) - } - - actualPhase := phaseInfo.Phase() - if phaseInfo.Phase().IsSuccess() { - actualPhase, err = array.CheckTaskOutput(ctx, dataStore, outputPrefix, baseOutputDataSandbox, t.ChildIdx, originalIdx) - if err != nil { - return MonitorError, loglinks, err - } - } - - t.NewArrayStatus.Detailed.SetItem(t.ChildIdx, bitarray.Item(actualPhase)) - t.NewArrayStatus.Summary.Inc(actualPhase) - - return MonitorSuccess, loglinks, nil -} - -func (t Task) Abort(ctx context.Context, tCtx core.TaskExecutionContext, kubeClient core.KubeClient) error { - podName := formatSubTaskName(ctx, tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(), t.ChildIdx, t.State.RetryAttempts.GetItem(t.ChildIdx)) - pod := &corev1.Pod{ - TypeMeta: metav1.TypeMeta{ - Kind: PodKind, - APIVersion: metav1.SchemeGroupVersion.String(), - }, - ObjectMeta: metav1.ObjectMeta{ - Name: podName, - Namespace: GetNamespaceForExecution(tCtx, t.Config.NamespaceTemplate), - }, - } - - err := kubeClient.GetClient().Delete(ctx, pod) - if err != nil { - if k8serrors.IsNotFound(err) { - - return nil - } - return err - } - - return nil - -} - -func (t Task) Finalize(ctx context.Context, tCtx core.TaskExecutionContext, kubeClient core.KubeClient) error { - podName := formatSubTaskName(ctx, tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(), t.ChildIdx, t.State.RetryAttempts.GetItem(t.ChildIdx)) - - pod := &v1.Pod{ - TypeMeta: metaV1.TypeMeta{ - Kind: PodKind, - APIVersion: v1.SchemeGroupVersion.String(), - }, - } - - err := kubeClient.GetClient().Get(ctx, k8sTypes.NamespacedName{ - Name: podName, - Namespace: GetNamespaceForExecution(tCtx, t.Config.NamespaceTemplate), - }, pod) - - if err != nil { - if !k8serrors.IsNotFound(err) { - logger.Errorf(ctx, "Error fetching pod [%s] in Finalize [%s]", podName, err) - return err - } - } else { - pod = clearFinalizer(pod) - err := kubeClient.GetClient().Update(ctx, pod) - if err != nil { - logger.Errorf(ctx, "Error updating pod finalizer [%s] in Finalize [%s]", podName, err) - return err - } - } - - // Deallocate Resource - err = deallocateResource(ctx, tCtx, t.Config, podName) - if err != nil { - logger.Errorf(ctx, "Error releasing allocation token [%s] in Finalize [%s]", podName, err) - return err - } - - return nil - -} - -func allocateResource(ctx context.Context, tCtx core.TaskExecutionContext, config *Config, podName string) (core.AllocationStatus, error) { - if !IsResourceConfigSet(config.ResourceConfig) { - return core.AllocationStatusGranted, nil - } - - resourceNamespace := core.ResourceNamespace(config.ResourceConfig.PrimaryLabel) - resourceConstraintSpec := core.ResourceConstraintsSpec{ - ProjectScopeResourceConstraint: nil, - NamespaceScopeResourceConstraint: nil, - } - - allocationStatus, err := tCtx.ResourceManager().AllocateResource(ctx, resourceNamespace, podName, resourceConstraintSpec) - if err != nil { - logger.Errorf(ctx, "Resource manager failed for TaskExecId [%s] token [%s]. error %s", - tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetID(), podName, err) - return core.AllocationUndefined, err - } - - logger.Infof(ctx, "Allocation result for [%s] is [%s]", podName, allocationStatus) - return allocationStatus, nil -} - -func deallocateResource(ctx context.Context, tCtx core.TaskExecutionContext, config *Config, podName string) error { - if !IsResourceConfigSet(config.ResourceConfig) { - return nil - } - resourceNamespace := core.ResourceNamespace(config.ResourceConfig.PrimaryLabel) - - err := tCtx.ResourceManager().ReleaseResource(ctx, resourceNamespace, podName) - if err != nil { - logger.Errorf(ctx, "Error releasing token [%s]. error %s", podName, err) - return err - } - - return nil -} diff --git a/go/tasks/plugins/array/k8s/task_test.go b/go/tasks/plugins/array/k8s/task_test.go deleted file mode 100644 index 874caedac..000000000 --- a/go/tasks/plugins/array/k8s/task_test.go +++ /dev/null @@ -1,136 +0,0 @@ -package k8s - -import ( - "context" - "testing" - - v1 "k8s.io/api/core/v1" - - "github.com/stretchr/testify/mock" - - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks" - "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/core" - - "github.com/flyteorg/flytestdlib/bitarray" - - "github.com/stretchr/testify/assert" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" -) - -func TestFinalize(t *testing.T) { - ctx := context.Background() - - tCtx := getMockTaskExecutionContext(ctx, 0) - kubeClient := mocks.KubeClient{} - kubeClient.OnGetClient().Return(mocks.NewFakeKubeClient()) - - resourceManager := mocks.ResourceManager{} - podTemplate, _, _ := FlyteArrayJobToK8sPodTemplate(ctx, tCtx, "") - pod := addPodFinalizer(&podTemplate) - pod.Name = formatSubTaskName(ctx, tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(), 1, 1) - assert.Equal(t, "notfound-1-1", pod.Name) - assert.NoError(t, kubeClient.GetClient().Create(ctx, pod)) - - resourceManager.OnReleaseResourceMatch(mock.Anything, mock.Anything, mock.Anything).Return(nil) - tCtx.OnResourceManager().Return(&resourceManager) - - config := Config{ - MaxArrayJobSize: 100, - ResourceConfig: ResourceConfig{ - PrimaryLabel: "p", - Limit: 10, - }, - } - - retryAttemptsArray, err := bitarray.NewCompactArray(2, 1) - assert.NoError(t, err) - - state := core.State{ - RetryAttempts: retryAttemptsArray, - } - - task := &Task{ - State: &state, - Config: &config, - ChildIdx: 1, - } - - err = task.Finalize(ctx, tCtx, &kubeClient) - assert.NoError(t, err) -} - -func TestGetTaskContainerIndex(t *testing.T) { - t.Run("test container target", func(t *testing.T) { - pod := &v1.Pod{ - Spec: v1.PodSpec{ - Containers: []v1.Container{ - { - Name: "container", - }, - }, - }, - } - index, err := getTaskContainerIndex(pod) - assert.NoError(t, err) - assert.Equal(t, 0, index) - }) - t.Run("test missing primary container annotation", func(t *testing.T) { - pod := &v1.Pod{ - Spec: v1.PodSpec{ - Containers: []v1.Container{ - { - Name: "container", - }, - { - Name: "container b", - }, - }, - }, - } - _, err := getTaskContainerIndex(pod) - assert.EqualError(t, err, "[POD_TEMPLATE_FAILED] Expected a specified primary container key when building an array job with a K8sPod spec target") - }) - t.Run("test get primary container index", func(t *testing.T) { - pod := &v1.Pod{ - Spec: v1.PodSpec{ - Containers: []v1.Container{ - { - Name: "container a", - }, - { - Name: "container b", - }, - { - Name: "container c", - }, - }, - }, - ObjectMeta: metav1.ObjectMeta{ - Annotations: map[string]string{ - primaryContainerKey: "container c", - }, - }, - } - index, err := getTaskContainerIndex(pod) - assert.NoError(t, err) - assert.Equal(t, 2, index) - }) - t.Run("specified primary container doesn't exist", func(t *testing.T) { - pod := &v1.Pod{ - Spec: v1.PodSpec{ - Containers: []v1.Container{ - { - Name: "container a", - }, - }, - }, - ObjectMeta: metav1.ObjectMeta{ - Annotations: map[string]string{ - primaryContainerKey: "container c", - }, - }, - } - _, err := getTaskContainerIndex(pod) - assert.EqualError(t, err, "[POD_TEMPLATE_FAILED] Couldn't find any container matching the primary container key when building an array job with a K8sPod spec target") - }) -} diff --git a/go/tasks/plugins/array/k8s/transformer.go b/go/tasks/plugins/array/k8s/transformer.go deleted file mode 100644 index 8b10c54d3..000000000 --- a/go/tasks/plugins/array/k8s/transformer.go +++ /dev/null @@ -1,182 +0,0 @@ -package k8s - -import ( - "context" - "regexp" - - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/template" - - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s/config" - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/utils" - - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io" - - "github.com/flyteorg/flyteplugins/go/tasks/plugins/array" - - idlCore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" - idlPlugins "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins" - "github.com/flyteorg/flyteplugins/go/tasks/errors" - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s" - core2 "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/core" - v1 "k8s.io/api/core/v1" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" -) - -const PodKind = "pod" -const primaryContainerKey = "primary_container_name" - -var namespaceRegex = regexp.MustCompile("(?i){{.namespace}}(?i)") - -type arrayTaskContext struct { - core.TaskExecutionContext - arrayInputReader io.InputReader -} - -// InputReader overrides the TaskExecutionContext from base and returns a specialized context for Array -func (a *arrayTaskContext) InputReader() io.InputReader { - return a.arrayInputReader -} - -func GetNamespaceForExecution(tCtx core.TaskExecutionContext, namespaceTemplate string) string { - - // Default to parent namespace - namespace := tCtx.TaskExecutionMetadata().GetNamespace() - if namespaceTemplate != "" { - if namespaceRegex.MatchString(namespaceTemplate) { - namespace = namespaceRegex.ReplaceAllString(namespaceTemplate, namespace) - } else { - namespace = namespaceTemplate - } - } - return namespace -} - -// Initializes a pod from an array job task template with a K8sPod set as the task target. -// TODO: This should be removed by end of 2021 (it duplicates the pod plugin logic) once we improve array job handling -// and move it to the node level. See https://github.com/flyteorg/flyte/issues/1131 -func buildPodMapTask(task *idlCore.TaskTemplate, metadata core.TaskExecutionMetadata) (v1.Pod, error) { - if task.GetK8SPod() == nil || task.GetK8SPod().PodSpec == nil { - return v1.Pod{}, errors.Errorf(errors.BadTaskSpecification, "Missing pod spec for task") - } - var podSpec = &v1.PodSpec{} - err := utils.UnmarshalStructToObj(task.GetK8SPod().PodSpec, &podSpec) - if err != nil { - return v1.Pod{}, errors.Errorf(errors.BadTaskSpecification, - "Unable to unmarshal task custom [%v], Err: [%v]", task.GetCustom(), err.Error()) - } - primaryContainerName, ok := task.GetConfig()[primaryContainerKey] - if !ok { - return v1.Pod{}, errors.Errorf(errors.BadTaskSpecification, - "invalid TaskSpecification, config missing [%s] key in [%v]", primaryContainerKey, task.GetConfig()) - } - - var pod = v1.Pod{ - Spec: *podSpec, - } - if task.GetK8SPod().Metadata != nil { - if task.GetK8SPod().Metadata.Annotations != nil { - pod.Annotations = task.GetK8SPod().Metadata.Annotations - } - if task.GetK8SPod().Metadata.Labels != nil { - pod.Labels = task.GetK8SPod().Metadata.Labels - } - } - if len(pod.Annotations) == 0 { - pod.Annotations = make(map[string]string) - } - pod.Annotations[primaryContainerKey] = primaryContainerName - - // Set the restart policy to *not* inherit from the default so that a completed pod doesn't get caught in a - // CrashLoopBackoff after the initial job completion. - pod.Spec.RestartPolicy = v1.RestartPolicyNever - flytek8s.GetServiceAccountNameFromTaskExecutionMetadata(metadata) - return pod, nil -} - -// FlyteArrayJobToK8sPodTemplate returns a pod template for the given task context. Note that Name is not set on the -// result object. It's up to the caller to set the Name before creating the object in K8s. -func FlyteArrayJobToK8sPodTemplate(ctx context.Context, tCtx core.TaskExecutionContext, namespaceTemplate string) ( - podTemplate v1.Pod, job *idlPlugins.ArrayJob, err error) { - - // Check that the taskTemplate is valid - taskTemplate, err := tCtx.TaskReader().Read(ctx) - if err != nil { - return v1.Pod{}, nil, err - } else if taskTemplate == nil { - return v1.Pod{}, nil, errors.Errorf(errors.BadTaskSpecification, "Required value not set, taskTemplate is nil") - } - - if taskTemplate.GetContainer() == nil && taskTemplate.GetK8SPod() == nil { - return v1.Pod{}, nil, errors.Errorf(errors.BadTaskSpecification, - "Required value not set, taskTemplate Container or K8sPod") - } - - arrTCtx := &arrayTaskContext{ - TaskExecutionContext: tCtx, - arrayInputReader: array.GetInputReader(tCtx, taskTemplate), - } - - var arrayJob *idlPlugins.ArrayJob - if taskTemplate.GetCustom() != nil { - arrayJob, err = core2.ToArrayJob(taskTemplate.GetCustom(), taskTemplate.TaskTypeVersion) - if err != nil { - return v1.Pod{}, nil, err - } - } - - annotations := utils.UnionMaps(config.GetK8sPluginConfig().DefaultAnnotations, tCtx.TaskExecutionMetadata().GetAnnotations()) - labels := utils.UnionMaps(config.GetK8sPluginConfig().DefaultLabels, tCtx.TaskExecutionMetadata().GetLabels()) - - var pod = v1.Pod{ - TypeMeta: metav1.TypeMeta{ - Kind: PodKind, - APIVersion: v1.SchemeGroupVersion.String(), - }, - ObjectMeta: metav1.ObjectMeta{ - // Note that name is missing here - Namespace: GetNamespaceForExecution(tCtx, namespaceTemplate), - Labels: labels, - Annotations: annotations, - OwnerReferences: []metav1.OwnerReference{tCtx.TaskExecutionMetadata().GetOwnerReference()}, - }, - } - - if taskTemplate.GetContainer() != nil { - podSpec, err := flytek8s.ToK8sPodSpecWithInterruptible(ctx, arrTCtx, true) - if err != nil { - return v1.Pod{}, nil, err - } - - pod.Spec = *podSpec - } else if taskTemplate.GetK8SPod() != nil { - k8sPod, err := buildPodMapTask(taskTemplate, tCtx.TaskExecutionMetadata()) - if err != nil { - return v1.Pod{}, nil, err - } - - pod.Labels = utils.UnionMaps(pod.Labels, k8sPod.Labels) - pod.Annotations = utils.UnionMaps(pod.Annotations, k8sPod.Annotations) - pod.Spec = k8sPod.Spec - - containerIndex, err := getTaskContainerIndex(&pod) - if err != nil { - return v1.Pod{}, nil, err - } - - templateParameters := template.Parameters{ - TaskExecMetadata: tCtx.TaskExecutionMetadata(), - Inputs: arrTCtx.arrayInputReader, - OutputPath: tCtx.OutputWriter(), - Task: tCtx.TaskReader(), - } - - err = flytek8s.AddFlyteCustomizationsToContainer( - ctx, templateParameters, flytek8s.ResourceCustomizationModeMergeExistingResources, &pod.Spec.Containers[containerIndex]) - if err != nil { - return v1.Pod{}, nil, err - } - } - - return pod, arrayJob, nil -} diff --git a/go/tasks/plugins/array/k8s/transformer_test.go b/go/tasks/plugins/array/k8s/transformer_test.go deleted file mode 100644 index 822e3c544..000000000 --- a/go/tasks/plugins/array/k8s/transformer_test.go +++ /dev/null @@ -1,257 +0,0 @@ -package k8s - -import ( - "context" - "encoding/json" - "fmt" - "testing" - - "k8s.io/apimachinery/pkg/api/resource" - - "github.com/flyteorg/flytestdlib/storage" - - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" - idlPlugins "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins" - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks" - mocks2 "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io/mocks" - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/utils" - structpb "github.com/golang/protobuf/ptypes/struct" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - v1 "k8s.io/api/core/v1" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - v12 "k8s.io/apimachinery/pkg/apis/meta/v1" -) - -const testPrimaryContainerName = "primary container" - -var podSpec = v1.PodSpec{ - Containers: []v1.Container{ - { - Name: testPrimaryContainerName, - Resources: v1.ResourceRequirements{ - Requests: v1.ResourceList{ - v1.ResourceCPU: resource.MustParse("1"), - v1.ResourceStorage: resource.MustParse("2"), - }, - }, - }, - { - Name: "secondary container", - }, - }, -} - -var arrayJob = idlPlugins.ArrayJob{ - Size: 100, -} - -func getK8sPodTask(t *testing.T, annotations map[string]string) *core.TaskTemplate { - marshalledPodspec, err := json.Marshal(podSpec) - if err != nil { - t.Fatal(err) - } - - structObj := &structpb.Struct{} - if err := json.Unmarshal(marshalledPodspec, structObj); err != nil { - t.Fatal(err) - } - - custom := &structpb.Struct{} - if err := utils.MarshalStruct(&arrayJob, custom); err != nil { - t.Fatal(err) - } - - return &core.TaskTemplate{ - TaskTypeVersion: 2, - Config: map[string]string{ - primaryContainerKey: testPrimaryContainerName, - }, - Target: &core.TaskTemplate_K8SPod{ - K8SPod: &core.K8SPod{ - PodSpec: structObj, - Metadata: &core.K8SObjectMetadata{ - Labels: map[string]string{ - "label": "foo", - }, - Annotations: annotations, - }, - }, - }, - Custom: custom, - } -} - -func TestBuildPodMapTask(t *testing.T) { - tMeta := &mocks.TaskExecutionMetadata{} - tMeta.OnGetSecurityContext().Return(core.SecurityContext{}) - tMeta.OnGetK8sServiceAccount().Return("sa") - pod, err := buildPodMapTask(getK8sPodTask(t, map[string]string{ - "anno": "bar", - }), tMeta) - assert.NoError(t, err) - var expected = podSpec.DeepCopy() - expected.RestartPolicy = v1.RestartPolicyNever - assert.EqualValues(t, *expected, pod.Spec) - assert.EqualValues(t, map[string]string{ - "label": "foo", - }, pod.Labels) - assert.EqualValues(t, map[string]string{ - "anno": "bar", - "primary_container_name": "primary container", - }, pod.Annotations) -} - -func TestBuildPodMapTask_Errors(t *testing.T) { - t.Run("invalid task template", func(t *testing.T) { - _, err := buildPodMapTask(&core.TaskTemplate{}, nil) - assert.EqualError(t, err, "[BadTaskSpecification] Missing pod spec for task") - }) - b, err := json.Marshal(podSpec) - if err != nil { - t.Fatal(err) - } - - structObj := &structpb.Struct{} - if err := json.Unmarshal(b, structObj); err != nil { - t.Fatal(err) - } - t.Run("missing primary container annotation", func(t *testing.T) { - _, err = buildPodMapTask(&core.TaskTemplate{ - Target: &core.TaskTemplate_K8SPod{ - K8SPod: &core.K8SPod{ - PodSpec: structObj, - }, - }, - }, nil) - assert.EqualError(t, err, "[BadTaskSpecification] invalid TaskSpecification, config missing [primary_container_name] key in [map[]]") - }) -} - -func TestBuildPodMapTask_AddAnnotations(t *testing.T) { - tMeta := &mocks.TaskExecutionMetadata{} - tMeta.OnGetSecurityContext().Return(core.SecurityContext{}) - tMeta.OnGetK8sServiceAccount().Return("sa") - podTask := getK8sPodTask(t, nil) - pod, err := buildPodMapTask(podTask, tMeta) - assert.NoError(t, err) - var expected = podSpec.DeepCopy() - expected.RestartPolicy = v1.RestartPolicyNever - assert.EqualValues(t, *expected, pod.Spec) - assert.EqualValues(t, map[string]string{ - "label": "foo", - }, pod.Labels) - assert.EqualValues(t, map[string]string{ - "primary_container_name": "primary container", - }, pod.Annotations) -} - -func TestFlyteArrayJobToK8sPodTemplate(t *testing.T) { - ctx := context.TODO() - tr := &mocks.TaskReader{} - tr.OnRead(ctx).Return(getK8sPodTask(t, map[string]string{ - "anno": "bar", - }), nil) - - ir := &mocks2.InputReader{} - ir.OnGetInputPrefixPath().Return("/prefix/") - ir.OnGetInputPath().Return("/prefix/inputs.pb") - ir.OnGetMatch(mock.Anything).Return(&core.LiteralMap{}, nil) - - tMeta := &mocks.TaskExecutionMetadata{} - tMeta.OnGetNamespace().Return("n") - tMeta.OnGetLabels().Return(map[string]string{ - "tCtx": "label", - }) - tMeta.OnGetAnnotations().Return(map[string]string{ - "tCtx": "anno", - }) - tMeta.OnGetOwnerReference().Return(v12.OwnerReference{}) - tMeta.OnGetSecurityContext().Return(core.SecurityContext{}) - tMeta.OnGetK8sServiceAccount().Return("sa") - mockResourceOverrides := mocks.TaskOverrides{} - mockResourceOverrides.OnGetResources().Return(&v1.ResourceRequirements{ - Requests: v1.ResourceList{ - "ephemeral-storage": resource.MustParse("1024Mi"), - }, - Limits: v1.ResourceList{ - "ephemeral-storage": resource.MustParse("2048Mi"), - }, - }) - tMeta.OnGetOverrides().Return(&mockResourceOverrides) - tMeta.OnGetPlatformResources().Return(&v1.ResourceRequirements{}) - tID := &mocks.TaskExecutionID{} - tID.OnGetID().Return(core.TaskExecutionIdentifier{ - NodeExecutionId: &core.NodeExecutionIdentifier{ - ExecutionId: &core.WorkflowExecutionIdentifier{ - Name: "my_name", - Project: "my_project", - Domain: "my_domain", - }, - }, - RetryAttempt: 1, - }) - tMeta.OnGetTaskExecutionID().Return(tID) - - outputReader := &mocks2.OutputWriter{} - outputReader.On("GetOutputPath").Return(storage.DataReference("/data/outputs.pb")) - outputReader.On("GetOutputPrefixPath").Return(storage.DataReference("/data/")) - outputReader.On("GetRawOutputPrefix").Return(storage.DataReference("")) - - tCtx := &mocks.TaskExecutionContext{} - tCtx.OnTaskReader().Return(tr) - tCtx.OnInputReader().Return(ir) - tCtx.OnTaskExecutionMetadata().Return(tMeta) - tCtx.OnOutputWriter().Return(outputReader) - - pod, job, err := FlyteArrayJobToK8sPodTemplate(ctx, tCtx, "") - assert.NoError(t, err) - assert.EqualValues(t, metav1.ObjectMeta{ - Namespace: "n", - Labels: map[string]string{ - "tCtx": "label", - "label": "foo", - }, - Annotations: map[string]string{ - "tCtx": "anno", - "anno": "bar", - "primary_container_name": "primary container", - "cluster-autoscaler.kubernetes.io/safe-to-evict": "false", - }, - OwnerReferences: []metav1.OwnerReference{ - {}, - }, - }, pod.ObjectMeta) - assert.EqualValues(t, &arrayJob, job) - defaultMemoryFromConfig := resource.MustParse("1024Mi") - assert.EqualValues(t, v1.ResourceRequirements{ - Requests: v1.ResourceList{ - v1.ResourceCPU: resource.MustParse("1"), - v1.ResourceMemory: defaultMemoryFromConfig, - v1.ResourceEphemeralStorage: resource.MustParse("1024Mi"), - }, - Limits: v1.ResourceList{ - v1.ResourceCPU: resource.MustParse("1"), - v1.ResourceMemory: defaultMemoryFromConfig, - v1.ResourceEphemeralStorage: resource.MustParse("2048Mi"), - }, - }, pod.Spec.Containers[0].Resources, fmt.Sprintf("%+v", pod.Spec.Containers[0].Resources)) - assert.EqualValues(t, []v1.EnvVar{ - { - Name: "FLYTE_INTERNAL_EXECUTION_ID", - Value: "my_name", - }, - { - Name: "FLYTE_INTERNAL_EXECUTION_PROJECT", - Value: "my_project", - }, - { - Name: "FLYTE_INTERNAL_EXECUTION_DOMAIN", - Value: "my_domain", - }, - { - Name: "FLYTE_ATTEMPT_NUMBER", - Value: "1", - }, - }, pod.Spec.Containers[0].Env) -} diff --git a/go/tasks/plugins/k8s/pod/container.go b/go/tasks/plugins/k8s/pod/container.go index 197ff502f..0ec8240ec 100644 --- a/go/tasks/plugins/k8s/pod/container.go +++ b/go/tasks/plugins/k8s/pod/container.go @@ -12,7 +12,7 @@ import ( ) const ( - containerTaskType = "container" + ContainerTaskType = "container" ) type containerPodBuilder struct { diff --git a/go/tasks/plugins/k8s/pod/container_test.go b/go/tasks/plugins/k8s/pod/container_test.go index 3832b00f3..d4b907915 100644 --- a/go/tasks/plugins/k8s/pod/container_test.go +++ b/go/tasks/plugins/k8s/pod/container_test.go @@ -106,9 +106,8 @@ func dummyContainerTaskContext(resources *v1.ResourceRequirements, command []str } func TestContainerTaskExecutor_BuildIdentityResource(t *testing.T) { - p := plugin{defaultPodBuilder, podBuilders} taskMetadata := &pluginsCoreMock.TaskExecutionMetadata{} - r, err := p.BuildIdentityResource(context.TODO(), taskMetadata) + r, err := DefaultPodPlugin.BuildIdentityResource(context.TODO(), taskMetadata) assert.NoError(t, err) assert.NotNil(t, r) _, ok := r.(*v1.Pod) @@ -117,12 +116,11 @@ func TestContainerTaskExecutor_BuildIdentityResource(t *testing.T) { } func TestContainerTaskExecutor_BuildResource(t *testing.T) { - p := plugin{defaultPodBuilder, podBuilders} command := []string{"command"} args := []string{"{{.Input}}"} taskCtx := dummyContainerTaskContext(containerResourceRequirements, command, args) - r, err := p.BuildResource(context.TODO(), taskCtx) + r, err := DefaultPodPlugin.BuildResource(context.TODO(), taskCtx) assert.NoError(t, err) assert.NotNil(t, r) j, ok := r.(*v1.Pod) @@ -142,7 +140,6 @@ func TestContainerTaskExecutor_BuildResource(t *testing.T) { } func TestContainerTaskExecutor_GetTaskStatus(t *testing.T) { - p := plugin{defaultPodBuilder, podBuilders} j := &v1.Pod{ Status: v1.PodStatus{}, } @@ -150,21 +147,21 @@ func TestContainerTaskExecutor_GetTaskStatus(t *testing.T) { ctx := context.TODO() t.Run("running", func(t *testing.T) { j.Status.Phase = v1.PodRunning - phaseInfo, err := p.GetTaskPhase(ctx, nil, j) + phaseInfo, err := DefaultPodPlugin.GetTaskPhase(ctx, nil, j) assert.NoError(t, err) assert.Equal(t, pluginsCore.PhaseRunning, phaseInfo.Phase()) }) t.Run("queued", func(t *testing.T) { j.Status.Phase = v1.PodPending - phaseInfo, err := p.GetTaskPhase(ctx, nil, j) + phaseInfo, err := DefaultPodPlugin.GetTaskPhase(ctx, nil, j) assert.NoError(t, err) assert.Equal(t, pluginsCore.PhaseQueued, phaseInfo.Phase()) }) t.Run("failNoCondition", func(t *testing.T) { j.Status.Phase = v1.PodFailed - phaseInfo, err := p.GetTaskPhase(ctx, nil, j) + phaseInfo, err := DefaultPodPlugin.GetTaskPhase(ctx, nil, j) assert.NoError(t, err) assert.Equal(t, pluginsCore.PhaseRetryableFailure, phaseInfo.Phase()) ec := phaseInfo.Err().GetCode() @@ -180,7 +177,7 @@ func TestContainerTaskExecutor_GetTaskStatus(t *testing.T) { Type: v1.PodReasonUnschedulable, }, } - phaseInfo, err := p.GetTaskPhase(ctx, nil, j) + phaseInfo, err := DefaultPodPlugin.GetTaskPhase(ctx, nil, j) assert.NoError(t, err) assert.Equal(t, pluginsCore.PhaseRetryableFailure, phaseInfo.Phase()) ec := phaseInfo.Err().GetCode() @@ -189,7 +186,7 @@ func TestContainerTaskExecutor_GetTaskStatus(t *testing.T) { t.Run("success", func(t *testing.T) { j.Status.Phase = v1.PodSucceeded - phaseInfo, err := p.GetTaskPhase(ctx, nil, j) + phaseInfo, err := DefaultPodPlugin.GetTaskPhase(ctx, nil, j) assert.NoError(t, err) assert.NotNil(t, phaseInfo) assert.Equal(t, pluginsCore.PhaseSuccess, phaseInfo.Phase()) @@ -197,14 +194,12 @@ func TestContainerTaskExecutor_GetTaskStatus(t *testing.T) { } func TestContainerTaskExecutor_GetProperties(t *testing.T) { - p := plugin{defaultPodBuilder, podBuilders} expected := k8s.PluginProperties{} - assert.Equal(t, expected, p.GetProperties()) + assert.Equal(t, expected, DefaultPodPlugin.GetProperties()) } func TestContainerTaskExecutor_GetTaskStatus_InvalidImageName(t *testing.T) { ctx := context.TODO() - p := plugin{defaultPodBuilder, podBuilders} reason := "InvalidImageName" message := "Failed to apply default image tag \"TEST/flyteorg/myapp:latest\": couldn't parse image reference" + " \"TEST/flyteorg/myapp:latest\": invalid reference format: repository name must be lowercase" @@ -235,7 +230,7 @@ func TestContainerTaskExecutor_GetTaskStatus_InvalidImageName(t *testing.T) { t.Run("failInvalidImageName", func(t *testing.T) { pendingPod.Status.Phase = v1.PodPending - phaseInfo, err := p.GetTaskPhase(ctx, nil, pendingPod) + phaseInfo, err := DefaultPodPlugin.GetTaskPhase(ctx, nil, pendingPod) finalReason := fmt.Sprintf("|%s", reason) finalMessage := fmt.Sprintf("|%s", message) assert.NoError(t, err) diff --git a/go/tasks/plugins/k8s/pod/plugin.go b/go/tasks/plugins/k8s/pod/plugin.go index fc15357dd..9e0aea1cb 100644 --- a/go/tasks/plugins/k8s/pod/plugin.go +++ b/go/tasks/plugins/k8s/pod/plugin.go @@ -11,6 +11,7 @@ import ( pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/k8s" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/tasklog" v1 "k8s.io/api/core/v1" @@ -19,13 +20,15 @@ import ( const ( podTaskType = "pod" - primaryContainerKey = "primary_container_name" + PrimaryContainerKey = "primary_container_name" ) var ( - defaultPodBuilder = containerPodBuilder{} - podBuilders = map[string]podBuilder{ - sidecarTaskType: sidecarPodBuilder{}, + DefaultPodPlugin = plugin{ + defaultPodBuilder: containerPodBuilder{}, + podBuilders: map[string]podBuilder{ + SidecarTaskType: sidecarPodBuilder{}, + }, } ) @@ -80,7 +83,16 @@ func (p plugin) BuildResource(ctx context.Context, taskCtx pluginsCore.TaskExecu return pod, nil } -func (plugin) GetTaskPhase(ctx context.Context, pluginContext k8s.PluginContext, r client.Object) (pluginsCore.PhaseInfo, error) { +func (p plugin) GetTaskPhase(ctx context.Context, pluginContext k8s.PluginContext, r client.Object) (pluginsCore.PhaseInfo, error) { + logPlugin, err := logs.InitializeLogPlugins(logs.GetLogConfig()) + if err != nil { + return pluginsCore.PhaseInfoUndefined, err + } + + return p.GetTaskPhaseWithLogs(ctx, pluginContext, r, logPlugin, " (User)") +} + +func (plugin) GetTaskPhaseWithLogs(ctx context.Context, pluginContext k8s.PluginContext, r client.Object, logPlugin tasklog.Plugin, logSuffix string) (pluginsCore.PhaseInfo, error) { pod := r.(*v1.Pod) transitionOccurredAt := flytek8s.GetLastTransitionOccurredAt(pod).Time @@ -89,7 +101,7 @@ func (plugin) GetTaskPhase(ctx context.Context, pluginContext k8s.PluginContext, } if pod.Status.Phase != v1.PodPending && pod.Status.Phase != v1.PodUnknown { - taskLogs, err := logs.GetLogsForContainerInPod(ctx, pod, 0, " (User)") + taskLogs, err := logs.GetLogsForContainerInPod(ctx, logPlugin, pod, 0, logSuffix) if err != nil { return pluginsCore.PhaseInfoUndefined, err } @@ -109,7 +121,7 @@ func (plugin) GetTaskPhase(ctx context.Context, pluginContext k8s.PluginContext, return pluginsCore.PhaseInfoUndefined, nil } - primaryContainerName, exists := r.GetAnnotations()[primaryContainerKey] + primaryContainerName, exists := r.GetAnnotations()[PrimaryContainerKey] if !exists { // if the primary container annotation dos not exist, then the task requires all containers // to succeed to declare success. therefore, if the pod is not in one of the above states we @@ -133,30 +145,25 @@ func (plugin) GetProperties() k8s.PluginProperties { } func init() { - podPlugin := plugin{ - defaultPodBuilder: defaultPodBuilder, - podBuilders: podBuilders, - } - - // Register containerTaskType and sidecarTaskType plugin entries. These separate task types + // Register ContainerTaskType and SidecarTaskType plugin entries. These separate task types // still exist within the system, only now both are evaluated using the same internal pod plugin // instance. This simplifies migration as users may keep the same configuration but are // seamlessly transitioned from separate container and sidecar plugins to a single pod plugin. pluginmachinery.PluginRegistry().RegisterK8sPlugin( k8s.PluginEntry{ - ID: containerTaskType, - RegisteredTaskTypes: []pluginsCore.TaskType{containerTaskType}, + ID: ContainerTaskType, + RegisteredTaskTypes: []pluginsCore.TaskType{ContainerTaskType}, ResourceToWatch: &v1.Pod{}, - Plugin: podPlugin, + Plugin: DefaultPodPlugin, IsDefault: true, }) pluginmachinery.PluginRegistry().RegisterK8sPlugin( k8s.PluginEntry{ - ID: sidecarTaskType, - RegisteredTaskTypes: []pluginsCore.TaskType{sidecarTaskType}, + ID: SidecarTaskType, + RegisteredTaskTypes: []pluginsCore.TaskType{SidecarTaskType}, ResourceToWatch: &v1.Pod{}, - Plugin: podPlugin, + Plugin: DefaultPodPlugin, IsDefault: false, }) @@ -164,9 +171,9 @@ func init() { pluginmachinery.PluginRegistry().RegisterK8sPlugin( k8s.PluginEntry{ ID: podTaskType, - RegisteredTaskTypes: []pluginsCore.TaskType{containerTaskType, sidecarTaskType}, + RegisteredTaskTypes: []pluginsCore.TaskType{ContainerTaskType, SidecarTaskType}, ResourceToWatch: &v1.Pod{}, - Plugin: podPlugin, + Plugin: DefaultPodPlugin, IsDefault: true, }) } diff --git a/go/tasks/plugins/k8s/pod/sidecar.go b/go/tasks/plugins/k8s/pod/sidecar.go index ba208a8a8..7a8495258 100644 --- a/go/tasks/plugins/k8s/pod/sidecar.go +++ b/go/tasks/plugins/k8s/pod/sidecar.go @@ -15,7 +15,7 @@ import ( ) const ( - sidecarTaskType = "sidecar" + SidecarTaskType = "sidecar" ) // Why, you might wonder do we recreate the generated go struct generated from the plugins.SidecarJob proto? Because @@ -80,13 +80,13 @@ func (sidecarPodBuilder) buildPodSpec(ctx context.Context, task *core.TaskTempla func getPrimaryContainerNameFromConfig(task *core.TaskTemplate) (string, error) { if len(task.GetConfig()) == 0 { return "", errors.Errorf(errors.BadTaskSpecification, - "invalid TaskSpecification, config needs to be non-empty and include missing [%s] key", primaryContainerKey) + "invalid TaskSpecification, config needs to be non-empty and include missing [%s] key", PrimaryContainerKey) } - primaryContainerName, ok := task.GetConfig()[primaryContainerKey] + primaryContainerName, ok := task.GetConfig()[PrimaryContainerKey] if !ok { return "", errors.Errorf(errors.BadTaskSpecification, - "invalid TaskSpecification, config missing [%s] key in [%v]", primaryContainerKey, task.GetConfig()) + "invalid TaskSpecification, config missing [%s] key in [%v]", PrimaryContainerKey, task.GetConfig()) } return primaryContainerName, nil @@ -144,7 +144,7 @@ func (sidecarPodBuilder) updatePodMetadata(ctx context.Context, pod *v1.Pod, tas return err } - pod.Annotations[primaryContainerKey] = primaryContainerName + pod.Annotations[PrimaryContainerKey] = primaryContainerName return nil } diff --git a/go/tasks/plugins/k8s/pod/sidecar_test.go b/go/tasks/plugins/k8s/pod/sidecar_test.go index 77cb40afa..e301e5bf2 100644 --- a/go/tasks/plugins/k8s/pod/sidecar_test.go +++ b/go/tasks/plugins/k8s/pod/sidecar_test.go @@ -51,7 +51,7 @@ func getSidecarTaskTemplateForTest(sideCarJob sidecarJob) *core.TaskTemplate { panic(err) } return &core.TaskTemplate{ - Type: sidecarTaskType, + Type: SidecarTaskType, Custom: &structObj, } } @@ -199,10 +199,10 @@ func TestBuildSidecarResource_TaskType2(t *testing.T) { } task := core.TaskTemplate{ - Type: sidecarTaskType, + Type: SidecarTaskType, TaskTypeVersion: 2, Config: map[string]string{ - primaryContainerKey: "primary container", + PrimaryContainerKey: "primary container", }, Target: &core.TaskTemplate_K8SPod{ K8SPod: &core.K8SPod{ @@ -241,12 +241,11 @@ func TestBuildSidecarResource_TaskType2(t *testing.T) { DefaultMemoryRequest: resource.MustParse("1024Mi"), GpuResourceName: ResourceNvidiaGPU, })) - p := &plugin{defaultPodBuilder, podBuilders} taskCtx := getDummySidecarTaskContext(&task, sidecarResourceRequirements) - res, err := p.BuildResource(context.TODO(), taskCtx) + res, err := DefaultPodPlugin.BuildResource(context.TODO(), taskCtx) assert.Nil(t, err) assert.EqualValues(t, map[string]string{ - primaryContainerKey: "primary container", + PrimaryContainerKey: "primary container", "anno": "bar", }, res.GetAnnotations()) assert.EqualValues(t, map[string]string{ @@ -284,10 +283,10 @@ func TestBuildSidecarResource_TaskType2(t *testing.T) { func TestBuildSidecarResource_TaskType2_Invalid_Spec(t *testing.T) { task := core.TaskTemplate{ - Type: sidecarTaskType, + Type: SidecarTaskType, TaskTypeVersion: 2, Config: map[string]string{ - primaryContainerKey: "primary container", + PrimaryContainerKey: "primary container", }, Target: &core.TaskTemplate_K8SPod{ K8SPod: &core.K8SPod{ @@ -303,9 +302,8 @@ func TestBuildSidecarResource_TaskType2_Invalid_Spec(t *testing.T) { }, } - p := &plugin{defaultPodBuilder, podBuilders} taskCtx := getDummySidecarTaskContext(&task, sidecarResourceRequirements) - _, err := p.BuildResource(context.TODO(), taskCtx) + _, err := DefaultPodPlugin.BuildResource(context.TODO(), taskCtx) assert.EqualError(t, err, "[BadTaskSpecification] Pod tasks with task type version > 1 should specify their target as a K8sPod with a defined pod spec") } @@ -323,11 +321,11 @@ func TestBuildSidecarResource_TaskType1(t *testing.T) { } task := core.TaskTemplate{ - Type: sidecarTaskType, + Type: SidecarTaskType, Custom: structObj, TaskTypeVersion: 1, Config: map[string]string{ - primaryContainerKey: "primary container", + PrimaryContainerKey: "primary container", }, } @@ -352,12 +350,11 @@ func TestBuildSidecarResource_TaskType1(t *testing.T) { DefaultCPURequest: resource.MustParse("1024m"), DefaultMemoryRequest: resource.MustParse("1024Mi"), })) - p := &plugin{defaultPodBuilder, podBuilders} taskCtx := getDummySidecarTaskContext(&task, sidecarResourceRequirements) - res, err := p.BuildResource(context.TODO(), taskCtx) + res, err := DefaultPodPlugin.BuildResource(context.TODO(), taskCtx) assert.Nil(t, err) assert.EqualValues(t, map[string]string{ - primaryContainerKey: "primary container", + PrimaryContainerKey: "primary container", }, res.GetAnnotations()) assert.EqualValues(t, map[string]string{}, res.GetLabels()) @@ -405,7 +402,7 @@ func TestBuildSideResource_TaskType1_InvalidSpec(t *testing.T) { } task := core.TaskTemplate{ - Type: sidecarTaskType, + Type: SidecarTaskType, Custom: structObj, TaskTypeVersion: 1, } @@ -418,16 +415,15 @@ func TestBuildSideResource_TaskType1_InvalidSpec(t *testing.T) { DefaultCPURequest: resource.MustParse("1024m"), DefaultMemoryRequest: resource.MustParse("1024Mi"), })) - p := &plugin{defaultPodBuilder, podBuilders} taskCtx := getDummySidecarTaskContext(&task, sidecarResourceRequirements) - _, err = p.BuildResource(context.TODO(), taskCtx) + _, err = DefaultPodPlugin.BuildResource(context.TODO(), taskCtx) assert.EqualError(t, err, "[BadTaskSpecification] invalid TaskSpecification, config needs to be non-empty and include missing [primary_container_name] key") task.Config = map[string]string{ "foo": "bar", } taskCtx = getDummySidecarTaskContext(&task, sidecarResourceRequirements) - _, err = p.BuildResource(context.TODO(), taskCtx) + _, err = DefaultPodPlugin.BuildResource(context.TODO(), taskCtx) assert.EqualError(t, err, "[BadTaskSpecification] invalid TaskSpecification, config missing [primary_container_name] key in [map[foo:bar]]") } @@ -446,7 +442,7 @@ func TestBuildSidecarResource(t *testing.T) { t.Fatal(err) } task := core.TaskTemplate{ - Type: sidecarTaskType, + Type: SidecarTaskType, Custom: &sidecarCustom, } @@ -471,12 +467,11 @@ func TestBuildSidecarResource(t *testing.T) { DefaultCPURequest: resource.MustParse("1024m"), DefaultMemoryRequest: resource.MustParse("1024Mi"), })) - p := &plugin{defaultPodBuilder, podBuilders} taskCtx := getDummySidecarTaskContext(&task, sidecarResourceRequirements) - res, err := p.BuildResource(context.TODO(), taskCtx) + res, err := DefaultPodPlugin.BuildResource(context.TODO(), taskCtx) assert.Nil(t, err) assert.EqualValues(t, map[string]string{ - primaryContainerKey: "a container", + PrimaryContainerKey: "a container", "a1": "a1", }, res.GetAnnotations()) @@ -528,9 +523,8 @@ func TestBuildSidecarReosurceMissingAnnotationsAndLabels(t *testing.T) { task := getSidecarTaskTemplateForTest(sideCarJob) - p := &plugin{defaultPodBuilder, podBuilders} taskCtx := getDummySidecarTaskContext(task, sidecarResourceRequirements) - resp, err := p.BuildResource(context.TODO(), taskCtx) + resp, err := DefaultPodPlugin.BuildResource(context.TODO(), taskCtx) assert.NoError(t, err) assert.EqualValues(t, map[string]string{}, resp.GetLabels()) assert.EqualValues(t, map[string]string{"primary_container_name": "PrimaryContainer"}, resp.GetAnnotations()) @@ -550,9 +544,8 @@ func TestBuildSidecarResourceMissingPrimary(t *testing.T) { task := getSidecarTaskTemplateForTest(sideCarJob) - p := &plugin{defaultPodBuilder, podBuilders} taskCtx := getDummySidecarTaskContext(task, sidecarResourceRequirements) - _, err := p.BuildResource(context.TODO(), taskCtx) + _, err := DefaultPodPlugin.BuildResource(context.TODO(), taskCtx) assert.True(t, errors.Is(err, errors2.Errorf("BadTaskSpecification", ""))) } @@ -584,11 +577,10 @@ func TestGetTaskSidecarStatus(t *testing.T) { }, } res.SetAnnotations(map[string]string{ - primaryContainerKey: "PrimaryContainer", + PrimaryContainerKey: "PrimaryContainer", }) - p := &plugin{defaultPodBuilder, podBuilders} taskCtx := getDummySidecarTaskContext(task, sidecarResourceRequirements) - phaseInfo, err := p.GetTaskPhase(context.TODO(), taskCtx, res) + phaseInfo, err := DefaultPodPlugin.GetTaskPhase(context.TODO(), taskCtx, res) assert.Nil(t, err) assert.Equal(t, expectedTaskPhase, phaseInfo.Phase(), "Expected [%v] got [%v] instead, for podPhase [%v]", expectedTaskPhase, phaseInfo.Phase(), podPhase) @@ -612,11 +604,10 @@ func TestDemystifiedSidecarStatus_PrimaryFailed(t *testing.T) { }, } res.SetAnnotations(map[string]string{ - primaryContainerKey: "Primary", + PrimaryContainerKey: "Primary", }) - p := &plugin{defaultPodBuilder, podBuilders} taskCtx := getDummySidecarTaskContext(&core.TaskTemplate{}, sidecarResourceRequirements) - phaseInfo, err := p.GetTaskPhase(context.TODO(), taskCtx, res) + phaseInfo, err := DefaultPodPlugin.GetTaskPhase(context.TODO(), taskCtx, res) assert.Nil(t, err) assert.Equal(t, pluginsCore.PhaseRetryableFailure, phaseInfo.Phase()) } @@ -638,11 +629,10 @@ func TestDemystifiedSidecarStatus_PrimarySucceeded(t *testing.T) { }, } res.SetAnnotations(map[string]string{ - primaryContainerKey: "Primary", + PrimaryContainerKey: "Primary", }) - p := &plugin{defaultPodBuilder, podBuilders} taskCtx := getDummySidecarTaskContext(&core.TaskTemplate{}, sidecarResourceRequirements) - phaseInfo, err := p.GetTaskPhase(context.TODO(), taskCtx, res) + phaseInfo, err := DefaultPodPlugin.GetTaskPhase(context.TODO(), taskCtx, res) assert.Nil(t, err) assert.Equal(t, pluginsCore.PhaseSuccess, phaseInfo.Phase()) } @@ -664,11 +654,10 @@ func TestDemystifiedSidecarStatus_PrimaryRunning(t *testing.T) { }, } res.SetAnnotations(map[string]string{ - primaryContainerKey: "Primary", + PrimaryContainerKey: "Primary", }) - p := &plugin{defaultPodBuilder, podBuilders} taskCtx := getDummySidecarTaskContext(&core.TaskTemplate{}, sidecarResourceRequirements) - phaseInfo, err := p.GetTaskPhase(context.TODO(), taskCtx, res) + phaseInfo, err := DefaultPodPlugin.GetTaskPhase(context.TODO(), taskCtx, res) assert.Nil(t, err) assert.Equal(t, pluginsCore.PhaseRunning, phaseInfo.Phase()) } @@ -685,17 +674,15 @@ func TestDemystifiedSidecarStatus_PrimaryMissing(t *testing.T) { }, } res.SetAnnotations(map[string]string{ - primaryContainerKey: "Primary", + PrimaryContainerKey: "Primary", }) - p := &plugin{defaultPodBuilder, podBuilders} taskCtx := getDummySidecarTaskContext(&core.TaskTemplate{}, sidecarResourceRequirements) - phaseInfo, err := p.GetTaskPhase(context.TODO(), taskCtx, res) + phaseInfo, err := DefaultPodPlugin.GetTaskPhase(context.TODO(), taskCtx, res) assert.Nil(t, err) assert.Equal(t, pluginsCore.PhasePermanentFailure, phaseInfo.Phase()) } func TestGetProperties(t *testing.T) { - p := &plugin{defaultPodBuilder, podBuilders} expected := k8s.PluginProperties{} - assert.Equal(t, expected, p.GetProperties()) + assert.Equal(t, expected, DefaultPodPlugin.GetProperties()) }