diff --git a/go.mod b/go.mod index 80b06c9df..aaefc2ca3 100644 --- a/go.mod +++ b/go.mod @@ -12,7 +12,7 @@ require ( github.com/aws/aws-sdk-go-v2/service/athena v1.0.0 github.com/coocood/freecache v1.1.1 github.com/dask/dask-kubernetes/v2023 v2023.0.0-20230626103304-abd02cd17b26 - github.com/flyteorg/flyteidl v1.5.10 + github.com/flyteorg/flyteidl v1.5.13 github.com/flyteorg/flytestdlib v1.0.15 github.com/go-test/deep v1.0.7 github.com/golang/protobuf v1.5.3 diff --git a/go.sum b/go.sum index fd47f0ad0..b83d5c84a 100644 --- a/go.sum +++ b/go.sum @@ -232,8 +232,8 @@ github.com/evanphx/json-patch v4.12.0+incompatible/go.mod h1:50XU6AFN0ol/bzJsmQL github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= github.com/fatih/color v1.13.0 h1:8LOYc1KYPPmyKMuN8QV2DNRWNbLo6LZ0iLs8+mlH53w= github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk= -github.com/flyteorg/flyteidl v1.5.10 h1:SHeiaWRt8EAVuFsat+BJswtc07HTZ4DqhfTEYSm621k= -github.com/flyteorg/flyteidl v1.5.10/go.mod h1:EtE/muM2lHHgBabjYcxqe9TWeJSL0kXwbI0RgVwI4Og= +github.com/flyteorg/flyteidl v1.5.13 h1:IQ2Cw+u36ew3BPyRDAcHdzc/GyNEOXOxhKy9jbS4hbo= +github.com/flyteorg/flyteidl v1.5.13/go.mod h1:EtE/muM2lHHgBabjYcxqe9TWeJSL0kXwbI0RgVwI4Og= github.com/flyteorg/flytestdlib v1.0.15 h1:kv9jDQmytbE84caY+pkZN8trJU2ouSAmESzpTEhfTt0= github.com/flyteorg/flytestdlib v1.0.15/go.mod h1:ghw/cjY0sEWIIbyCtcJnL/Gt7ZS7gf9SUi0CCPhbz3s= github.com/flyteorg/stow v0.3.6 h1:jt50ciM14qhKBaIrB+ppXXY+SXB59FNREFgTJqCyqIk= diff --git a/go/tasks/plugins/webapi/agent/plugin.go b/go/tasks/plugins/webapi/agent/plugin.go index 70a335021..dbcb568d4 100644 --- a/go/tasks/plugins/webapi/agent/plugin.go +++ b/go/tasks/plugins/webapi/agent/plugin.go @@ -72,7 +72,8 @@ func (p Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContextR return nil, nil, fmt.Errorf("failed to connect to agent with error: %v", err) } - res, err := client.CreateTask(ctx, &admin.CreateTaskRequest{Inputs: inputs, Template: taskTemplate, OutputPrefix: outputPrefix}) + taskExecutionMetadata := buildTaskExecutionMetadata(taskCtx.TaskExecutionMetadata()) + res, err := client.CreateTask(ctx, &admin.CreateTaskRequest{Inputs: inputs, Template: taskTemplate, OutputPrefix: outputPrefix, TaskExecutionMetadata: &taskExecutionMetadata}) if err != nil { return nil, nil, err } @@ -183,6 +184,18 @@ func getClientFunc(ctx context.Context, endpoint string, connectionCache map[str return service.NewAsyncAgentServiceClient(conn), nil } +func buildTaskExecutionMetadata(taskExecutionMetadata pluginsCore.TaskExecutionMetadata) admin.TaskExecutionMetadata { + taskExecutionID := taskExecutionMetadata.GetTaskExecutionID().GetID() + return admin.TaskExecutionMetadata{ + TaskExecutionId: &taskExecutionID, + Namespace: taskExecutionMetadata.GetNamespace(), + Labels: taskExecutionMetadata.GetLabels(), + Annotations: taskExecutionMetadata.GetAnnotations(), + K8SServiceAccount: taskExecutionMetadata.GetK8sServiceAccount(), + EnvironmentVariables: taskExecutionMetadata.GetEnvironmentVariables(), + } +} + func newAgentPlugin() webapi.PluginEntry { supportedTaskTypes := GetConfig().SupportedTaskTypes