Skip to content

Commit

Permalink
[Artifacts/Serverless] Read GH handle from grpc metadata and propagat…
Browse files Browse the repository at this point in the history
…e if present (#261)
  • Loading branch information
wild-endeavor authored May 13, 2024
1 parent 997a60d commit 3605e3d
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 73 deletions.
69 changes: 0 additions & 69 deletions flyteadmin/pkg/manager/impl/exec_manager_other_test.go

This file was deleted.

10 changes: 8 additions & 2 deletions flyteadmin/pkg/manager/impl/execution_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -852,6 +852,7 @@ func (m *ExecutionManager) fillInTemplateArgs(ctx context.Context, query core.Ar
Project: project,
Domain: domain,
Name: ak.GetName(),
Org: ak.GetOrg(),
},
Partitions: &core.Partitions{
Value: partitions,
Expand All @@ -865,7 +866,7 @@ func (m *ExecutionManager) fillInTemplateArgs(ctx context.Context, query core.Ar
}

// ResolveParameterMapArtifacts will go through the parameter map, and resolve any artifact queries.
func (m *ExecutionManager) ResolveParameterMapArtifacts(ctx context.Context, inputs *core.ParameterMap, inputsForQueryTemplating map[string]*core.Literal) (*core.ParameterMap, []*core.ArtifactID, error) {
func (m *ExecutionManager) ResolveParameterMapArtifacts(ctx context.Context, inputs *core.ParameterMap, inputsForQueryTemplating map[string]*core.Literal, executionOrg string) (*core.ParameterMap, []*core.ArtifactID, error) {

// only top level replace for now. Need to make this recursive
var artifactIDs []*core.ArtifactID
Expand All @@ -874,6 +875,7 @@ func (m *ExecutionManager) ResolveParameterMapArtifacts(ctx context.Context, inp
}
outputs := map[string]*core.Parameter{}

// copy ghHandle into request
for k, v := range inputs.Parameters {
if inputsForQueryTemplating != nil {
if _, ok := inputsForQueryTemplating[k]; ok {
Expand All @@ -895,6 +897,10 @@ func (m *ExecutionManager) ResolveParameterMapArtifacts(ctx context.Context, inp
logger.Errorf(ctx, "Failed to fill in template args for [%s] [%v]", k, err)
return nil, nil, err
}

if filledInQuery.GetArtifactId().GetArtifactKey() != nil {
filledInQuery.GetArtifactId().GetArtifactKey().Org = executionOrg
}
req := &artifactsIdl.GetArtifactRequest{
Query: &filledInQuery,
Details: false,
Expand Down Expand Up @@ -1006,7 +1012,7 @@ func (m *ExecutionManager) launchExecution(
// and so we can fill in template args.
// ArtifactIDs are also returned for lineage purposes.
ctxPD := contextutils.WithProjectDomain(ctx, request.Project, request.Domain)
lpExpectedInputs, usedArtifactIDs, err = m.ResolveParameterMapArtifacts(ctxPD, launchPlan.Closure.ExpectedInputs, inputsForQueryTemplating)
lpExpectedInputs, usedArtifactIDs, err = m.ResolveParameterMapArtifacts(ctxPD, launchPlan.Closure.ExpectedInputs, inputsForQueryTemplating, request.Org)
if err != nil {
logger.Errorf(ctx, "Error looking up launch plan closure parameter map: %v", err)
return nil, nil, err
Expand Down
124 changes: 122 additions & 2 deletions flyteadmin/pkg/manager/impl/execution_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (

"github.com/flyteorg/flyte/flyteadmin/auth"
"github.com/flyteorg/flyte/flyteadmin/pkg/artifacts"
artifactMocks "github.com/flyteorg/flyte/flyteadmin/pkg/artifacts/mocks"
eventWriterMocks "github.com/flyteorg/flyte/flyteadmin/pkg/async/events/mocks"
notificationMocks "github.com/flyteorg/flyte/flyteadmin/pkg/async/notifications/mocks"
"github.com/flyteorg/flyte/flyteadmin/pkg/common"
Expand All @@ -48,6 +49,7 @@ import (
"github.com/flyteorg/flyte/flyteadmin/plugins"
"github.com/flyteorg/flyte/flyteidl/clients/go/coreutils"
"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin"
artifactsIdl "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/artifacts"
"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core"
"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/event"
"github.com/flyteorg/flyte/flytestdlib/contextutils"
Expand Down Expand Up @@ -5959,10 +5961,11 @@ func TestQueryTemplate(t *testing.T) {
Name: "testname",
}

akNameOnly := &core.ArtifactKey{
akNamePlusOrg := &core.ArtifactKey{
Project: "",
Domain: "",
Name: "testname",
Org: "my-gh-handle",
}

t.Run("test all present, nothing to fill in", func(t *testing.T) {
Expand All @@ -5975,7 +5978,7 @@ func TestQueryTemplate(t *testing.T) {
q := core.ArtifactQuery{
Identifier: &core.ArtifactQuery_ArtifactId{
ArtifactId: &core.ArtifactID{
ArtifactKey: akNameOnly,
ArtifactKey: akNamePlusOrg,
Partitions: p,
TimePartition: nil,
},
Expand All @@ -5985,6 +5988,7 @@ func TestQueryTemplate(t *testing.T) {
filledQuery, err := m.fillInTemplateArgs(ctx, q, otherInputs.Literals)
assert.NoError(t, err)
assert.True(t, proto.Equal(&q, &filledQuery))
assert.Equal(t, "my-gh-handle", filledQuery.GetArtifactId().GetArtifactKey().GetOrg())

q.GetArtifactId().ArtifactKey = ak
ctx = contextutils.WithProjectDomain(ctx, "project", "domain")
Expand Down Expand Up @@ -6080,3 +6084,119 @@ func TestLiteralParsing(t *testing.T) {
})
}
}

func TestResolveNotWorking(t *testing.T) {
mockConfig := getMockExecutionsConfigProvider()

execManager := NewExecutionManager(nil, nil, mockConfig, nil, mockScope.NewTestScope(), mockScope.NewTestScope(), nil, nil, nil, nil, nil, nil, &eventWriterMocks.WorkflowExecutionEventWriter{}, artifacts.NewArtifactRegistry(context.Background(), nil)).(*ExecutionManager)

pm, artifactIDs, err := execManager.ResolveParameterMapArtifacts(context.Background(), nil, nil, "")
assert.Nil(t, err)
fmt.Println(pm, artifactIDs)

}

func TestTrackingBitExtract(t *testing.T) {
mockConfig := getMockExecutionsConfigProvider()

execManager := NewExecutionManager(nil, nil, mockConfig, nil, mockScope.NewTestScope(), mockScope.NewTestScope(), nil, nil, nil, nil, nil, nil, &eventWriterMocks.WorkflowExecutionEventWriter{}, artifacts.NewArtifactRegistry(context.Background(), nil)).(*ExecutionManager)

lit := core.Literal{
Value: &core.Literal_Scalar{
Scalar: &core.Scalar{
Value: &core.Scalar_Primitive{
Primitive: &core.Primitive{
Value: &core.Primitive_Integer{
Integer: 1,
},
},
},
},
},
Metadata: map[string]string{"_ua": "proj/domain/name@version"},
}
inputMap := core.LiteralMap{
Literals: map[string]*core.Literal{
"a": &lit,
},
}
inputColl := core.LiteralCollection{
Literals: []*core.Literal{
&lit,
},
}

var trackers = make(map[string]string)
execManager.ExtractArtifactTrackers(trackers, &lit)
assert.Equal(t, 1, len(trackers))

trackers = make(map[string]string)
execManager.ExtractArtifactTrackers(trackers, &core.Literal{Value: &core.Literal_Map{Map: &inputMap}})
assert.Equal(t, 1, len(trackers))

trackers = make(map[string]string)
execManager.ExtractArtifactTrackers(trackers, &core.Literal{Value: &core.Literal_Collection{Collection: &inputColl}})
assert.Equal(t, 1, len(trackers))
assert.Equal(t, "", trackers["proj/domain/name@version"])
}

func TestResolveParameterMapArtifacts(t *testing.T) {
ak := &core.ArtifactKey{
Project: "project",
Domain: "domain",
Name: "testname",
}
returnID := &core.ArtifactID{
ArtifactKey: ak,
Version: "abc",
}
one, err := coreutils.MakeLiteral(1)
assert.NoError(t, err)

pMap := map[string]*core.LabelValue{
"partition1": {Value: &core.LabelValue_StaticValue{StaticValue: "my value"}},
"partition2": {Value: &core.LabelValue_StaticValue{StaticValue: "my value 2"}},
}
p := &core.Partitions{Value: pMap}

q := core.ArtifactQuery{
Identifier: &core.ArtifactQuery_ArtifactId{
ArtifactId: &core.ArtifactID{
ArtifactKey: ak,
Partitions: p,
},
},
}

inputs := core.ParameterMap{
Parameters: map[string]*core.Parameter{
"input1": {
Var: nil,
Behavior: &core.Parameter_ArtifactQuery{ArtifactQuery: &q},
},
},
}

t.Run("context metadata provides github handle", func(t *testing.T) {
client := artifactMocks.ArtifactRegistryClient{}
client.On("GetArtifact", mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
req := args.Get(1).(*artifactsIdl.GetArtifactRequest)
assert.Equal(t, "user-handle", req.GetQuery().GetArtifactId().GetArtifactKey().GetOrg())
}).Return(&artifactsIdl.GetArtifactResponse{Artifact: &artifactsIdl.Artifact{
ArtifactId: returnID,
Spec: &artifactsIdl.ArtifactSpec{
Value: one,
},
}}, nil)

ctx := context.Background()

m := ExecutionManager{
artifactRegistry: &artifacts.ArtifactRegistry{Client: &client},
}

_, x, err := m.ResolveParameterMapArtifacts(ctx, &inputs, nil, "user-handle")
assert.NoError(t, err)
assert.Equal(t, 1, len(x))
})
}

0 comments on commit 3605e3d

Please sign in to comment.