Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Delete sidecar pod when task completes. #26

Merged
merged 18 commits into from
Oct 11, 2019
70 changes: 70 additions & 0 deletions go/tasks/v1/flytek8s/k8splugin_state.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package flytek8s

import (
"fmt"
"github.com/lyft/flyteplugins/go/tasks/v1/types"
"encoding/json"
)

// stateKey is the key in a task's custom-state map under which the K8sObjectState is stored.
const stateKey = "os"

// K8sObjectStatus is plugin-internal bookkeeping for the backing k8s object. It is not read or updated
// by upstream components (e.g. the node manager).
type K8sObjectStatus int

const (
	// k8sObjectUnknown is the zero value; String reports it as "NotStarted".
	k8sObjectUnknown K8sObjectStatus = iota
	// k8sObjectExists is reported by String as "Running".
	k8sObjectExists
	// k8sObjectDeleted is reported by String as "Deleted".
	k8sObjectDeleted
)

// String returns the human-readable label for q. Any value that is not one of the declared constants
// yields "IllegalK8sObjectStatus".
func (q K8sObjectStatus) String() string {
	// Start from the fallback label and overwrite it only for known values.
	label := "IllegalK8sObjectStatus"
	switch q {
	case k8sObjectDeleted:
		label = "Deleted"
	case k8sObjectExists:
		label = "Running"
	case k8sObjectUnknown:
		label = "NotStarted"
	}
	return label
}

// K8sObjectState is the plugin-internal state kept in a task's custom state under stateKey. It is not
// read or updated by upstream components (e.g. the node manager).
type K8sObjectState struct {
	// Status is the recorded status of the k8s object; serialized under the JSON key "s".
	Status K8sObjectStatus `json:"s"`
	// TerminalPhase is the task phase stored alongside Status; serialized under the JSON key "tp".
	TerminalPhase types.TaskPhase `json:"tp"`
}

func retrieveK8sObjectStatus(customState map[string]interface{}) (K8sObjectStatus, types.TaskPhase, error) {
for k, v := range customState {
surindersinghp marked this conversation as resolved.
Show resolved Hide resolved
if k == stateKey {
state, err := convertToState(v)
if err != nil {
return k8sObjectUnknown, types.TaskPhaseUnknown, err
}
return state.Status, state.TerminalPhase, nil
}
}
return k8sObjectUnknown, types.TaskPhaseUnknown, nil
}

// storeK8sObjectStatus builds a new custom-state map with a single entry, keyed by stateKey, whose value
// is a K8sObjectState carrying the given status and phase.
func storeK8sObjectStatus(status K8sObjectStatus, phase types.TaskPhase) map[string]interface{} {
	return map[string]interface{}{
		stateKey: K8sObjectState{
			Status:        status,
			TerminalPhase: phase,
		},
	}
}

// convertToState decodes an arbitrary custom-state value into a K8sObjectState by marshalling it to JSON
// and unmarshalling the result into the struct. This handles both an in-memory K8sObjectState and the
// generic map form the value takes after a JSON round trip.
//
// It returns the zero K8sObjectState and a non-nil error if iface cannot be marshalled, or if the
// resulting JSON does not fit K8sObjectState. The unmarshal error wraps the underlying cause so callers
// can inspect it with errors.Is / errors.As.
func convertToState(iface interface{}) (K8sObjectState, error) {
	raw, err := json.Marshal(iface)
	if err != nil {
		return K8sObjectState{}, err
	}

	var state K8sObjectState
	if err := json.Unmarshal(raw, &state); err != nil {
		// Keep the original cause in the chain instead of discarding it.
		return K8sObjectState{}, fmt.Errorf("failed to unmarshal state into K8sObjectState: %w", err)
	}

	return state, nil
}
26 changes: 26 additions & 0 deletions go/tasks/v1/flytek8s/k8splugin_state_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package flytek8s

import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/lyft/flyteplugins/go/tasks/v1/types"
"encoding/json"
)

// TestRetrieveK8sObjectStatus checks that a status/phase pair written by storeK8sObjectStatus is read
// back unchanged by retrieveK8sObjectStatus after the custom state has been through a JSON round trip.
func TestRetrieveK8sObjectStatus(t *testing.T) {
	wantStatus := k8sObjectExists
	wantPhase := types.TaskPhaseRunning

	// Encode the stored state and decode it into a generic map, so the value under the state key is no
	// longer a K8sObjectState but whatever encoding/json produces for it.
	encoded, err := json.Marshal(storeK8sObjectStatus(wantStatus, wantPhase))
	assert.NoError(t, err)

	decoded := map[string]interface{}{}
	err = json.Unmarshal(encoded, &decoded)
	assert.NoError(t, err)

	gotStatus, gotPhase, err := retrieveK8sObjectStatus(decoded)
	assert.NoError(t, err)
	assert.Equal(t, wantStatus, gotStatus)
	assert.Equal(t, wantPhase, gotPhase)
}
surindersinghp marked this conversation as resolved.
Show resolved Hide resolved
45 changes: 41 additions & 4 deletions go/tasks/v1/flytek8s/plugin_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,35 @@ func (e *K8sTaskExecutor) CheckTaskStatus(ctx context.Context, taskCtx types.Tas
PhaseVersion: taskCtx.GetPhaseVersion(),
}

