From 2fd51ebc670934bd26c0283c5c963574dbc7e737 Mon Sep 17 00:00:00 2001 From: ddl-rliu <140021987+ddl-rliu@users.noreply.github.com> Date: Thu, 22 Aug 2024 23:53:33 -0700 Subject: [PATCH] Add custominfo to agents (#5604) Signed-off-by: ddl-rliu Signed-off-by: Bugra Gedik --- .../go/tasks/plugins/webapi/agent/plugin.go | 32 +++++++++++-------- .../tasks/plugins/webapi/agent/plugin_test.go | 16 +++++++--- 2 files changed, 30 insertions(+), 18 deletions(-) diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go b/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go index 20a65ccba1..a7b2a3d1d4 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go @@ -8,6 +8,7 @@ import ( "time" "golang.org/x/exp/maps" + "google.golang.org/protobuf/types/known/structpb" "k8s.io/apimachinery/pkg/util/wait" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin" @@ -39,10 +40,11 @@ type Plugin struct { type ResourceWrapper struct { Phase flyteIdl.TaskExecution_Phase // Deprecated: Please Use Phase instead. - State admin.State - Outputs *flyteIdl.LiteralMap - Message string - LogLinks []*flyteIdl.TaskLog + State admin.State + Outputs *flyteIdl.LiteralMap + Message string + LogLinks []*flyteIdl.TaskLog + CustomInfo *structpb.Struct } // IsTerminal is used to avoid making network calls to the agent service if the resource is already in a terminal state. @@ -192,10 +194,11 @@ func (p *Plugin) ExecuteTaskSync( } return nil, ResourceWrapper{ - Phase: resource.Phase, - Outputs: resource.Outputs, - Message: resource.Message, - LogLinks: resource.LogLinks, + Phase: resource.Phase, + Outputs: resource.Outputs, + Message: resource.Message, + LogLinks: resource.LogLinks, + CustomInfo: resource.CustomInfo, }, err } @@ -221,11 +224,12 @@ func (p *Plugin) Get(ctx context.Context, taskCtx webapi.GetContext) (latest web } return ResourceWrapper{ - Phase: res.Resource.Phase, - State: res.Resource.State, - Outputs: res.Resource.Outputs, - Message: res.Resource.Message, - LogLinks: res.Resource.LogLinks, + Phase: res.Resource.Phase, + State: res.Resource.State, + Outputs: res.Resource.Outputs, + Message: res.Resource.Message, + LogLinks: res.Resource.LogLinks, + CustomInfo: res.Resource.CustomInfo, }, nil } @@ -254,7 +258,7 @@ func (p *Plugin) Delete(ctx context.Context, taskCtx webapi.DeleteContext) error func (p *Plugin) Status(ctx context.Context, taskCtx webapi.StatusContext) (phase core.PhaseInfo, err error) { resource := taskCtx.Resource().(ResourceWrapper) - taskInfo := &core.TaskInfo{Logs: resource.LogLinks} + taskInfo := &core.TaskInfo{Logs: resource.LogLinks, CustomInfo: resource.CustomInfo} switch resource.Phase { case flyteIdl.TaskExecution_QUEUED: diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go b/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go index 3db1c464b6..9e8c97903e 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go @@ -8,6 +8,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "golang.org/x/exp/maps" + "google.golang.org/protobuf/types/known/structpb" agentMocks "github.com/flyteorg/flyte/flyteidl/clients/go/admin/mocks" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin" @@ -114,17 +115,24 @@ func TestPlugin(t *testing.T) { }) t.Run("test RUNNING Status", func(t *testing.T) { + simpleStruct := structpb.Struct{ + Fields: map[string]*structpb.Value{ + "foo": {Kind: &structpb.Value_StringValue{StringValue: "foo"}}, + }, + } taskContext := new(webapiPlugin.StatusContext) taskContext.On("Resource").Return(ResourceWrapper{ - State: admin.State_RUNNING, - Outputs: nil, - Message: "Job is running", - LogLinks: []*flyteIdlCore.TaskLog{{Uri: "http://localhost:3000/log", Name: "Log Link"}}, + State: admin.State_RUNNING, + Outputs: nil, + Message: "Job is running", + LogLinks: []*flyteIdlCore.TaskLog{{Uri: "http://localhost:3000/log", Name: "Log Link"}}, + CustomInfo: &simpleStruct, }) phase, err := plugin.Status(context.Background(), taskContext) assert.NoError(t, err) assert.Equal(t, pluginsCore.PhaseRunning, phase.Phase()) + assert.Equal(t, &simpleStruct, phase.Info().CustomInfo) }) t.Run("test PERMANENT_FAILURE Status", func(t *testing.T) {