Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Render task template in the agent client #384

Merged
merged 5 commits into from
Aug 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion go/tasks/plugins/webapi/agent/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"github.com/stretchr/testify/mock"
"google.golang.org/grpc"
"k8s.io/apimachinery/pkg/util/rand"
"k8s.io/utils/strings/slices"
)

type MockPlugin struct {
Expand All @@ -39,7 +40,11 @@ type MockPlugin struct {
type MockClient struct {
}

func (m *MockClient) CreateTask(_ context.Context, _ *admin.CreateTaskRequest, _ ...grpc.CallOption) (*admin.CreateTaskResponse, error) {
func (m *MockClient) CreateTask(_ context.Context, createTaskRequest *admin.CreateTaskRequest, _ ...grpc.CallOption) (*admin.CreateTaskResponse, error) {
expectedArgs := []string{"pyflyte-fast-execute", "--output-prefix", "fake://bucket/prefix/nhv"}
if slices.Equal(createTaskRequest.Template.GetContainer().Args, expectedArgs) {
return nil, fmt.Errorf("args not as expected")
}
return &admin.CreateTaskResponse{ResourceMeta: []byte{1, 2, 3, 4}}, nil
}

Expand Down Expand Up @@ -95,6 +100,9 @@ func TestEndToEnd(t *testing.T) {
template := flyteIdlCore.TaskTemplate{
Type: "bigquery_query_job_task",
Custom: st,
Target: &flyteIdlCore.TaskTemplate_Container{
Container: &flyteIdlCore.Container{Args: []string{"pyflyte-fast-execute", "--output-prefix", "{{.outputPrefix}}"}},
},
}
basePrefix := storage.DataReference("fake://bucket/prefix/")

Expand Down
21 changes: 17 additions & 4 deletions go/tasks/plugins/webapi/agent/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import (
pluginErrors "github.com/flyteorg/flyteplugins/go/tasks/errors"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"
pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/template"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/ioutils"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/webapi"
"github.com/flyteorg/flytestdlib/promutils"
Expand Down Expand Up @@ -68,6 +68,19 @@ func (p Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContextR
return nil, nil, err
}

if taskTemplate.GetContainer() != nil {
templateParameters := template.Parameters{
TaskExecMetadata: taskCtx.TaskExecutionMetadata(),
Inputs: taskCtx.InputReader(),
OutputPath: taskCtx.OutputWriter(),
Task: taskCtx.TaskReader(),
}
modifiedArgs, err := template.Render(ctx, taskTemplate.GetContainer().Args, templateParameters)
if err != nil {
return nil, nil, err
}
taskTemplate.GetContainer().Args = modifiedArgs
}
outputPrefix := taskCtx.OutputWriter().GetOutputPrefixPath().String()

agent, err := getFinalAgent(taskTemplate.Type, p.cfg)
Expand Down Expand Up @@ -150,7 +163,7 @@ func (p Plugin) Status(ctx context.Context, taskCtx webapi.StatusContext) (phase

switch resource.State {
case admin.State_RUNNING:
return core.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, taskInfo), nil
return core.PhaseInfoRunning(core.DefaultPhaseVersion, taskInfo), nil
case admin.State_PERMANENT_FAILURE:
return core.PhaseInfoFailure(pluginErrors.TaskFailedWithError, "failed to run the job", taskInfo), nil
case admin.State_RETRYABLE_FAILURE:
Expand All @@ -164,7 +177,7 @@ func (p Plugin) Status(ctx context.Context, taskCtx webapi.StatusContext) (phase
}
return core.PhaseInfoSuccess(taskInfo), nil
}
return core.PhaseInfoUndefined, pluginErrors.Errorf(pluginsCore.SystemErrorCode, "unknown execution phase [%v].", resource.State)
return core.PhaseInfoUndefined, pluginErrors.Errorf(core.SystemErrorCode, "unknown execution phase [%v].", resource.State)
}

func getFinalAgent(taskType string, cfg *Config) (*Agent, error) {
Expand Down Expand Up @@ -225,7 +238,7 @@ func getClientFunc(ctx context.Context, agent *Agent, connectionCache map[*Agent
return service.NewAsyncAgentServiceClient(conn), nil
}

func buildTaskExecutionMetadata(taskExecutionMetadata pluginsCore.TaskExecutionMetadata) admin.TaskExecutionMetadata {
func buildTaskExecutionMetadata(taskExecutionMetadata core.TaskExecutionMetadata) admin.TaskExecutionMetadata {
taskExecutionID := taskExecutionMetadata.GetTaskExecutionID().GetID()
return admin.TaskExecutionMetadata{
TaskExecutionId: &taskExecutionID,
Expand Down