Skip to content

Commit

Permalink
#minor Pluginmachinery support for Intra-task checkpointing (flyteorg…
Browse files Browse the repository at this point in the history
  • Loading branch information
kumare3 authored Nov 24, 2021
1 parent 984071a commit 6771e7d
Show file tree
Hide file tree
Showing 30 changed files with 627 additions and 130 deletions.
4 changes: 2 additions & 2 deletions go/tasks/pluginmachinery/core/mocks/fake_k8s_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func (m *FakeKubeCache) Get(ctx context.Context, key client.ObjectKey, out clien
item, found := m.Cache[formatKey(key, out.GetObjectKind().GroupVersionKind())]
if found {
// deep copy to avoid mutating cache
item = item.(runtime.Object).DeepCopyObject()
item = item.DeepCopyObject()
_, isUnstructured := out.(*unstructured.Unstructured)
if isUnstructured {
// Copy the value of the item in the cache to the returned value
Expand Down Expand Up @@ -96,7 +96,7 @@ func (m *FakeKubeCache) List(ctx context.Context, list client.ObjectList, opts .
}
}

objs = append(objs, val.(runtime.Object).DeepCopyObject())
objs = append(objs, val.DeepCopyObject())
}

return meta.SetList(list, objs)
Expand Down
4 changes: 2 additions & 2 deletions go/tasks/pluginmachinery/core/mocks/fake_k8s_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func (m *FakeKubeClient) Get(ctx context.Context, key client.ObjectKey, out clie
item, found := m.Cache[formatKey(key, out.GetObjectKind().GroupVersionKind())]
if found {
// deep copy to avoid mutating cache
item = item.(runtime.Object).DeepCopyObject()
item = item.DeepCopyObject()
_, isUnstructured := out.(*unstructured.Unstructured)
if isUnstructured {
// Copy the value of the item in the cache to the returned value
Expand Down Expand Up @@ -79,7 +79,7 @@ func (m *FakeKubeClient) List(ctx context.Context, list client.ObjectList, opts
}
}

objs = append(objs, val.(runtime.Object).DeepCopyObject())
objs = append(objs, val.DeepCopyObject())
}

return meta.SetList(list, objs)
Expand Down
60 changes: 40 additions & 20 deletions go/tasks/pluginmachinery/core/template/template.go
Original file line number Diff line number Diff line change
@@ -1,25 +1,52 @@
// Package template exports the Render method
// Render Evaluates templates in each command with the equivalent value from passed args. Templates are case-insensitive
// Supported templates are:
// - {{ .InputFile }} to receive the input file path. The protocol used will depend on the underlying system
// configuration. E.g. s3://bucket/key/to/file.pb or /var/run/local.pb are both valid.
// - {{ .OutputPrefix }} to receive the path prefix for where to store the outputs.
// - {{ .Inputs.myInput }} to receive the actual value of the input passed. See docs on LiteralMapToTemplateArgs for how
// what to expect each literal type to be serialized as.
// - {{ .RawOutputDataPrefix }} to receive a path where the raw output data should be ideally written. It is guaranteed
// to be unique per retry and finally one will be saved as the output path
// - {{ .PerRetryUniqueKey }} A key/id/str that is generated per retry and is guaranteed to be unique. Useful in query
// manipulations
// - {{ .TaskTemplatePath }} A path in blobstore/metadata store (e.g. s3, gcs etc) to where an offloaded version of the
// task template exists and can be accessed by the container / task execution environment. The template is a
// a serialized protobuf
// - {{ .PrevCheckpointPrefix }} A path to the checkpoint directory for the previous attempt. If this is the first attempt
// then this is replaced by an empty string
// - {{ .CheckpointOutputPrefix }} A Flyte aware path where the current execution should write the checkpoints.
package template

import (
"context"
"fmt"
"reflect"
"regexp"
"strings"

"github.com/flyteorg/flytestdlib/logger"

"reflect"

idlCore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io"
"github.com/flyteorg/flytestdlib/logger"
"github.com/golang/protobuf/ptypes"
"github.com/pkg/errors"
)

var alphaNumericOnly = regexp.MustCompile("[^a-zA-Z0-9_]+")
var startsWithAlpha = regexp.MustCompile("^[^a-zA-Z_]+")

// Regexes for Supported templates
var inputFileRegex = regexp.MustCompile(`(?i){{\s*[\.$]Input\s*}}`)
var inputPrefixRegex = regexp.MustCompile(`(?i){{\s*[\.$]InputPrefix\s*}}`)
var outputRegex = regexp.MustCompile(`(?i){{\s*[\.$]OutputPrefix\s*}}`)
var inputVarRegex = regexp.MustCompile(`(?i){{\s*[\.$]Inputs\.(?P<input_name>[^}\s]+)\s*}}`)
var rawOutputDataPrefixRegex = regexp.MustCompile(`(?i){{\s*[\.$]RawOutputDataPrefix\s*}}`)
var perRetryUniqueKey = regexp.MustCompile(`(?i){{\s*[\.$]PerRetryUniqueKey\s*}}`)
var taskTemplateRegex = regexp.MustCompile(`(?i){{\s*[\.$]TaskTemplatePath\s*}}`)
var prevCheckpointPrefixRegex = regexp.MustCompile(`(?i){{\s*[\.$]PrevCheckpointPrefix\s*}}`)
var currCheckpointPrefixRegex = regexp.MustCompile(`(?i){{\s*[\.$]CheckpointOutputPrefix\s*}}`)

type ErrorCollection struct {
Errors []error
}
Expand All @@ -33,22 +60,17 @@ func (e ErrorCollection) Error() string {
return sb.String()
}

// The Parameters struct is used by the Templating Engine to replace the templated parameters
// Parameters struct is used by the Templating Engine to replace the templated parameters
type Parameters struct {
TaskExecMetadata core.TaskExecutionMetadata
Inputs io.InputReader
OutputPath io.OutputFilePaths
Task core.TaskTemplatePath
}

// Evaluates templates in each command with the equivalent value from passed args. Templates are case-insensitive
// Supported templates are:
// - {{ .InputFile }} to receive the input file path. The protocol used will depend on the underlying system
// configuration. E.g. s3://bucket/key/to/file.pb or /var/run/local.pb are both valid.
// - {{ .OutputPrefix }} to receive the path prefix for where to store the outputs.
// - {{ .Inputs.myInput }} to receive the actual value of the input passed. See docs on LiteralMapToTemplateArgs for how
// what to expect each literal type to be serialized as.
// Render Evaluates templates in each command with the equivalent value from passed args. Templates are case-insensitive
// If a command isn't a valid template or failed to evaluate, it'll be returned as is.
// Refer to the package docs for a list of supported templates
// NOTE: I wanted to do in-place replacement, until I realized that in-place replacement will alter the definition of the
// graph. This is not desirable, as we may have to retry and in that case the replacement will not work and we want
// to create a new location for outputs
Expand Down Expand Up @@ -79,20 +101,18 @@ func Render(ctx context.Context, inputTemplate []string, params Parameters) ([]s
return res, nil
}

var inputFileRegex = regexp.MustCompile(`(?i){{\s*[\.$]Input\s*}}`)
var inputPrefixRegex = regexp.MustCompile(`(?i){{\s*[\.$]InputPrefix\s*}}`)
var outputRegex = regexp.MustCompile(`(?i){{\s*[\.$]OutputPrefix\s*}}`)
var inputVarRegex = regexp.MustCompile(`(?i){{\s*[\.$]Inputs\.(?P<input_name>[^}\s]+)\s*}}`)
var rawOutputDataPrefixRegex = regexp.MustCompile(`(?i){{\s*[\.$]RawOutputDataPrefix\s*}}`)
var perRetryUniqueKey = regexp.MustCompile(`(?i){{\s*[\.$]PerRetryUniqueKey\s*}}`)
var taskTemplateRegex = regexp.MustCompile(`(?i){{\s*[\.$]TaskTemplatePath\s*}}`)

func render(ctx context.Context, inputTemplate string, params Parameters, perRetryKey string) (string, error) {

val := inputFileRegex.ReplaceAllString(inputTemplate, params.Inputs.GetInputPath().String())
val = outputRegex.ReplaceAllString(val, params.OutputPath.GetOutputPrefixPath().String())
val = inputPrefixRegex.ReplaceAllString(val, params.Inputs.GetInputPrefixPath().String())
val = rawOutputDataPrefixRegex.ReplaceAllString(val, params.OutputPath.GetRawOutputPrefix().String())
prevCheckpoint := params.OutputPath.GetPreviousCheckpointsPrefix().String()
if prevCheckpoint == "" {
prevCheckpoint = "\"\""
}
val = prevCheckpointPrefixRegex.ReplaceAllString(val, prevCheckpoint)
val = currCheckpointPrefixRegex.ReplaceAllString(val, params.OutputPath.GetCheckpointPrefix().String())
val = perRetryUniqueKey.ReplaceAllString(val, perRetryKey)

// For Task template, we will replace only if there is a match. This is because, task template replacement
Expand Down
126 changes: 126 additions & 0 deletions go/tasks/pluginmachinery/core/template/template_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,16 @@ func (d dummyInputReader) Get(ctx context.Context) (*core.LiteralMap, error) {
type dummyOutputPaths struct {
outputPath storage.DataReference
rawOutputDataPrefix storage.DataReference
prevCheckpointPath storage.DataReference
checkpointPath storage.DataReference
}

func (d dummyOutputPaths) GetPreviousCheckpointsPrefix() storage.DataReference {
return d.prevCheckpointPath
}

func (d dummyOutputPaths) GetCheckpointPrefix() storage.DataReference {
return d.checkpointPath
}

func (d dummyOutputPaths) GetRawOutputPrefix() storage.DataReference {
Expand Down Expand Up @@ -443,6 +453,122 @@ func TestReplaceTemplateCommandArgs(t *testing.T) {
}, params)
assert.Error(t, err)
})

t.Run("missing checkpoint args", func(t *testing.T) {
params := Parameters{
TaskExecMetadata: taskMetadata,
Inputs: in,
OutputPath: dummyOutputPaths{
outputPath: out.outputPath,
rawOutputDataPrefix: out.rawOutputDataPrefix,
prevCheckpointPath: "s3://prev-checkpoint/prefix",
checkpointPath: "s3://new-checkpoint/prefix",
},
}
actual, err := Render(context.TODO(), []string{
"hello",
"world",
"{{ .Input }}",
"{{ .OutputPrefix }}",
}, params)
assert.NoError(t, err)
assert.Equal(t, []string{
"hello",
"world",
"input/blah",
"output/blah",
}, actual)
})

t.Run("no prev checkpoint", func(t *testing.T) {
params := Parameters{
TaskExecMetadata: taskMetadata,
Inputs: in,
OutputPath: dummyOutputPaths{
outputPath: out.outputPath,
rawOutputDataPrefix: out.rawOutputDataPrefix,
prevCheckpointPath: "",
checkpointPath: "s3://new-checkpoint/prefix",
},
}
actual, err := Render(context.TODO(), []string{
"hello",
"world",
"{{ .Input }}",
"{{ .OutputPrefix }}",
"--prev={{ .PrevCheckpointPrefix }}",
"--checkpoint={{ .CheckpointOutputPrefix }}",
}, params)
assert.NoError(t, err)
assert.Equal(t, []string{
"hello",
"world",
"input/blah",
"output/blah",
"--prev=\"\"",
"--checkpoint=s3://new-checkpoint/prefix",
}, actual)
})

t.Run("all checkpoints", func(t *testing.T) {
params := Parameters{
TaskExecMetadata: taskMetadata,
Inputs: in,
OutputPath: dummyOutputPaths{
outputPath: out.outputPath,
rawOutputDataPrefix: out.rawOutputDataPrefix,
prevCheckpointPath: "s3://prev-checkpoint/prefix",
checkpointPath: "s3://new-checkpoint/prefix",
},
}
actual, err := Render(context.TODO(), []string{
"hello",
"world",
"{{ .Input }}",
"{{ .OutputPrefix }}",
"--prev={{ .PrevCheckpointPrefix }}",
"--checkpoint={{ .CheckpointOutputPrefix }}",
}, params)
assert.NoError(t, err)
assert.Equal(t, []string{
"hello",
"world",
"input/blah",
"output/blah",
"--prev=s3://prev-checkpoint/prefix",
"--checkpoint=s3://new-checkpoint/prefix",
}, actual)
})

t.Run("all checkpoints ignore case", func(t *testing.T) {
params := Parameters{
TaskExecMetadata: taskMetadata,
Inputs: in,
OutputPath: dummyOutputPaths{
outputPath: out.outputPath,
rawOutputDataPrefix: out.rawOutputDataPrefix,
prevCheckpointPath: "s3://prev-checkpoint/prefix",
checkpointPath: "s3://new-checkpoint/prefix",
},
}
actual, err := Render(context.TODO(), []string{
"hello",
"world",
"{{ .Input }}",
"{{ .OutputPrefix }}",
"--prev={{ .prevcheckpointprefix }}",
"--checkpoint={{ .checkpointoutputprefix }}",
}, params)
assert.NoError(t, err)
assert.Equal(t, []string{
"hello",
"world",
"input/blah",
"output/blah",
"--prev=s3://prev-checkpoint/prefix",
"--checkpoint=s3://new-checkpoint/prefix",
}, actual)
})
}

func TestReplaceTemplateCommandArgsSpecialChars(t *testing.T) {
Expand Down
2 changes: 2 additions & 0 deletions go/tasks/pluginmachinery/flytek8s/container_helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,8 @@ func getTemplateParametersForTest(resourceRequirements, platformResources *v1.Re
mockOutputPathPrefix := storage.DataReference("s3://output/path")
mockOutputPath.OnGetRawOutputPrefix().Return(mockOutputPathPrefix)
mockOutputPath.OnGetOutputPrefixPath().Return(mockOutputPathPrefix)
mockOutputPath.OnGetCheckpointPrefix().Return("/checkpoint")
mockOutputPath.OnGetPreviousCheckpointsPrefix().Return("/prev")

return template.Parameters{
TaskExecMetadata: &mockTaskExecMetadata,
Expand Down
2 changes: 2 additions & 0 deletions go/tasks/pluginmachinery/flytek8s/pod_helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ func dummyExecContext(r *v1.ResourceRequirements) pluginsCore.TaskExecutionConte
ow := &pluginsIOMock.OutputWriter{}
ow.OnGetOutputPrefixPath().Return("")
ow.OnGetRawOutputPrefix().Return("")
ow.OnGetCheckpointPrefix().Return("/checkpoint")
ow.OnGetPreviousCheckpointsPrefix().Return("/prev")

tCtx := &pluginsCoreMock.TaskExecutionContext{}
tCtx.OnTaskExecutionMetadata().Return(dummyTaskExecutionMetadata(r))
Expand Down
Loading

0 comments on commit 6771e7d

Please sign in to comment.