Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Add GetDeckPath to OutputReader (#268)
Browse files Browse the repository at this point in the history
* Add GetDeckPath to RemoteFileOutputReader

Signed-off-by: Kevin Su <[email protected]>

* make generate

Signed-off-by: Kevin Su <[email protected]>

* Fix tests

Signed-off-by: Kevin Su <[email protected]>

* Fix tests

Signed-off-by: Kevin Su <[email protected]>

* Fix tests

Signed-off-by: Kevin Su <[email protected]>

* more tests

Signed-off-by: Kevin Su <[email protected]>

* nit

Signed-off-by: Kevin Su <[email protected]>
  • Loading branch information
pingsutw authored Jun 7, 2022
1 parent 2c83e24 commit f577f62
Show file tree
Hide file tree
Showing 22 changed files with 237 additions and 12 deletions.
4 changes: 4 additions & 0 deletions go/tasks/pluginmachinery/core/template/template_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ type dummyOutputPaths struct {
checkpointPath storage.DataReference
}

// GetDeckPath is a stub on the dummy output paths used by the template tests.
// It panics unconditionally, signalling that the code under test is not
// expected to ask for the deck path.
func (d dummyOutputPaths) GetDeckPath() storage.DataReference {
panic("should not be called")
}

// GetPreviousCheckpointsPrefix returns the prevCheckpointPath value stored on
// the dummy output paths.
func (d dummyOutputPaths) GetPreviousCheckpointsPrefix() storage.DataReference {
return d.prevCheckpointPath
}
Expand Down
4 changes: 4 additions & 0 deletions go/tasks/pluginmachinery/io/iface.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ type OutputReader interface {
Exists(ctx context.Context) (bool, error)
// Read returns the output -> *core.LiteralMap (nil if void), *ExecutionError if user error when reading the output and error to indicate system problems
Read(ctx context.Context) (*core.LiteralMap, *ExecutionError, error)
// GetDeckPath returns a fully qualified path (URN) of the deck file.
GetDeckPath() *storage.DataReference
}

// CheckpointPaths provides the paths / keys to input Checkpoints directory and an output checkpoints directory.
Expand Down Expand Up @@ -77,6 +79,8 @@ type OutputFilePaths interface {
GetOutputPrefixPath() storage.DataReference
// GetOutputPath returns a fully qualified path (URN) to where the framework expects the output to exist in the configured storage backend
GetOutputPath() storage.DataReference
// GetDeckPath returns a fully qualified path (URN) to where the framework expects the deck.html to exist in the configured storage backend
GetDeckPath() storage.DataReference
// GetErrorPath returns a fully qualified path (URN) where the error information should be placed as a protobuf core.ErrorDocument. It is not directly
// used by the framework, but could be used in the future
GetErrorPath() storage.DataReference
Expand Down
32 changes: 32 additions & 0 deletions go/tasks/pluginmachinery/io/mocks/output_file_paths.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

36 changes: 36 additions & 0 deletions go/tasks/pluginmachinery/io/mocks/output_reader.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

32 changes: 32 additions & 0 deletions go/tasks/pluginmachinery/io/mocks/output_writer.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 9 additions & 1 deletion go/tasks/pluginmachinery/ioutils/in_memory_output_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,15 @@ import (
"context"
"fmt"

"github.com/flyteorg/flytestdlib/storage"

"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io"
)

// InMemoryOutputReader keeps a task's output literals, an optional deck path
// and an optional execution error in memory.
type InMemoryOutputReader struct {
// literals is the literal map handed back by Read.
literals *core.LiteralMap
// DeckPath is the reference handed back by GetDeckPath; it may be nil.
DeckPath *storage.DataReference
// err is the execution error handed back by Read; it may be nil.
err *io.ExecutionError
}

Expand Down Expand Up @@ -40,9 +43,14 @@ func (r InMemoryOutputReader) Read(ctx context.Context) (*core.LiteralMap, *io.E
return r.literals, r.err, nil
}

func NewInMemoryOutputReader(literals *core.LiteralMap, err *io.ExecutionError) InMemoryOutputReader {
// GetDeckPath returns the reader's DeckPath field unchanged, which may be nil.
func (r InMemoryOutputReader) GetDeckPath() *storage.DataReference {
return r.DeckPath
}

// NewInMemoryOutputReader builds an InMemoryOutputReader from the given
// output literals, deck path and execution error. The arguments are stored
// as-is; callers pass nil for any of them that does not apply.
func NewInMemoryOutputReader(literals *core.LiteralMap, deckPath *storage.DataReference, err *io.ExecutionError) InMemoryOutputReader {
	return InMemoryOutputReader{
		literals: literals,
		DeckPath: deckPath,
		err:      err,
	}
}
44 changes: 44 additions & 0 deletions go/tasks/pluginmachinery/ioutils/in_memory_output_reader_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package ioutils

import (
"context"
"testing"

flyteIdlCore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
"github.com/flyteorg/flytestdlib/storage"
"github.com/stretchr/testify/assert"
)

// TestInMemoryOutputReader checks that a reader built by
// NewInMemoryOutputReader hands back the deck path and literals it was given,
// and that it reports itself as an existing, non-file, non-error output.
func TestInMemoryOutputReader(t *testing.T) {
	ctx := context.TODO()
	deckPath := storage.DataReference("s3://bucket/key")

	// A single integer literal named "results" holding the value 3.
	primitive := &flyteIdlCore.Primitive{Value: &flyteIdlCore.Primitive_Integer{Integer: 3}}
	scalar := &flyteIdlCore.Scalar{Value: &flyteIdlCore.Scalar_Primitive{Primitive: primitive}}
	lt := map[string]*flyteIdlCore.Literal{
		"results": {Value: &flyteIdlCore.Literal_Scalar{Scalar: scalar}},
	}

	reader := NewInMemoryOutputReader(&flyteIdlCore.LiteralMap{Literals: lt}, &deckPath, nil)
	assert.Equal(t, &deckPath, reader.GetDeckPath())

	isError, err := reader.IsError(ctx)
	assert.False(t, isError)
	assert.NoError(t, err)

	assert.False(t, reader.IsFile(ctx))

	exists, err := reader.Exists(ctx)
	assert.True(t, exists)
	assert.NoError(t, err)

	literalMap, executionErr, err := reader.Read(ctx)
	assert.Equal(t, lt, literalMap.Literals)
	assert.Nil(t, executionErr)
	assert.NoError(t, err)
}
3 changes: 3 additions & 0 deletions go/tasks/pluginmachinery/ioutils/paths.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ const (
// OutputsSuffix specifies that outputs are assumed to be written to this "file"/"suffix" under the given prefix
// The outputs file has a format of core.LiteralMap
OutputsSuffix = "outputs.pb"
// deckSuffix specifies that the deck file is assumed to be written to this "file"/"suffix" under the given prefix
// The deck file has a format of HTML
deckSuffix = "deck.html"
// ErrorsSuffix specifies that the errors are written to this prefix/file under the given prefix. The Error File
// has a format of core.ErrorDocument
ErrorsSuffix = "error.pb"
Expand Down
5 changes: 5 additions & 0 deletions go/tasks/pluginmachinery/ioutils/remote_file_output_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,11 @@ func (r RemoteFileOutputReader) IsFile(ctx context.Context) bool {
return true
}

// GetDeckPath asks the reader's output paths for the deck path and returns a
// pointer to a local copy of that value.
func (r RemoteFileOutputReader) GetDeckPath() *storage.DataReference {
	deckPath := r.outPath.GetDeckPath()
	return &deckPath
}

func NewRemoteFileOutputReader(_ context.Context, store storage.ComposedProtobufStore, outPaths io.OutputFilePaths, maxDatasetSize int64) RemoteFileOutputReader {
return RemoteFileOutputReader{
outPath: outPaths,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"context"
"testing"

"github.com/flyteorg/flytestdlib/storage"

"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
pluginsIOMock "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io/mocks"
storageMocks "github.com/flyteorg/flytestdlib/storage/mocks"
Expand All @@ -16,6 +18,8 @@ func TestReadOrigin(t *testing.T) {

opath := &pluginsIOMock.OutputFilePaths{}
opath.OnGetErrorPath().Return("")
deckPath := "deck.html"
opath.OnGetDeckPath().Return(storage.DataReference(deckPath))

t.Run("user", func(t *testing.T) {
errorDoc := &core.ErrorDocument{
Expand Down Expand Up @@ -44,6 +48,7 @@ func TestReadOrigin(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, core.ExecutionError_USER, ee.Kind)
assert.False(t, ee.IsRecoverable)
assert.Equal(t, deckPath, r.GetDeckPath().String())
})

t.Run("system", func(t *testing.T) {
Expand Down
4 changes: 4 additions & 0 deletions go/tasks/pluginmachinery/ioutils/remote_file_output_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ func (w RemoteFileOutputPaths) GetOutputPath() storage.DataReference {
return constructPath(w.store, w.outputPrefix, OutputsSuffix)
}

// GetDeckPath builds the deck file reference by handing the store, the output
// prefix and deckSuffix to constructPath.
func (w RemoteFileOutputPaths) GetDeckPath() storage.DataReference {
	deckPath := constructPath(w.store, w.outputPrefix, deckSuffix)
	return deckPath
}

// GetErrorPath builds the error file reference by handing the store, the
// output prefix and ErrorsSuffix to constructPath.
func (w RemoteFileOutputPaths) GetErrorPath() storage.DataReference {
	errorPath := constructPath(w.store, w.outputPrefix, ErrorsSuffix)
	return errorPath
}
Expand Down
48 changes: 48 additions & 0 deletions go/tasks/pluginmachinery/ioutils/remote_file_output_writer_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package ioutils

import (
"context"
"testing"

"github.com/flyteorg/flytestdlib/promutils"
"github.com/flyteorg/flytestdlib/storage"
"github.com/stretchr/testify/assert"
)

// TestRemoteFileOutputWriter verifies that the checkpoint-aware output paths,
// and the output writer built on top of them, report references equal to
// constructPath applied to the matching prefix and suffix.
func TestRemoteFileOutputWriter(t *testing.T) {
	var (
		ctx                    = context.TODO()
		outputPrefix           = storage.DataReference("output")
		rawOutputPrefix        = storage.DataReference("sandbox")
		previousCheckpointPath = storage.DataReference("checkpoint")
	)

	memStore, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope())
	assert.Nil(t, err)

	paths := NewCheckpointRemoteFilePaths(ctx, memStore, outputPrefix, NewRawOutputPaths(ctx, rawOutputPrefix), previousCheckpointPath)

	t.Run("Test NewCheckpointRemoteFilePaths", func(t *testing.T) {
		// Prefixes are passed through untouched.
		assert.Equal(t, previousCheckpointPath, paths.GetPreviousCheckpointsPrefix())
		assert.Equal(t, outputPrefix, paths.GetOutputPrefixPath())

		// Checkpoints live under the raw output prefix; everything else under the output prefix.
		assert.Equal(t, constructPath(memStore, rawOutputPrefix, CheckpointPrefix), paths.GetCheckpointPrefix())
		assert.Equal(t, constructPath(memStore, outputPrefix, OutputsSuffix), paths.GetOutputPath())
		assert.Equal(t, constructPath(memStore, outputPrefix, deckSuffix), paths.GetDeckPath())
		assert.Equal(t, constructPath(memStore, outputPrefix, ErrorsSuffix), paths.GetErrorPath())
		assert.Equal(t, constructPath(memStore, outputPrefix, FuturesSuffix), paths.GetFuturesPath())
	})

	t.Run("Test NewRemoteFileOutputWriter", func(t *testing.T) {
		writer := NewRemoteFileOutputWriter(ctx, memStore, paths)

		assert.Equal(t, constructPath(memStore, rawOutputPrefix, CheckpointPrefix), writer.GetCheckpointPrefix())
		assert.Equal(t, constructPath(memStore, outputPrefix, OutputsSuffix), writer.GetOutputPath())
		assert.Equal(t, constructPath(memStore, outputPrefix, deckSuffix), writer.GetDeckPath())
		assert.Equal(t, constructPath(memStore, outputPrefix, ErrorsSuffix), writer.GetErrorPath())
	})
}
2 changes: 1 addition & 1 deletion go/tasks/plugins/array/awsbatch/monitor.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ func CheckSubTasksState(ctx context.Context, taskMeta core.TaskExecutionMetadata
return nil, err
}

if err = ow.Put(ctx, ioutils.NewInMemoryOutputReader(nil, &io.ExecutionError{
if err = ow.Put(ctx, ioutils.NewInMemoryOutputReader(nil, nil, &io.ExecutionError{
ExecutionError: &core2.ExecutionError{
Code: "",
Message: subJob.Status.Message,
Expand Down
2 changes: 1 addition & 1 deletion go/tasks/plugins/array/k8s/management.go
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ func LaunchAndCheckSubTasksState(ctx context.Context, tCtx core.TaskExecutionCon
return currentState, externalResources, err
}

if err = ow.Put(ctx, ioutils.NewInMemoryOutputReader(nil, &io.ExecutionError{
if err = ow.Put(ctx, ioutils.NewInMemoryOutputReader(nil, nil, &io.ExecutionError{
ExecutionError: phaseInfo.Err(),
IsRecoverable: phaseInfo.Phase() != core.PhasePermanentFailure,
})); err != nil {
Expand Down
4 changes: 2 additions & 2 deletions go/tasks/plugins/array/outputs.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ func (w assembleOutputsWorker) Process(ctx context.Context, workItem workqueue.W
}

ow := ioutils.NewRemoteFileOutputWriter(ctx, i.dataStore, i.outputPaths)
if err = ow.Put(ctx, ioutils.NewInMemoryOutputReader(finalOutputs, nil)); err != nil {
if err = ow.Put(ctx, ioutils.NewInMemoryOutputReader(finalOutputs, nil, nil)); err != nil {
return workqueue.WorkStatusNotDone, err
}

Expand Down Expand Up @@ -313,7 +313,7 @@ func (a assembleErrorsWorker) Process(ctx context.Context, workItem workqueue.Wo
}

ow := ioutils.NewRemoteFileOutputWriter(ctx, w.dataStore, w.outputPaths)
if err = ow.Put(ctx, ioutils.NewInMemoryOutputReader(nil, &io.ExecutionError{
if err = ow.Put(ctx, ioutils.NewInMemoryOutputReader(nil, nil, &io.ExecutionError{
ExecutionError: &core.ExecutionError{
Code: "",
Message: msg,
Expand Down
2 changes: 1 addition & 1 deletion go/tasks/plugins/hive/execution_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,7 @@ func WriteOutputs(ctx context.Context, tCtx core.TaskExecutionContext, currentSt
},
},
},
}, nil))
}, nil, nil))
if err != nil {
logger.Errorf(ctx, "Error writing outputs file: [%s]", err)
return currentState, err
Expand Down
2 changes: 1 addition & 1 deletion go/tasks/plugins/k8s/sagemaker/builtin_training.go
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ func (m awsSagemakerPlugin) getTaskPhaseForTrainingJob(
return pluginsCore.PhaseInfoUndefined, pluginErrors.Wrapf(pluginErrors.BadTaskSpecification, err, "failed to create outputs for the task")
}
// Instantiate an output reader with the literal map, and write the output to the remote location referred to by the OutputWriter
if err := pluginContext.OutputWriter().Put(ctx, ioutils.NewInMemoryOutputReader(outputLiteralMap, nil)); err != nil {
if err := pluginContext.OutputWriter().Put(ctx, ioutils.NewInMemoryOutputReader(outputLiteralMap, nil, nil)); err != nil {
return pluginsCore.PhaseInfoUndefined, pluginErrors.Wrapf(pluginErrors.BadTaskSpecification, err, "Unable to write output to the remote location")
}
logger.Debugf(ctx, "Successfully produced and returned outputs")
Expand Down
Loading

0 comments on commit f577f62

Please sign in to comment.