Skip to content

Commit

Permalink
Create a FileOutput reader if the agent produce file output (#391)
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin Su <[email protected]>
  • Loading branch information
pingsutw authored Sep 19, 2023
1 parent 45a095e commit 35ae1b4
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 14 deletions.
22 changes: 15 additions & 7 deletions flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,15 @@ func (m *MockClient) CreateTask(_ context.Context, createTaskRequest *admin.Crea
return &admin.CreateTaskResponse{ResourceMeta: []byte{1, 2, 3, 4}}, nil
}

func (m *MockClient) GetTask(_ context.Context, _ *admin.GetTaskRequest, _ ...grpc.CallOption) (*admin.GetTaskResponse, error) {
return &admin.GetTaskResponse{Resource: &admin.Resource{State: admin.State_SUCCEEDED, Outputs: &flyteIdlCore.LiteralMap{
Literals: map[string]*flyteIdlCore.Literal{
"arr": coreutils.MustMakeLiteral([]interface{}{[]interface{}{"a", "b"}, []interface{}{1, 2}}),
},
}}}, nil
func (m *MockClient) GetTask(_ context.Context, req *admin.GetTaskRequest, _ ...grpc.CallOption) (*admin.GetTaskResponse, error) {
if req.GetTaskType() == "bigquery_query_job_task" {
return &admin.GetTaskResponse{Resource: &admin.Resource{State: admin.State_SUCCEEDED, Outputs: &flyteIdlCore.LiteralMap{
Literals: map[string]*flyteIdlCore.Literal{
"arr": coreutils.MustMakeLiteral([]interface{}{[]interface{}{"a", "b"}, []interface{}{1, 2}}),
},
}}}, nil
}
return &admin.GetTaskResponse{Resource: &admin.Resource{State: admin.State_SUCCEEDED}}, nil
}

func (m *MockClient) DeleteTask(_ context.Context, _ *admin.DeleteTaskRequest, _ ...grpc.CallOption) (*admin.DeleteTaskResponse, error) {
Expand Down Expand Up @@ -113,6 +116,11 @@ func TestEndToEnd(t *testing.T) {

phase := tests.RunPluginEndToEndTest(t, plugin, &template, inputs, nil, nil, iter)
assert.Equal(t, true, phase.Phase().IsSuccess())

template.Type = "spark_job"
phase = tests.RunPluginEndToEndTest(t, plugin, &template, inputs, nil, nil, iter)
assert.Equal(t, true, phase.Phase().IsSuccess())

})

t.Run("failed to create a job", func(t *testing.T) {
Expand Down Expand Up @@ -251,7 +259,7 @@ func getTaskContext(t *testing.T) *pluginCoreMocks.TaskExecutionContext {
func newMockAgentPlugin() webapi.PluginEntry {
return webapi.PluginEntry{
ID: "agent-service",
SupportedTaskTypes: []core.TaskType{"bigquery_query_job_task"},
SupportedTaskTypes: []core.TaskType{"bigquery_query_job_task", "spark_job"},
PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, error) {
return &MockPlugin{
Plugin{
Expand Down
35 changes: 29 additions & 6 deletions flyteplugins/go/tasks/plugins/webapi/agent/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"fmt"

"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin"
"github.com/flyteorg/flytestdlib/config"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"

Expand All @@ -19,8 +18,11 @@ import (
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/template"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/ioutils"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/webapi"
"github.com/flyteorg/flytestdlib/config"
"github.com/flyteorg/flytestdlib/logger"
"github.com/flyteorg/flytestdlib/promutils"
"google.golang.org/grpc"
)
Expand Down Expand Up @@ -176,17 +178,38 @@ func (p Plugin) Status(ctx context.Context, taskCtx webapi.StatusContext) (phase
case admin.State_RETRYABLE_FAILURE:
return core.PhaseInfoRetryableFailure(pluginErrors.TaskFailedWithError, "failed to run the job", taskInfo), nil
case admin.State_SUCCEEDED:
if resource.Outputs != nil {
err := taskCtx.OutputWriter().Put(ctx, ioutils.NewInMemoryOutputReader(resource.Outputs, nil, nil))
if err != nil {
return core.PhaseInfoUndefined, err
}
err = writeOutput(ctx, taskCtx, resource)
if err != nil {
logger.Errorf(ctx, "Failed to write output with err %s", err.Error())
return core.PhaseInfoUndefined, err
}
return core.PhaseInfoSuccess(taskInfo), nil
}
return core.PhaseInfoUndefined, pluginErrors.Errorf(core.SystemErrorCode, "unknown execution phase [%v].", resource.State)
}

func writeOutput(ctx context.Context, taskCtx webapi.StatusContext, resource *ResourceWrapper) error {
taskTemplate, err := taskCtx.TaskReader().Read(ctx)
if err != nil {
return err
}

if taskTemplate.Interface == nil || taskTemplate.Interface.Outputs == nil || taskTemplate.Interface.Outputs.Variables == nil {
logger.Debugf(ctx, "The task declares no outputs. Skipping writing the outputs.")
return nil
}

var opReader io.OutputReader
if resource.Outputs != nil {
logger.Debugf(ctx, "Agent returned an output")
opReader = ioutils.NewInMemoryOutputReader(resource.Outputs, nil, nil)
} else {
logger.Debugf(ctx, "Agent didn't return any output, assuming file based outputs.")
opReader = ioutils.NewRemoteFileOutputReader(ctx, taskCtx.DataStore(), taskCtx.OutputWriter(), taskCtx.MaxDatasetSizeBytes())
}
return taskCtx.OutputWriter().Put(ctx, opReader)
}

func getFinalAgent(taskType string, cfg *Config) (*Agent, error) {
if id, exists := cfg.AgentForTaskTypes[taskType]; exists {
if agent, exists := cfg.Agents[id]; exists {
Expand Down
2 changes: 1 addition & 1 deletion flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func TestPlugin(t *testing.T) {
assert.Equal(t, plugin.cfg.ResourceConstraints, constraints)
})

t.Run("tet newAgentPlugin", func(t *testing.T) {
t.Run("test newAgentPlugin", func(t *testing.T) {
p := newAgentPlugin()
assert.NotNil(t, p)
assert.Equal(t, "agent-service", p.ID)
Expand Down
1 change: 1 addition & 0 deletions flyteplugins/tests/end_to_end.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ func RunPluginEndToEndTest(t *testing.T, executor pluginCore.Plugin, template *i
outputWriter.OnGetOutputPath().Return(basePrefix + "/outputs.pb")
outputWriter.OnGetCheckpointPrefix().Return("/checkpoint")
outputWriter.OnGetPreviousCheckpointsPrefix().Return("/prev")
outputWriter.OnPutMatch(mock.Anything, mock.Anything).Return(nil)

outputWriter.OnPutMatch(mock.Anything, mock.Anything).Return(nil).Run(func(args mock.Arguments) {
or := args.Get(1).(io.OutputReader)
Expand Down

0 comments on commit 35ae1b4

Please sign in to comment.