Skip to content

Commit

Permalink
Use proto file name to infer message type (flyteorg#436)
Browse files Browse the repository at this point in the history
Signed-off-by: Hongxin Liang <[email protected]>
  • Loading branch information
honnix authored Oct 23, 2023
1 parent cdfa222 commit 1eda51e
Showing 1 changed file with 56 additions and 19 deletions.
75 changes: 56 additions & 19 deletions cmd/register/register_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,37 +92,74 @@ var projectColumns = []printer.Column{
{Header: "Additional Info", JSONPath: "$.Info"},
}

// Regex to match file name like xxx_1.pb, xxx_2.pb, or xxx_3.pb, and the subgroup catches the number 1, 2 or 3
// This is used to match proto files created by pyflyte, where xxx_1.pb is a task spec, xxx_2.pb is a workflow spec, and xxx_3.pb is launch plan
var fnameRegex = regexp.MustCompile(`^.*_(?P<index>[1-3])\.pb$`)

type unMarshalFunc = func(ctx context.Context, fileContents []byte, fname string, errCollection errors2.ErrorCollection) (proto.Message, error)

// Order matters here
var unMarshalFuncs = []unMarshalFunc{
unMarshalTask,
unMarshalWorkflow,
unMarshalLaunchPlan,
}

func UnMarshalContents(ctx context.Context, fileContents []byte, fname string) (proto.Message, error) {
workflowSpec := &admin.WorkflowSpec{}
errCollection := errors2.ErrorCollection{}
err := proto.Unmarshal(fileContents, workflowSpec)
if err == nil {
return workflowSpec, nil

for _, f := range reorderUnMarshalFuncs(fname) {
if m, err := f(ctx, fileContents, fname, errCollection); err == nil {
return m, nil
}
}

errCollection.Append(fmt.Errorf("as a Workflow: %w", err))
return nil, fmt.Errorf("failed unmarshalling file %v. Errors: %w", fname, errCollection.ErrorOrDefault())
}

logger.Debugf(ctx, "Failed to unmarshal file %v for workflow type", fname)
taskSpec := &admin.TaskSpec{}
err = proto.Unmarshal(fileContents, taskSpec)
if err == nil {
return taskSpec, nil
}
func unMarshalTask(ctx context.Context, fileContents []byte, fname string, errCollection errors2.ErrorCollection) (proto.Message, error) {
return unMarshal(ctx, fileContents, fname, errCollection, "Task", "task", &admin.TaskSpec{})
}

errCollection.Append(fmt.Errorf("as a Task: %w", err))
func unMarshalWorkflow(ctx context.Context, fileContents []byte, fname string, errCollection errors2.ErrorCollection) (proto.Message, error) {
return unMarshal(ctx, fileContents, fname, errCollection, "Workflow", "workflow", &admin.WorkflowSpec{})
}

func unMarshalLaunchPlan(ctx context.Context, fileContents []byte, fname string, errCollection errors2.ErrorCollection) (proto.Message, error) {
return unMarshal(ctx, fileContents, fname, errCollection, "Launchplan", "launch plan", &admin.LaunchPlan{})
}

logger.Debugf(ctx, "Failed to unmarshal file %v for task type", fname)
launchPlan := &admin.LaunchPlan{}
err = proto.Unmarshal(fileContents, launchPlan)
func unMarshal(ctx context.Context, fileContents []byte, fname string, errCollection errors2.ErrorCollection, tpe string, typeAlt string, m proto.Message) (proto.Message, error) {
err := proto.Unmarshal(fileContents, m)
if err == nil {
return launchPlan, nil
return m, nil
}

errCollection.Append(fmt.Errorf("as a Launchplan: %w", err))
errCollection.Append(fmt.Errorf("as a %s type: %w", tpe, err))
logger.Debugf(ctx, "Failed to unmarshal file %s for %v type", fname, typeAlt)
return nil, err
}

logger.Debugf(ctx, "Failed to unmarshal file %v for launch plan type", fname)
return nil, fmt.Errorf("failed unmarshalling file %v. Errors: %w", fname, errCollection.ErrorOrDefault())
func reorderUnMarshalFuncs(fname string) []unMarshalFunc {
if match := fnameRegex.FindStringSubmatch(fname); match != nil {
indexStr := match[fnameRegex.SubexpIndex("index")]
index, err := strconv.Atoi(indexStr)
if err != nil {
panic(fmt.Sprintf("unexpected error when coverting [%s] to int, file name [%s]", indexStr, fname))
}

var reordered []unMarshalFunc
for i, f := range unMarshalFuncs {
if i == index-1 {
reordered = append([]unMarshalFunc{f}, reordered...)
} else {
reordered = append(reordered, f)
}
}
return reordered
}

return unMarshalFuncs
}

func register(ctx context.Context, message proto.Message, cmdCtx cmdCore.CommandContext, dryRun, enableSchedule bool) error {
Expand Down

0 comments on commit 1eda51e

Please sign in to comment.