diff --git a/flyteplugins/go/tasks/plugins/array/catalog.go b/flyteplugins/go/tasks/plugins/array/catalog.go index c86e16853..df57ac9ef 100644 --- a/flyteplugins/go/tasks/plugins/array/catalog.go +++ b/flyteplugins/go/tasks/plugins/array/catalog.go @@ -79,12 +79,20 @@ func DetermineDiscoverability(ctx context.Context, tCtx core.TaskExecutionContex size := -1 var literalCollection *idlCore.LiteralCollection - var discoveredInputName string + literals := make([][]*idlCore.Literal, 0) + discoveredInputNames := make([]string, 0) for inputName, literal := range inputs.Literals { if literalCollection = literal.GetCollection(); literalCollection != nil { - size = len(literal.GetCollection().Literals) - discoveredInputName = inputName - break + // validate length of input list + if size != -1 && size != len(literalCollection.Literals) { + state = state.SetPhase(arrayCore.PhasePermanentFailure, 0).SetReason("all maptask input lists must be the same length") + return state, nil + } + + literals = append(literals, literalCollection.Literals) + discoveredInputNames = append(discoveredInputNames, inputName) + + size = len(literalCollection.Literals) } } @@ -105,7 +113,7 @@ func DetermineDiscoverability(ctx context.Context, tCtx core.TaskExecutionContex arrayJobSize = int64(size) // build input readers - inputReaders = ConstructStaticInputReaders(tCtx.InputReader(), literalCollection.Literals, discoveredInputName) + inputReaders = ConstructStaticInputReaders(tCtx.InputReader(), literals, discoveredInputNames) } if arrayJobSize > maxArrayJobSize { @@ -242,16 +250,17 @@ func WriteToDiscovery(ctx context.Context, tCtx core.TaskExecutionContext, state } var literalCollection *idlCore.LiteralCollection - var discoveredInputName string + literals := make([][]*idlCore.Literal, 0) + discoveredInputNames := make([]string, 0) for inputName, literal := range inputs.Literals { if literalCollection = literal.GetCollection(); literalCollection != nil { - discoveredInputName = inputName - break + literals = append(literals, literalCollection.Literals) + discoveredInputNames = append(discoveredInputNames, inputName) } } // build input readers - inputReaders = ConstructStaticInputReaders(tCtx.InputReader(), literalCollection.Literals, discoveredInputName) + inputReaders = ConstructStaticInputReaders(tCtx.InputReader(), literals, discoveredInputNames) } // output reader @@ -470,14 +479,19 @@ func ConstructCatalogReaderWorkItems(ctx context.Context, taskReader core.TaskRe // ConstructStaticInputReaders constructs input readers that comply with the io.InputReader interface but have their // inputs already populated. -func ConstructStaticInputReaders(inputPaths io.InputFilePaths, inputs []*idlCore.Literal, inputName string) []io.InputReader { +func ConstructStaticInputReaders(inputPaths io.InputFilePaths, inputs [][]*idlCore.Literal, inputNames []string) []io.InputReader { inputReaders := make([]io.InputReader, 0, len(inputs)) - for i := 0; i < len(inputs); i++ { - inputReaders = append(inputReaders, NewStaticInputReader(inputPaths, &idlCore.LiteralMap{ - Literals: map[string]*idlCore.Literal{ - inputName: inputs[i], - }, - })) + if len(inputs) == 0 { + return inputReaders + } + + for i := 0; i < len(inputs[0]); i++ { + literals := make(map[string]*idlCore.Literal) + for j := 0; j < len(inputNames); j++ { + literals[inputNames[j]] = inputs[j][i] + } + + inputReaders = append(inputReaders, NewStaticInputReader(inputPaths, &idlCore.LiteralMap{Literals: literals})) } return inputReaders diff --git a/flyteplugins/go/tasks/plugins/array/catalog_test.go b/flyteplugins/go/tasks/plugins/array/catalog_test.go index 04f4577ca..454d23794 100644 --- a/flyteplugins/go/tasks/plugins/array/catalog_test.go +++ b/flyteplugins/go/tasks/plugins/array/catalog_test.go @@ -5,36 +5,103 @@ import ( "errors" "testing" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins" - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/utils" - structpb "github.com/golang/protobuf/ptypes/struct" - - stdErrors "github.com/flyteorg/flytestdlib/errors" pluginErrors "github.com/flyteorg/flyteplugins/go/tasks/errors" - - "github.com/flyteorg/flytestdlib/bitarray" - - core2 "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/catalog" - - "github.com/flyteorg/flytestdlib/promutils" - "github.com/flyteorg/flytestdlib/storage" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - catalogMocks "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/catalog/mocks" + core2 "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" + pluginMocks "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io" ioMocks "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io/mocks" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/utils" + arrayCore "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/core" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + "github.com/flyteorg/flytestdlib/bitarray" + stdErrors "github.com/flyteorg/flytestdlib/errors" + "github.com/flyteorg/flytestdlib/promutils" + "github.com/flyteorg/flytestdlib/storage" - pluginMocks "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks" + "github.com/go-test/deep" - arrayCore "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/core" + structpb "github.com/golang/protobuf/ptypes/struct" - "github.com/go-test/deep" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +var ( + dummyInputLiteral = &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Primitive{ + Primitive: &core.Primitive{ + Value: &core.Primitive_Integer{ + Integer: 3, + }, + }, + }, + }, + }, + } + singleInput = &core.LiteralMap{ + Literals: map[string]*core.Literal{ + "foo": { + Value: &core.Literal_Collection{ + Collection: &core.LiteralCollection{ + Literals: []*core.Literal{ + dummyInputLiteral, dummyInputLiteral, dummyInputLiteral, + }, + }, + }, + }, + }, + } + multipleInputs = &core.LiteralMap{ + Literals: map[string]*core.Literal{ + "foo": { + Value: &core.Literal_Collection{ + Collection: &core.LiteralCollection{ + Literals: []*core.Literal{ + dummyInputLiteral, dummyInputLiteral, dummyInputLiteral, + }, + }, + }, + }, + "bar": { + Value: &core.Literal_Collection{ + Collection: &core.LiteralCollection{ + Literals: []*core.Literal{ + dummyInputLiteral, dummyInputLiteral, dummyInputLiteral, + }, + }, + }, + }, + }, + } + multipleInputsInvalid = &core.LiteralMap{ + Literals: map[string]*core.Literal{ + "foo": { + Value: &core.Literal_Collection{ + Collection: &core.LiteralCollection{ + Literals: []*core.Literal{ + dummyInputLiteral, dummyInputLiteral, dummyInputLiteral, + }, + }, + }, + }, + "bar": { + Value: &core.Literal_Collection{ + Collection: &core.LiteralCollection{ + Literals: []*core.Literal{ + dummyInputLiteral, dummyInputLiteral, + }, + }, + }, + }, + }, + } ) func TestNewLiteralScalarOfInteger(t *testing.T) { @@ -55,7 +122,7 @@ func TestCatalogBitsetToLiteralCollection(t *testing.T) { } func runDetermineDiscoverabilityTest(t testing.TB, taskTemplate *core.TaskTemplate, future catalog.DownloadFuture, - expectedState *arrayCore.State, maxArrayJobSize int64, expectedError error) { + inputs *core.LiteralMap, expectedState *arrayCore.State, maxArrayJobSize int64, expectedError error) { ctx := context.Background() @@ -70,32 +137,7 @@ func runDetermineDiscoverabilityTest(t testing.TB, taskTemplate *core.TaskTempla ir := &ioMocks.InputReader{} ir.OnGetInputPrefixPath().Return("/prefix/") - dummyInputLiteral := &core.Literal{ - Value: &core.Literal_Scalar{ - Scalar: &core.Scalar{ - Value: &core.Scalar_Primitive{ - Primitive: &core.Primitive{ - Value: &core.Primitive_Integer{ - Integer: 3, - }, - }, - }, - }, - }, - } - ir.On("Get", mock.Anything).Return(&core.LiteralMap{ - Literals: map[string]*core.Literal{ - "foo": { - Value: &core.Literal_Collection{ - Collection: &core.LiteralCollection{ - Literals: []*core.Literal{ - dummyInputLiteral, dummyInputLiteral, dummyInputLiteral, - }, - }, - }, - }, - }, - }, nil) + ir.On("Get", mock.Anything).Return(inputs, nil) ow := &ioMocks.OutputWriter{} ow.OnGetOutputPrefixPath().Return("/prefix/") @@ -151,7 +193,7 @@ func TestDetermineDiscoverability(t *testing.T) { f.OnGetResponse().Return(download, nil) t.Run("Bad Task Spec", func(t *testing.T) { - runDetermineDiscoverabilityTest(t, template, f, nil, 0, stdErrors.Errorf(pluginErrors.BadTaskSpecification, "")) + runDetermineDiscoverabilityTest(t, template, f, singleInput, nil, 0, stdErrors.Errorf(pluginErrors.BadTaskSpecification, "")) }) template = &core.TaskTemplate{ @@ -178,7 +220,7 @@ func TestDetermineDiscoverability(t *testing.T) { t.Run("Run AWS Batch single job", func(t *testing.T) { toCache := arrayCore.InvertBitSet(bitarray.NewBitSet(1), 1) template.Type = AwsBatchTaskType - runDetermineDiscoverabilityTest(t, template, f, &arrayCore.State{ + runDetermineDiscoverabilityTest(t, template, f, singleInput, &arrayCore.State{ CurrentPhase: arrayCore.PhasePreLaunch, PhaseVersion: core2.DefaultPhaseVersion, ExecutionArraySize: 1, @@ -192,7 +234,7 @@ func TestDetermineDiscoverability(t *testing.T) { t.Run("Not discoverable", func(t *testing.T) { toCache := arrayCore.InvertBitSet(bitarray.NewBitSet(1), 1) - runDetermineDiscoverabilityTest(t, template, f, &arrayCore.State{ + runDetermineDiscoverabilityTest(t, template, f, singleInput, &arrayCore.State{ CurrentPhase: arrayCore.PhasePreLaunch, PhaseVersion: core2.DefaultPhaseVersion, ExecutionArraySize: 1, @@ -213,7 +255,7 @@ func TestDetermineDiscoverability(t *testing.T) { toCache := bitarray.NewBitSet(1) toCache.Set(0) - runDetermineDiscoverabilityTest(t, template, f, &arrayCore.State{ + runDetermineDiscoverabilityTest(t, template, f, singleInput, &arrayCore.State{ CurrentPhase: arrayCore.PhasePreLaunch, PhaseVersion: core2.DefaultPhaseVersion, ExecutionArraySize: 1, @@ -231,7 +273,7 @@ func TestDetermineDiscoverability(t *testing.T) { download.OnGetCachedResults().Return(cachedResults).Once() toCache := bitarray.NewBitSet(1) - runDetermineDiscoverabilityTest(t, template, f, &arrayCore.State{ + runDetermineDiscoverabilityTest(t, template, f, singleInput, &arrayCore.State{ CurrentPhase: arrayCore.PhasePreLaunch, PhaseVersion: core2.DefaultPhaseVersion, ExecutionArraySize: 1, @@ -247,7 +289,7 @@ func TestDetermineDiscoverability(t *testing.T) { future.OnGetResponseStatus().Return(catalog.ResponseStatusNotReady) future.On("OnReady", mock.Anything).Return(func(_ context.Context, _ catalog.Future) {}) - runDetermineDiscoverabilityTest(t, template, future, &arrayCore.State{ + runDetermineDiscoverabilityTest(t, template, future, singleInput, &arrayCore.State{ CurrentPhase: arrayCore.PhaseStart, PhaseVersion: core2.DefaultPhaseVersion, OriginalArraySize: 1, @@ -256,7 +298,7 @@ func TestDetermineDiscoverability(t *testing.T) { }) t.Run("MaxArrayJobSizeFailure", func(t *testing.T) { - runDetermineDiscoverabilityTest(t, template, f, &arrayCore.State{ + runDetermineDiscoverabilityTest(t, template, f, singleInput, &arrayCore.State{ CurrentPhase: arrayCore.PhasePermanentFailure, PhaseVersion: core2.DefaultPhaseVersion, OriginalArraySize: 1, @@ -277,46 +319,60 @@ func TestDiscoverabilityTaskType1(t *testing.T) { f.OnGetResponseError().Return(nil) f.OnGetResponse().Return(download, nil) + arrayJob := &plugins.ArrayJob{ + SuccessCriteria: &plugins.ArrayJob_MinSuccessRatio{ + MinSuccessRatio: 0.5, + }, + } + var arrayJobCustom structpb.Struct + err := utils.MarshalStruct(arrayJob, &arrayJobCustom) + assert.NoError(t, err) + templateType1 := &core.TaskTemplate{ + Id: &core.Identifier{ + ResourceType: core.ResourceType_TASK, + Project: "p", + Domain: "d", + Name: "n", + Version: "1", + }, + Interface: &core.TypedInterface{ + Inputs: &core.VariableMap{Variables: map[string]*core.Variable{ + "foo": { + Description: "foo", + }, + }}, + Outputs: &core.VariableMap{Variables: map[string]*core.Variable{}}, + }, + Target: &core.TaskTemplate_Container{ + Container: &core.Container{ + Command: []string{"cmd"}, + Args: []string{"{{$inputPrefix}}"}, + Image: "img1", + }, + }, + TaskTypeVersion: 1, + Custom: &arrayJobCustom, + } + t.Run("Not discoverable", func(t *testing.T) { download.OnGetCachedResults().Return(bitarray.NewBitSet(1)).Once() toCache := arrayCore.InvertBitSet(bitarray.NewBitSet(uint(3)), uint(3)) - arrayJob := &plugins.ArrayJob{ - SuccessCriteria: &plugins.ArrayJob_MinSuccessRatio{ - MinSuccessRatio: 0.5, - }, - } - var arrayJobCustom structpb.Struct - err := utils.MarshalStruct(arrayJob, &arrayJobCustom) - assert.NoError(t, err) - templateType1 := &core.TaskTemplate{ - Id: &core.Identifier{ - ResourceType: core.ResourceType_TASK, - Project: "p", - Domain: "d", - Name: "n", - Version: "1", - }, - Interface: &core.TypedInterface{ - Inputs: &core.VariableMap{Variables: map[string]*core.Variable{ - "foo": { - Description: "foo", - }, - }}, - Outputs: &core.VariableMap{Variables: map[string]*core.Variable{}}, - }, - Target: &core.TaskTemplate_Container{ - Container: &core.Container{ - Command: []string{"cmd"}, - Args: []string{"{{$inputPrefix}}"}, - Image: "img1", - }, - }, - TaskTypeVersion: 1, - Custom: &arrayJobCustom, - } + runDetermineDiscoverabilityTest(t, templateType1, f, singleInput, &arrayCore.State{ + CurrentPhase: arrayCore.PhasePreLaunch, + PhaseVersion: core2.DefaultPhaseVersion, + ExecutionArraySize: 3, + OriginalArraySize: 3, + OriginalMinSuccesses: 2, + IndexesToCache: toCache, + Reason: "Task is not discoverable.", + }, 3, nil) + }) + + t.Run("MultipleInputs", func(t *testing.T) { + toCache := arrayCore.InvertBitSet(bitarray.NewBitSet(uint(3)), uint(3)) - runDetermineDiscoverabilityTest(t, templateType1, f, &arrayCore.State{ + runDetermineDiscoverabilityTest(t, templateType1, f, multipleInputs, &arrayCore.State{ CurrentPhase: arrayCore.PhasePreLaunch, PhaseVersion: core2.DefaultPhaseVersion, ExecutionArraySize: 3, @@ -326,4 +382,16 @@ func TestDiscoverabilityTaskType1(t *testing.T) { Reason: "Task is not discoverable.", }, 3, nil) }) + + t.Run("MultipleInputsInvalid", func(t *testing.T) { + runDetermineDiscoverabilityTest(t, templateType1, f, multipleInputsInvalid, &arrayCore.State{ + CurrentPhase: arrayCore.PhasePermanentFailure, + PhaseVersion: core2.DefaultPhaseVersion, + ExecutionArraySize: 0, + OriginalArraySize: 0, + OriginalMinSuccesses: 0, + IndexesToCache: nil, + Reason: "all maptask input lists must be the same length", + }, 3, nil) + }) }