Skip to content

Commit

Permalink
Added support for pyflyte serialize fast register (flyteorg#248)
Browse files Browse the repository at this point in the history
* Added support for pyflyte serialize fast register (flyteorg#239)
Signed-off-by: Yuvraj <[email protected]>
  • Loading branch information
yindia authored Dec 29, 2021
1 parent 7bb9759 commit 0686a1c
Show file tree
Hide file tree
Showing 6 changed files with 42 additions and 7 deletions.
1 change: 1 addition & 0 deletions cmd/config/subcommand/register/files_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,6 @@ type FilesConfig struct {
K8ServiceAccount string `json:"k8ServiceAccount" pflag:", deprecated. Please use --K8sServiceAccount"`
OutputLocationPrefix string `json:"outputLocationPrefix" pflag:", custom output location prefix for offloaded types (files/schemas)."`
SourceUploadPath string `json:"sourceUploadPath" pflag:", Location for source code in storage."`
DestinationDirectory string `json:"destinationDirectory" pflag:", Location of source code in container."`
DryRun bool `json:"dryRun" pflag:",execute command without making any modifications."`
}
1 change: 1 addition & 0 deletions cmd/config/subcommand/register/filesconfig_flags.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 14 additions & 0 deletions cmd/config/subcommand/register/filesconfig_flags_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 9 additions & 3 deletions cmd/register/files.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,19 +72,25 @@ Override IamRole during registration:
::
flytectl register file _pb_output/* -d development -p flytesnacks --continueOnError --version v2 -i "arn:aws:iam::123456789:role/dummy"
flytectl register file _pb_output/* -d development -p flytesnacks --continueOnError --version v2 --assumableIamRole "arn:aws:iam::123456789:role/dummy"
Override Kubernetes service account during registration:
::
flytectl register file _pb_output/* -d development -p flytesnacks --continueOnError --version v2 -k "kubernetes-service-account"
flytectl register file _pb_output/* -d development -p flytesnacks --continueOnError --version v2 --k8sServiceAccount "kubernetes-service-account"
Override Output location prefix during registration:
::
flytectl register file _pb_output/* -d development -p flytesnacks --continueOnError --version v2 -l "s3://dummy/prefix"
flytectl register file _pb_output/* -d development -p flytesnacks --continueOnError --version v2 --outputLocationPrefix "s3://dummy/prefix"
Override Destination dir of source code in container during registration:
::
flytectl register file _pb_output/* -d development -p flytesnacks --continueOnError --version v2 --destinationDirectory "/root"
Usage
`
Expand Down
19 changes: 16 additions & 3 deletions cmd/register/register_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ const registrationVersionPattern = "{{ registration.version }}"

// Additional variable define in fast serialized proto that needs to be replace in registration time
const registrationRemotePackagePattern = "{{ .remote_package_path }}"
const registrationDestDirPattern = "{{ .dest_dir }}"

// All supported extensions for compress
var supportedExtensions = []string{".tar", ".tgz", ".tar.gz"}
Expand Down Expand Up @@ -219,16 +220,22 @@ func hydrateIdentifier(identifier *core.Identifier, version string, force bool)
}
}

func hydrateTaskSpec(task *admin.TaskSpec, sourceCode string, sourceUploadPath string, version string) error {
func hydrateTaskSpec(task *admin.TaskSpec, sourceCode, sourceUploadPath, version, destinationDir string) error {
if task.Template.GetContainer() != nil {
for k := range task.Template.GetContainer().Args {
if task.Template.GetContainer().Args[k] == "" || task.Template.GetContainer().Args[k] == registrationRemotePackagePattern {
if task.Template.GetContainer().Args[k] == registrationRemotePackagePattern {
remotePath, err := getRemoteStoragePath(context.Background(), Client, sourceUploadPath, sourceCode, version)
if err != nil {
return err
}
task.Template.GetContainer().Args[k] = string(remotePath)
}
if task.Template.GetContainer().Args[k] == registrationDestDirPattern {
task.Template.GetContainer().Args[k] = "."
if len(destinationDir) > 0 {
task.Template.GetContainer().Args[k] = destinationDir
}
}
}
} else if task.Template.GetK8SPod() != nil && task.Template.GetK8SPod().PodSpec != nil {
var podSpec = v1.PodSpec{}
Expand All @@ -245,6 +252,12 @@ func hydrateTaskSpec(task *admin.TaskSpec, sourceCode string, sourceUploadPath s
}
podSpec.Containers[containerIdx].Args[argIdx] = string(remotePath)
}
if arg == registrationDestDirPattern {
podSpec.Containers[containerIdx].Args[argIdx] = "."
if len(destinationDir) > 0 {
podSpec.Containers[containerIdx].Args[argIdx] = destinationDir
}
}
}
}
podSpecStruct, err := utils.MarshalObjToStruct(podSpec)
Expand Down Expand Up @@ -340,7 +353,7 @@ func hydrateSpec(message proto.Message, sourceCode string, config rconfig.FilesC
taskSpec := message.(*admin.TaskSpec)
hydrateIdentifier(taskSpec.Template.Id, config.Version, config.Force)
// In case of fast serialize input proto also have on additional variable to substitute i.e destination bucket for source code
if err := hydrateTaskSpec(taskSpec, sourceCode, config.SourceUploadPath, config.Version); err != nil {
if err := hydrateTaskSpec(taskSpec, sourceCode, config.SourceUploadPath, config.Version, config.DestinationDirectory); err != nil {
return err
}

Expand Down
2 changes: 1 addition & 1 deletion cmd/register/register_util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -544,7 +544,7 @@ func TestHydrateTaskSpec(t *testing.T) {
},
},
}
err = hydrateTaskSpec(task, "sourcey", rconfig.DefaultFilesConfig.SourceUploadPath, rconfig.DefaultFilesConfig.Version)
err = hydrateTaskSpec(task, "sourcey", rconfig.DefaultFilesConfig.SourceUploadPath, rconfig.DefaultFilesConfig.Version, "")
assert.NoError(t, err)
var hydratedPodSpec = v1.PodSpec{}
err = utils.UnmarshalStructToObj(task.Template.GetK8SPod().PodSpec, &hydratedPodSpec)
Expand Down

0 comments on commit 0686a1c

Please sign in to comment.