Skip to content

Commit

Permalink
[K8s-Array] Add finalizer (flyteorg#186)
Browse files Browse the repository at this point in the history
* Add finalizer

Signed-off-by: Anmol Khurana <[email protected]>

* Add finalizer

Signed-off-by: Anmol Khurana <[email protected]>

* Add finalizer

Signed-off-by: Anmol Khurana <[email protected]>

* Add finalizer

Signed-off-by: Anmol Khurana <[email protected]>

* Add finalizer

Signed-off-by: Anmol Khurana <[email protected]>

* PR comment

Signed-off-by: Anmol Khurana <[email protected]>

* Fix finalizer name

Signed-off-by: Anmol Khurana <[email protected]>
  • Loading branch information
akhurana001 authored Jul 18, 2021
1 parent a20bca7 commit b671abc
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 6 deletions.
2 changes: 1 addition & 1 deletion go/tasks/plugins/array/k8s/monitor.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ func FetchPodStatusAndLogs(ctx context.Context, client core.KubeClient, name k8s
o, err := logPlugin.GetTaskLogs(tasklog.Input{
PodName: pod.Name,
Namespace: pod.Namespace,
LogName: fmt.Sprintf(" #%d-%d", index, retryAttempt),
LogName: fmt.Sprintf(" #%d-%d", retryAttempt, index),
PodUnixStartTime: pod.CreationTimestamp.Unix(),
})

Expand Down
4 changes: 2 additions & 2 deletions go/tasks/plugins/array/k8s/monitor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,10 +169,10 @@ func TestCheckSubTasksState(t *testing.T) {
assert.NotEmpty(t, logLinks)
assert.Equal(t, 10, len(logLinks))
for i := 0; i < 10; i = i + 2 {
assert.Equal(t, fmt.Sprintf("Kubernetes Logs #%d-0 (PhaseRunning)", i/2), logLinks[i].Name)
assert.Equal(t, fmt.Sprintf("Kubernetes Logs #0-%d (PhaseRunning)", i/2), logLinks[i].Name)
assert.Equal(t, fmt.Sprintf("k8s/log/a-n-b/notfound-%d/pod?namespace=a-n-b", i/2), logLinks[i].Uri)

assert.Equal(t, fmt.Sprintf("Cloudwatch Logs #%d-0 (PhaseRunning)", i/2), logLinks[i+1].Name)
assert.Equal(t, fmt.Sprintf("Cloudwatch Logs #0-%d (PhaseRunning)", i/2), logLinks[i+1].Name)
assert.Equal(t, fmt.Sprintf("https://console.aws.amazon.com/cloudwatch/home?region=us-east-1#logStream:group=/kubernetes/flyte;prefix=var.log.containers.notfound-%d;streamFilter=typeLogStreamPrefix", i/2), logLinks[i+1].Uri)
}

Expand Down
54 changes: 52 additions & 2 deletions go/tasks/plugins/array/k8s/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"strconv"
"strings"

metaV1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"sigs.k8s.io/controller-runtime/pkg/client"

idlCore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
Expand Down Expand Up @@ -45,6 +46,29 @@ const (
LaunchReturnState
)

const finalizer = "flyte/array"

func addPodFinalizer(pod *corev1.Pod) *corev1.Pod {
pod.Finalizers = append(pod.Finalizers, finalizer)
return pod
}

func removeString(list []string, target string) []string {
ret := make([]string, 0)
for _, s := range list {
if s != target {
ret = append(ret, s)
}
}

return ret
}

func clearFinalizer(pod *corev1.Pod) *corev1.Pod {
pod.Finalizers = removeString(pod.Finalizers, finalizer)
return pod
}

const (
MonitorSuccess MonitorResult = iota
MonitorError
Expand Down Expand Up @@ -107,7 +131,7 @@ func (t Task) Launch(ctx context.Context, tCtx core.TaskExecutionContext, kubeCl
pod = ApplyPodPolicies(ctx, t.Config, pod)
pod = applyNodeSelectorLabels(ctx, t.Config, pod)
pod = applyPodTolerations(ctx, t.Config, pod)

pod = addPodFinalizer(pod)
allocationStatus, err := allocateResource(ctx, tCtx, t.Config, podName)
if err != nil {
return LaunchError, err
Expand Down Expand Up @@ -227,8 +251,34 @@ func (t Task) Finalize(ctx context.Context, tCtx core.TaskExecutionContext, kube
indexStr := strconv.Itoa(t.ChildIdx)
podName := formatSubTaskName(ctx, tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(), indexStr)

pod := &v1.Pod{
TypeMeta: metaV1.TypeMeta{
Kind: PodKind,
APIVersion: v1.SchemeGroupVersion.String(),
},
}

err := kubeClient.GetClient().Get(ctx, k8sTypes.NamespacedName{
Name: podName,
Namespace: GetNamespaceForExecution(tCtx, t.Config.NamespaceTemplate),
}, pod)

if err != nil {
if !k8serrors.IsNotFound(err) {
logger.Errorf(ctx, "Error fetching pod [%s] in Finalize [%s]", podName, err)
return err
}
} else {
pod = clearFinalizer(pod)
err := kubeClient.GetClient().Update(ctx, pod)
if err != nil {
logger.Errorf(ctx, "Error updating pod finalizer [%s] in Finalize [%s]", podName, err)
return err
}
}

// Deallocate Resource
err := deallocateResource(ctx, tCtx, t.Config, t.ChildIdx)
err = deallocateResource(ctx, tCtx, t.Config, t.ChildIdx)
if err != nil {
logger.Errorf(ctx, "Error releasing allocation token [%s] in Finalize [%s]", podName, err)
return err
Expand Down
41 changes: 40 additions & 1 deletion go/tasks/plugins/array/k8s/task_test.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,52 @@
package k8s

import (
"context"
"testing"

"github.com/stretchr/testify/assert"
v1 "k8s.io/api/core/v1"

"github.com/stretchr/testify/mock"

"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks"

"github.com/stretchr/testify/assert"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)

func TestFinalize(t *testing.T) {
ctx := context.Background()

tCtx := getMockTaskExecutionContext(ctx)
kubeClient := mocks.KubeClient{}
kubeClient.OnGetClient().Return(mocks.NewFakeKubeClient())

resourceManager := mocks.ResourceManager{}
podTemplate, _, _ := FlyteArrayJobToK8sPodTemplate(ctx, tCtx, "")
pod := addPodFinalizer(&podTemplate)
pod.Name = formatSubTaskName(ctx, tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(), "1")
assert.NoError(t, kubeClient.GetClient().Create(ctx, pod))

resourceManager.OnReleaseResourceMatch(mock.Anything, mock.Anything, mock.Anything).Return(nil)
tCtx.OnResourceManager().Return(&resourceManager)

config := Config{
MaxArrayJobSize: 100,
ResourceConfig: ResourceConfig{
PrimaryLabel: "p",
Limit: 10,
},
}

task := &Task{
Config: &config,
ChildIdx: 1,
}

err := task.Finalize(ctx, tCtx, &kubeClient)
assert.NoError(t, err)
}

func TestGetTaskContainerIndex(t *testing.T) {
t.Run("test container target", func(t *testing.T) {
pod := &v1.Pod{
Expand Down

0 comments on commit b671abc

Please sign in to comment.