Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Address resolution #546

Merged
merged 16 commits into from
May 5, 2023
115 changes: 114 additions & 1 deletion dataproxy/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,16 @@ import (
"encoding/base32"
"encoding/base64"
"fmt"
"github.com/flyteorg/flyteadmin/pkg/common"
"net/url"
"reflect"
"strconv"
"strings"
"time"

"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
"github.com/flyteorg/flytestdlib/logger"

"github.com/flyteorg/flyteadmin/pkg/errors"
"google.golang.org/grpc/codes"

Expand Down Expand Up @@ -37,6 +43,7 @@ type Service struct {
dataStore *storage.DataStore
shardSelector ioutils.ShardSelector
nodeExecutionManager interfaces.NodeExecutionInterface
taskExecutionManager interfaces.TaskExecutionInterface
}

// CreateUploadLocation creates a temporary signed url to allow callers to upload content.
Expand Down Expand Up @@ -231,9 +238,114 @@ func createStorageLocation(ctx context.Context, store *storage.DataStore,
return storagePath, nil
}

func (s Service) validateResolveArtifactRequest(req *service.GetDataRequest) error {
if req.GetFlyteUrl() == "" {
wild-endeavor marked this conversation as resolved.
Show resolved Hide resolved
return fmt.Errorf("source is required. Provided empty string")
}
if !strings.HasPrefix(req.GetFlyteUrl(), "flyte://") {
return fmt.Errorf("request does not start with the correct prefix")
}

return nil
}

func (s Service) GetTaskExecutionID(ctx context.Context, attempt int, nodeExecID core.NodeExecutionIdentifier) (*core.TaskExecutionIdentifier, error) {
taskExecs, err := s.taskExecutionManager.ListTaskExecutions(ctx, admin.TaskExecutionListRequest{
NodeExecutionId: &nodeExecID,
Limit: 1,
Filters: fmt.Sprintf("eq(retry_attempt,%s)", strconv.Itoa(attempt)),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a strongly typed way of doing this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not really... i think we should punt on this.

})
if err != nil || len(taskExecs.TaskExecutions) == 0 {
return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "failed to list task executions [%v]. Error: %v", nodeExecID, err)
wild-endeavor marked this conversation as resolved.
Show resolved Hide resolved
}
taskExec := taskExecs.TaskExecutions[0]
return taskExec.Id, nil
}

// GetData resolves a flyte URL to the data it refers to.
//
// The URL is validated and parsed first; failures of either step are reported as InvalidArgument.
// When the URL carries a retry attempt, the matching task execution is looked up and its inputs or
// outputs are returned as a literal map; requesting a deck for a task attempt is rejected.
// Without a retry attempt the node execution's data is used: inputs and outputs are returned as a
// literal map, and any other artifact type is treated as a deck and answered with a download link.
func (s Service) GetData(ctx context.Context, req *service.GetDataRequest) (
	*service.GetDataResponse, error) {

	logger.Debugf(ctx, "resolving flyte url query: %s", req.GetFlyteUrl())
	if validationErr := s.validateResolveArtifactRequest(req); validationErr != nil {
		return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "failed to validate resolve artifact request. Error: %v", validationErr)
	}

	nodeExecID, attempt, ioType, err := common.ParseFlyteURL(req.GetFlyteUrl())
	if err != nil {
		return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "failed to parse artifact url Error: %v", err)
	}

	// Both lookup paths wrap the selected literal map in the same response shape.
	literalMapResponse := func(literals *core.LiteralMap) *service.GetDataResponse {
		return &service.GetDataResponse{
			Data: &service.GetDataResponse_LiteralMap{
				LiteralMap: literals,
			},
		}
	}

	// A retry attempt in the URL means the caller wants a specific task execution.
	if attempt != nil {
		taskExecID, err := s.GetTaskExecutionID(ctx, *attempt, nodeExecID)
		if err != nil {
			return nil, err
		}

		taskData, err := s.taskExecutionManager.GetTaskExecutionData(ctx, admin.TaskExecutionGetDataRequest{
			Id: taskExecID,
		})
		if err != nil {
			return nil, err
		}

		switch ioType {
		case common.INPUT:
			return literalMapResponse(taskData.FullInputs), nil
		case common.OUTPUT:
			return literalMapResponse(taskData.FullOutputs), nil
		default:
			return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "deck type cannot be specified with a retry attempt, just use the node instead")
		}
	}

	// No retry attempt: answer from the node execution.
	nodeData, err := s.nodeExecutionManager.GetNodeExecutionData(ctx, admin.NodeExecutionGetDataRequest{
		Id: &nodeExecID,
	})
	if err != nil {
		return nil, err
	}

	switch ioType {
	case common.INPUT:
		return literalMapResponse(nodeData.FullInputs), nil
	case common.OUTPUT:
		return literalMapResponse(nodeData.FullOutputs), nil
	}

	// Anything else is assumed to be a deck, which is served through a download link.
	deckLink, err := s.CreateDownloadLink(ctx, &service.CreateDownloadLinkRequest{
		ArtifactType: service.ArtifactType_ARTIFACT_TYPE_DECK,
		Source:       &service.CreateDownloadLinkRequest_NodeExecutionId{NodeExecutionId: &nodeExecID},
	})
	if err != nil {
		return nil, err
	}
	return &service.GetDataResponse{
		Data: &service.GetDataResponse_FlyteDeckDownloadLink{
			FlyteDeckDownloadLink: deckLink,
		},
	}, nil
}

