Skip to content

Commit

Permalink
Adjust ObjectStore plugin to pass cluster name (#114)
Browse files Browse the repository at this point in the history
* Adjust ObjectStore plugin to pass cluster name

Signed-off-by: Iaroslav Ciupin <[email protected]>
  • Loading branch information
iaroslav-ciupin authored Mar 6, 2024
1 parent f1569a2 commit d0baca9
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 39 deletions.
4 changes: 4 additions & 0 deletions flyteadmin/pkg/manager/impl/execution_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -1752,6 +1752,8 @@ func (m *ExecutionManager) GetExecutionData(
}
}

cluster := execution.GetSpec().GetMetadata().GetSystemMetadata().GetExecutionCluster()

id := request.GetId()
objectStore := plugins.Get[util.ObjectStore](m.pluginRegistry, plugins.PluginIDObjectStore)
var inputs *core.LiteralMap
Expand All @@ -1763,6 +1765,7 @@ func (m *ExecutionManager) GetExecutionData(
m.urlData,
m.config.ApplicationConfiguration().GetRemoteDataConfig(),
m.storageClient,
cluster,
id.Project,
id.Domain,
executionModel.InputsURI.String(),
Expand All @@ -1779,6 +1782,7 @@ func (m *ExecutionManager) GetExecutionData(
m.config.ApplicationConfiguration().GetRemoteDataConfig(),
m.storageClient,
util.ToExecutionClosureInterface(execution.Closure),
cluster,
id.Project,
id.Domain,
objectStore)
Expand Down
51 changes: 39 additions & 12 deletions flyteadmin/pkg/manager/impl/node_execution_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -521,30 +521,57 @@ func (m *NodeExecutionManager) GetNodeExecutionData(
}

ctx = getNodeExecutionContext(ctx, request.Id)
nodeExecutionModel, err := util.GetNodeExecutionModel(ctx, m.db, request.Id)
if err != nil {
logger.Debugf(ctx, "Failed to get node execution with id [%+v] with err %v",
request.Id, err)
return nil, err
}
group, groupCtx := errgroup.WithContext(ctx)
var nodeExecutionModel *models.NodeExecution
var nodeExecution *admin.NodeExecution
group.Go(func() error {
var err error
nodeExecutionModel, err = util.GetNodeExecutionModel(ctx, m.db, request.Id)
if err != nil {
logger.Errorf(ctx, "failed to get node execution with id [%+v] with err %v",
request.Id, err)
return err
}
nodeExecution, err = transformers.FromNodeExecutionModel(*nodeExecutionModel, transformers.DefaultExecutionTransformerOptions)
if err != nil {
logger.Errorf(ctx, "failed to transform node execution model [%+v] when fetching data: %v", request.Id, err)
}
return err
})

nodeExecution, err := transformers.FromNodeExecutionModel(*nodeExecutionModel, transformers.DefaultExecutionTransformerOptions)
if err != nil {
logger.Debugf(ctx, "failed to transform node execution model [%+v] when fetching data: %v", request.Id, err)
cluster := ""
group.Go(func() error {
// when fetching remote S3 URIs, we need a cluster name to send to Union dataproxy to get the correct tunnel
execModel, err := util.GetExecutionModel(ctx, m.db, *request.GetId().GetExecutionId())
if err != nil {
logger.Errorf(ctx, "failed to fetch execution model: %v", err)
return err
}
execution, err := transformers.FromExecutionModel(ctx, *execModel, transformers.DefaultExecutionTransformerOptions)
if err != nil {
logger.Errorf(ctx, "failed to transform execution model [%+v] to proto object with err: %v", execModel.Name, err)
return err
}
cluster = execution.GetSpec().GetMetadata().GetSystemMetadata().GetExecutionCluster()
return nil
})

if err := group.Wait(); err != nil {
return nil, err
}

id := request.GetId().GetExecutionId()
objectStore := plugins.Get[util.ObjectStore](m.pluginRegistry, plugins.PluginIDObjectStore)
var inputs *core.LiteralMap
var inputURLBlob *admin.UrlBlob
group, groupCtx := errgroup.WithContext(ctx)
group, groupCtx = errgroup.WithContext(ctx)
group.Go(func() error {
var err error
inputs, inputURLBlob, err = util.GetInputs(groupCtx,
m.urlData,
m.config.ApplicationConfiguration().GetRemoteDataConfig(),
m.storageClient,
cluster,
id.Project,
id.Domain,
nodeExecution.InputUri,
Expand All @@ -561,14 +588,14 @@ func (m *NodeExecutionManager) GetNodeExecutionData(
m.config.ApplicationConfiguration().GetRemoteDataConfig(),
m.storageClient,
nodeExecution.Closure,
cluster,
id.Project,
id.Domain,
objectStore)
return err
})

err = group.Wait()
if err != nil {
if err := group.Wait(); err != nil {
return nil, err
}
response := &admin.NodeExecutionGetDataResponse{
Expand Down
45 changes: 35 additions & 10 deletions flyteadmin/pkg/manager/impl/task_execution_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -298,32 +298,56 @@ func (m *TaskExecutionManager) ListTaskExecutions(
}, nil
}

func (m *TaskExecutionManager) GetTaskExecutionData(
ctx context.Context, request admin.TaskExecutionGetDataRequest) (*admin.TaskExecutionGetDataResponse, error) {
func (m *TaskExecutionManager) GetTaskExecutionData(ctx context.Context,
request admin.TaskExecutionGetDataRequest) (*admin.TaskExecutionGetDataResponse, error) {
if err := validation.ValidateTaskExecutionIdentifier(request.Id); err != nil {
logger.Debugf(ctx, "Invalid identifier [%+v]: %v", request.Id, err)
}
ctx = getTaskExecutionContext(ctx, request.Id)
taskExecution, err := m.GetTaskExecution(ctx, admin.TaskExecutionGetRequest{
Id: request.Id,
group, groupCtx := errgroup.WithContext(ctx)
var taskExecution *admin.TaskExecution
group.Go(func() error {
var err error
taskExecution, err = m.GetTaskExecution(ctx, admin.TaskExecutionGetRequest{Id: request.Id})
if err != nil {
logger.Errorf(ctx, "Failed to get task execution with id [%+v] with err %v", request.Id, err)
}
return err
})
if err != nil {
logger.Debugf(ctx, "Failed to get task execution with id [%+v] with err %v",
request.Id, err)

cluster := ""
group.Go(func() error {
// when fetching remote S3 URIs, we need a cluster name to send to Union dataproxy to get the correct tunnel
execModel, err := util.GetExecutionModel(ctx, m.db, *request.GetId().GetNodeExecutionId().GetExecutionId())
if err != nil {
logger.Errorf(ctx, "failed to fetch execution model: %v", err)
return err
}
execution, err := transformers.FromExecutionModel(ctx, *execModel, transformers.DefaultExecutionTransformerOptions)
if err != nil {
logger.Errorf(ctx, "failed to transform execution model [%+v] to proto object with err: %v", execModel.Name, err)
return err
}
cluster = execution.GetSpec().GetMetadata().GetSystemMetadata().GetExecutionCluster()
return nil
})

if err := group.Wait(); err != nil {
return nil, err
}

objectStore := plugins.Get[util.ObjectStore](m.pluginRegistry, plugins.PluginIDObjectStore)
id := request.GetId().GetNodeExecutionId().GetExecutionId()
var inputs *core.LiteralMap
var inputURLBlob *admin.UrlBlob
group, groupCtx := errgroup.WithContext(ctx)
group, groupCtx = errgroup.WithContext(ctx)
group.Go(func() error {
var err error
inputs, inputURLBlob, err = util.GetInputs(groupCtx,
m.urlData,
m.config.ApplicationConfiguration().GetRemoteDataConfig(),
m.storageClient,
cluster,
id.Project,
id.Domain,
taskExecution.InputUri,
Expand All @@ -340,16 +364,17 @@ func (m *TaskExecutionManager) GetTaskExecutionData(
m.config.ApplicationConfiguration().GetRemoteDataConfig(),
m.storageClient,
taskExecution.Closure,
cluster,
id.Project,
id.Domain,
objectStore)
return err
})

err = group.Wait()
if err != nil {
if err := group.Wait(); err != nil {
return nil, err
}

response := &admin.TaskExecutionGetDataResponse{
Inputs: inputURLBlob,
Outputs: outputURLBlob,
Expand Down
23 changes: 14 additions & 9 deletions flyteadmin/pkg/manager/impl/util/data.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ const (
)

type GetObjectRequest struct {
Cluster string
Project string
Domain string
Prefix string
Expand Down Expand Up @@ -50,7 +51,7 @@ func GetInputs(ctx context.Context,
urlData dataInterfaces.RemoteURLInterface,
remoteDataConfig *runtimeInterfaces.RemoteDataConfig,
storageClient *storage.DataStore,
project, domain, inputURI string,
cluster, project, domain, inputURI string,
objectStore ObjectStore,
) (*core.LiteralMap, *admin.UrlBlob, error) {
var inputsURLBlob admin.UrlBlob
Expand All @@ -69,11 +70,10 @@ func GetInputs(ctx context.Context,
}

if shouldFetchData(remoteDataConfig, inputsURLBlob) {
base := string(storageClient.GetBaseContainerFQN(ctx))
if strings.HasPrefix(inputURI, base) {
if IsLocalURI(ctx, storageClient, inputURI) {
err = storageClient.ReadProtobuf(ctx, storage.DataReference(inputURI), &fullInputs)
} else {
err = readFromDataPlane(ctx, objectStore, project, domain, inputURI, &fullInputs)
err = readFromDataPlane(ctx, objectStore, cluster, project, domain, inputURI, &fullInputs)
}
if err != nil {
// If we fail to read the protobuf from the remote store, we shouldn't fail the request altogether.
Expand Down Expand Up @@ -126,7 +126,7 @@ func GetOutputs(ctx context.Context,
remoteDataConfig *runtimeInterfaces.RemoteDataConfig,
storageClient *storage.DataStore,
closure ExecutionClosure,
project, domain string,
cluster, project, domain string,
objectStore ObjectStore,
) (*core.LiteralMap, *admin.UrlBlob, error) {
var outputsURLBlob admin.UrlBlob
Expand All @@ -150,11 +150,10 @@ func GetOutputs(ctx context.Context,
logger.Debugf(ctx, "execution closure contains output data that exceeds max data size for responses")
}
} else if shouldFetchOutputData(remoteDataConfig, outputsURLBlob, closure.GetOutputUri()) {
base := string(storageClient.GetBaseContainerFQN(ctx))
if strings.HasPrefix(closure.GetOutputUri(), base) {
if IsLocalURI(ctx, storageClient, closure.GetOutputUri()) {
err = storageClient.ReadProtobuf(ctx, storage.DataReference(closure.GetOutputUri()), fullOutputs)
} else {
err = readFromDataPlane(ctx, objectStore, project, domain, closure.GetOutputUri(), fullOutputs)
err = readFromDataPlane(ctx, objectStore, cluster, project, domain, closure.GetOutputUri(), fullOutputs)
}
if err != nil {
// If we fail to read the protobuf from the remote store, we shouldn't fail the request altogether.
Expand All @@ -166,9 +165,14 @@ func GetOutputs(ctx context.Context,
return fullOutputs, &outputsURLBlob, nil
}

func IsLocalURI(ctx context.Context, store *storage.DataStore, uri string) bool {
base := store.GetBaseContainerFQN(ctx)
return strings.HasPrefix(uri, string(base))
}

func readFromDataPlane(ctx context.Context,
objectStore ObjectStore,
project, domain, reference string,
cluster, project, domain, reference string,
msg proto.Message,
) error {
if objectStore == nil {
Expand All @@ -181,6 +185,7 @@ func readFromDataPlane(ctx context.Context,
}

out, err := objectStore.GetObject(ctx, GetObjectRequest{
Cluster: cluster,
Prefix: refURL.Path,
Project: project,
Domain: domain,
Expand Down
21 changes: 13 additions & 8 deletions flyteadmin/pkg/manager/impl/util/data_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@ var testLiteralMap = &core.LiteralMap{
},
}

const testOutputsURI = "s3://bucket/bar/outputs.pb"
const (
testOutputsURI = "s3://bucket/bar/outputs.pb"
clusterName = "foo-cluster"
)

type objectStoreMock struct {
mock.Mock
Expand Down Expand Up @@ -149,7 +152,7 @@ func TestGetInputs(t *testing.T) {
t.Run("should sign URL", func(t *testing.T) {
remoteDataConfig.SignedURL = interfaces.SignedURL{Enabled: true}

fullInputs, inputURLBlob, err := GetInputs(context.TODO(), mockRemoteURL, &remoteDataConfig, mockStorage, project, domain, inputsURI, nil)
fullInputs, inputURLBlob, err := GetInputs(context.TODO(), mockRemoteURL, &remoteDataConfig, mockStorage, "", project, domain, inputsURI, nil)

assert.NoError(t, err)
assert.True(t, proto.Equal(fullInputs, testLiteralMap))
Expand All @@ -159,7 +162,7 @@ func TestGetInputs(t *testing.T) {
t.Run("should not sign URL", func(t *testing.T) {
remoteDataConfig.SignedURL = interfaces.SignedURL{Enabled: false}

fullInputs, inputURLBlob, err := GetInputs(context.TODO(), mockRemoteURL, &remoteDataConfig, mockStorage, project, domain, inputsURI, nil)
fullInputs, inputURLBlob, err := GetInputs(context.TODO(), mockRemoteURL, &remoteDataConfig, mockStorage, "", project, domain, inputsURI, nil)

assert.NoError(t, err)
assert.True(t, proto.Equal(fullInputs, testLiteralMap))
Expand All @@ -172,6 +175,7 @@ func TestGetInputs(t *testing.T) {
bts, _ := proto.Marshal(testLiteralMap)
store.
On("GetObject", GetObjectRequest{
Cluster: clusterName,
Project: project,
Domain: domain,
Prefix: "/foo/bar",
Expand All @@ -182,7 +186,7 @@ func TestGetInputs(t *testing.T) {
ctx := context.TODO()
inputURI := "s3://wrong/foo/bar"

fullInputs, inputURLBlob, err := GetInputs(ctx, mockRemoteURL, &remoteDataConfig, mockStorage, project, domain, inputURI, store)
fullInputs, inputURLBlob, err := GetInputs(ctx, mockRemoteURL, &remoteDataConfig, mockStorage, clusterName, project, domain, inputURI, store)

assert.NoError(t, err)
assert.True(t, proto.Equal(fullInputs, testLiteralMap))
Expand Down Expand Up @@ -222,7 +226,7 @@ func TestGetOutputs(t *testing.T) {
t.Run("offloaded outputs with signed URL", func(t *testing.T) {
remoteDataConfig.SignedURL = interfaces.SignedURL{Enabled: true}

fullOutputs, outputURLBlob, err := GetOutputs(context.TODO(), mockRemoteURL, &remoteDataConfig, mockStorage, closure, project, domain, nil)
fullOutputs, outputURLBlob, err := GetOutputs(context.TODO(), mockRemoteURL, &remoteDataConfig, mockStorage, closure, "", project, domain, nil)

assert.NoError(t, err)
assert.True(t, proto.Equal(fullOutputs, testLiteralMap))
Expand All @@ -232,7 +236,7 @@ func TestGetOutputs(t *testing.T) {
t.Run("offloaded outputs without signed URL", func(t *testing.T) {
remoteDataConfig.SignedURL = interfaces.SignedURL{Enabled: false}

fullOutputs, outputURLBlob, err := GetOutputs(context.TODO(), mockRemoteURL, &remoteDataConfig, mockStorage, closure, project, domain, nil)
fullOutputs, outputURLBlob, err := GetOutputs(context.TODO(), mockRemoteURL, &remoteDataConfig, mockStorage, closure, "", project, domain, nil)

assert.NoError(t, err)
assert.True(t, proto.Equal(fullOutputs, testLiteralMap))
Expand All @@ -245,6 +249,7 @@ func TestGetOutputs(t *testing.T) {
bts, _ := proto.Marshal(testLiteralMap)
store.
On("GetObject", GetObjectRequest{
Cluster: clusterName,
Project: project,
Domain: domain,
Prefix: "/foo/bar",
Expand All @@ -257,7 +262,7 @@ func TestGetOutputs(t *testing.T) {
},
}

fullOutputs, outputURLBlob, err := GetOutputs(context.TODO(), mockRemoteURL, &remoteDataConfig, mockStorage, testClosure, project, domain, store)
fullOutputs, outputURLBlob, err := GetOutputs(context.TODO(), mockRemoteURL, &remoteDataConfig, mockStorage, testClosure, clusterName, project, domain, store)

assert.NoError(t, err)
assert.True(t, proto.Equal(fullOutputs, testLiteralMap))
Expand Down Expand Up @@ -286,7 +291,7 @@ func TestGetOutputs(t *testing.T) {
},
}

fullOutputs, outputURLBlob, err := GetOutputs(context.TODO(), mockRemoteURL, &remoteDataConfig, mockStorage, closure, project, domain, nil)
fullOutputs, outputURLBlob, err := GetOutputs(context.TODO(), mockRemoteURL, &remoteDataConfig, mockStorage, closure, "", project, domain, nil)

assert.NoError(t, err)
assert.True(t, proto.Equal(fullOutputs, testLiteralMap))
Expand Down

0 comments on commit d0baca9

Please sign in to comment.