Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Refactor k8s array plugin to reuse pod plugin for subtasks #244

Merged
merged 21 commits into from
Mar 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 1 addition & 6 deletions go/tasks/logs/logging_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,7 @@ type logPlugin struct {
}

// Internal
func GetLogsForContainerInPod(ctx context.Context, pod *v1.Pod, index uint32, nameSuffix string) ([]*core.TaskLog, error) {
logPlugin, err := InitializeLogPlugins(GetLogConfig())
if err != nil {
return nil, err
}

func GetLogsForContainerInPod(ctx context.Context, logPlugin tasklog.Plugin, pod *v1.Pod, index uint32, nameSuffix string) ([]*core.TaskLog, error) {
if logPlugin == nil {
return nil, nil
}
Expand Down
60 changes: 34 additions & 26 deletions go/tasks/logs/logging_utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,29 +15,32 @@ import (
const podName = "PodName"

func TestGetLogsForContainerInPod_NoPlugins(t *testing.T) {
assert.NoError(t, SetLogConfig(&LogConfig{}))
l, err := GetLogsForContainerInPod(context.TODO(), nil, 0, " Suffix")
logPlugin, err := InitializeLogPlugins(&LogConfig{})
assert.NoError(t, err)
l, err := GetLogsForContainerInPod(context.TODO(), logPlugin, nil, 0, " Suffix")
assert.NoError(t, err)
assert.Nil(t, l)
}

func TestGetLogsForContainerInPod_NoLogs(t *testing.T) {
assert.NoError(t, SetLogConfig(&LogConfig{
logPlugin, err := InitializeLogPlugins(&LogConfig{
IsCloudwatchEnabled: true,
CloudwatchRegion: "us-east-1",
CloudwatchLogGroup: "/kubernetes/flyte-production",
}))
p, err := GetLogsForContainerInPod(context.TODO(), nil, 0, " Suffix")
})
assert.NoError(t, err)
p, err := GetLogsForContainerInPod(context.TODO(), logPlugin, nil, 0, " Suffix")
assert.NoError(t, err)
assert.Nil(t, p)
}

func TestGetLogsForContainerInPod_BadIndex(t *testing.T) {
assert.NoError(t, SetLogConfig(&LogConfig{
logPlugin, err := InitializeLogPlugins(&LogConfig{
IsCloudwatchEnabled: true,
CloudwatchRegion: "us-east-1",
CloudwatchLogGroup: "/kubernetes/flyte-production",
}))
})
assert.NoError(t, err)

pod := &v1.Pod{
Spec: v1.PodSpec{
Expand All @@ -57,17 +60,18 @@ func TestGetLogsForContainerInPod_BadIndex(t *testing.T) {
}
pod.Name = podName

p, err := GetLogsForContainerInPod(context.TODO(), pod, 1, " Suffix")
p, err := GetLogsForContainerInPod(context.TODO(), logPlugin, pod, 1, " Suffix")
assert.NoError(t, err)
assert.Nil(t, p)
}

func TestGetLogsForContainerInPod_MissingStatus(t *testing.T) {
assert.NoError(t, SetLogConfig(&LogConfig{
logPlugin, err := InitializeLogPlugins(&LogConfig{
IsCloudwatchEnabled: true,
CloudwatchRegion: "us-east-1",
CloudwatchLogGroup: "/kubernetes/flyte-production",
}))
})
assert.NoError(t, err)

pod := &v1.Pod{
Spec: v1.PodSpec{
Expand All @@ -81,16 +85,17 @@ func TestGetLogsForContainerInPod_MissingStatus(t *testing.T) {
}
pod.Name = podName

p, err := GetLogsForContainerInPod(context.TODO(), pod, 1, " Suffix")
p, err := GetLogsForContainerInPod(context.TODO(), logPlugin, pod, 1, " Suffix")
assert.NoError(t, err)
assert.Nil(t, p)
}

func TestGetLogsForContainerInPod_Cloudwatch(t *testing.T) {
assert.NoError(t, SetLogConfig(&LogConfig{IsCloudwatchEnabled: true,
logPlugin, err := InitializeLogPlugins(&LogConfig{IsCloudwatchEnabled: true,
CloudwatchRegion: "us-east-1",
CloudwatchLogGroup: "/kubernetes/flyte-production",
}))
})
assert.NoError(t, err)

pod := &v1.Pod{
Spec: v1.PodSpec{
Expand All @@ -110,16 +115,17 @@ func TestGetLogsForContainerInPod_Cloudwatch(t *testing.T) {
}
pod.Name = podName

logs, err := GetLogsForContainerInPod(context.TODO(), pod, 0, " Suffix")
logs, err := GetLogsForContainerInPod(context.TODO(), logPlugin, pod, 0, " Suffix")
assert.Nil(t, err)
assert.Len(t, logs, 1)
}

func TestGetLogsForContainerInPod_K8s(t *testing.T) {
assert.NoError(t, SetLogConfig(&LogConfig{
logPlugin, err := InitializeLogPlugins(&LogConfig{
IsKubernetesEnabled: true,
KubernetesURL: "k8s.com",
}))
})
assert.NoError(t, err)

pod := &v1.Pod{
Spec: v1.PodSpec{
Expand All @@ -139,19 +145,20 @@ func TestGetLogsForContainerInPod_K8s(t *testing.T) {
}
pod.Name = podName

logs, err := GetLogsForContainerInPod(context.TODO(), pod, 0, " Suffix")
logs, err := GetLogsForContainerInPod(context.TODO(), logPlugin, pod, 0, " Suffix")
assert.Nil(t, err)
assert.Len(t, logs, 1)
}

func TestGetLogsForContainerInPod_All(t *testing.T) {
assert.NoError(t, SetLogConfig(&LogConfig{
logPlugin, err := InitializeLogPlugins(&LogConfig{
IsKubernetesEnabled: true,
KubernetesURL: "k8s.com",
IsCloudwatchEnabled: true,
CloudwatchRegion: "us-east-1",
CloudwatchLogGroup: "/kubernetes/flyte-production",
}))
})
assert.NoError(t, err)

pod := &v1.Pod{
Spec: v1.PodSpec{
Expand All @@ -171,18 +178,18 @@ func TestGetLogsForContainerInPod_All(t *testing.T) {
}
pod.Name = podName

logs, err := GetLogsForContainerInPod(context.TODO(), pod, 0, " Suffix")
logs, err := GetLogsForContainerInPod(context.TODO(), logPlugin, pod, 0, " Suffix")
assert.Nil(t, err)
assert.Len(t, logs, 2)
}

func TestGetLogsForContainerInPod_Stackdriver(t *testing.T) {

assert.NoError(t, SetLogConfig(&LogConfig{
logPlugin, err := InitializeLogPlugins(&LogConfig{
IsStackDriverEnabled: true,
GCPProjectName: "myGCPProject",
StackdriverLogResourceName: "aws_ec2_instance",
}))
})
assert.NoError(t, err)

pod := &v1.Pod{
Spec: v1.PodSpec{
Expand All @@ -202,7 +209,7 @@ func TestGetLogsForContainerInPod_Stackdriver(t *testing.T) {
}
pod.Name = podName

logs, err := GetLogsForContainerInPod(context.TODO(), pod, 0, " Suffix")
logs, err := GetLogsForContainerInPod(context.TODO(), logPlugin, pod, 0, " Suffix")
assert.Nil(t, err)
assert.Len(t, logs, 1)
}
Expand Down Expand Up @@ -252,7 +259,8 @@ func TestGetLogsForContainerInPod_LegacyTemplate(t *testing.T) {
}

func assertTestSucceeded(tb testing.TB, config *LogConfig, expectedTaskLogs []*core.TaskLog) {
assert.NoError(tb, SetLogConfig(config))
logPlugin, err := InitializeLogPlugins(config)
assert.NoError(tb, err)

pod := &v1.Pod{
ObjectMeta: v12.ObjectMeta{
Expand All @@ -275,7 +283,7 @@ func assertTestSucceeded(tb testing.TB, config *LogConfig, expectedTaskLogs []*c
},
}

logs, err := GetLogsForContainerInPod(context.TODO(), pod, 0, " my-Suffix")
logs, err := GetLogsForContainerInPod(context.TODO(), logPlugin, pod, 0, " my-Suffix")
assert.Nil(tb, err)
assert.Len(tb, logs, len(expectedTaskLogs))
if diff := deep.Equal(logs, expectedTaskLogs); len(diff) > 0 {
Expand Down
18 changes: 16 additions & 2 deletions go/tasks/pluginmachinery/core/mocks/fake_k8s_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ import (
"reflect"
"sync"

"k8s.io/apimachinery/pkg/api/meta"

v1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/errors"
"k8s.io/apimachinery/pkg/api/meta"
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/runtime/schema"
Expand Down Expand Up @@ -89,6 +89,20 @@ func (m *FakeKubeClient) Create(ctx context.Context, obj client.Object, opts ...
m.syncObj.Lock()
defer m.syncObj.Unlock()

// if obj is a *v1.Pod then append a ContainerStatus for each Container
pod, ok := obj.(*v1.Pod)
if ok {
for i := range pod.Spec.Containers {
if len(pod.Status.ContainerStatuses) > i {
continue
}

pod.Status.ContainerStatuses = append(pod.Status.ContainerStatuses, v1.ContainerStatus{
ContainerID: "docker://container-name",
})
}
}

accessor, err := meta.Accessor(obj)
if err != nil {
return err
Expand Down
24 changes: 13 additions & 11 deletions go/tasks/plugins/array/k8s/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,19 @@ package k8s
import (
"context"

"sigs.k8s.io/controller-runtime/pkg/cache"
"sigs.k8s.io/controller-runtime/pkg/client"

idlCore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"

"github.com/flyteorg/flyteplugins/go/tasks/errors"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"
"github.com/flyteorg/flyteplugins/go/tasks/plugins/array"
arrayCore "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/core"

"github.com/flyteorg/flytestdlib/logger"
"github.com/flyteorg/flytestdlib/promutils"

"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery"

"github.com/flyteorg/flyteplugins/go/tasks/errors"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"
"sigs.k8s.io/controller-runtime/pkg/cache"
"sigs.k8s.io/controller-runtime/pkg/client"
)

const executorName = "k8s-array"
Expand Down Expand Up @@ -145,18 +144,21 @@ func (e Executor) Handle(ctx context.Context, tCtx core.TaskExecutionContext) (c
}

func (e Executor) Abort(ctx context.Context, tCtx core.TaskExecutionContext) error {
return nil
pluginState := &arrayCore.State{}
if _, err := tCtx.PluginStateReader().Get(pluginState); err != nil {
return errors.Wrapf(errors.CorruptedPluginState, err, "Failed to read unmarshal custom state")
}

return TerminateSubTasks(ctx, tCtx, e.kubeClient, GetConfig(), abortSubtask, pluginState)
}

func (e Executor) Finalize(ctx context.Context, tCtx core.TaskExecutionContext) error {
pluginConfig := GetConfig()

pluginState := &arrayCore.State{}
if _, err := tCtx.PluginStateReader().Get(pluginState); err != nil {
return errors.Wrapf(errors.CorruptedPluginState, err, "Failed to read unmarshal custom state")
}

return TerminateSubTasks(ctx, tCtx, e.kubeClient, pluginConfig, pluginState)
return TerminateSubTasks(ctx, tCtx, e.kubeClient, GetConfig(), finalizeSubtask, pluginState)
}

func (e Executor) Start(ctx context.Context) error {
Expand Down
26 changes: 11 additions & 15 deletions go/tasks/plugins/array/k8s/integration_test.go
Original file line number Diff line number Diff line change
@@ -1,33 +1,29 @@
package k8s

import (
"context"
"strconv"
"testing"

"github.com/flyteorg/flyteplugins/go/tasks/plugins/array"

"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io"

"github.com/flyteorg/flytestdlib/storage"
"github.com/flyteorg/flyteidl/clients/go/coreutils"

"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"

"context"

v1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"sigs.k8s.io/controller-runtime/pkg/client"

"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/workqueue"
"github.com/flyteorg/flyteplugins/go/tasks/plugins/array"

"github.com/flyteorg/flytestdlib/contextutils"
"github.com/flyteorg/flytestdlib/promutils"
"github.com/flyteorg/flytestdlib/promutils/labeled"
"github.com/flyteorg/flytestdlib/storage"

"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks"
"github.com/flyteorg/flytestdlib/promutils"
"github.com/stretchr/testify/assert"

"github.com/flyteorg/flyteidl/clients/go/coreutils"
v1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"

"sigs.k8s.io/controller-runtime/pkg/client"
)

func init() {
Expand Down
Loading