Skip to content

Commit

Permalink
Refactor k8s array plugin to reuse pod plugin for subtasks (flyteorg#244
Browse files Browse the repository at this point in the history
)

* refactored to use pod plugin in map task subtasks

Signed-off-by: Daniel Rammer <[email protected]>

* initializing array log plugins

Signed-off-by: Daniel Rammer <[email protected]>

* fixed namespace template tests

Signed-off-by: Daniel Rammer <[email protected]>

* fixed podname for logs

Signed-off-by: Daniel Rammer <[email protected]>

* setting task log names based on retry attempt and index

Signed-off-by: Daniel Rammer <[email protected]>

* setting task log names correctly

Signed-off-by: Daniel Rammer <[email protected]>

* filled out management unit tests

Signed-off-by: Daniel Rammer <[email protected]>

* cleaned up subtask execution context and added unit tests

Signed-off-by: Daniel Rammer <[email protected]>

* fixed existing unit tests

Signed-off-by: Daniel Rammer <[email protected]>

* fixed lint issues

Signed-off-by: Daniel Rammer <[email protected]>

* added function comments to management.go

Signed-off-by: Daniel Rammer <[email protected]>

* fixed phase transitions from failures when launching or monitoring subtasks

Signed-off-by: Daniel Rammer <[email protected]>

* added finalizeSubtask function

Signed-off-by: Daniel Rammer <[email protected]>

* added abort functionality

Signed-off-by: Daniel Rammer <[email protected]>

* fixed lint issues

Signed-off-by: Daniel Rammer <[email protected]>

* fixed log suffix to match pod name

Signed-off-by: Daniel Rammer <[email protected]>

* fixed abort error

Signed-off-by: Daniel Rammer <[email protected]>

* finalizing subtask on terminal phase

Signed-off-by: Daniel Rammer <[email protected]>

* deallocating resource on launchSubtask fail

Signed-off-by: Daniel Rammer <[email protected]>

* fixed management derr issue

Signed-off-by: Daniel Rammer <[email protected]>
  • Loading branch information
hamersaw authored Mar 16, 2022
1 parent d9a8d89 commit c4fe311
Show file tree
Hide file tree
Showing 23 changed files with 1,468 additions and 2,001 deletions.
7 changes: 1 addition & 6 deletions go/tasks/logs/logging_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,7 @@ type logPlugin struct {
}

