diff --git a/flytectl/cmd/register/register_util.go b/flytectl/cmd/register/register_util.go index fa51cc29c8..1cdc7893c5 100644 --- a/flytectl/cmd/register/register_util.go +++ b/flytectl/cmd/register/register_util.go @@ -18,6 +18,7 @@ import ( "github.com/flyteorg/flytestdlib/contextutils" "github.com/flyteorg/flytestdlib/promutils" "github.com/flyteorg/flytestdlib/promutils/labeled" + "github.com/flyteorg/flytestdlib/utils" "github.com/google/go-github/github" @@ -34,6 +35,7 @@ import ( "github.com/golang/protobuf/proto" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + v1 "k8s.io/api/core/v1" ) // Variable define in serialized proto that needs to be replace in registration time @@ -222,6 +224,33 @@ func hydrateTaskSpec(task *admin.TaskSpec, sourceCode string) error { task.Template.GetContainer().Args[k] = string(remotePath) } } + } else if task.Template.GetK8SPod() != nil && task.Template.GetK8SPod().PodSpec != nil { + var podSpec = v1.PodSpec{} + err := utils.UnmarshalStructToObj(task.Template.GetK8SPod().PodSpec, &podSpec) + if err != nil { + return err + } + for containerIdx, container := range podSpec.Containers { + for argIdx, arg := range container.Args { + if arg == registrationRemotePackagePattern { + remotePath, err := getRemoteStoragePath(context.Background(), Client, rconfig.DefaultFilesConfig.SourceUploadPath, sourceCode, rconfig.DefaultFilesConfig.Version) + if err != nil { + return err + } + podSpec.Containers[containerIdx].Args[argIdx] = string(remotePath) + } + } + } + podSpecStruct, err := utils.MarshalObjToStruct(podSpec) + if err != nil { + return err + } + task.Template.Target = &core.TaskTemplate_K8SPod{ + K8SPod: &core.K8SPod{ + Metadata: task.Template.GetK8SPod().Metadata, + PodSpec: podSpecStruct, + }, + } } return nil } diff --git a/flytectl/cmd/register/register_util_test.go b/flytectl/cmd/register/register_util_test.go index bb85c1c22f..4963ebe4b3 100644 --- a/flytectl/cmd/register/register_util_test.go +++ b/flytectl/cmd/register/register_util_test.go @@ -9,6 +9,10 @@ import ( "strings" "testing" + "github.com/flyteorg/flytestdlib/utils" + + v1 "k8s.io/api/core/v1" + "github.com/flyteorg/flytestdlib/contextutils" "github.com/flyteorg/flytestdlib/promutils" "github.com/flyteorg/flytestdlib/promutils/labeled" @@ -397,3 +401,55 @@ func TestHydrateNode(t *testing.T) { assert.NotNil(t, err) }) } + +func TestHydrateTaskSpec(t *testing.T) { + testScope := promutils.NewTestScope() + labeled.SetMetricKeys(contextutils.AppNameKey, contextutils.ProjectKey, contextutils.DomainKey) + s, err := storage.NewDataStore(&storage.Config{ + Type: storage.TypeMemory, + }, testScope.NewSubScope("flytectl")) + assert.Nil(t, err) + Client = s + + metadata := &core.K8SObjectMetadata{ + Labels: map[string]string{ + "l": "a", + }, + Annotations: map[string]string{ + "a": "b", + }, + } + + podSpec := v1.PodSpec{ + Containers: []v1.Container{ + { + Args: []string{"foo", "bar"}, + }, + { + Args: []string{"baz", registrationRemotePackagePattern}, + }, + }, + } + podSpecStruct, err := utils.MarshalObjToStruct(podSpec) + if err != nil { + t.Fatal(err) + } + + task := &admin.TaskSpec{ + Template: &core.TaskTemplate{ + Target: &core.TaskTemplate_K8SPod{ + K8SPod: &core.K8SPod{ + Metadata: metadata, + PodSpec: podSpecStruct, + }, + }, + }, + } + err = hydrateTaskSpec(task, "sourcey") + assert.NoError(t, err) + var hydratedPodSpec = v1.PodSpec{} + err = utils.UnmarshalStructToObj(task.Template.GetK8SPod().PodSpec, &hydratedPodSpec) + assert.NoError(t, err) + assert.Len(t, hydratedPodSpec.Containers[1].Args, 2) + assert.True(t, strings.HasSuffix(hydratedPodSpec.Containers[1].Args[1], "sourcey")) +} diff --git a/flytectl/go.sum b/flytectl/go.sum index 475f57a778..28983a22c6 100644 --- a/flytectl/go.sum +++ b/flytectl/go.sum @@ -534,6 +534,7 @@ github.com/hashicorp/go-version v1.3.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09 github.com/hashicorp/go.net v0.0.1/go.mod h1:hjKkEWcCURg++eb33jQU7oqQcI9XDCnUzHA0oac0k90= github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= +github.com/hashicorp/golang-lru v0.5.4 h1:YDjusn29QI/Das2iO9M0BHnIbxPeyuCHsjMW+lJfyTc= github.com/hashicorp/golang-lru v0.5.4/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ=