diff --git a/go/tasks/v1/flytek8s/k8splugin_state.go b/go/tasks/v1/flytek8s/k8splugin_state.go new file mode 100644 index 000000000..aabd10f2f --- /dev/null +++ b/go/tasks/v1/flytek8s/k8splugin_state.go @@ -0,0 +1,71 @@ +package flytek8s + +import ( + "encoding/json" + "fmt" + + "github.com/lyft/flyteplugins/go/tasks/v1/types" +) + +const stateKey = "os" + +// This status internal state of the object not read/updated by upstream components (eg. Node manager) +type K8sObjectStatus int + +const ( + k8sObjectUnknown K8sObjectStatus = iota + k8sObjectExists + k8sObjectDeleted +) + +func (q K8sObjectStatus) String() string { + switch q { + case k8sObjectUnknown: + return "NotStarted" + case k8sObjectExists: + return "Running" + case k8sObjectDeleted: + return "Deleted" + } + return "IllegalK8sObjectStatus" +} + +// This status internal state of the object not read/updated by upstream components (eg. Node manager) +type K8sObjectState struct { + Status K8sObjectStatus `json:"s"` + TerminalPhase types.TaskPhase `json:"tp"` +} + +func retrieveK8sObjectState(customState map[string]interface{}) (K8sObjectStatus, types.TaskPhase, error) { + v, found := customState[stateKey] + if !found { + return k8sObjectUnknown, types.TaskPhaseUnknown, nil + } + + state, err := convertToState(v) + if err != nil { + return k8sObjectUnknown, types.TaskPhaseUnknown, err + } + return state.Status, state.TerminalPhase, nil +} + +func storeK8sObjectState(status K8sObjectStatus, phase types.TaskPhase) map[string]interface{} { + customState := make(map[string]interface{}) + customState[stateKey] = K8sObjectState{Status: status, TerminalPhase: phase} + return customState +} + +func convertToState(iface interface{}) (K8sObjectState, error) { + raw, err := json.Marshal(iface) + if err != nil { + return K8sObjectState{}, err + } + + item := &K8sObjectState{} + err = json.Unmarshal(raw, item) + if err != nil { + return K8sObjectState{}, fmt.Errorf("failed to unmarshal state into K8sObjectState") + } + + return *item, nil +} diff --git a/go/tasks/v1/flytek8s/k8splugin_state_test.go b/go/tasks/v1/flytek8s/k8splugin_state_test.go new file mode 100644 index 000000000..6495d2bd5 --- /dev/null +++ b/go/tasks/v1/flytek8s/k8splugin_state_test.go @@ -0,0 +1,26 @@ +package flytek8s + +import ( + "testing" + "github.com/stretchr/testify/assert" + "github.com/lyft/flyteplugins/go/tasks/v1/types" + "encoding/json" +) + +func TestRetrieveK8sObjectStatus(t *testing.T) { + status := k8sObjectExists + phase := types.TaskPhaseRunning + customState := storeK8sObjectState(status, phase) + + raw, err := json.Marshal(customState) + assert.NoError(t, err) + + unmarshalledCustomState := make(map[string]interface{}) + err = json.Unmarshal(raw, &unmarshalledCustomState) + assert.NoError(t, err) + + retrievedStatus, retrievedPhase, err := retrieveK8sObjectState(unmarshalledCustomState) + assert.NoError(t, err) + assert.Equal(t, status, retrievedStatus) + assert.Equal(t, phase, retrievedPhase) +} diff --git a/go/tasks/v1/flytek8s/plugin_executor.go b/go/tasks/v1/flytek8s/plugin_executor.go old mode 100755 new mode 100644 index e964137de..d4adb025a --- a/go/tasks/v1/flytek8s/plugin_executor.go +++ b/go/tasks/v1/flytek8s/plugin_executor.go @@ -15,6 +15,8 @@ import ( "strings" + "fmt" + eventErrors "github.com/lyft/flyteidl/clients/go/events/errors" "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" "github.com/lyft/flyteplugins/go/tasks/v1/errors" @@ -251,6 +253,38 @@ func (e *K8sTaskExecutor) CheckTaskStatus(ctx context.Context, taskCtx types.Tas finalStatus = types.TaskStatusPermanentFailure(err) } else { AddObjectMetadata(taskCtx, o) + + // NOTE: To ensure objects are cleaned up, the plugins need a persistent step in addition to upstream plugin executor + // state machine. Once the object reaches its terminal state, we commit the completion in two steps: + // Round1: mark the object as deleted in state store (object's custom state) + // Round2: instead of regular retrieval (which may fail in this case), just delete the object + objStatus, terminalPhase, err := retrieveK8sObjectState(taskCtx.GetCustomState()) + if err != nil { + logger.Warningf(ctx, "Failed to retrieve object status: %v. Error: %v", + taskCtx.GetTaskExecutionID().GetGeneratedName(), err) + return types.TaskStatusUndefined, err + } + + if objStatus == k8sObjectDeleted { + // kill the object execution if still alive + err = instance.kubeClient.Delete(ctx, o) + + if err != nil { + if IsK8sObjectNotExists(err) { + logger.Debugf(ctx, "the k8s object %v was found to have successfully exited after completion", taskCtx.GetTaskExecutionID().GetGeneratedName()) + } else { + return types.TaskStatusUndefined, err + } + } else { + logger.Debugf(ctx, "deleted the k8s object %v in terminal phase", taskCtx.GetTaskExecutionID().GetGeneratedName()) + } + finalStatus.Phase = terminalPhase + if terminalPhase.IsPermanentFailure() { + finalStatus.Err = fmt.Errorf("k8s task failed, error info not available") + } + return finalStatus, nil + } + finalStatus, info, err = e.getResource(ctx, taskCtx, o) if err != nil { return types.TaskStatusUndefined, err @@ -283,10 +317,20 @@ func (e *K8sTaskExecutor) CheckTaskStatus(ctx context.Context, taskCtx types.Tas // This must happen after sending admin event. It's safe against partial failures because if the event failed, we will // simply retry in the next round. If the event succeeded but this failed, we will try again the next round to send // the same event (idempotent) and then come here again... - if finalStatus.Phase.IsTerminal() && len(o.GetFinalizers()) > 0 { - err = e.ClearFinalizers(ctx, o) - if err != nil { - return types.TaskStatusUndefined, err + if finalStatus.Phase.IsTerminal() { + if len(o.GetFinalizers()) > 0 { + err = e.ClearFinalizers(ctx, o) + if err != nil { + return types.TaskStatusUndefined, err + } + } + + if e.handler.GetProperties().DeleteResourceOnAbort { + finalStatus = types.TaskStatus{ + Phase: taskCtx.GetPhase(), + PhaseVersion: taskCtx.GetPhaseVersion(), + State: storeK8sObjectState(k8sObjectDeleted, finalStatus.Phase), + } } } diff --git a/go/tasks/v1/flytek8s/plugin_executor_test.go b/go/tasks/v1/flytek8s/plugin_executor_test.go index d2d9814bc..9758fd2d6 100755 --- a/go/tasks/v1/flytek8s/plugin_executor_test.go +++ b/go/tasks/v1/flytek8s/plugin_executor_test.go @@ -285,6 +285,7 @@ func TestK8sTaskExecutor_CheckTaskStatus(t *testing.T) { expectedNewStatus.PhaseVersion = uint32(1) tctx.On("GetPhase").Return(expectedOldPhase) tctx.On("GetPhaseVersion").Return(uint32(1)) + tctx.On("GetCustomState").Return(nil) mockResourceHandler.On("GetTaskStatus", mock.Anything, mock.Anything, mock.MatchedBy(func(o *v1.Pod) bool { return true })).Return(expectedNewStatus, nil, nil) evRecorder.On("RecordTaskEvent", mock.MatchedBy(func(c context.Context) bool { return true }), @@ -326,6 +327,7 @@ func TestK8sTaskExecutor_CheckTaskStatus(t *testing.T) { expectedNewStatus.PhaseVersion = uint32(1) tctx.On("GetPhase").Return(expectedOldPhase) tctx.On("GetPhaseVersion").Return(uint32(1)) + tctx.On("GetCustomState").Return(nil) mockResourceHandler.On("GetTaskStatus", mock.Anything, mock.Anything, mock.MatchedBy(func(o *v1.Pod) bool { return true })).Return(expectedNewStatus, nil, nil) evRecorder.On("RecordTaskEvent", mock.MatchedBy(func(c context.Context) bool { return true }), @@ -371,6 +373,7 @@ func TestK8sTaskExecutor_CheckTaskStatus(t *testing.T) { expectedNewStatus.PhaseVersion = uint32(1) tctx.On("GetPhase").Return(expectedOldPhase) tctx.On("GetPhaseVersion").Return(uint32(1)) + tctx.On("GetCustomState").Return(nil) mockResourceHandler.On("GetTaskStatus", mock.Anything, mock.Anything, mock.MatchedBy(func(o *v1.Pod) bool { return true })).Return(expectedNewStatus, nil, nil) s, err := k.CheckTaskStatus(ctx, tctx, nil) @@ -395,6 +398,7 @@ func TestK8sTaskExecutor_CheckTaskStatus(t *testing.T) { expectedOldPhase := types.TaskPhaseRunning tctx.On("GetPhase").Return(expectedOldPhase) tctx.On("GetPhaseVersion").Return(uint32(0)) + tctx.On("GetCustomState").Return(nil) evRecorder.On("RecordTaskEvent", mock.MatchedBy(func(c context.Context) bool { return true }), mock.MatchedBy(func(e *event.TaskExecutionEvent) bool { return true })).Return(nil) @@ -427,9 +431,6 @@ func TestK8sTaskExecutor_CheckTaskStatus(t *testing.T) { } assert.NoError(t, c.Create(ctx, testPod)) - defer func() { - assert.NoError(t, c.Delete(ctx, testPod)) - }() assert.NoError(t, store.WriteProtobuf(ctx, tctx.GetErrorFile(), storage.Options{}, &core.ErrorDocument{ Error: &core.ContainerError{ @@ -444,14 +445,36 @@ func TestK8sTaskExecutor_CheckTaskStatus(t *testing.T) { expectedOldPhase := types.TaskPhaseQueued tctx.On("GetPhase").Return(expectedOldPhase) tctx.On("GetPhaseVersion").Return(uint32(0)) + tctx.On("GetCustomState").Return(nil) mockResourceHandler.On("GetTaskStatus", mock.Anything, mock.Anything, mock.MatchedBy(func(o *v1.Pod) bool { return true })).Return(types.TaskStatusSucceeded, nil, nil) + mockResourceHandler.On("GetProperties").Return(types.ExecutorProperties{DeleteResourceOnAbort:true}) evRecorder.On("RecordTaskEvent", mock.MatchedBy(func(c context.Context) bool { return true }), mock.MatchedBy(func(e *event.TaskExecutionEvent) bool { return true })).Return(nil) s, err := k.CheckTaskStatus(ctx, tctx, nil) - assert.Nil(t, s.State) + // first time after termination, we expect phase to not change but have custom state populated + assert.NotNil(t, s.State) + assert.Equal(t, flytek8s.K8sObjectStatus(2), s.State["os"].(flytek8s.K8sObjectState).Status) + assert.Equal(t, types.TaskPhase(3), s.State["os"].(flytek8s.K8sObjectState).TerminalPhase) assert.NoError(t, err) + assert.Equal(t, types.TaskPhaseQueued, s.Phase) + + // another round of CheckTaskStatus with custom state from previous round + mockResourceHandler1 := &mocks.K8sResourceHandler{} + tctx1 := getMockTaskContext() + mockResourceHandler1.On("GetProperties").Return(types.ExecutorProperties{DeleteResourceOnAbort:true}) + mockResourceHandler1.On("BuildIdentityResource", mock.Anything, tctx1).Return(&v1.Pod{}, nil) + mockResourceHandler1.On("GetTaskStatus", mock.Anything, mock.Anything, mock.MatchedBy(func(o *v1.Pod) bool { return true })).Return(types.TaskStatusSucceeded, nil, nil) + + k1 := flytek8s.NewK8sTaskExecutorForResource("x1", &v1.Pod{}, mockResourceHandler1, time.Second) + assert.NoError(t, k1.Initialize(ctx, params)) + + tctx1.On("GetPhase").Return(expectedOldPhase) + tctx1.On("GetPhaseVersion").Return(uint32(0)) + tctx1.On("GetCustomState").Return(s.State) + s, err = k1.CheckTaskStatus(ctx, tctx1, nil) + assert.Nil(t, s.State) assert.Equal(t, types.TaskPhasePermanentFailure, s.Phase) }) @@ -494,6 +517,7 @@ func TestK8sTaskExecutor_CheckTaskStatus(t *testing.T) { expectedOldPhase := types.TaskPhaseQueued tctx.On("GetPhase").Return(expectedOldPhase) tctx.On("GetPhaseVersion").Return(uint32(0)) + tctx.On("GetCustomState").Return(nil) mockResourceHandler.On("GetTaskStatus", mock.Anything, mock.Anything, mock.MatchedBy(func(o *v1.Pod) bool { return true })).Return(types.TaskStatusSucceeded, nil, nil) evRecorder.On("RecordTaskEvent", mock.MatchedBy(func(c context.Context) bool { return true }), @@ -536,6 +560,7 @@ func TestK8sTaskExecutor_CheckTaskStatus(t *testing.T) { expectedOldPhase := types.TaskPhaseQueued tctx.On("GetPhase").Return(expectedOldPhase) tctx.On("GetPhaseVersion").Return(uint32(0)) + tctx.On("GetCustomState").Return(nil) mockResourceHandler.On("GetTaskStatus", mock.Anything, mock.Anything, mock.MatchedBy(func(o *v1.Pod) bool { return true })).Return(types.TaskStatusQueued, nil, nil) s, err := k.CheckTaskStatus(ctx, tctx, nil)