func NewService(cfg config.DataProxyConfig,
nodeExec interfaces.NodeExecutionInterface,
dataStore *storage.DataStore) (Service, error) {
dataStore *storage.DataStore,
taskExec interfaces.TaskExecutionInterface) (Service, error) {

// Context is not used in the constructor. Should ideally be removed.
selector, err := ioutils.NewBase36PrefixShardSelector(context.TODO())
Expand All @@ -246,5 +358,6 @@ func NewService(cfg config.DataProxyConfig,
dataStore: dataStore,
shardSelector: selector,
nodeExecutionManager: nodeExec,
taskExecutionManager: taskExec,
}, nil
}
28 changes: 24 additions & 4 deletions dataproxy/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package dataproxy

import (
"context"
"fmt"
"github.com/flyteorg/flyteadmin/pkg/common"
"testing"
"time"

Expand Down Expand Up @@ -32,9 +34,10 @@ func TestNewService(t *testing.T) {
assert.NoError(t, err)

nodeExecutionManager := &mocks.MockNodeExecutionManager{}
taskExecutionManager := &mocks.MockTaskExecutionManager{}
s, err := NewService(config.DataProxyConfig{
Upload: config.DataProxyUploadConfig{},
}, nodeExecutionManager, dataStore)
}, nodeExecutionManager, dataStore, taskExecutionManager)
assert.NoError(t, err)
assert.NotNil(t, s)
}
Expand All @@ -57,7 +60,8 @@ func TestCreateUploadLocation(t *testing.T) {
dataStore, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope())
assert.NoError(t, err)
nodeExecutionManager := &mocks.MockNodeExecutionManager{}
s, err := NewService(config.DataProxyConfig{}, nodeExecutionManager, dataStore)
taskExecutionManager := &mocks.MockTaskExecutionManager{}
s, err := NewService(config.DataProxyConfig{}, nodeExecutionManager, dataStore, taskExecutionManager)
assert.NoError(t, err)
t.Run("No project/domain", func(t *testing.T) {
_, err = s.CreateUploadLocation(context.Background(), &service.CreateUploadLocationRequest{})
Expand Down Expand Up @@ -92,8 +96,9 @@ func TestCreateDownloadLink(t *testing.T) {
},
}, nil
})
taskExecutionManager := &mocks.MockTaskExecutionManager{}

s, err := NewService(config.DataProxyConfig{Download: config.DataProxyDownloadConfig{MaxExpiresIn: stdlibConfig.Duration{Duration: time.Hour}}}, nodeExecutionManager, dataStore)
s, err := NewService(config.DataProxyConfig{Download: config.DataProxyDownloadConfig{MaxExpiresIn: stdlibConfig.Duration{Duration: time.Hour}}}, nodeExecutionManager, dataStore, taskExecutionManager)
assert.NoError(t, err)