// NOTE: To ensure objects are cleaned up, the plugins need a persistent step in addition to upstream plugin executor
// state machine. Once the object reaches its terminal state, we commit the completion in two steps:
// Round1: mark the object as deleted in state store (object's custom state)
// Round2: instead of regular retrieval (which may fail in this case), just delete the object
objStatus, terminalPhase, err := retrieveK8sObjectStatus(taskCtx.GetCustomState())
if err != nil {
logger.Warningf(ctx, "Failed to retrieve object status: %v. Error: %v",
taskCtx.GetTaskExecutionID().GetGeneratedName(), err)
return types.TaskStatusUndefined, err
}
if objStatus == k8sObjectDeleted {
// kill the object execution if still live
if e.handler.GetProperties().DeleteResourceOnAbort {
err = instance.kubeClient.Delete(ctx, o)
surindersinghp marked this conversation as resolved.
Show resolved Hide resolved

if err != nil {
if IsK8sObjectNotExists(err) {
logger.Debugf(ctx, "the k8s object %v was found to have successfully exited after completion", taskCtx.GetTaskExecutionID().GetGeneratedName())
} else {
return types.TaskStatusUndefined, err
}
} else {
logger.Debugf(ctx, "deleted the k8s object %v in terminal phase", taskCtx.GetTaskExecutionID().GetGeneratedName())
}
}
finalStatus.Phase = terminalPhase
surindersinghp marked this conversation as resolved.
Show resolved Hide resolved
return finalStatus, nil
}

var info *events.TaskEventInfo

