diff --git a/flytepropeller/pkg/compiler/transformers/k8s/utils_test.go b/flytepropeller/pkg/compiler/transformers/k8s/utils_test.go index 7c48602b3c..f3f924c917 100644 --- a/flytepropeller/pkg/compiler/transformers/k8s/utils_test.go +++ b/flytepropeller/pkg/compiler/transformers/k8s/utils_test.go @@ -223,6 +223,7 @@ func TestStripTypeMetadata(t *testing.T) { }, }, }, + Structure: &core.TypeStructure{Tag: "str"}, }, }, }, diff --git a/flytepropeller/pkg/compiler/transformers/k8s/workflow.go b/flytepropeller/pkg/compiler/transformers/k8s/workflow.go index 8c72c3eea6..c06d34d81d 100644 --- a/flytepropeller/pkg/compiler/transformers/k8s/workflow.go +++ b/flytepropeller/pkg/compiler/transformers/k8s/workflow.go @@ -168,11 +168,23 @@ func BuildFlyteWorkflow(wfClosure *core.CompiledWorkflowClosure, inputs *core.Li return nil, errs } - for _, t := range wfClosure.Tasks { + wf := wfClosure.Primary.Template + tasks := wfClosure.Tasks + // Fill in inputs in the start node. + if inputs != nil { + if ok := validateInputs(common.StartNodeID, wf.GetInterface(), *inputs, errs.NewScope()); !ok { + return nil, errs + } + } else if requiresInputs(wf) { + errs.Collect(errors.NewValueRequiredErr("root", "inputs")) + return nil, errs + } + + for _, t := range tasks { t.Template.Interface = StripInterfaceTypeMetadata(t.Template.Interface) } - primarySpec, err := buildFlyteWorkflowSpec(wfClosure.Primary, wfClosure.Tasks, errs.NewScope()) + primarySpec, err := buildFlyteWorkflowSpec(wfClosure.Primary, tasks, errs.NewScope()) if err != nil { errs.Collect(errors.NewWorkflowBuildError(err)) return nil, errs @@ -180,7 +192,7 @@ func BuildFlyteWorkflow(wfClosure *core.CompiledWorkflowClosure, inputs *core.Li subwfs := make(map[v1alpha1.WorkflowID]*v1alpha1.WorkflowSpec, len(wfClosure.SubWorkflows)) for _, subWf := range wfClosure.SubWorkflows { - spec, err := buildFlyteWorkflowSpec(subWf, wfClosure.Tasks, errs.NewScope()) + spec, err := buildFlyteWorkflowSpec(subWf, tasks, errs.NewScope()) if err != nil { errs.Collect(errors.NewWorkflowBuildError(err)) } else { @@ -192,18 +204,6 @@ func BuildFlyteWorkflow(wfClosure *core.CompiledWorkflowClosure, inputs *core.Li return nil, errs } - wf := wfClosure.Primary.Template - tasks := wfClosure.Tasks - // Fill in inputs in the start node. - if inputs != nil { - if ok := validateInputs(common.StartNodeID, wf.GetInterface(), *inputs, errs.NewScope()); !ok { - return nil, errs - } - } else if requiresInputs(wf) { - errs.Collect(errors.NewValueRequiredErr("root", "inputs")) - return nil, errs - } - interruptible := false if wf.GetMetadataDefaults() != nil { interruptible = wf.GetMetadataDefaults().GetInterruptible() diff --git a/flytepropeller/pkg/compiler/transformers/k8s/workflow_test.go b/flytepropeller/pkg/compiler/transformers/k8s/workflow_test.go index d2387e2b2d..f9c9a8c517 100644 --- a/flytepropeller/pkg/compiler/transformers/k8s/workflow_test.go +++ b/flytepropeller/pkg/compiler/transformers/k8s/workflow_test.go @@ -185,6 +185,72 @@ func TestBuildFlyteWorkflow_withInputs(t *testing.T) { assert.Equal(t, int64(123), wf.Inputs.Literals["x"].GetScalar().GetPrimitive().GetInteger()) } +func TestBuildFlyteWorkflow_withUnionInputs(t *testing.T) { + w := createSampleMockWorkflow() + + startNode := w.GetNodes()[common.StartNodeID].(*mockNode) + strType := core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_STRING}, Structure: &core.TypeStructure{Tag: "str"}} + floatType := core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_FLOAT}, Structure: &core.TypeStructure{Tag: "float"}} + vars := []*core.Variable{ + { + Type: &core.LiteralType{Type: &core.LiteralType_UnionType{UnionType: &core.UnionType{Variants: []*core.LiteralType{&strType, &floatType}}}}, + }, + { + Type: &core.LiteralType{Type: &core.LiteralType_UnionType{UnionType: &core.UnionType{Variants: []*core.LiteralType{&strType, &floatType}}}}, + }, + } + + w.Template.Interface = &core.TypedInterface{ + Inputs: &core.VariableMap{ + Variables: map[string]*core.Variable{ + "x": vars[0], + "y": vars[1], + }, + }, + } + + startNode.iface = &core.TypedInterface{ + Outputs: &core.VariableMap{ + Variables: map[string]*core.Variable{ + "x": vars[0], + "y": vars[1], + }, + }, + } + + stringLiteral, err := coreutils.MakePrimitiveLiteral("hello") + assert.NoError(t, err) + floatLiteral, err := coreutils.MakePrimitiveLiteral(1.0) + assert.NoError(t, err) + inputs := &core.LiteralMap{ + Literals: map[string]*core.Literal{ + "x": {Value: &core.Literal_Scalar{Scalar: &core.Scalar{Value: &core.Scalar_Union{Union: &core.Union{Value: floatLiteral, Type: &floatType}}}}}, + "y": {Value: &core.Literal_Scalar{Scalar: &core.Scalar{Value: &core.Scalar_Union{Union: &core.Union{Value: stringLiteral, Type: &strType}}}}}, + }, + } + + errors.SetConfig(errors.Config{IncludeSource: true}) + wf, err := BuildFlyteWorkflow( + &core.CompiledWorkflowClosure{ + Primary: w.GetCoreWorkflow(), + Tasks: []*core.CompiledTask{ + { + Template: &core.TaskTemplate{ + Id: &core.Identifier{Name: "ref_1"}, + }, + }, + }, + }, + inputs, nil, "") + assert.NoError(t, err) + assert.NotNil(t, wf) + errors.SetConfig(errors.Config{}) + + assert.Equal(t, 2, len(wf.Inputs.Literals)) + assert.Equal(t, 1.0, wf.Inputs.Literals["x"].GetScalar().GetUnion().GetValue().GetScalar().GetPrimitive().GetFloatValue()) + assert.Equal(t, "hello", wf.Inputs.Literals["y"].GetScalar().GetUnion().GetValue().GetScalar().GetPrimitive().GetStringValue()) +} + func TestGenerateName(t *testing.T) { t.Run("Invalid params", func(t *testing.T) { _, _, _, _, _, err := generateName(nil, nil)