// Internal
func GetLogsForContainerInPod(ctx context.Context, pod *v1.Pod, index uint32, nameSuffix string) ([]*core.TaskLog, error) {
logPlugin, err := InitializeLogPlugins(GetLogConfig())
if err != nil {
return nil, err
}

func GetLogsForContainerInPod(ctx context.Context, logPlugin tasklog.Plugin, pod *v1.Pod, index uint32, nameSuffix string) ([]*core.TaskLog, error) {
if logPlugin == nil {
return nil, nil
}
Expand Down
60 changes: 34 additions & 26 deletions go/tasks/logs/logging_utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,29 +15,32 @@ import (
const podName = "PodName"

func TestGetLogsForContainerInPod_NoPlugins(t *testing.T) {
assert.NoError(t, SetLogConfig(&LogConfig{}))
l, err := GetLogsForContainerInPod(context.TODO(), nil, 0, " Suffix")
logPlugin, err := InitializeLogPlugins(&LogConfig{})
assert.NoError(t, err)
l, err := GetLogsForContainerInPod(context.TODO(), logPlugin, nil, 0, " Suffix")
assert.NoError(t, err)
assert.Nil(t, l)
}

func TestGetLogsForContainerInPod_NoLogs(t *testing.T) {
assert.NoError(t, SetLogConfig(&LogConfig{
logPlugin, err := InitializeLogPlugins(&LogConfig{
IsCloudwatchEnabled: true,
CloudwatchRegion: "us-east-1",
CloudwatchLogGroup: "/kubernetes/flyte-production",
}))
p, err := GetLogsForContainerInPod(context.TODO(), nil, 0, " Suffix")
})
assert.NoError(t, err)
p, err := GetLogsForContainerInPod(context.TODO(), logPlugin, nil, 0, " Suffix")
assert.NoError(t, err)
assert.Nil(t, p)
}

func TestGetLogsForContainerInPod_BadIndex(t *testing.T) {
assert.NoError(t, SetLogConfig(&LogConfig{
logPlugin, err := InitializeLogPlugins(&LogConfig{
IsCloudwatchEnabled: true,
CloudwatchRegion: "us-east-1",
CloudwatchLogGroup: "/kubernetes/flyte-production",
}))
})
assert.NoError(t, err)

pod := &v1.Pod{
Spec: v1.PodSpec{
Expand All @@ -57,17 +60,18 @@ func TestGetLogsForContainerInPod_BadIndex(t *testing.T) {
}
pod.Name = podName

p, err := GetLogsForContainerInPod(context.TODO(), pod, 1, " Suffix")
p, err := GetLogsForContainerInPod(context.TODO(), logPlugin, pod, 1, " Suffix")
assert.NoError(t, err)
assert.Nil(t, p)
}

func TestGetLogsForContainerInPod_MissingStatus(t *testing.T) {
assert.NoError(t, SetLogConfig(&LogConfig{
logPlugin, err := InitializeLogPlugins(&LogConfig{
IsCloudwatchEnabled: true,
CloudwatchRegion: "us-east-1",
CloudwatchLogGroup: "/kubernetes/flyte-production",
}))
})
assert.NoError(t, err)

pod := &v1.Pod{
Spec: v1.PodSpec{
Expand All @@ -81,16 +85,17 @@ func TestGetLogsForContainerInPod_MissingStatus(t *testing.T) {
}
pod.Name = podName

p, err := GetLogsForContainerInPod(context.TODO(), pod, 1, " Suffix")
p, err := GetLogsForContainerInPod(context.TODO(), logPlugin, pod, 1, " Suffix")
assert.NoError(t, err)
assert.Nil(t, p)
}

func TestGetLogsForContainerInPod_Cloudwatch(t *testing.T) {
assert.NoError(t, SetLogConfig(&LogConfig{IsCloudwatchEnabled: true,
logPlugin, err := InitializeLogPlugins(&LogConfig{IsCloudwatchEnabled: true,
CloudwatchRegion: "us-east-1",
CloudwatchLogGroup: "/kubernetes/flyte-production",
}))
})
assert.NoError(t, err)

pod := &v1.Pod{
Spec: v1.PodSpec{
Expand All @@ -110,16 +115,17 @@ func TestGetLogsForContainerInPod_Cloudwatch(t *testing.T) {
}
pod.Name = podName

logs, err := GetLogsForContainerInPod(context.TODO(), pod, 0, " Suffix")
logs, err := GetLogsForContainerInPod(context.TODO(), logPlugin, pod, 0, " Suffix")
assert.Nil(t, err)
assert.Len(t, logs, 1)
}

func TestGetLogsForContainerInPod_K8s(t *testing.T) {
assert.NoError(t, SetLogConfig(&LogConfig{
logPlugin, err := InitializeLogPlugins(&LogConfig{
IsKubernetesEnabled: true,
KubernetesURL: "k8s.com",
}))
})
assert.NoError(t, err)

pod := &v1.Pod{
Spec: v1.PodSpec{
Expand All @@ -139,19 +145,20 @@ func TestGetLogsForContainerInPod_K8s(t *testing.T) {
}
pod.Name = podName

logs, err := GetLogsForContainerInPod(context.TODO(), pod, 0, " Suffix")
logs, err := GetLogsForContainerInPod(context.TODO(), logPlugin, pod, 0, " Suffix")
assert.Nil(t, err)
assert.Len(t, logs, 1)
}

func TestGetLogsForContainerInPod_All(t *testing.T) {
assert.NoError(t, SetLogConfig(&LogConfig{
logPlugin, err := InitializeLogPlugins(&LogConfig{
IsKubernetesEnabled: true,
KubernetesURL: "k8s.com",
IsCloudwatchEnabled: true,
CloudwatchRegion: "us-east-1",
CloudwatchLogGroup: "/kubernetes/flyte-production",
}))
})
assert.NoError(t, err)

pod := &v1.Pod{
Spec: v1.PodSpec{
Expand All @@ -171,18 +178,18 @@ func TestGetLogsForContainerInPod_All(t *testing.T) {
}
pod.Name = podName

logs, err := GetLogsForContainerInPod(context.TODO(), pod, 0, " Suffix")
logs, err := GetLogsForContainerInPod(context.TODO(), logPlugin, pod, 0, " Suffix")
assert.Nil(t, err)
assert.Len(t, logs, 2)
}

func TestGetLogsForContainerInPod_Stackdriver(t *testing.T) {

assert.NoError(t, SetLogConfig(&LogConfig{
logPlugin, err := InitializeLogPlugins(&LogConfig{
IsStackDriverEnabled: true,
GCPProjectName: "myGCPProject",
StackdriverLogResourceName: "aws_ec2_instance",
}))
})
assert.NoError(t, err)

pod := &v1.Pod{
Spec: v1.PodSpec{
Expand All @@ -202,7 +209,7 @@ func TestGetLogsForContainerInPod_Stackdriver(t *testing.T) {
}
pod.Name = podName

logs, err := GetLogsForContainerInPod(context.TODO(), pod, 0, " Suffix")
logs, err := GetLogsForContainerInPod(context.TODO(), logPlugin, pod, 0, " Suffix")
assert.Nil(t, err)
assert.Len(t, logs, 1)
}
Expand Down Expand Up @@ -252,7 +259,8 @@ func TestGetLogsForContainerInPod_LegacyTemplate(t *testing.T) {
}

func assertTestSucceeded(tb testing.TB, config *LogConfig, expectedTaskLogs []*core.TaskLog) {
assert.NoError(tb, SetLogConfig(config))
logPlugin, err := InitializeLogPlugins(config)
assert.NoError(tb, err)

pod := &v1.Pod{
ObjectMeta: v12.ObjectMeta{
Expand All @@ -275,7 +283,7 @@ func assertTestSucceeded(tb testing.TB, config *LogConfig, expectedTaskLogs []*c
},
}

logs, err := GetLogsForContainerInPod(context.TODO(), pod, 0, " my-Suffix")
logs, err := GetLogsForContainerInPod(context.TODO(), logPlugin, pod, 0, " my-Suffix")
assert.Nil(tb, err)
assert.Len(tb, logs, len(expectedTaskLogs))
if diff := deep.Equal(logs, expectedTaskLogs); len(diff) > 0 {
Expand Down
18 changes: 16 additions & 2 deletions go/tasks/pluginmachinery/core/mocks/fake_k8s_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ import (
"reflect"
"sync"

"k8s.io/apimachinery/pkg/api/meta"

v1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/errors"
"k8s.io/apimachinery/pkg/api/meta"
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/runtime/schema"
Expand Down Expand Up @@ -89,6 +89,20 @@ func (m *FakeKubeClient) Create(ctx context.Context, obj client.Object, opts ...
m.syncObj.Lock()
defer m.syncObj.Unlock()

// if obj is a *v1.Pod then append a ContainerStatus for each Container
pod, ok := obj.(*v1.Pod)
if ok {
for i := range pod.Spec.Containers {
if len(pod.Status.ContainerStatuses) > i {
continue
}

pod.Status.ContainerStatuses = append(pod.Status.ContainerStatuses, v1.ContainerStatus{
ContainerID: "docker://container-name",
})
}
}

accessor, err := meta.Accessor(obj)
if err != nil {
return err
Expand Down
24 changes: 13 additions & 11 deletions go/tasks/plugins/array/k8s/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,19 @@ package k8s
import (
"context"

"sigs.k8s.io/controller-runtime/pkg/cache"
"sigs.k8s.io/controller-runtime/pkg/client"

idlCore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"

"github.com/flyteorg/flyteplugins/go/tasks/errors"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"
"github.com/flyteorg/flyteplugins/go/tasks/plugins/array"
arrayCore "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/core"

"github.com/flyteorg/flytestdlib/logger"
"github.com/flyteorg/flytestdlib/promutils"

"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery"

"github.com/flyteorg/flyteplugins/go/tasks/errors"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"
"sigs.k8s.io/controller-runtime/pkg/cache"
"sigs.k8s.io/controller-runtime/pkg/client"
)

const executorName = "k8s-array"
Expand Down Expand Up @@ -145,18 +144,21 @@ func (e Executor) Handle(ctx context.Context, tCtx core.TaskExecutionContext) (c
}

func (e Executor) Abort(ctx context.Context, tCtx core.TaskExecutionContext) error {
return nil
pluginState := &arrayCore.State{}
if _, err := tCtx.PluginStateReader().Get(pluginState); err != nil {
return errors.Wrapf(errors.CorruptedPluginState, err, "Failed to read unmarshal custom state")
}

return TerminateSubTasks(ctx, tCtx, e.kubeClient, GetConfig(), abortSubtask, pluginState)
}

func (e Executor) Finalize(ctx context.Context, tCtx core.TaskExecutionContext) error {
pluginConfig := GetConfig()

pluginState := &arrayCore.State{}
if _, err := tCtx.PluginStateReader().Get(pluginState); err != nil {
return errors.Wrapf(errors.CorruptedPluginState, err, "Failed to read unmarshal custom state")
}

return TerminateSubTasks(ctx, tCtx, e.kubeClient, pluginConfig, pluginState)
return TerminateSubTasks(ctx, tCtx, e.kubeClient, GetConfig(), finalizeSubtask, pluginState)
}

func (e Executor) Start(ctx context.Context) error {
Expand Down
26 changes: 11 additions & 15 deletions go/tasks/plugins/array/k8s/integration_test.go
Original file line number Diff line number Diff line change
@@ -1,33 +1,29 @@
package k8s

import (
"context"
"strconv"
"testing"

"github.com/flyteorg/flyteplugins/go/tasks/plugins/array"

"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io"

"github.com/flyteorg/flytestdlib/storage"
"github.com/flyteorg/flyteidl/clients/go/coreutils"

"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"

"context"

v1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"sigs.k8s.io/controller-runtime/pkg/client"

"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/workqueue"
"github.com/flyteorg/flyteplugins/go/tasks/plugins/array"

"github.com/flyteorg/flytestdlib/contextutils"
"github.com/flyteorg/flytestdlib/promutils"
"github.com/flyteorg/flytestdlib/promutils/labeled"
"github.com/flyteorg/flytestdlib/storage"

"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks"
"github.com/flyteorg/flytestdlib/promutils"
"github.com/stretchr/testify/assert"

"github.com/flyteorg/flyteidl/clients/go/coreutils"
v1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"

"sigs.k8s.io/controller-runtime/pkg/client"
)

func init() {
Expand Down
Loading

0 comments on commit c4fe311

Please sign in to comment.