Skip to content

Commit

Permalink
#minor Fix single task execution node input generation (flyteorg#104)
Browse files Browse the repository at this point in the history
  • Loading branch information
katrogan authored Jun 29, 2020
1 parent dcf45c9 commit 935bb89
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 637 deletions.
1 change: 0 additions & 1 deletion flyteadmin/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ require (
github.com/kelseyhightower/envconfig v1.4.0 // indirect
github.com/lib/pq v1.3.0
github.com/lyft/flyteidl v0.17.34
github.com/lyft/flyteplugins v0.3.35 // indirect
github.com/lyft/flytepropeller v0.2.64
github.com/lyft/flytestdlib v0.3.9
github.com/magiconair/properties v1.8.1
Expand Down
3 changes: 3 additions & 0 deletions flyteadmin/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,7 @@ github.com/lyft/flyteidl v0.17.8/go.mod h1:/zQXxuHO11u/saxTTZc8oYExIGEShXB+xCB1/
github.com/lyft/flyteidl v0.17.27 h1:0EdSHauzdPEYmubYib/XC6fLb+srzP4yDRN1P9o4W/I=
github.com/lyft/flyteidl v0.17.27/go.mod h1:/zQXxuHO11u/saxTTZc8oYExIGEShXB+xCB1/F1Cu20=
github.com/lyft/flyteidl v0.17.32/go.mod h1:/zQXxuHO11u/saxTTZc8oYExIGEShXB+xCB1/F1Cu20=
github.com/lyft/flyteidl v0.17.34 h1:8ERT/8vY40dOPPJrdD8ossBb30WkvzUx/IAFMR/7+9U=
github.com/lyft/flyteidl v0.17.34/go.mod h1:/zQXxuHO11u/saxTTZc8oYExIGEShXB+xCB1/F1Cu20=
github.com/lyft/flyteplugins v0.3.11/go.mod h1:FOSo04q4EheU6lm0oZFvfYAWgjrum/BDUK+mUT7qDFA=
github.com/lyft/flyteplugins v0.3.33/go.mod h1:HHO6KC/2z77n9o9KM697YvSP85IWDe6jl6tAIrMLqWU=
Expand Down Expand Up @@ -663,6 +664,7 @@ github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJy
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4=
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/subosito/gotenv v1.2.0 h1:Slr1R9HxAlEKefgq5jn9U+DnETlIUa6HfgEzj0g5d7s=
github.com/subosito/gotenv v1.2.0/go.mod h1:N0PQaV/YGNqwC0u51sEeR/aUtSLEXKX9iv69rRypqCw=
Expand Down Expand Up @@ -1025,6 +1027,7 @@ gopkg.in/yaml.v2 v2.2.7/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10=
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.0-20200605160147-a5ece683394c h1:grhR+C34yXImVGp7EzNk+DTIk+323eIUWOmEevy6bDo=
gopkg.in/yaml.v3 v3.0.0-20200605160147-a5ece683394c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gotest.tools v2.2.0+incompatible/go.mod h1:DsYFclhRJ6vuDpmuTbkuFWG+y2sxOXAzmJt81HFBacw=
honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
Expand Down
135 changes: 5 additions & 130 deletions flyteadmin/pkg/manager/impl/util/single_task_execution.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ var defaultTimeout = ptypes.DurationProto(24 * time.Hour)

const systemNamePrefix = ".flytegen.%s"

const noInputNodeID = ""

func generateNodeNameFromTask(taskName string) string {
if len(taskName) >= maxNodeIDLength {
taskName = taskName[len(taskName)-maxNodeIDLength:]
Expand All @@ -48,42 +50,7 @@ func generateWorkflowNameFromTask(taskName string) string {
return fmt.Sprintf(systemNamePrefix, taskName)
}

func getBinding(literal *core.Literal) *core.BindingData {
if literal.GetScalar() != nil {
return &core.BindingData{
Value: &core.BindingData_Scalar{
Scalar: literal.GetScalar(),
},
}
} else if literal.GetCollection() != nil {
bindings := make([]*core.BindingData, len(literal.GetCollection().Literals))
for idx, subLiteral := range literal.GetCollection().Literals {
bindings[idx] = getBinding(subLiteral)
}
return &core.BindingData{
Value: &core.BindingData_Collection{
Collection: &core.BindingDataCollection{
Bindings: bindings,
},
},
}
} else if literal.GetMap() != nil {
bindings := make(map[string]*core.BindingData)
for key, subLiteral := range literal.GetMap().Literals {
bindings[key] = getBinding(subLiteral)
}
return &core.BindingData{
Value: &core.BindingData_Map{
Map: &core.BindingDataMap{
Bindings: bindings,
},
},
}
}
return nil
}

func generateBindingsFromOutputs(outputs core.VariableMap, nodeID string) []*core.Binding {
func generateBindings(outputs core.VariableMap, nodeID string) []*core.Binding {
bindings := make([]*core.Binding, 0, len(outputs.Variables))
for key := range outputs.Variables {
binding := &core.Binding{
Expand All @@ -103,87 +70,6 @@ func generateBindingsFromOutputs(outputs core.VariableMap, nodeID string) []*cor
return bindings
}

func generateBindingsFromInputs(inputTemplate core.VariableMap, inputs core.LiteralMap) ([]*core.Binding, error) {
bindings := make([]*core.Binding, 0, len(inputTemplate.Variables))
for key, val := range inputTemplate.Variables {
binding := &core.Binding{
Var: key,
}
var bindingData core.BindingData
if val.Type.GetSimple() != core.SimpleType_NONE {
if inputs.Literals[key] != nil {
bindingData = core.BindingData{
Value: &core.BindingData_Scalar{
Scalar: inputs.Literals[key].GetScalar(),
},
}
}

} else if val.Type.GetSchema() != nil {
if inputs.Literals[key] != nil && inputs.Literals[key].GetScalar() != nil {
bindingData = core.BindingData{
Value: &core.BindingData_Scalar{
Scalar: &core.Scalar{
Value: &core.Scalar_Schema{
Schema: inputs.Literals[key].GetScalar().GetSchema(),
},
},
},
}
}
} else if val.Type.GetCollectionType() != nil {
if inputs.Literals[key] != nil && inputs.Literals[key].GetCollection() != nil &&
inputs.Literals[key].GetCollection().GetLiterals() != nil {
collectionBindings := make([]*core.BindingData, len(inputs.Literals[key].GetCollection().GetLiterals()))
for idx, literal := range inputs.Literals[key].GetCollection().GetLiterals() {
collectionBindings[idx] = getBinding(literal)

}
bindingData = core.BindingData{
Value: &core.BindingData_Collection{
Collection: &core.BindingDataCollection{
Bindings: collectionBindings,
},
},
}
}
} else if val.Type.GetMapValueType() != nil {
if inputs.Literals[key] != nil && inputs.Literals[key].GetMap() != nil &&
inputs.Literals[key].GetMap().Literals != nil {
bindingDataMap := make(map[string]*core.BindingData)
for k, v := range inputs.Literals[key].GetMap().Literals {
bindingDataMap[k] = getBinding(v)
}

bindingData = core.BindingData{
Value: &core.BindingData_Map{
Map: &core.BindingDataMap{
Bindings: bindingDataMap,
},
},
}
}
} else if val.Type.GetBlob() != nil {
if inputs.Literals[key] != nil && inputs.Literals[key].GetScalar() != nil {
bindingData = core.BindingData{
Value: &core.BindingData_Scalar{
Scalar: &core.Scalar{
Value: &core.Scalar_Blob{
Blob: inputs.Literals[key].GetScalar().GetBlob(),
},
},
},
}
}
} else {
return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "Unrecognized value type [%+v]", val.GetType())
}
binding.Binding = &bindingData
bindings = append(bindings, binding)
}
return bindings, nil
}

func CreateOrGetWorkflowModel(
ctx context.Context, request admin.ExecutionCreateRequest, db repositories.RepositoryInterface,
workflowManager interfaces.WorkflowInterface, namedEntityManager interfaces.NamedEntityInterface, taskIdentifier *core.Identifier,
Expand All @@ -207,17 +93,6 @@ func CreateOrGetWorkflowModel(
return nil, err
}
// If we got this far, there is no existing workflow. Create a skeleton one now.
var requestInputs = core.LiteralMap{
Literals: make(map[string]*core.Literal),
}
if request.Inputs != nil {
requestInputs = *request.Inputs
}
generatedInputs, err := generateBindingsFromInputs(*task.Closure.CompiledTask.Template.Interface.Inputs, requestInputs)
if err != nil {
logger.Warningf(ctx, "Failed to generate requestInputs from task input bindings: %v", err)
return nil, err
}
workflowSpec := admin.WorkflowSpec{
Template: &core.WorkflowTemplate{
Id: &workflowIdentifier,
Expand All @@ -230,7 +105,7 @@ func CreateOrGetWorkflowModel(
Retries: &defaultRetryStrategy,
Timeout: defaultTimeout,
},
Inputs: generatedInputs,
Inputs: generateBindings(*task.Closure.CompiledTask.Template.Interface.Inputs, noInputNodeID),
Target: &core.Node_TaskNode{
TaskNode: &core.TaskNode{
Reference: &core.TaskNode_ReferenceId{
Expand All @@ -241,7 +116,7 @@ func CreateOrGetWorkflowModel(
},
},

Outputs: generateBindingsFromOutputs(*task.Closure.CompiledTask.Template.Interface.Outputs, generateNodeNameFromTask(taskIdentifier.Name)),
Outputs: generateBindings(*task.Closure.CompiledTask.Template.Interface.Outputs, generateNodeNameFromTask(taskIdentifier.Name)),
},
}

Expand Down
Loading

0 comments on commit 935bb89

Please sign in to comment.