diff --git a/dataproxy/service.go b/dataproxy/service.go
index 89db6dc62..5ab9e32a8 100644
--- a/dataproxy/service.go
+++ b/dataproxy/service.go
@@ -264,9 +264,13 @@ func (s Service) GetTaskExecutionID(ctx context.Context, attempt int, nodeExecID
 		Limit:   1,
 		Filters: fmt.Sprintf("eq(retry_attempt,%s)", strconv.Itoa(attempt)),
 	})
-	if err != nil || len(taskExecs.TaskExecutions) == 0 {
+	if err != nil {
 		return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "failed to list task executions [%v]. Error: %v", nodeExecID, err)
 	}
+	if len(taskExecs.TaskExecutions) == 0 {
+		// err is always nil on this path, so it is deliberately left out of the message.
+		return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "no task executions were listed [%v]", nodeExecID)
+	}
 	taskExec := taskExecs.TaskExecutions[0]
 	return taskExec.Id, nil
 }
diff --git a/dataproxy/service_test.go b/dataproxy/service_test.go
index 87dcb938f..db1c0e61d 100644
--- a/dataproxy/service_test.go
+++ b/dataproxy/service_test.go
@@ -15,10 +15,10 @@ import (
 	commonMocks "github.com/flyteorg/flyteadmin/pkg/common/mocks"
 	stdlibConfig "github.com/flyteorg/flytestdlib/config"
 
-	"google.golang.org/protobuf/types/known/durationpb"
-
+	"github.com/flyteorg/flyteadmin/pkg/errors"
 	"github.com/flyteorg/flytestdlib/contextutils"
 	"github.com/flyteorg/flytestdlib/promutils/labeled"
+	"google.golang.org/protobuf/types/known/durationpb"
 
 	"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service"
 
@@ -275,3 +275,46 @@ func TestService_GetData(t *testing.T) {
 		assert.Error(t, err)
 	})
 }
