From d0baca9068268d8787806913e284dcad31607542 Mon Sep 17 00:00:00 2001 From: Iaroslav Ciupin Date: Wed, 6 Mar 2024 09:41:47 +0200 Subject: [PATCH] Adjust ObjectStore plugin to pass cluster name (#114) * Adjust ObjectStore plugin to pass cluster name Signed-off-by: Iaroslav Ciupin --- .../pkg/manager/impl/execution_manager.go | 4 ++ .../manager/impl/node_execution_manager.go | 51 ++++++++++++++----- .../manager/impl/task_execution_manager.go | 45 ++++++++++++---- flyteadmin/pkg/manager/impl/util/data.go | 23 +++++---- flyteadmin/pkg/manager/impl/util/data_test.go | 21 +++++--- 5 files changed, 105 insertions(+), 39 deletions(-) diff --git a/flyteadmin/pkg/manager/impl/execution_manager.go b/flyteadmin/pkg/manager/impl/execution_manager.go index e36fcccadf..d7ff666088 100644 --- a/flyteadmin/pkg/manager/impl/execution_manager.go +++ b/flyteadmin/pkg/manager/impl/execution_manager.go @@ -1752,6 +1752,8 @@ func (m *ExecutionManager) GetExecutionData( } } + cluster := execution.GetSpec().GetMetadata().GetSystemMetadata().GetExecutionCluster() + id := request.GetId() objectStore := plugins.Get[util.ObjectStore](m.pluginRegistry, plugins.PluginIDObjectStore) var inputs *core.LiteralMap @@ -1763,6 +1765,7 @@ func (m *ExecutionManager) GetExecutionData( m.urlData, m.config.ApplicationConfiguration().GetRemoteDataConfig(), m.storageClient, + cluster, id.Project, id.Domain, executionModel.InputsURI.String(), @@ -1779,6 +1782,7 @@ func (m *ExecutionManager) GetExecutionData( m.config.ApplicationConfiguration().GetRemoteDataConfig(), m.storageClient, util.ToExecutionClosureInterface(execution.Closure), + cluster, id.Project, id.Domain, objectStore) diff --git a/flyteadmin/pkg/manager/impl/node_execution_manager.go b/flyteadmin/pkg/manager/impl/node_execution_manager.go index f06a9b9195..4c6b52fe35 100644 --- a/flyteadmin/pkg/manager/impl/node_execution_manager.go +++ b/flyteadmin/pkg/manager/impl/node_execution_manager.go @@ -521,16 +521,42 @@ func (m *NodeExecutionManager) GetNodeExecutionData( } ctx = getNodeExecutionContext(ctx, request.Id) - nodeExecutionModel, err := util.GetNodeExecutionModel(ctx, m.db, request.Id) - if err != nil { - logger.Debugf(ctx, "Failed to get node execution with id [%+v] with err %v", - request.Id, err) - return nil, err - } + group, groupCtx := errgroup.WithContext(ctx) + var nodeExecutionModel *models.NodeExecution + var nodeExecution *admin.NodeExecution + group.Go(func() error { + var err error + nodeExecutionModel, err = util.GetNodeExecutionModel(ctx, m.db, request.Id) + if err != nil { + logger.Errorf(ctx, "failed to get node execution with id [%+v] with err %v", + request.Id, err) + return err + } + nodeExecution, err = transformers.FromNodeExecutionModel(*nodeExecutionModel, transformers.DefaultExecutionTransformerOptions) + if err != nil { + logger.Errorf(ctx, "failed to transform node execution model [%+v] when fetching data: %v", request.Id, err) + } + return err + }) - nodeExecution, err := transformers.FromNodeExecutionModel(*nodeExecutionModel, transformers.DefaultExecutionTransformerOptions) - if err != nil { - logger.Debugf(ctx, "failed to transform node execution model [%+v] when fetching data: %v", request.Id, err) + cluster := "" + group.Go(func() error { + // when fetching remote S3 URIs, we need a cluster name to send to Union dataproxy to get the correct tunnel + execModel, err := util.GetExecutionModel(ctx, m.db, *request.GetId().GetExecutionId()) + if err != nil { + logger.Errorf(ctx, "failed to fetch execution model: %v", err) + return err + } + execution, err := transformers.FromExecutionModel(ctx, *execModel, transformers.DefaultExecutionTransformerOptions) + if err != nil { + logger.Errorf(ctx, "failed to transform execution model [%+v] to proto object with err: %v", execModel.Name, err) + return err + } + cluster = execution.GetSpec().GetMetadata().GetSystemMetadata().GetExecutionCluster() + return nil + }) + + if err := group.Wait(); err != nil { return nil, err } @@ -538,13 +564,14 @@ func (m *NodeExecutionManager) GetNodeExecutionData( objectStore := plugins.Get[util.ObjectStore](m.pluginRegistry, plugins.PluginIDObjectStore) var inputs *core.LiteralMap var inputURLBlob *admin.UrlBlob - group, groupCtx := errgroup.WithContext(ctx) + group, groupCtx = errgroup.WithContext(ctx) group.Go(func() error { var err error inputs, inputURLBlob, err = util.GetInputs(groupCtx, m.urlData, m.config.ApplicationConfiguration().GetRemoteDataConfig(), m.storageClient, + cluster, id.Project, id.Domain, nodeExecution.InputUri, @@ -561,14 +588,14 @@ func (m *NodeExecutionManager) GetNodeExecutionData( m.config.ApplicationConfiguration().GetRemoteDataConfig(), m.storageClient, nodeExecution.Closure, + cluster, id.Project, id.Domain, objectStore) return err }) - err = group.Wait() - if err != nil { + if err := group.Wait(); err != nil { return nil, err } response := &admin.NodeExecutionGetDataResponse{ diff --git a/flyteadmin/pkg/manager/impl/task_execution_manager.go b/flyteadmin/pkg/manager/impl/task_execution_manager.go index a7f3849e0e..0e88ab1998 100644 --- a/flyteadmin/pkg/manager/impl/task_execution_manager.go +++ b/flyteadmin/pkg/manager/impl/task_execution_manager.go @@ -298,18 +298,41 @@ func (m *TaskExecutionManager) ListTaskExecutions( }, nil } -func (m *TaskExecutionManager) GetTaskExecutionData( - ctx context.Context, request admin.TaskExecutionGetDataRequest) (*admin.TaskExecutionGetDataResponse, error) { +func (m *TaskExecutionManager) GetTaskExecutionData(ctx context.Context, + request admin.TaskExecutionGetDataRequest) (*admin.TaskExecutionGetDataResponse, error) { if err := validation.ValidateTaskExecutionIdentifier(request.Id); err != nil { logger.Debugf(ctx, "Invalid identifier [%+v]: %v", request.Id, err) } ctx = getTaskExecutionContext(ctx, request.Id) - taskExecution, err := m.GetTaskExecution(ctx, admin.TaskExecutionGetRequest{ - Id: request.Id, + group, groupCtx := errgroup.WithContext(ctx) + var taskExecution *admin.TaskExecution + group.Go(func() error { + var err error + taskExecution, err = m.GetTaskExecution(ctx, admin.TaskExecutionGetRequest{Id: request.Id}) + if err != nil { + logger.Errorf(ctx, "Failed to get task execution with id [%+v] with err %v", request.Id, err) + } + return err }) - if err != nil { - logger.Debugf(ctx, "Failed to get task execution with id [%+v] with err %v", - request.Id, err) + + cluster := "" + group.Go(func() error { + // when fetching remote S3 URIs, we need a cluster name to send to Union dataproxy to get the correct tunnel + execModel, err := util.GetExecutionModel(ctx, m.db, *request.GetId().GetNodeExecutionId().GetExecutionId()) + if err != nil { + logger.Errorf(ctx, "failed to fetch execution model: %v", err) + return err + } + execution, err := transformers.FromExecutionModel(ctx, *execModel, transformers.DefaultExecutionTransformerOptions) + if err != nil { + logger.Errorf(ctx, "failed to transform execution model [%+v] to proto object with err: %v", execModel.Name, err) + return err + } + cluster = execution.GetSpec().GetMetadata().GetSystemMetadata().GetExecutionCluster() + return nil + }) + + if err := group.Wait(); err != nil { return nil, err } @@ -317,13 +340,14 @@ func (m *TaskExecutionManager) GetTaskExecutionData( id := request.GetId().GetNodeExecutionId().GetExecutionId() var inputs *core.LiteralMap var inputURLBlob *admin.UrlBlob - group, groupCtx := errgroup.WithContext(ctx) + group, groupCtx = errgroup.WithContext(ctx) group.Go(func() error { var err error inputs, inputURLBlob, err = util.GetInputs(groupCtx, m.urlData, m.config.ApplicationConfiguration().GetRemoteDataConfig(), m.storageClient, + cluster, id.Project, id.Domain, taskExecution.InputUri, @@ -340,16 +364,17 @@ func (m *TaskExecutionManager) GetTaskExecutionData( m.config.ApplicationConfiguration().GetRemoteDataConfig(), m.storageClient, taskExecution.Closure, + cluster, id.Project, id.Domain, objectStore) return err }) - err = group.Wait() - if err != nil { + if err := group.Wait(); err != nil { return nil, err } + response := &admin.TaskExecutionGetDataResponse{ Inputs: inputURLBlob, Outputs: outputURLBlob, diff --git a/flyteadmin/pkg/manager/impl/util/data.go b/flyteadmin/pkg/manager/impl/util/data.go index af165e82a6..080cc9fe71 100644 --- a/flyteadmin/pkg/manager/impl/util/data.go +++ b/flyteadmin/pkg/manager/impl/util/data.go @@ -23,6 +23,7 @@ const ( ) type GetObjectRequest struct { + Cluster string Project string Domain string Prefix string @@ -50,7 +51,7 @@ func GetInputs(ctx context.Context, urlData dataInterfaces.RemoteURLInterface, remoteDataConfig *runtimeInterfaces.RemoteDataConfig, storageClient *storage.DataStore, - project, domain, inputURI string, + cluster, project, domain, inputURI string, objectStore ObjectStore, ) (*core.LiteralMap, *admin.UrlBlob, error) { var inputsURLBlob admin.UrlBlob @@ -69,11 +70,10 @@ func GetInputs(ctx context.Context, } if shouldFetchData(remoteDataConfig, inputsURLBlob) { - base := string(storageClient.GetBaseContainerFQN(ctx)) - if strings.HasPrefix(inputURI, base) { + if IsLocalURI(ctx, storageClient, inputURI) { err = storageClient.ReadProtobuf(ctx, storage.DataReference(inputURI), &fullInputs) } else { - err = readFromDataPlane(ctx, objectStore, project, domain, inputURI, &fullInputs) + err = readFromDataPlane(ctx, objectStore, cluster, project, domain, inputURI, &fullInputs) } if err != nil { // If we fail to read the protobuf from the remote store, we shouldn't fail the request altogether. @@ -126,7 +126,7 @@ func GetOutputs(ctx context.Context, remoteDataConfig *runtimeInterfaces.RemoteDataConfig, storageClient *storage.DataStore, closure ExecutionClosure, - project, domain string, + cluster, project, domain string, objectStore ObjectStore, ) (*core.LiteralMap, *admin.UrlBlob, error) { var outputsURLBlob admin.UrlBlob @@ -150,11 +150,10 @@ func GetOutputs(ctx context.Context, logger.Debugf(ctx, "execution closure contains output data that exceeds max data size for responses") } } else if shouldFetchOutputData(remoteDataConfig, outputsURLBlob, closure.GetOutputUri()) { - base := string(storageClient.GetBaseContainerFQN(ctx)) - if strings.HasPrefix(closure.GetOutputUri(), base) { + if IsLocalURI(ctx, storageClient, closure.GetOutputUri()) { err = storageClient.ReadProtobuf(ctx, storage.DataReference(closure.GetOutputUri()), fullOutputs) } else { - err = readFromDataPlane(ctx, objectStore, project, domain, closure.GetOutputUri(), fullOutputs) + err = readFromDataPlane(ctx, objectStore, cluster, project, domain, closure.GetOutputUri(), fullOutputs) } if err != nil { // If we fail to read the protobuf from the remote store, we shouldn't fail the request altogether. @@ -166,9 +165,14 @@ func GetOutputs(ctx context.Context, return fullOutputs, &outputsURLBlob, nil } +func IsLocalURI(ctx context.Context, store *storage.DataStore, uri string) bool { + base := store.GetBaseContainerFQN(ctx) + return strings.HasPrefix(uri, string(base)) +} + func readFromDataPlane(ctx context.Context, objectStore ObjectStore, - project, domain, reference string, + cluster, project, domain, reference string, msg proto.Message, ) error { if objectStore == nil { @@ -181,6 +185,7 @@ func readFromDataPlane(ctx context.Context, } out, err := objectStore.GetObject(ctx, GetObjectRequest{ + Cluster: cluster, Prefix: refURL.Path, Project: project, Domain: domain, diff --git a/flyteadmin/pkg/manager/impl/util/data_test.go b/flyteadmin/pkg/manager/impl/util/data_test.go index e5ed58596f..31a9a9aaaf 100644 --- a/flyteadmin/pkg/manager/impl/util/data_test.go +++ b/flyteadmin/pkg/manager/impl/util/data_test.go @@ -24,7 +24,10 @@ var testLiteralMap = &core.LiteralMap{ }, } -const testOutputsURI = "s3://bucket/bar/outputs.pb" +const ( + testOutputsURI = "s3://bucket/bar/outputs.pb" + clusterName = "foo-cluster" +) type objectStoreMock struct { mock.Mock @@ -149,7 +152,7 @@ func TestGetInputs(t *testing.T) { t.Run("should sign URL", func(t *testing.T) { remoteDataConfig.SignedURL = interfaces.SignedURL{Enabled: true} - fullInputs, inputURLBlob, err := GetInputs(context.TODO(), mockRemoteURL, &remoteDataConfig, mockStorage, project, domain, inputsURI, nil) + fullInputs, inputURLBlob, err := GetInputs(context.TODO(), mockRemoteURL, &remoteDataConfig, mockStorage, "", project, domain, inputsURI, nil) assert.NoError(t, err) assert.True(t, proto.Equal(fullInputs, testLiteralMap)) @@ -159,7 +162,7 @@ func TestGetInputs(t *testing.T) { t.Run("should not sign URL", func(t *testing.T) { remoteDataConfig.SignedURL = interfaces.SignedURL{Enabled: false} - fullInputs, inputURLBlob, err := GetInputs(context.TODO(), mockRemoteURL, &remoteDataConfig, mockStorage, project, domain, inputsURI, nil) + fullInputs, inputURLBlob, err := GetInputs(context.TODO(), mockRemoteURL, &remoteDataConfig, mockStorage, "", project, domain, inputsURI, nil) assert.NoError(t, err) assert.True(t, proto.Equal(fullInputs, testLiteralMap)) @@ -172,6 +175,7 @@ func TestGetInputs(t *testing.T) { bts, _ := proto.Marshal(testLiteralMap) store. On("GetObject", GetObjectRequest{ + Cluster: clusterName, Project: project, Domain: domain, Prefix: "/foo/bar", @@ -182,7 +186,7 @@ func TestGetInputs(t *testing.T) { ctx := context.TODO() inputURI := "s3://wrong/foo/bar" - fullInputs, inputURLBlob, err := GetInputs(ctx, mockRemoteURL, &remoteDataConfig, mockStorage, project, domain, inputURI, store) + fullInputs, inputURLBlob, err := GetInputs(ctx, mockRemoteURL, &remoteDataConfig, mockStorage, clusterName, project, domain, inputURI, store) assert.NoError(t, err) assert.True(t, proto.Equal(fullInputs, testLiteralMap)) @@ -222,7 +226,7 @@ func TestGetOutputs(t *testing.T) { t.Run("offloaded outputs with signed URL", func(t *testing.T) { remoteDataConfig.SignedURL = interfaces.SignedURL{Enabled: true} - fullOutputs, outputURLBlob, err := GetOutputs(context.TODO(), mockRemoteURL, &remoteDataConfig, mockStorage, closure, project, domain, nil) + fullOutputs, outputURLBlob, err := GetOutputs(context.TODO(), mockRemoteURL, &remoteDataConfig, mockStorage, closure, "", project, domain, nil) assert.NoError(t, err) assert.True(t, proto.Equal(fullOutputs, testLiteralMap)) @@ -232,7 +236,7 @@ func TestGetOutputs(t *testing.T) { t.Run("offloaded outputs without signed URL", func(t *testing.T) { remoteDataConfig.SignedURL = interfaces.SignedURL{Enabled: false} - fullOutputs, outputURLBlob, err := GetOutputs(context.TODO(), mockRemoteURL, &remoteDataConfig, mockStorage, closure, project, domain, nil) + fullOutputs, outputURLBlob, err := GetOutputs(context.TODO(), mockRemoteURL, &remoteDataConfig, mockStorage, closure, "", project, domain, nil) assert.NoError(t, err) assert.True(t, proto.Equal(fullOutputs, testLiteralMap)) @@ -245,6 +249,7 @@ func TestGetOutputs(t *testing.T) { bts, _ := proto.Marshal(testLiteralMap) store. On("GetObject", GetObjectRequest{ + Cluster: clusterName, Project: project, Domain: domain, Prefix: "/foo/bar", @@ -257,7 +262,7 @@ func TestGetOutputs(t *testing.T) { }, } - fullOutputs, outputURLBlob, err := GetOutputs(context.TODO(), mockRemoteURL, &remoteDataConfig, mockStorage, testClosure, project, domain, store) + fullOutputs, outputURLBlob, err := GetOutputs(context.TODO(), mockRemoteURL, &remoteDataConfig, mockStorage, testClosure, clusterName, project, domain, store) assert.NoError(t, err) assert.True(t, proto.Equal(fullOutputs, testLiteralMap)) @@ -286,7 +291,7 @@ func TestGetOutputs(t *testing.T) { }, } - fullOutputs, outputURLBlob, err := GetOutputs(context.TODO(), mockRemoteURL, &remoteDataConfig, mockStorage, closure, project, domain, nil) + fullOutputs, outputURLBlob, err := GetOutputs(context.TODO(), mockRemoteURL, &remoteDataConfig, mockStorage, closure, "", project, domain, nil) assert.NoError(t, err) assert.True(t, proto.Equal(fullOutputs, testLiteralMap))