Skip to content

Commit

Permalink
Refactor echo plugin (#5565)
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin Su <[email protected]>
  • Loading branch information
pingsutw authored Jul 17, 2024
1 parent 9638db0 commit 0b39839
Showing 1 changed file with 60 additions and 34 deletions.
94 changes: 60 additions & 34 deletions flyteplugins/go/tasks/plugins/testing/echo.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package testing
import (
"context"
"fmt"
"sync"
"time"

idlcore "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core"
Expand All @@ -20,6 +21,7 @@ const (
type EchoPlugin struct {
enqueueOwner core.EnqueueOwner
taskStartTimes map[string]time.Time
sync.Mutex
}

func (e *EchoPlugin) GetID() string {
Expand All @@ -30,9 +32,11 @@ func (e *EchoPlugin) GetProperties() core.PluginProperties {
return core.PluginProperties{}
}

func (e *EchoPlugin) Handle(ctx context.Context, tCtx core.TaskExecutionContext) (core.Transition, error) {
echoConfig := ConfigSection.GetConfig().(*Config)

// Enqueue the task to be re-evaluated after SleepDuration.
// If the task is already enqueued, return the start time of the task.
func (e *EchoPlugin) addTask(ctx context.Context, tCtx core.TaskExecutionContext) time.Time {
e.Lock()
defer e.Unlock()
var startTime time.Time
var exists bool
taskExecutionID := tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName()
Expand All @@ -42,47 +46,34 @@ func (e *EchoPlugin) Handle(ctx context.Context, tCtx core.TaskExecutionContext)

// start timer to enqueue owner once task sleep duration has elapsed
go func() {
echoConfig := ConfigSection.GetConfig().(*Config)
time.Sleep(echoConfig.SleepDuration.Duration)
if err := e.enqueueOwner(tCtx.TaskExecutionMetadata().GetOwnerID()); err != nil {
logger.Warnf(ctx, "failed to enqueue owner [%s]: %v", tCtx.TaskExecutionMetadata().GetOwnerID(), err)
}
}()
}
return startTime
}

if time.Since(startTime) >= echoConfig.SleepDuration.Duration {
// copy inputs to outputs
inputToOutputVariableMappings, err := compileInputToOutputVariableMappings(ctx, tCtx)
if err != nil {
return core.UnknownTransition, err
}

if len(inputToOutputVariableMappings) > 0 {
inputLiterals, err := tCtx.InputReader().Get(ctx)
if err != nil {
return core.UnknownTransition, err
}

outputLiterals := make(map[string]*idlcore.Literal, len(inputToOutputVariableMappings))
for inputVariableName, outputVariableName := range inputToOutputVariableMappings {
outputLiterals[outputVariableName] = inputLiterals.Literals[inputVariableName]
}
// Remove the task from the taskStartTimes map.
func (e *EchoPlugin) removeTask(taskExecutionID string) {
e.Lock()
defer e.Unlock()
delete(e.taskStartTimes, taskExecutionID)
}

outputLiteralMap := &idlcore.LiteralMap{
Literals: outputLiterals,
}
func (e *EchoPlugin) Handle(ctx context.Context, tCtx core.TaskExecutionContext) (core.Transition, error) {
echoConfig := ConfigSection.GetConfig().(*Config)

outputFile := tCtx.OutputWriter().GetOutputPath()
if err := tCtx.DataStore().WriteProtobuf(ctx, outputFile, storage.Options{}, outputLiteralMap); err != nil {
return core.UnknownTransition, err
}
if echoConfig.SleepDuration.Duration == time.Duration(0) {
return copyInputsToOutputs(ctx, tCtx)
}

or := ioutils.NewRemoteFileOutputReader(ctx, tCtx.DataStore(), tCtx.OutputWriter(), 0)
if err = tCtx.OutputWriter().Put(ctx, or); err != nil {
return core.UnknownTransition, err
}
}
startTime := e.addTask(ctx, tCtx)

return core.DoTransition(core.PhaseInfoSuccess(nil)), nil
if time.Since(startTime) >= echoConfig.SleepDuration.Duration {
return copyInputsToOutputs(ctx, tCtx)
}

return core.DoTransition(core.PhaseInfoRunning(core.DefaultPhaseVersion, nil)), nil
Expand All @@ -94,10 +85,45 @@ func (e *EchoPlugin) Abort(ctx context.Context, tCtx core.TaskExecutionContext)

func (e *EchoPlugin) Finalize(ctx context.Context, tCtx core.TaskExecutionContext) error {
taskExecutionID := tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName()
delete(e.taskStartTimes, taskExecutionID)
e.removeTask(taskExecutionID)
return nil
}

// copyInputsToOutputs copies the input literals to the output location.
func copyInputsToOutputs(ctx context.Context, tCtx core.TaskExecutionContext) (core.Transition, error) {
inputToOutputVariableMappings, err := compileInputToOutputVariableMappings(ctx, tCtx)
if err != nil {
return core.UnknownTransition, err
}

if len(inputToOutputVariableMappings) > 0 {
inputLiterals, err := tCtx.InputReader().Get(ctx)
if err != nil {
return core.UnknownTransition, err
}

outputLiterals := make(map[string]*idlcore.Literal, len(inputToOutputVariableMappings))
for inputVariableName, outputVariableName := range inputToOutputVariableMappings {
outputLiterals[outputVariableName] = inputLiterals.Literals[inputVariableName]
}

outputLiteralMap := &idlcore.LiteralMap{
Literals: outputLiterals,
}

outputFile := tCtx.OutputWriter().GetOutputPath()
if err := tCtx.DataStore().WriteProtobuf(ctx, outputFile, storage.Options{}, outputLiteralMap); err != nil {
return core.UnknownTransition, err
}

or := ioutils.NewRemoteFileOutputReader(ctx, tCtx.DataStore(), tCtx.OutputWriter(), 0)
if err = tCtx.OutputWriter().Put(ctx, or); err != nil {
return core.UnknownTransition, err
}
}
return core.DoTransition(core.PhaseInfoSuccess(nil)), nil
}

func compileInputToOutputVariableMappings(ctx context.Context, tCtx core.TaskExecutionContext) (map[string]string, error) {
// validate outputs are castable from inputs otherwise error as this plugin is not applicable
taskTemplate, err := tCtx.TaskReader().Read(ctx)
Expand Down

0 comments on commit 0b39839

Please sign in to comment.