+
+func TestService_Error(t *testing.T) {
+	dataStore := commonMocks.GetMockStorageClient()
+	nodeExecutionManager := &mocks.MockNodeExecutionManager{}
+	taskExecutionManager := &mocks.MockTaskExecutionManager{}
+	s, err := NewService(config.DataProxyConfig{}, nodeExecutionManager, dataStore, taskExecutionManager)
+	assert.NoError(t, err)
+
+	t.Run("returns an error when listing task executions fails", func(t *testing.T) {
+		taskExecutionManager.SetListTaskExecutionsCallback(func(ctx context.Context, request admin.TaskExecutionListRequest) (*admin.TaskExecutionList, error) {
+			return nil, errors.NewFlyteAdminErrorf(1, "not found")
+		})
+		nodeExecID := core.NodeExecutionIdentifier{
+			NodeId: "n0",
+			ExecutionId: &core.WorkflowExecutionIdentifier{
+				Project: "proj",
+				Domain:  "dev",
+				Name:    "wfexecid",
+			},
+		}
+		_, err := s.GetTaskExecutionID(context.Background(), 0, nodeExecID)
+		assert.Error(t, err, "failed to list")
+	})
+
+	t.Run("returns an error when no task executions are listed", func(t *testing.T) {
+		taskExecutionManager.SetListTaskExecutionsCallback(func(ctx context.Context, request admin.TaskExecutionListRequest) (*admin.TaskExecutionList, error) {
+			return &admin.TaskExecutionList{
+				TaskExecutions: nil,
+				Token:          "",
+			}, nil
+		})
+		nodeExecID := core.NodeExecutionIdentifier{
+			NodeId: "n0",
+			ExecutionId: &core.WorkflowExecutionIdentifier{
+				Project: "proj",
+				Domain:  "dev",
+				Name:    "wfexecid",
+			},
+		}
+		_, err := s.GetTaskExecutionID(context.Background(), 0, nodeExecID)
+		assert.Error(t, err, "no task executions")
+	})
+}
diff --git a/pkg/common/flyte_url.go b/pkg/common/flyte_url.go
index 34e3c62fc..ec9a3ca08 100644
--- a/pkg/common/flyte_url.go
+++ b/pkg/common/flyte_url.go
@@ -4,7 +4,6 @@ import (
 	"fmt"
 	"regexp"
 	"strconv"
-	"strings"
 
 	"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin"
 	"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
@@ -23,31 +22,46 @@ const (
 	ArtifactTypeD // deck
 )
 
-var re = regexp.MustCompile("flyte://v1/([a-zA-Z0-9_-]+)/([a-zA-Z0-9_-]+)/([a-zA-Z0-9_-]+)/([a-zA-Z0-9_-]+)(?:/([0-9]+))?/([iod])$")
+var re = regexp.MustCompile("flyte://v1/(?P<project>[a-zA-Z0-9_-]+)/(?P<domain>[a-zA-Z0-9_-]+)/(?P<exec>[a-zA-Z0-9_-]+)/(?P<node>[a-zA-Z0-9_-]+)(?:/(?P<attempt>[0-9]+))?/(?P<artifactType>[iod])$")
+
+// MatchRegex returns the submatches of the first match of reg in input, keyed
+// by capture group name. It returns nil when input does not match. Optional
+// groups that did not participate in the match map to the empty string.
+func MatchRegex(reg *regexp.Regexp, input string) map[string]string {
+	names := reg.SubexpNames()
+	res := reg.FindStringSubmatch(input)
+	if res == nil {
+		return nil
+	}
+	dict := make(map[string]string, len(names))
+	for i := 1; i < len(res); i++ {
+		dict[names[i]] = res[i]
+	}
+	return dict
+}
 
 func ParseFlyteURL(flyteURL string) (core.NodeExecutionIdentifier, *int, ArtifactType, error) {
 	// flyteURL is of the form flyte://v1/project/domain/execution_id/node_id/attempt/[iod]
 	// where i stands for inputs.pb o for outputs.pb and d for the flyte deck
 	// If the retry attempt is missing, the io requested is assumed to be for the node instead of the task execution
-	re.MatchString(flyteURL)
-	matches := re.FindStringSubmatch(flyteURL)
-	if len(matches) != 7 && len(matches) != 6 {
-		return core.NodeExecutionIdentifier{}, nil, ArtifactTypeUndefined, fmt.Errorf("failed to parse flyte url, only %d matches found", len(matches))
-	}
-	proj := matches[1]
-	domain := matches[2]
-	executionID := matches[3]
-	nodeID := matches[4]
+	matches := MatchRegex(re, flyteURL)
+	if matches == nil {
+		// Fail fast with a parse error instead of carrying empty fields forward.
+		return core.NodeExecutionIdentifier{}, nil, ArtifactTypeUndefined, fmt.Errorf("failed to parse flyte url [%s]", flyteURL)
+	}
+	proj := matches["project"]
+	domain := matches["domain"]
+	executionID := matches["exec"]
+	nodeID := matches["node"]
 	var attempt *int // nil means node execution, not a task execution
-	if len(matches) == 7 && matches[5] != "" {
-		a, err := strconv.Atoi(matches[5])
+	if matches["attempt"] != "" {
+		a, err := strconv.Atoi(matches["attempt"])
 		if err != nil {
 			return core.NodeExecutionIdentifier{}, nil, ArtifactTypeUndefined, fmt.Errorf("failed to parse attempt, %s", err)
 		}
 		attempt = &a
 	}
-	ioLower := strings.ToLower(matches[len(matches)-1])
-	ioType, err := ArtifactTypeString(ioLower)
+	ioType, err := ArtifactTypeString(matches["artifactType"])
 	if err != nil {
 		return core.NodeExecutionIdentifier{}, nil, ArtifactTypeUndefined, err
 	}
diff --git a/pkg/common/flyte_url_test.go b/pkg/common/flyte_url_test.go
index 57c2cb51a..378860cf1 100644
--- a/pkg/common/flyte_url_test.go
+++ b/pkg/common/flyte_url_test.go
@@ -164,3 +164,11 @@ func TestFlyteURLsFromTaskExecutionID(t *testing.T) {
 		assert.Equal(t, "", urls.GetDeck())
 	})
 }
+
+func TestMatchRegexDirectly(t *testing.T) {
+	result := MatchRegex(re, "flyte://v1/fs/dev/abc/n0-dn0-9-n0-n0/i")
+	assert.Equal(t, "", result["attempt"])
+
+	result = MatchRegex(re, "flyteff://v2/fs/dfdsaev/abc/n0-dn0-9-n0-n0/i")
+	assert.Nil(t, result)
+}