Skip to content

Commit

Permalink
Overlap fetching input and output data (#109)
Browse files Browse the repository at this point in the history
## Overview
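This change updates `GetExecutionData`, `GetNodeExecutionData`, and `GetTaskExecutionData` to fetch input and output data concurrently (overlapped reads via an `errgroup`) instead of one after the other. It also consolidates the two separate error groups in `launchExecution` into a single group.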

## Test Plan
- [x] Existing unit tests pass

## Rollout Plan (if applicable)
Pick up and roll out via org-sync pipelines. The change is low risk.

## Upstream Changes
Should this change be upstreamed to OSS (flyteorg/flyte)? If so, please check this box for auditing. Note that this is the responsibility of each developer. See [this guide](https://unionai.atlassian.net/wiki/spaces/ENG/pages/447610883/Flyte+-+Union+Cloud+Development+Runbook/#When-are-versions-updated%3F).
- [x] To be upstreamed

## Jira Issue
https://unionai.atlassian.net/browse/CLOUD-1621
  • Loading branch information
andrewwdye authored Mar 3, 2024
1 parent b14f821 commit f1569a2
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 74 deletions.
73 changes: 39 additions & 34 deletions flyteadmin/pkg/manager/impl/execution_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -1058,35 +1058,34 @@ func (m *ExecutionManager) launchExecution(
}

// Overlap the blob store reads and writes
getClosureGroup, getClosureGroupCtx := errgroup.WithContext(ctx)
group, groupCtx := errgroup.WithContext(ctx)
var closure *admin.WorkflowClosure
getClosureGroup.Go(func() error {
group.Go(func() error {
var err error
closure, err = util.FetchAndGetWorkflowClosure(getClosureGroupCtx, m.storageClient, workflowModel.RemoteClosureIdentifier)
closure, err = util.FetchAndGetWorkflowClosure(groupCtx, m.storageClient, workflowModel.RemoteClosureIdentifier)
if err != nil {
logger.Debugf(ctx, "Failed to get workflow with id %+v with err %v", launchPlan.Spec.WorkflowId, err)
}
return err
})

offloadInputsGroup, offloadInputsGroupCtx := errgroup.WithContext(ctx)
var inputsURI storage.DataReference
offloadInputsGroup.Go(func() error {
group.Go(func() error {
var err error
inputsURI, err = common.OffloadLiteralMap(offloadInputsGroupCtx, m.storageClient, executionInputs,
inputsURI, err = common.OffloadLiteralMap(groupCtx, m.storageClient, executionInputs,
workflowExecutionID.Org, workflowExecutionID.Project, workflowExecutionID.Domain, workflowExecutionID.Name, shared.Inputs)
return err
})

var userInputsURI storage.DataReference
offloadInputsGroup.Go(func() error {
group.Go(func() error {
var err error
userInputsURI, err = common.OffloadLiteralMap(offloadInputsGroupCtx, m.storageClient, request.Inputs,
userInputsURI, err = common.OffloadLiteralMap(groupCtx, m.storageClient, request.Inputs,
workflowExecutionID.Org, workflowExecutionID.Project, workflowExecutionID.Domain, workflowExecutionID.Name, shared.UserInputs)
return err
})

err = getClosureGroup.Wait()
err = group.Wait()
if err != nil {
return nil, nil, err
}
Expand Down Expand Up @@ -1190,11 +1189,6 @@ func (m *ExecutionManager) launchExecution(
notificationsSettings = make([]*admin.Notification, 0)
}

err = offloadInputsGroup.Wait()
if err != nil {
return nil, nil, err
}

// Publish of event is also gated on the artifact client being available, even though it's not directly required.
// TODO: Artifact feature gate, remove when ready
if m.artifactRegistry.GetClient() != nil {
Expand Down Expand Up @@ -1760,30 +1754,41 @@ func (m *ExecutionManager) GetExecutionData(

id := request.GetId()
objectStore := plugins.Get[util.ObjectStore](m.pluginRegistry, plugins.PluginIDObjectStore)
inputs, inputURLBlob, err := util.GetInputs(ctx,
m.urlData,
m.config.ApplicationConfiguration().GetRemoteDataConfig(),
m.storageClient,
id.Project,
id.Domain,
executionModel.InputsURI.String(),
objectStore)
if err != nil {
return nil, err
}
var inputs *core.LiteralMap
var inputURLBlob *admin.UrlBlob
group, groupCtx := errgroup.WithContext(ctx)
group.Go(func() error {
var err error
inputs, inputURLBlob, err = util.GetInputs(groupCtx,
m.urlData,
m.config.ApplicationConfiguration().GetRemoteDataConfig(),
m.storageClient,
id.Project,
id.Domain,
executionModel.InputsURI.String(),
objectStore)
return err
})

var outputs *core.LiteralMap
var outputURLBlob *admin.UrlBlob
group.Go(func() error {
var err error
outputs, outputURLBlob, err = util.GetOutputs(groupCtx,
m.urlData,
m.config.ApplicationConfiguration().GetRemoteDataConfig(),
m.storageClient,
util.ToExecutionClosureInterface(execution.Closure),
id.Project,
id.Domain,
objectStore)
return err
})

outputs, outputURLBlob, err := util.GetOutputs(ctx,
m.urlData,
m.config.ApplicationConfiguration().GetRemoteDataConfig(),
m.storageClient,
util.ToExecutionClosureInterface(execution.Closure),
id.Project,
id.Domain,
objectStore)
err = group.Wait()
if err != nil {
return nil, err
}

response := &admin.WorkflowExecutionGetDataResponse{
Inputs: inputURLBlob,
Outputs: outputURLBlob,
Expand Down
52 changes: 32 additions & 20 deletions flyteadmin/pkg/manager/impl/node_execution_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/golang/protobuf/proto"
"github.com/prometheus/client_golang/prometheus"
"github.com/samber/lo"
"golang.org/x/sync/errgroup"
"google.golang.org/grpc/codes"

cloudeventInterfaces "github.com/flyteorg/flyte/flyteadmin/pkg/async/cloudevent/interfaces"
Expand Down Expand Up @@ -535,30 +536,41 @@ func (m *NodeExecutionManager) GetNodeExecutionData(

id := request.GetId().GetExecutionId()
objectStore := plugins.Get[util.ObjectStore](m.pluginRegistry, plugins.PluginIDObjectStore)
inputs, inputURLBlob, err := util.GetInputs(ctx,
m.urlData,
m.config.ApplicationConfiguration().GetRemoteDataConfig(),
m.storageClient,
id.Project,
id.Domain,
nodeExecution.InputUri,
objectStore)
if err != nil {
return nil, err
}
var inputs *core.LiteralMap
var inputURLBlob *admin.UrlBlob
group, groupCtx := errgroup.WithContext(ctx)
group.Go(func() error {
var err error
inputs, inputURLBlob, err = util.GetInputs(groupCtx,
m.urlData,
m.config.ApplicationConfiguration().GetRemoteDataConfig(),
m.storageClient,
id.Project,
id.Domain,
nodeExecution.InputUri,
objectStore)
return err
})

outputs, outputURLBlob, err := util.GetOutputs(ctx,
m.urlData,
m.config.ApplicationConfiguration().GetRemoteDataConfig(),
m.storageClient,
nodeExecution.Closure,
id.Project,
id.Domain,
objectStore)
var outputs *core.LiteralMap
var outputURLBlob *admin.UrlBlob
group.Go(func() error {
var err error
outputs, outputURLBlob, err = util.GetOutputs(groupCtx,
m.urlData,
m.config.ApplicationConfiguration().GetRemoteDataConfig(),
m.storageClient,
nodeExecution.Closure,
id.Project,
id.Domain,
objectStore)
return err
})

err = group.Wait()
if err != nil {
return nil, err
}

response := &admin.NodeExecutionGetDataResponse{
Inputs: inputURLBlob,
Outputs: outputURLBlob,
Expand Down
52 changes: 32 additions & 20 deletions flyteadmin/pkg/manager/impl/task_execution_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (

"github.com/golang/protobuf/proto"
"github.com/prometheus/client_golang/prometheus"
"golang.org/x/sync/errgroup"
"google.golang.org/grpc/codes"

cloudeventInterfaces "github.com/flyteorg/flyte/flyteadmin/pkg/async/cloudevent/interfaces"
Expand Down Expand Up @@ -314,30 +315,41 @@ func (m *TaskExecutionManager) GetTaskExecutionData(

objectStore := plugins.Get[util.ObjectStore](m.pluginRegistry, plugins.PluginIDObjectStore)
id := request.GetId().GetNodeExecutionId().GetExecutionId()
inputs, inputURLBlob, err := util.GetInputs(ctx,
m.urlData,
m.config.ApplicationConfiguration().GetRemoteDataConfig(),
m.storageClient,
id.Project,
id.Domain,
taskExecution.InputUri,
objectStore)
if err != nil {
return nil, err
}
var inputs *core.LiteralMap
var inputURLBlob *admin.UrlBlob
group, groupCtx := errgroup.WithContext(ctx)
group.Go(func() error {
var err error
inputs, inputURLBlob, err = util.GetInputs(groupCtx,
m.urlData,
m.config.ApplicationConfiguration().GetRemoteDataConfig(),
m.storageClient,
id.Project,
id.Domain,
taskExecution.InputUri,
objectStore)
return err
})

outputs, outputURLBlob, err := util.GetOutputs(ctx,
m.urlData,
m.config.ApplicationConfiguration().GetRemoteDataConfig(),
m.storageClient,
taskExecution.Closure,
id.Project,
id.Domain,
objectStore)
var outputs *core.LiteralMap
var outputURLBlob *admin.UrlBlob
group.Go(func() error {
var err error
outputs, outputURLBlob, err = util.GetOutputs(groupCtx,
m.urlData,
m.config.ApplicationConfiguration().GetRemoteDataConfig(),
m.storageClient,
taskExecution.Closure,
id.Project,
id.Domain,
objectStore)
return err
})

err = group.Wait()
if err != nil {
return nil, err
}

response := &admin.TaskExecutionGetDataResponse{
Inputs: inputURLBlob,
Outputs: outputURLBlob,
Expand Down

0 comments on commit f1569a2

Please sign in to comment.