if err != nil {
Expand Down Expand Up @@ -283,10 +312,18 @@ func (e *K8sTaskExecutor) CheckTaskStatus(ctx context.Context, taskCtx types.Tas
// This must happen after sending admin event. It's safe against partial failures because if the event failed, we will
// simply retry in the next round. If the event succeeded but this failed, we will try again the next round to send
// the same event (idempotent) and then come here again...
if finalStatus.Phase.IsTerminal() && len(o.GetFinalizers()) > 0 {
err = e.ClearFinalizers(ctx, o)
if err != nil {
return types.TaskStatusUndefined, err
if finalStatus.Phase.IsTerminal() {
if len(o.GetFinalizers()) > 0 {
err = e.ClearFinalizers(ctx, o)
if err != nil {
return types.TaskStatusUndefined, err
}
}

finalStatus = types.TaskStatus{
Phase: taskCtx.GetPhase(),
PhaseVersion: taskCtx.GetPhaseVersion(),
State: storeK8sObjectStatus(k8sObjectDeleted, finalStatus.Phase),
surindersinghp marked this conversation as resolved.
Show resolved Hide resolved
}
}

Expand Down
29 changes: 28 additions & 1 deletion go/tasks/v1/flytek8s/plugin_executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,7 @@ func TestK8sTaskExecutor_CheckTaskStatus(t *testing.T) {
expectedNewStatus.PhaseVersion = uint32(1)
tctx.On("GetPhase").Return(expectedOldPhase)
tctx.On("GetPhaseVersion").Return(uint32(1))
tctx.On("GetCustomState").Return(nil)
mockResourceHandler.On("GetTaskStatus", mock.Anything, mock.Anything, mock.MatchedBy(func(o *v1.Pod) bool { return true })).Return(expectedNewStatus, nil, nil)

evRecorder.On("RecordTaskEvent", mock.MatchedBy(func(c context.Context) bool { return true }),
Expand Down Expand Up @@ -326,6 +327,7 @@ func TestK8sTaskExecutor_CheckTaskStatus(t *testing.T) {
expectedNewStatus.PhaseVersion = uint32(1)
tctx.On("GetPhase").Return(expectedOldPhase)
tctx.On("GetPhaseVersion").Return(uint32(1))
tctx.On("GetCustomState").Return(nil)
mockResourceHandler.On("GetTaskStatus", mock.Anything, mock.Anything, mock.MatchedBy(func(o *v1.Pod) bool { return true })).Return(expectedNewStatus, nil, nil)

evRecorder.On("RecordTaskEvent", mock.MatchedBy(func(c context.Context) bool { return true }),
Expand Down Expand Up @@ -371,6 +373,7 @@ func TestK8sTaskExecutor_CheckTaskStatus(t *testing.T) {
expectedNewStatus.PhaseVersion = uint32(1)
tctx.On("GetPhase").Return(expectedOldPhase)
tctx.On("GetPhaseVersion").Return(uint32(1))
tctx.On("GetCustomState").Return(nil)
mockResourceHandler.On("GetTaskStatus", mock.Anything, mock.Anything, mock.MatchedBy(func(o *v1.Pod) bool { return true })).Return(expectedNewStatus, nil, nil)

s, err := k.CheckTaskStatus(ctx, tctx, nil)
Expand All @@ -395,6 +398,7 @@ func TestK8sTaskExecutor_CheckTaskStatus(t *testing.T) {
expectedOldPhase := types.TaskPhaseRunning
tctx.On("GetPhase").Return(expectedOldPhase)
tctx.On("GetPhaseVersion").Return(uint32(0))
tctx.On("GetCustomState").Return(nil)

evRecorder.On("RecordTaskEvent", mock.MatchedBy(func(c context.Context) bool { return true }),
mock.MatchedBy(func(e *event.TaskExecutionEvent) bool { return true })).Return(nil)
Expand Down Expand Up @@ -444,14 +448,35 @@ func TestK8sTaskExecutor_CheckTaskStatus(t *testing.T) {
expectedOldPhase := types.TaskPhaseQueued
tctx.On("GetPhase").Return(expectedOldPhase)
tctx.On("GetPhaseVersion").Return(uint32(0))
tctx.On("GetCustomState").Return(nil)
mockResourceHandler.On("GetTaskStatus", mock.Anything, mock.Anything, mock.MatchedBy(func(o *v1.Pod) bool { return true })).Return(types.TaskStatusSucceeded, nil, nil)

evRecorder.On("RecordTaskEvent", mock.MatchedBy(func(c context.Context) bool { return true }),
mock.MatchedBy(func(e *event.TaskExecutionEvent) bool { return true })).Return(nil)

s, err := k.CheckTaskStatus(ctx, tctx, nil)
assert.Nil(t, s.State)
// first time after termination, we expect phase to not change but have custom state populated
assert.NotNil(t, s.State)
assert.Equal(t, flytek8s.K8sObjectStatus(2), s.State["os"].(flytek8s.K8sObjectState).Status)
assert.Equal(t, types.TaskPhase(3), s.State["os"].(flytek8s.K8sObjectState).TerminalPhase)
assert.NoError(t, err)
assert.Equal(t, types.TaskPhaseQueued, s.Phase)

// another round of CheckTaskStatus with custom state from previous round
mockResourceHandler1 := &mocks.K8sResourceHandler{}
tctx1 := getMockTaskContext()
mockResourceHandler1.On("GetProperties").Return(types.ExecutorProperties{DeleteResourceOnAbort:true})
mockResourceHandler1.On("BuildIdentityResource", mock.Anything, tctx1).Return(&v1.Pod{}, nil)
mockResourceHandler1.On("GetTaskStatus", mock.Anything, mock.Anything, mock.MatchedBy(func(o *v1.Pod) bool { return true })).Return(types.TaskStatusSucceeded, nil, nil)

k1 := flytek8s.NewK8sTaskExecutorForResource("x1", &v1.Pod{}, mockResourceHandler1, time.Second)
assert.NoError(t, k1.Initialize(ctx, params))

tctx1.On("GetPhase").Return(expectedOldPhase)
tctx1.On("GetPhaseVersion").Return(uint32(0))
tctx1.On("GetCustomState").Return(s.State)
s, err = k1.CheckTaskStatus(ctx, tctx1, nil)
assert.Nil(t, s.State)
assert.Equal(t, types.TaskPhasePermanentFailure, s.Phase)
})

Expand Down Expand Up @@ -494,6 +519,7 @@ func TestK8sTaskExecutor_CheckTaskStatus(t *testing.T) {
expectedOldPhase := types.TaskPhaseQueued
tctx.On("GetPhase").Return(expectedOldPhase)
tctx.On("GetPhaseVersion").Return(uint32(0))
tctx.On("GetCustomState").Return(nil)
mockResourceHandler.On("GetTaskStatus", mock.Anything, mock.Anything, mock.MatchedBy(func(o *v1.Pod) bool { return true })).Return(types.TaskStatusSucceeded, nil, nil)

evRecorder.On("RecordTaskEvent", mock.MatchedBy(func(c context.Context) bool { return true }),
Expand Down Expand Up @@ -536,6 +562,7 @@ func TestK8sTaskExecutor_CheckTaskStatus(t *testing.T) {
expectedOldPhase := types.TaskPhaseQueued
tctx.On("GetPhase").Return(expectedOldPhase)
tctx.On("GetPhaseVersion").Return(uint32(0))
tctx.On("GetCustomState").Return(nil)
mockResourceHandler.On("GetTaskStatus", mock.Anything, mock.Anything, mock.MatchedBy(func(o *v1.Pod) bool { return true })).Return(types.TaskStatusQueued, nil, nil)

s, err := k.CheckTaskStatus(ctx, tctx, nil)
Expand Down