diff --git a/flytectl/cmd/register/register_util.go b/flytectl/cmd/register/register_util.go index 1cdc7893c5..cc0130d7dd 100644 --- a/flytectl/cmd/register/register_util.go +++ b/flytectl/cmd/register/register_util.go @@ -46,6 +46,12 @@ const registrationVersionPattern = "{{ registration.version }}" // Additional variable define in fast serialized proto that needs to be replace in registration time const registrationRemotePackagePattern = "{{ .remote_package_path }}" +// All supported extensions for compress +var supportedExtensions = []string{".tar", ".tgz", ".tar.gz"} + +// All supported extensions for gzip compress +var validGzipExtensions = []string{".tgz", ".tar.gz"} + type Result struct { Name string Status string @@ -445,15 +451,14 @@ func registerFile(ctx context.Context, fileName, sourceCode string, registerResu func getArchiveReaderCloser(ctx context.Context, ref string) (io.ReadCloser, error) { dataRef := storage.DataReference(ref) scheme, _, key, err := dataRef.Split() - segments := strings.Split(key, ".") - ext := segments[len(segments)-1] if err != nil { return nil, err } var dataRefReaderCloser io.ReadCloser - if ext != "tar" && ext != "tgz" { - return nil, errors.New("only .tar and .tgz extension archives are supported") + isValid, extension := checkSupportedExtensionForCompress(key) + if !isValid { + return nil, errors.New("only .tar, .tar.gz and .tgz extension archives are supported") } if scheme == "http" || scheme == "https" { @@ -464,9 +469,13 @@ func getArchiveReaderCloser(ctx context.Context, ref string) (io.ReadCloser, err if err != nil { return nil, err } - if ext == "tgz" { - if dataRefReaderCloser, err = gzip.NewReader(dataRefReaderCloser); err != nil { - return nil, err + + for _, ext := range validGzipExtensions { + if ext == extension { + if dataRefReaderCloser, err = gzip.NewReader(dataRefReaderCloser); err != nil { + return nil, err + } + break } } return dataRefReaderCloser, err @@ -486,7 +495,8 @@ func getJSONSpec(message proto.Message) string { func filterExampleFromRelease(releases github.RepositoryRelease) []github.ReleaseAsset { var assets []github.ReleaseAsset for _, v := range releases.Assets { - if strings.HasSuffix(*v.Name, ".tgz") { + isValid, _ := checkSupportedExtensionForCompress(*v.Name) + if isValid { assets = append(assets, v) } } @@ -604,3 +614,12 @@ func deprecatedCheck(ctx context.Context) { rconfig.DefaultFilesConfig.K8sServiceAccount = rconfig.DefaultFilesConfig.K8ServiceAccount } } + +func checkSupportedExtensionForCompress(file string) (bool, string) { + for _, extension := range supportedExtensions { + if strings.HasSuffix(file, extension) { + return true, extension + } + } + return false, "" +} diff --git a/flytectl/cmd/register/register_util_test.go b/flytectl/cmd/register/register_util_test.go index 4963ebe4b3..7f784776b3 100644 --- a/flytectl/cmd/register/register_util_test.go +++ b/flytectl/cmd/register/register_util_test.go @@ -163,7 +163,7 @@ func TestGetSortedArchivedInvalidArchiveFileList(t *testing.T) { assert.Equal(t, 0, len(fileList)) assert.True(t, strings.HasPrefix(tmpDir, "/tmp/register")) assert.NotNil(t, err) - assert.Equal(t, errors.New("only .tar and .tgz extension archives are supported"), err) + assert.Equal(t, errors.New("only .tar, .tar.gz and .tgz extension archives are supported"), err) // Clean up the temp directory. assert.Nil(t, os.RemoveAll(tmpDir), "unable to delete temp dir %v", tmpDir) }