Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Fix map task cache misses #363

Merged
merged 5 commits into from
Jun 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ require (
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.0.0 // indirect
github.com/aws/aws-sdk-go-v2/service/sts v1.0.0 // indirect
github.com/aws/smithy-go v1.1.0 // indirect
github.com/benlaurie/objecthash v0.0.0-20180202135721-d1e3d6079fc1 // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/cespare/xxhash v1.1.0 // indirect
github.com/cespare/xxhash/v2 v2.1.2 // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,8 @@ github.com/aws/aws-sdk-go-v2/service/sts v1.0.0/go.mod h1:5f+cELGATgill5Pu3/vK3E
github.com/aws/smithy-go v1.0.0/go.mod h1:EzMw8dbp/YJL4A5/sbhGddag+NPT7q084agLbB9LgIw=
github.com/aws/smithy-go v1.1.0 h1:D6CSsM3gdxaGaqXnPgOBCeL6Mophqzu7KJOu7zW78sU=
github.com/aws/smithy-go v1.1.0/go.mod h1:EzMw8dbp/YJL4A5/sbhGddag+NPT7q084agLbB9LgIw=
github.com/benlaurie/objecthash v0.0.0-20180202135721-d1e3d6079fc1 h1:VRtJdDi2lqc3MFwmouppm2jlm6icF+7H3WYKpLENMTo=
github.com/benlaurie/objecthash v0.0.0-20180202135721-d1e3d6079fc1/go.mod h1:jvdWlw8vowVGnZqSDC7yhPd7AifQeQbRDkZcQXV2nRg=
github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q=
github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
Expand Down
28 changes: 21 additions & 7 deletions go/tasks/pluginmachinery/catalog/async_client_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,11 @@ import (
"hash/fnv"
"reflect"

"github.com/flyteorg/flytestdlib/promutils"

"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/workqueue"
"github.com/flyteorg/flytestdlib/bitarray"

"github.com/flyteorg/flytestdlib/errors"

"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/workqueue"
"github.com/flyteorg/flytestdlib/promutils"
)

const specialEncoderKey = "abcdefghijklmnopqrstuvwxyz123456"
Expand Down Expand Up @@ -41,6 +39,18 @@ func consistentHash(str string) (string, error) {
return base32Encoder.EncodeToString(b), nil
}

func hashInputs(ctx context.Context, key Key) (string, error) {
inputs := &core.LiteralMap{}
if key.TypedInterface.Inputs != nil {
retInputs, err := key.InputReader.Get(ctx)
if err != nil {
return "", err
}
inputs = retInputs
}
return HashLiteralMap(ctx, inputs)
}

func (c AsyncClientImpl) Download(ctx context.Context, requests ...DownloadRequest) (outputFuture DownloadFuture, err error) {
status := ResponseStatusReady
cachedResults := bitarray.NewBitSet(uint(len(requests)))
Expand Down Expand Up @@ -95,8 +105,12 @@ func (c AsyncClientImpl) Upload(ctx context.Context, requests ...UploadRequest)
status := ResponseStatusReady
var respErr error
for idx, request := range requests {
workItemID := formatWorkItemID(request.Key, idx, "")
err := c.Writer.Queue(ctx, workItemID, NewWriterWorkItem(
inputHash, err := hashInputs(ctx, request.Key)
if err != nil {
return nil, errors.Wrapf(ErrSystemError, err, "Failed to hash inputs for item: %v", request.Key)
}
workItemID := formatWorkItemID(request.Key, idx, inputHash)
err = c.Writer.Queue(ctx, workItemID, NewWriterWorkItem(
request.Key,
request.ArtifactData,
request.ArtifactMetadata))
Expand Down
98 changes: 91 additions & 7 deletions go/tasks/pluginmachinery/catalog/async_client_impl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,62 @@ import (
"reflect"
"testing"

"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
mocks2 "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io/mocks"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/workqueue"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/workqueue/mocks"
"github.com/flyteorg/flytestdlib/bitarray"
"github.com/stretchr/testify/mock"

"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/workqueue"
)

var exampleInterface = &core.TypedInterface{
Inputs: &core.VariableMap{
Variables: map[string]*core.Variable{
"a": {
Type: &core.LiteralType{
Type: &core.LiteralType_Simple{
Simple: core.SimpleType_INTEGER,
},
},
},
},
},
}
var input1 = &core.LiteralMap{
Literals: map[string]*core.Literal{
"a": {
Value: &core.Literal_Scalar{
Scalar: &core.Scalar{
Value: &core.Scalar_Primitive{
Primitive: &core.Primitive{
Value: &core.Primitive_Integer{
Integer: 1,
},
},
},
},
},
},
},
}
var input2 = &core.LiteralMap{
Literals: map[string]*core.Literal{
"a": {
Value: &core.Literal_Scalar{
Scalar: &core.Scalar{
Value: &core.Scalar_Primitive{
Primitive: &core.Primitive{
Value: &core.Primitive_Integer{
Integer: 2,
},
},
},
},
},
},
},
}

func TestAsyncClientImpl_Download(t *testing.T) {
ctx := context.Background()

Expand Down Expand Up @@ -61,24 +109,50 @@ func TestAsyncClientImpl_Download(t *testing.T) {
func TestAsyncClientImpl_Upload(t *testing.T) {
ctx := context.Background()

inputHash1 := "{UNSPECIFIED {} [] 0}:-0-DNhkpTTPC5YDtRGb4yT-PFxgMSgHzHrKAQKgQGEfGRY"
inputHash2 := "{UNSPECIFIED {} [] 0}:-1-26M4dwarvBVJqJSUC4JC1GtRYgVBIAmQfsFSdLVMlAc"

q := &mocks.IndexedWorkQueue{}
info := &mocks.WorkItemInfo{}
info.OnItem().Return(NewReaderWorkItem(Key{}, &mocks2.OutputWriter{}))
info.OnStatus().Return(workqueue.WorkStatusSucceeded)
q.OnGet("{UNSPECIFIED {} [] 0}:-0-").Return(info, true, nil)
q.OnGet(inputHash1).Return(info, true, nil)
q.OnGet(inputHash2).Return(info, true, nil)
q.OnQueueMatch(mock.Anything, mock.Anything, mock.Anything).Return(nil)

inputReader1 := &mocks2.InputReader{}
inputReader1.OnGetMatch(mock.Anything).Return(input1, nil)
inputReader2 := &mocks2.InputReader{}
inputReader2.OnGetMatch(mock.Anything).Return(input2, nil)

tests := []struct {
name string
requests []UploadRequest
wantPutFuture UploadFuture
wantErr bool
}{
{"UploadSucceeded", []UploadRequest{
{
Key: Key{},
{
"UploadSucceeded",
// The second request has the same Key.Identifier and Key.Cache version but a different
// Key.InputReader. This should lead to a different WorkItemID in the queue.
// See https://github.com/flyteorg/flyte/issues/3787 for more details
[]UploadRequest{
{
Key: Key{
TypedInterface: *exampleInterface,
InputReader: inputReader1,
},
},
{
Key: Key{
TypedInterface: *exampleInterface,
InputReader: inputReader2,
},
},
},
}, newUploadFuture(ResponseStatusReady, nil), false},
newUploadFuture(ResponseStatusReady, nil),
false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand All @@ -93,6 +167,16 @@ func TestAsyncClientImpl_Upload(t *testing.T) {
if !reflect.DeepEqual(gotPutFuture, tt.wantPutFuture) {
t.Errorf("AsyncClientImpl.Sidecar() = %v, want %v", gotPutFuture, tt.wantPutFuture)
}
expectedWorkItemIDs := []string{inputHash1, inputHash2}
gottenWorkItemIDs := make([]string, 0)
for _, mockCall := range q.Calls {
if mockCall.Method == "Get" {
gottenWorkItemIDs = append(gottenWorkItemIDs, mockCall.Arguments[0].(string))
}
}
if !reflect.DeepEqual(gottenWorkItemIDs, expectedWorkItemIDs) {
t.Errorf("Retrieved workitem IDs = %v, want %v", gottenWorkItemIDs, expectedWorkItemIDs)
}
})
}
}
Expand Down
78 changes: 78 additions & 0 deletions go/tasks/pluginmachinery/catalog/hashing.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
package catalog

import (
"context"
"encoding/base64"

"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
"github.com/flyteorg/flytestdlib/pbhash"
)

var emptyLiteralMap = core.LiteralMap{Literals: map[string]*core.Literal{}}

// Hashify a literal, in other words, produce a new literal where the corresponding value is removed in case
// the literal hash is set.
func hashify(literal *core.Literal) *core.Literal {
// Two recursive cases:
// 1. A collection of literals or
// 2. A map of literals

if literal.GetCollection() != nil {
literals := literal.GetCollection().Literals
literalsHash := make([]*core.Literal, 0)
for _, lit := range literals {
literalsHash = append(literalsHash, hashify(lit))
}
return &core.Literal{
Value: &core.Literal_Collection{
Collection: &core.LiteralCollection{
Literals: literalsHash,
},
},
}
}
if literal.GetMap() != nil {
literalsMap := make(map[string]*core.Literal)
for key, lit := range literal.GetMap().Literals {
literalsMap[key] = hashify(lit)
}
return &core.Literal{
Value: &core.Literal_Map{
Map: &core.LiteralMap{
Literals: literalsMap,
},
},
}
}

// And a base case that consists of a scalar, where the hash might be set
if literal.GetHash() != "" {
return &core.Literal{
Hash: literal.GetHash(),
}
}
return literal
}

func HashLiteralMap(ctx context.Context, literalMap *core.LiteralMap) (string, error) {
if literalMap == nil || len(literalMap.Literals) == 0 {
literalMap = &emptyLiteralMap
}

// Hashify, i.e. generate a copy of the literal map where each literal value is removed
// in case the corresponding hash is set.
hashifiedLiteralMap := make(map[string]*core.Literal, len(literalMap.Literals))
for name, literal := range literalMap.Literals {
hashifiedLiteralMap[name] = hashify(literal)
}
hashifiedInputs := &core.LiteralMap{
Literals: hashifiedLiteralMap,
}

inputsHash, err := pbhash.ComputeHash(ctx, hashifiedInputs)
if err != nil {
return "", err
}

return base64.RawURLEncoding.EncodeToString(inputsHash), nil
}
Loading