t.Run("Invalid expiry", func(t *testing.T) {
Expand Down Expand Up @@ -128,7 +133,8 @@ func TestCreateDownloadLink(t *testing.T) {
func TestCreateDownloadLocation(t *testing.T) {
dataStore := commonMocks.GetMockStorageClient()
nodeExecutionManager := &mocks.MockNodeExecutionManager{}
s, err := NewService(config.DataProxyConfig{Download: config.DataProxyDownloadConfig{MaxExpiresIn: stdlibConfig.Duration{Duration: time.Hour}}}, nodeExecutionManager, dataStore)
taskExecutionManager := &mocks.MockTaskExecutionManager{}
s, err := NewService(config.DataProxyConfig{Download: config.DataProxyDownloadConfig{MaxExpiresIn: stdlibConfig.Duration{Duration: time.Hour}}}, nodeExecutionManager, dataStore, taskExecutionManager)
assert.NoError(t, err)

t.Run("Invalid expiry", func(t *testing.T) {
Expand Down Expand Up @@ -161,3 +167,17 @@ func TestCreateDownloadLocation(t *testing.T) {
assert.NoError(t, err)
})
}

// TestParseFlyteUrl checks that common.ParseFlyteURL extracts the project, domain, execution,
// node, optional retry attempt and artifact type from well-formed flyte URLs, and that it
// reports an error for URLs that lack a valid artifact selector.
func TestParseFlyteUrl(t *testing.T) {
	zero := 0
	validCases := []struct {
		name    string
		url     string
		attempt *int // nil means the URL addresses the node, not a task attempt
		kind    common.IOType
	}{
		{name: "task attempt outputs", url: "flyte://v1/fs/dev/abc/n0/0/o", attempt: &zero, kind: common.OUTPUT},
		{name: "node inputs", url: "flyte://v1/fs/dev/abc/n0/i", attempt: nil, kind: common.INPUT},
		{name: "node deck", url: "flyte://v1/fs/dev/abc/n0/d", attempt: nil, kind: common.DECK},
	}

	for i, tc := range validCases {
		tc := tc
		t.Run(fmt.Sprintf("valid %d: %s", i, tc.name), func(t *testing.T) {
			ne, attempt, kind, err := common.ParseFlyteURL(tc.url)
			assert.NoError(t, err)

			assert.Equal(t, "n0", ne.NodeId)
			if assert.NotNil(t, ne.ExecutionId) {
				assert.Equal(t, "fs", ne.ExecutionId.Project)
				assert.Equal(t, "dev", ne.ExecutionId.Domain)
				assert.Equal(t, "abc", ne.ExecutionId.Name)
			}
			assert.Equal(t, tc.kind, kind)

			if tc.attempt == nil {
				assert.Nil(t, attempt)
				return
			}
			if assert.NotNil(t, attempt) {
				assert.Equal(t, *tc.attempt, *attempt)
			}
		})
	}

	invalidURLs := []string{
		"flyte://v1/fs/dev/abc/n0",   // no artifact selector
		"flyte://v1/fs/dev/abc/n0/x", // unknown artifact selector
	}
	for i, u := range invalidURLs {
		u := u
		t.Run(fmt.Sprintf("invalid %d", i), func(t *testing.T) {
			_, _, _, err := common.ParseFlyteURL(u)
			assert.Error(t, err)
		})
	}
}
4 changes: 2 additions & 2 deletions flyteadmin_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ externalEvents:
eventTypes: all
Logger:
show-source: true
level: 6
level: 5
storage:
type: stow
stow:
Expand All @@ -129,7 +129,7 @@ storage:
secret_key: miniostorage
signedUrl:
stowConfigOverride:
endpoint: http://localhost:30084
endpoint: http://localhost:30002
cache:
max_size_mbs: 10
target_gc_percent: 100
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -209,3 +209,4 @@ require (
)

replace github.com/robfig/cron/v3 => github.com/unionai/cron/v3 v3.0.2-0.20210825070134-bfc34418fe84
replace github.com/flyteorg/flyteidl => ../flyteidl
92 changes: 92 additions & 0 deletions pkg/common/flyte_url.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
package common

import (
"fmt"
"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin"
"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
"regexp"
"strconv"
)

// IOType names the artifact a flyte URL refers to; it is the last path segment of the URL
// ("i" for inputs, "o" for outputs, "d" for the flyte deck).
type IOType string
wild-endeavor marked this conversation as resolved.
Show resolved Hide resolved

// Artifact selectors used as the final path segment of a flyte URL.
// NOTE(review): only UndefinedIo is declared with the IOType type. INPUT, OUTPUT and DECK each
// have their own initializer, so they are untyped string constants — confirm this is intended.
const (
// UndefinedIo is the empty selector that ParseFlyteURL returns alongside an error.
UndefinedIo IOType = ""
// INPUT selects the inputs of an execution.
INPUT = "i"
// OUTPUT selects the outputs of an execution.
OUTPUT = "o"
// DECK selects the flyte deck of an execution.
DECK = "d"
)

func ParseFlyteURL(flyteURL string) (core.NodeExecutionIdentifier, *int, IOType, error) {
// flyteURL is of the form flyte://v1/project/domain/execution_id/node_id/attempt/[iod]
// where i stands for inputs.pb o for outputs.pb and d for the flyte deck
// If the retry attempt is missing, the io requested is assumed to be for the node instead of the task execution
zero := 0
re, err := regexp.Compile("flyte://v1/([a-zA-Z0-9_-]+)/([a-zA-Z0-9_-]+)/([a-zA-Z0-9_-]+)/([a-zA-Z0-9_-]+)(?:/([0-9]+))?/([iod])")
wild-endeavor marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
return core.NodeExecutionIdentifier{}, &zero, UndefinedIo, err
}
re.MatchString(flyteURL)
matches := re.FindStringSubmatch(flyteURL)
if len(matches) != 7 && len(matches) != 6 {
return core.NodeExecutionIdentifier{}, &zero, UndefinedIo, fmt.Errorf("failed to parse flyte url, only %d matches found", len(matches))
}
proj := matches[1]
domain := matches[2]
executionID := matches[3]
nodeID := matches[4]
var attempt *int // nil means node execution, not a task execution
if len(matches) == 7 && matches[5] != "" {
a, err := strconv.Atoi(matches[5])
if err != nil {
return core.NodeExecutionIdentifier{}, &zero, UndefinedIo, fmt.Errorf("failed to parse attempt, %s", err)
}
attempt = &a
}
var ioType IOType
switch matches[len(matches)-1] {
wild-endeavor marked this conversation as resolved.
Show resolved Hide resolved
case "i":
ioType = INPUT
case "o":
ioType = OUTPUT
case "d":
ioType = DECK
}

return core.NodeExecutionIdentifier{
NodeId: nodeID,
ExecutionId: &core.WorkflowExecutionIdentifier{
Project: proj,
Domain: domain,
Name: executionID,
},
}, attempt, ioType, nil
}

func FlyteURLsFromNodeExecutionID(nodeExecutionID core.NodeExecutionIdentifier, deck bool) *admin.FlyteURLs {
base := fmt.Sprintf("flyte://v1/%s/%s/%s/%s", nodeExecutionID.ExecutionId.Project,
nodeExecutionID.ExecutionId.Domain, nodeExecutionID.ExecutionId.Name, nodeExecutionID.NodeId)

res := &admin.FlyteURLs{
Inputs: fmt.Sprintf("%s/%s", base, INPUT),
Outputs: fmt.Sprintf("%s/%s", base, OUTPUT),
}
if deck {
res.Deck = fmt.Sprintf("%s/%s", base, DECK)
}
return res
}

func FlyteURLsFromTaskExecutionID(taskExecutionID core.TaskExecutionIdentifier, deck bool) *admin.FlyteURLs {
base := fmt.Sprintf("flyte://v1/%s/%s/%s/%s/%s", taskExecutionID.NodeExecutionId.ExecutionId.Project,
taskExecutionID.NodeExecutionId.ExecutionId.Domain, taskExecutionID.NodeExecutionId.ExecutionId.Name, taskExecutionID.NodeExecutionId.NodeId, strconv.Itoa(int(taskExecutionID.RetryAttempt)))

res := &admin.FlyteURLs{
Inputs: fmt.Sprintf("%s/%s", base, INPUT),
Outputs: fmt.Sprintf("%s/%s", base, OUTPUT),
}
if deck {
res.Deck = fmt.Sprintf("%s/%s", base, DECK)
}
return res
}
1 change: 1 addition & 0 deletions pkg/manager/impl/node_execution_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,7 @@ func (m *NodeExecutionManager) GetNodeExecutionData(
Outputs: outputURLBlob,
FullInputs: inputs,
FullOutputs: outputs,
FlyteUrls: common.FlyteURLsFromNodeExecutionID(*request.Id, nodeExecution.GetClosure() != nil && nodeExecution.GetClosure().GetDeckUri() != ""),
}

if len(nodeExecutionModel.DynamicWorkflowRemoteClosureReference) > 0 {
Expand Down
1 change: 1 addition & 0 deletions pkg/manager/impl/task_execution_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,7 @@ func (m *TaskExecutionManager) GetTaskExecutionData(
Outputs: outputURLBlob,
FullInputs: inputs,
FullOutputs: outputs,
FlyteUrls: common.FlyteURLsFromTaskExecutionID(*request.Id, false),
}

m.metrics.TaskExecutionInputBytes.Observe(float64(response.Inputs.Bytes))
Expand Down
2 changes: 1 addition & 1 deletion pkg/server/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ func newGRPCServer(ctx context.Context, pluginRegistry *plugins.Registry, cfg *c
service.RegisterIdentityServiceServer(grpcServer, authCtx.IdentityService())
}

dataProxySvc, err := dataproxy.NewService(cfg.DataProxy, adminServer.NodeExecutionManager, dataStorageClient)
dataProxySvc, err := dataproxy.NewService(cfg.DataProxy, adminServer.NodeExecutionManager, dataStorageClient, adminServer.TaskExecutionManager)
if err != nil {
return nil, fmt.Errorf("failed to initialize dataProxy service. Error: %w", err)
}
Expand Down