From ea9211fad2cdc18345e94e1341b959cd74b16da6 Mon Sep 17 00:00:00 2001 From: Yuvraj Date: Tue, 21 Dec 2021 10:43:50 +0530 Subject: [PATCH] Added script for checking diff of generated code (#372) * Added ci check for go generate Signed-off-by: Yuvraj Signed-off-by: Katrina Rogan Co-authored-by: Katrina Rogan --- .github/workflows/pull_request.yml | 10 ++ Makefile | 3 - boilerplate/flyte/golang_support_tools/go.mod | 4 +- boilerplate/flyte/golang_support_tools/go.sum | 5 +- .../flyte/golang_test_targets/Makefile | 4 + .../flyte/golang_test_targets/go-gen.sh | 22 ++++ events/admin_eventsink_integration_test.go | 3 +- events/mocks/event_recorder.go | 121 +++++++++++++----- events/mocks/event_sink.go | 80 ++++++++++++ events/mocks/writer.go | 78 +++++++++++ events/node_event_recorder_test.go | 60 ++++----- events/task_event_recorder_test.go | 60 ++++----- events/workflow_event_recorder_test.go | 58 +++++---- go.mod | 1 + .../handler/mocks/node_execution_context.go | 1 - 15 files changed, 385 insertions(+), 125 deletions(-) create mode 100755 boilerplate/flyte/golang_test_targets/go-gen.sh create mode 100644 events/mocks/event_sink.go create mode 100644 events/mocks/writer.go diff --git a/.github/workflows/pull_request.yml b/.github/workflows/pull_request.yml index b5cc546af4..93fad8868c 100644 --- a/.github/workflows/pull_request.yml +++ b/.github/workflows/pull_request.yml @@ -128,3 +128,13 @@ jobs: GO111MODULE: "on" with: args: make install && make lint + + generate: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v1 + - uses: actions/setup-go@v2 + with: + go-version: '1.16' + - name: Go generate and diff + run: DELTA_CHECK=true make generate \ No newline at end of file diff --git a/Makefile b/Makefile index b2f078723c..0e653fa93b 100644 --- a/Makefile +++ b/Makefile @@ -54,6 +54,3 @@ golden: go test ./cmd/kubectl-flyte/cmd -update go test ./pkg/compiler/test -update -.PHONY: generate -generate: download_tooling - @go generate ./... diff --git a/boilerplate/flyte/golang_support_tools/go.mod b/boilerplate/flyte/golang_support_tools/go.mod index 53f645159a..4e2e7e8d7c 100644 --- a/boilerplate/flyte/golang_support_tools/go.mod +++ b/boilerplate/flyte/golang_support_tools/go.mod @@ -4,9 +4,9 @@ go 1.16 require ( github.com/alvaroloes/enumer v1.1.2 - github.com/flyteorg/flytestdlib v0.3.22 + github.com/flyteorg/flytestdlib v0.4.7 github.com/golangci/golangci-lint v1.38.0 - github.com/pseudomuto/protoc-gen-doc v1.4.1 // indirect + github.com/pseudomuto/protoc-gen-doc v0.0.0-00010101000000-000000000000 // indirect github.com/vektra/mockery v0.0.0-20181123154057-e78b021dcbb5 ) diff --git a/boilerplate/flyte/golang_support_tools/go.sum b/boilerplate/flyte/golang_support_tools/go.sum index 261048f745..a62010d29f 100644 --- a/boilerplate/flyte/golang_support_tools/go.sum +++ b/boilerplate/flyte/golang_support_tools/go.sum @@ -210,8 +210,8 @@ github.com/fatih/color v1.10.0 h1:s36xzo75JdqLaaWoiEHk767eHiwo0598uUxyfiPkDsg= github.com/fatih/color v1.10.0/go.mod h1:ELkj/draVOlAH/xkhN6mQ50Qd0MPOk5AAr3maGEBuJM= github.com/fatih/structtag v1.2.0 h1:/OdNE99OxoI/PqaW/SuSK9uxxT3f/tcSZgon/ssNSx4= github.com/fatih/structtag v1.2.0/go.mod h1:mBJUNpUnHmRKrKlQQlmCrh5PuhftFbNv8Ys4/aAZl94= -github.com/flyteorg/flytestdlib v0.3.22 h1:nJEPaCdxzXBaeg2p4fdo3I3Ua09NedFRaUwuLafLEdw= -github.com/flyteorg/flytestdlib v0.3.22/go.mod h1:1XG0DwYTUm34Yrffm1Qy9Tdr/pWQypEqTq5dUxw3/cM= +github.com/flyteorg/flytestdlib v0.4.7 h1:SMPPXI3j/MjP7D2fqaR+lPQkTrqYS7xZbwsgJI2F8SU= +github.com/flyteorg/flytestdlib v0.4.7/go.mod h1:fv1ar34LJLMTaf0tbfetisLykUlARi7rP+NQTUn6QQs= github.com/flyteorg/protoc-gen-doc v1.4.2 h1:Otw0F+RHaPQ8XlpzhLLgjsCMcrAIcMO01Zh+ALe3rrE= github.com/flyteorg/protoc-gen-doc v1.4.2/go.mod h1:exDTOVwqpp30eV/EDPFLZy3Pwr2sn6hBC1WIYH/UbIg= github.com/form3tech-oss/jwt-go v3.2.2+incompatible h1:TcekIExNqud5crz4xD2pavyTgWiPvpYe4Xau31I0PRk= @@ -1263,6 +1263,7 @@ honnef.co/go/tools v0.0.1-2020.1.4/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9 honnef.co/go/tools v0.1.2 h1:SMdYLJl312RXuxXziCCHhRsp/tvct9cGKey0yv95tZM= honnef.co/go/tools v0.1.2/go.mod h1:NgwopIslSNH47DimFoV78dnkksY2EFtX0ajyb3K/las= k8s.io/api v0.0.0-20210217171935-8e2decd92398/go.mod h1:60tmSUpHxGPFerNHbo/ayI2lKxvtrhbxFyXuEIWJd78= +k8s.io/api v0.20.2/go.mod h1:d7n6Ehyzx+S+cE3VhTGfVNNqtGc/oL9DCdYYahlurV8= k8s.io/apimachinery v0.0.0-20210217011835-527a61b4dffe/go.mod h1:Z7ps/g0rjlTeMstYrMOUttJfT2Gg34DEaG/f2PYLCWY= k8s.io/apimachinery v0.20.2 h1:hFx6Sbt1oG0n6DZ+g4bFt5f6BoMkOjKWsQFu077M3Vg= k8s.io/apimachinery v0.20.2/go.mod h1:WlLqWAHZGg07AeltaI0MV5uk1Omp8xaN0JGLY6gkRpU= diff --git a/boilerplate/flyte/golang_test_targets/Makefile b/boilerplate/flyte/golang_test_targets/Makefile index 21d8b5b776..280e1e55e4 100644 --- a/boilerplate/flyte/golang_test_targets/Makefile +++ b/boilerplate/flyte/golang_test_targets/Makefile @@ -8,6 +8,10 @@ download_tooling: #download dependencies (including test deps) for the package @boilerplate/flyte/golang_test_targets/download_tooling.sh +.PHONY: generate +generate: download_tooling #generate go code + @boilerplate/flyte/golang_test_targets/go-gen.sh + .PHONY: lint lint: download_tooling #lints the package for common code smells GL_DEBUG=linters_output,env golangci-lint run --deadline=5m --exclude deprecated -v diff --git a/boilerplate/flyte/golang_test_targets/go-gen.sh b/boilerplate/flyte/golang_test_targets/go-gen.sh new file mode 100755 index 0000000000..54bd6af61b --- /dev/null +++ b/boilerplate/flyte/golang_test_targets/go-gen.sh @@ -0,0 +1,22 @@ +#!/usr/bin/env bash + +set -ex + +echo "Running go generate" +go generate ./... + +# This section is used by GitHub workflow to ensure that the generation step was run +if [ -n "$DELTA_CHECK" ]; then + DIRTY=$(git status --porcelain) + if [ -n "$DIRTY" ]; then + echo "FAILED: Go code updated without commiting generated code." + echo "Ensure make generate has run and all changes are committed." + DIFF=$(git diff) + echo "diff detected: $DIFF" + DIFF=$(git diff --name-only) + echo "files different: $DIFF" + exit 1 + else + echo "SUCCESS: Generated code is up to date." + fi +fi diff --git a/events/admin_eventsink_integration_test.go b/events/admin_eventsink_integration_test.go index 97d4ad1774..44c23a3448 100644 --- a/events/admin_eventsink_integration_test.go +++ b/events/admin_eventsink_integration_test.go @@ -6,11 +6,12 @@ package events import ( "context" "fmt" - "github.com/stretchr/testify/assert" netUrl "net/url" "testing" "time" + "github.com/stretchr/testify/assert" + "github.com/flyteorg/flyteidl/clients/go/admin" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" diff --git a/events/mocks/event_recorder.go b/events/mocks/event_recorder.go index 8efbe1ce1a..29efbf2755 100644 --- a/events/mocks/event_recorder.go +++ b/events/mocks/event_recorder.go @@ -1,51 +1,112 @@ +// Code generated by mockery v1.0.1. DO NOT EDIT. + package mocks import ( - "context" + context "context" + + event "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" + mock "github.com/stretchr/testify/mock" ) -type MockRecorder struct { - RecordNodeEventCb func(ctx context.Context, event *event.NodeExecutionEvent) error - RecordTaskEventCb func(ctx context.Context, event *event.TaskExecutionEvent) error - RecordWorkflowEventCb func(ctx context.Context, event *event.WorkflowExecutionEvent) error +// EventRecorder is an autogenerated mock type for the EventRecorder type +type EventRecorder struct { + mock.Mock } -func (m *MockRecorder) RecordNodeEvent(ctx context.Context, event *event.NodeExecutionEvent) error { - if m.RecordNodeEventCb != nil { - return m.RecordNodeEventCb(ctx, event) - } +type EventRecorder_RecordNodeEvent struct { + *mock.Call +} + +func (_m EventRecorder_RecordNodeEvent) Return(_a0 error) *EventRecorder_RecordNodeEvent { + return &EventRecorder_RecordNodeEvent{Call: _m.Call.Return(_a0)} +} + +func (_m *EventRecorder) OnRecordNodeEvent(ctx context.Context, _a1 *event.NodeExecutionEvent) *EventRecorder_RecordNodeEvent { + c := _m.On("RecordNodeEvent", ctx, _a1) + return &EventRecorder_RecordNodeEvent{Call: c} +} - return nil +func (_m *EventRecorder) OnRecordNodeEventMatch(matchers ...interface{}) *EventRecorder_RecordNodeEvent { + c := _m.On("RecordNodeEvent", matchers...) + return &EventRecorder_RecordNodeEvent{Call: c} } -func (m *MockRecorder) RecordTaskEvent(ctx context.Context, event *event.TaskExecutionEvent) error { - if m.RecordTaskEventCb != nil { - return m.RecordTaskEventCb(ctx, event) +// RecordNodeEvent provides a mock function with given fields: ctx, _a1 +func (_m *EventRecorder) RecordNodeEvent(ctx context.Context, _a1 *event.NodeExecutionEvent) error { + ret := _m.Called(ctx, _a1) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *event.NodeExecutionEvent) error); ok { + r0 = rf(ctx, _a1) + } else { + r0 = ret.Error(0) } - return nil + return r0 +} + +type EventRecorder_RecordTaskEvent struct { + *mock.Call +} + +func (_m EventRecorder_RecordTaskEvent) Return(_a0 error) *EventRecorder_RecordTaskEvent { + return &EventRecorder_RecordTaskEvent{Call: _m.Call.Return(_a0)} +} + +func (_m *EventRecorder) OnRecordTaskEvent(ctx context.Context, _a1 *event.TaskExecutionEvent) *EventRecorder_RecordTaskEvent { + c := _m.On("RecordTaskEvent", ctx, _a1) + return &EventRecorder_RecordTaskEvent{Call: c} +} + +func (_m *EventRecorder) OnRecordTaskEventMatch(matchers ...interface{}) *EventRecorder_RecordTaskEvent { + c := _m.On("RecordTaskEvent", matchers...) + return &EventRecorder_RecordTaskEvent{Call: c} } -func (m *MockRecorder) RecordWorkflowEvent(ctx context.Context, event *event.WorkflowExecutionEvent) error { - if m.RecordWorkflowEventCb != nil { - return m.RecordWorkflowEventCb(ctx, event) +// RecordTaskEvent provides a mock function with given fields: ctx, _a1 +func (_m *EventRecorder) RecordTaskEvent(ctx context.Context, _a1 *event.TaskExecutionEvent) error { + ret := _m.Called(ctx, _a1) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *event.TaskExecutionEvent) error); ok { + r0 = rf(ctx, _a1) + } else { + r0 = ret.Error(0) } - return nil + return r0 +} + +type EventRecorder_RecordWorkflowEvent struct { + *mock.Call +} + +func (_m EventRecorder_RecordWorkflowEvent) Return(_a0 error) *EventRecorder_RecordWorkflowEvent { + return &EventRecorder_RecordWorkflowEvent{Call: _m.Call.Return(_a0)} +} + +func (_m *EventRecorder) OnRecordWorkflowEvent(ctx context.Context, _a1 *event.WorkflowExecutionEvent) *EventRecorder_RecordWorkflowEvent { + c := _m.On("RecordWorkflowEvent", ctx, _a1) + return &EventRecorder_RecordWorkflowEvent{Call: c} } -func NewMock() *MockRecorder { - return &MockRecorder{ - RecordNodeEventCb: func(ctx context.Context, event *event.NodeExecutionEvent) error { - return nil - }, - RecordTaskEventCb: func(ctx context.Context, event *event.TaskExecutionEvent) error { - return nil - }, - RecordWorkflowEventCb: func(ctx context.Context, event *event.WorkflowExecutionEvent) error { - return nil - }, +func (_m *EventRecorder) OnRecordWorkflowEventMatch(matchers ...interface{}) *EventRecorder_RecordWorkflowEvent { + c := _m.On("RecordWorkflowEvent", matchers...) + return &EventRecorder_RecordWorkflowEvent{Call: c} +} + +// RecordWorkflowEvent provides a mock function with given fields: ctx, _a1 +func (_m *EventRecorder) RecordWorkflowEvent(ctx context.Context, _a1 *event.WorkflowExecutionEvent) error { + ret := _m.Called(ctx, _a1) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *event.WorkflowExecutionEvent) error); ok { + r0 = rf(ctx, _a1) + } else { + r0 = ret.Error(0) } + + return r0 } diff --git a/events/mocks/event_sink.go b/events/mocks/event_sink.go new file mode 100644 index 0000000000..61046ee46d --- /dev/null +++ b/events/mocks/event_sink.go @@ -0,0 +1,80 @@ +// Code generated by mockery v1.0.1. DO NOT EDIT. + +package mocks + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" + + protoiface "google.golang.org/protobuf/runtime/protoiface" +) + +// EventSink is an autogenerated mock type for the EventSink type +type EventSink struct { + mock.Mock +} + +type EventSink_Close struct { + *mock.Call +} + +func (_m EventSink_Close) Return(_a0 error) *EventSink_Close { + return &EventSink_Close{Call: _m.Call.Return(_a0)} +} + +func (_m *EventSink) OnClose() *EventSink_Close { + c := _m.On("Close") + return &EventSink_Close{Call: c} +} + +func (_m *EventSink) OnCloseMatch(matchers ...interface{}) *EventSink_Close { + c := _m.On("Close", matchers...) + return &EventSink_Close{Call: c} +} + +// Close provides a mock function with given fields: +func (_m *EventSink) Close() error { + ret := _m.Called() + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +type EventSink_Sink struct { + *mock.Call +} + +func (_m EventSink_Sink) Return(_a0 error) *EventSink_Sink { + return &EventSink_Sink{Call: _m.Call.Return(_a0)} +} + +func (_m *EventSink) OnSink(ctx context.Context, message protoiface.MessageV1) *EventSink_Sink { + c := _m.On("Sink", ctx, message) + return &EventSink_Sink{Call: c} +} + +func (_m *EventSink) OnSinkMatch(matchers ...interface{}) *EventSink_Sink { + c := _m.On("Sink", matchers...) + return &EventSink_Sink{Call: c} +} + +// Sink provides a mock function with given fields: ctx, message +func (_m *EventSink) Sink(ctx context.Context, message protoiface.MessageV1) error { + ret := _m.Called(ctx, message) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, protoiface.MessageV1) error); ok { + r0 = rf(ctx, message) + } else { + r0 = ret.Error(0) + } + + return r0 +} diff --git a/events/mocks/writer.go b/events/mocks/writer.go new file mode 100644 index 0000000000..f03066760f --- /dev/null +++ b/events/mocks/writer.go @@ -0,0 +1,78 @@ +// Code generated by mockery v1.0.1. DO NOT EDIT. + +package mocks + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" +) + +// writer is an autogenerated mock type for the writer type +type writer struct { + mock.Mock +} + +type writer_Flush struct { + *mock.Call +} + +func (_m writer_Flush) Return(_a0 error) *writer_Flush { + return &writer_Flush{Call: _m.Call.Return(_a0)} +} + +func (_m *writer) OnFlush() *writer_Flush { + c := _m.On("Flush") + return &writer_Flush{Call: c} +} + +func (_m *writer) OnFlushMatch(matchers ...interface{}) *writer_Flush { + c := _m.On("Flush", matchers...) + return &writer_Flush{Call: c} +} + +// Flush provides a mock function with given fields: +func (_m *writer) Flush() error { + ret := _m.Called() + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +type writer_Write struct { + *mock.Call +} + +func (_m writer_Write) Return(_a0 error) *writer_Write { + return &writer_Write{Call: _m.Call.Return(_a0)} +} + +func (_m *writer) OnWrite(ctx context.Context, content string) *writer_Write { + c := _m.On("Write", ctx, content) + return &writer_Write{Call: c} +} + +func (_m *writer) OnWriteMatch(matchers ...interface{}) *writer_Write { + c := _m.On("Write", matchers...) + return &writer_Write{Call: c} +} + +// Write provides a mock function with given fields: ctx, content +func (_m *writer) Write(ctx context.Context, content string) error { + ret := _m.Called(ctx, content) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string) error); ok { + r0 = rf(ctx, content) + } else { + r0 = ret.Error(0) + } + + return r0 +} diff --git a/events/node_event_recorder_test.go b/events/node_event_recorder_test.go index f202af77b0..8fcddd5696 100644 --- a/events/node_event_recorder_test.go +++ b/events/node_event_recorder_test.go @@ -36,11 +36,12 @@ func getRawOutputNodeEv() *event.NodeExecutionEvent { } func TestRecordNodeEvent_Success_ReferenceOutputs(t *testing.T) { - eventRecorder := mocks.MockRecorder{} - eventRecorder.RecordNodeEventCb = func(ctx context.Context, event *event.NodeExecutionEvent) error { + ctx := context.TODO() + eventRecorder := mocks.EventRecorder{} + eventRecorder.OnRecordNodeEventMatch(ctx, mock.MatchedBy(func(event *event.NodeExecutionEvent) bool { assert.True(t, proto.Equal(event, getReferenceNodeEv())) - return nil - } + return true + })).Return(nil) mockStore := &storage.DataStore{ ComposedProtobufStore: &storageMocks.ComposedProtobufStore{}, ReferenceConstructor: &storageMocks.ReferenceConstructor{}, @@ -50,16 +51,17 @@ func TestRecordNodeEvent_Success_ReferenceOutputs(t *testing.T) { eventRecorder: &eventRecorder, store: mockStore, } - err := recorder.RecordNodeEvent(context.TODO(), getReferenceNodeEv(), referenceEventConfig) + err := recorder.RecordNodeEvent(ctx, getReferenceNodeEv(), referenceEventConfig) assert.NoError(t, err) } func TestRecordNodeEvent_Success_InlineOutputs(t *testing.T) { - eventRecorder := mocks.MockRecorder{} - eventRecorder.RecordNodeEventCb = func(ctx context.Context, event *event.NodeExecutionEvent) error { + ctx := context.TODO() + eventRecorder := mocks.EventRecorder{} + eventRecorder.OnRecordNodeEventMatch(ctx, mock.MatchedBy(func(event *event.NodeExecutionEvent) bool { assert.True(t, proto.Equal(event, getRawOutputNodeEv())) - return nil - } + return true + })).Return(nil) pbStore := &storageMocks.ComposedProtobufStore{} pbStore.OnReadProtobufMatch(mock.Anything, mock.MatchedBy(func(ref storage.DataReference) bool { return ref.String() == referenceURI @@ -76,16 +78,17 @@ func TestRecordNodeEvent_Success_InlineOutputs(t *testing.T) { eventRecorder: &eventRecorder, store: mockStore, } - err := recorder.RecordNodeEvent(context.TODO(), getReferenceNodeEv(), inlineEventConfig) + err := recorder.RecordNodeEvent(ctx, getReferenceNodeEv(), inlineEventConfig) assert.NoError(t, err) } func TestRecordNodeEvent_Failure_FetchInlineOutputs(t *testing.T) { - eventRecorder := mocks.MockRecorder{} - eventRecorder.RecordNodeEventCb = func(ctx context.Context, event *event.NodeExecutionEvent) error { + ctx := context.TODO() + eventRecorder := mocks.EventRecorder{} + eventRecorder.OnRecordNodeEventMatch(ctx, mock.MatchedBy(func(event *event.NodeExecutionEvent) bool { assert.True(t, proto.Equal(event, getReferenceNodeEv())) - return nil - } + return true + })).Return(nil) pbStore := &storageMocks.ComposedProtobufStore{} pbStore.OnReadProtobufMatch(mock.Anything, mock.MatchedBy(func(ref storage.DataReference) bool { return ref.String() == referenceURI @@ -99,19 +102,19 @@ func TestRecordNodeEvent_Failure_FetchInlineOutputs(t *testing.T) { eventRecorder: &eventRecorder, store: mockStore, } - err := recorder.RecordNodeEvent(context.TODO(), getReferenceNodeEv(), inlineEventConfig) + err := recorder.RecordNodeEvent(ctx, getReferenceNodeEv(), inlineEventConfig) assert.NoError(t, err) } func TestRecordNodeEvent_Failure_FallbackReference_Retry(t *testing.T) { - eventRecorder := mocks.MockRecorder{} - eventRecorder.RecordNodeEventCb = func(ctx context.Context, event *event.NodeExecutionEvent) error { - if event.GetOutputData() != nil { - return status.Error(codes.ResourceExhausted, "message too large") - } - assert.True(t, proto.Equal(event, getReferenceNodeEv())) - return nil - } + ctx := context.TODO() + eventRecorder := mocks.EventRecorder{} + eventRecorder.OnRecordNodeEventMatch(ctx, mock.MatchedBy(func(event *event.NodeExecutionEvent) bool { + return event.GetOutputData() != nil + })).Return(status.Error(codes.ResourceExhausted, "message too large")) + eventRecorder.OnRecordNodeEventMatch(ctx, mock.MatchedBy(func(event *event.NodeExecutionEvent) bool { + return event.GetOutputData() == nil && proto.Equal(event, getReferenceNodeEv()) + })).Return(nil) pbStore := &storageMocks.ComposedProtobufStore{} pbStore.OnReadProtobufMatch(mock.Anything, mock.MatchedBy(func(ref storage.DataReference) bool { return ref.String() == referenceURI @@ -128,15 +131,14 @@ func TestRecordNodeEvent_Failure_FallbackReference_Retry(t *testing.T) { eventRecorder: &eventRecorder, store: mockStore, } - err := recorder.RecordNodeEvent(context.TODO(), getReferenceNodeEv(), inlineEventConfigFallback) + err := recorder.RecordNodeEvent(ctx, getReferenceNodeEv(), inlineEventConfigFallback) assert.NoError(t, err) } func TestRecordNodeEvent_Failure_FallbackReference_Unretriable(t *testing.T) { - eventRecorder := mocks.MockRecorder{} - eventRecorder.RecordNodeEventCb = func(ctx context.Context, event *event.NodeExecutionEvent) error { - return errors.New("foo") - } + ctx := context.TODO() + eventRecorder := mocks.EventRecorder{} + eventRecorder.OnRecordNodeEventMatch(ctx, mock.Anything).Return(errors.New("foo")) pbStore := &storageMocks.ComposedProtobufStore{} pbStore.OnReadProtobufMatch(mock.Anything, mock.MatchedBy(func(ref storage.DataReference) bool { return ref.String() == referenceURI @@ -153,6 +155,6 @@ func TestRecordNodeEvent_Failure_FallbackReference_Unretriable(t *testing.T) { eventRecorder: &eventRecorder, store: mockStore, } - err := recorder.RecordNodeEvent(context.TODO(), getReferenceNodeEv(), inlineEventConfigFallback) + err := recorder.RecordNodeEvent(ctx, getReferenceNodeEv(), inlineEventConfigFallback) assert.EqualError(t, err, "foo") } diff --git a/events/task_event_recorder_test.go b/events/task_event_recorder_test.go index eaabb4bf38..e44b86c19f 100644 --- a/events/task_event_recorder_test.go +++ b/events/task_event_recorder_test.go @@ -47,11 +47,12 @@ func getRawOutputTaskEv() *event.TaskExecutionEvent { } func TestRecordTaskEvent_Success_ReferenceOutputs(t *testing.T) { - eventRecorder := mocks.MockRecorder{} - eventRecorder.RecordTaskEventCb = func(ctx context.Context, event *event.TaskExecutionEvent) error { + ctx := context.TODO() + eventRecorder := mocks.EventRecorder{} + eventRecorder.OnRecordTaskEventMatch(ctx, mock.MatchedBy(func(event *event.TaskExecutionEvent) bool { assert.True(t, proto.Equal(event, getReferenceTaskEv())) - return nil - } + return true + })).Return(nil) mockStore := &storage.DataStore{ ComposedProtobufStore: &storageMocks.ComposedProtobufStore{}, ReferenceConstructor: &storageMocks.ReferenceConstructor{}, @@ -61,16 +62,17 @@ func TestRecordTaskEvent_Success_ReferenceOutputs(t *testing.T) { eventRecorder: &eventRecorder, store: mockStore, } - err := recorder.RecordTaskEvent(context.TODO(), getReferenceTaskEv(), referenceEventConfig) + err := recorder.RecordTaskEvent(ctx, getReferenceTaskEv(), referenceEventConfig) assert.NoError(t, err) } func TestRecordTaskEvent_Success_InlineOutputs(t *testing.T) { - eventRecorder := mocks.MockRecorder{} - eventRecorder.RecordTaskEventCb = func(ctx context.Context, event *event.TaskExecutionEvent) error { + ctx := context.TODO() + eventRecorder := mocks.EventRecorder{} + eventRecorder.OnRecordTaskEventMatch(ctx, mock.MatchedBy(func(event *event.TaskExecutionEvent) bool { assert.True(t, proto.Equal(event, getRawOutputTaskEv())) - return nil - } + return true + })).Return(nil) pbStore := &storageMocks.ComposedProtobufStore{} pbStore.OnReadProtobufMatch(mock.Anything, mock.MatchedBy(func(ref storage.DataReference) bool { return ref.String() == referenceURI @@ -87,16 +89,17 @@ func TestRecordTaskEvent_Success_InlineOutputs(t *testing.T) { eventRecorder: &eventRecorder, store: mockStore, } - err := recorder.RecordTaskEvent(context.TODO(), getReferenceTaskEv(), inlineEventConfig) + err := recorder.RecordTaskEvent(ctx, getReferenceTaskEv(), inlineEventConfig) assert.NoError(t, err) } func TestRecordTaskEvent_Failure_FetchInlineOutputs(t *testing.T) { - eventRecorder := mocks.MockRecorder{} - eventRecorder.RecordTaskEventCb = func(ctx context.Context, event *event.TaskExecutionEvent) error { + ctx := context.TODO() + eventRecorder := mocks.EventRecorder{} + eventRecorder.OnRecordTaskEventMatch(ctx, mock.MatchedBy(func(event *event.TaskExecutionEvent) bool { assert.True(t, proto.Equal(event, getReferenceTaskEv())) - return nil - } + return true + })).Return(nil) pbStore := &storageMocks.ComposedProtobufStore{} pbStore.OnReadProtobufMatch(mock.Anything, mock.MatchedBy(func(ref storage.DataReference) bool { return ref.String() == referenceURI @@ -110,19 +113,19 @@ func TestRecordTaskEvent_Failure_FetchInlineOutputs(t *testing.T) { eventRecorder: &eventRecorder, store: mockStore, } - err := recorder.RecordTaskEvent(context.TODO(), getReferenceTaskEv(), inlineEventConfig) + err := recorder.RecordTaskEvent(ctx, getReferenceTaskEv(), inlineEventConfig) assert.NoError(t, err) } func TestRecordTaskEvent_Failure_FallbackReference_Retry(t *testing.T) { - eventRecorder := mocks.MockRecorder{} - eventRecorder.RecordTaskEventCb = func(ctx context.Context, event *event.TaskExecutionEvent) error { - if event.GetOutputData() != nil { - return status.Error(codes.ResourceExhausted, "message too large") - } - assert.True(t, proto.Equal(event, getReferenceTaskEv())) - return nil - } + ctx := context.TODO() + eventRecorder := mocks.EventRecorder{} + eventRecorder.OnRecordTaskEventMatch(ctx, mock.MatchedBy(func(event *event.TaskExecutionEvent) bool { + return event.GetOutputData() != nil + })).Return(status.Error(codes.ResourceExhausted, "message too large")) + eventRecorder.OnRecordTaskEventMatch(ctx, mock.MatchedBy(func(event *event.TaskExecutionEvent) bool { + return event.GetOutputData() == nil && proto.Equal(event, getReferenceTaskEv()) + })).Return(nil) pbStore := &storageMocks.ComposedProtobufStore{} pbStore.OnReadProtobufMatch(mock.Anything, mock.MatchedBy(func(ref storage.DataReference) bool { return ref.String() == referenceURI @@ -139,15 +142,14 @@ func TestRecordTaskEvent_Failure_FallbackReference_Retry(t *testing.T) { eventRecorder: &eventRecorder, store: mockStore, } - err := recorder.RecordTaskEvent(context.TODO(), getReferenceTaskEv(), inlineEventConfigFallback) + err := recorder.RecordTaskEvent(ctx, getReferenceTaskEv(), inlineEventConfigFallback) assert.NoError(t, err) } func TestRecordTaskEvent_Failure_FallbackReference_Unretriable(t *testing.T) { - eventRecorder := mocks.MockRecorder{} - eventRecorder.RecordTaskEventCb = func(ctx context.Context, event *event.TaskExecutionEvent) error { - return errors.New("foo") - } + ctx := context.TODO() + eventRecorder := mocks.EventRecorder{} + eventRecorder.OnRecordTaskEventMatch(ctx, mock.Anything).Return(errors.New("foo")) pbStore := &storageMocks.ComposedProtobufStore{} pbStore.OnReadProtobufMatch(mock.Anything, mock.MatchedBy(func(ref storage.DataReference) bool { return ref.String() == referenceURI @@ -164,6 +166,6 @@ func TestRecordTaskEvent_Failure_FallbackReference_Unretriable(t *testing.T) { eventRecorder: &eventRecorder, store: mockStore, } - err := recorder.RecordTaskEvent(context.TODO(), getReferenceTaskEv(), inlineEventConfigFallback) + err := recorder.RecordTaskEvent(ctx, getReferenceTaskEv(), inlineEventConfigFallback) assert.EqualError(t, err, "foo") } diff --git a/events/workflow_event_recorder_test.go b/events/workflow_event_recorder_test.go index 312dd4f40b..154661941a 100644 --- a/events/workflow_event_recorder_test.go +++ b/events/workflow_event_recorder_test.go @@ -36,11 +36,12 @@ func getRawOutputWorkflowEv() *event.WorkflowExecutionEvent { } func TestRecordWorkflowEvent_Success_ReferenceOutputs(t *testing.T) { - eventRecorder := mocks.MockRecorder{} - eventRecorder.RecordWorkflowEventCb = func(ctx context.Context, event *event.WorkflowExecutionEvent) error { + ctx := context.TODO() + eventRecorder := mocks.EventRecorder{} + eventRecorder.OnRecordWorkflowEventMatch(ctx, mock.MatchedBy(func(event *event.WorkflowExecutionEvent) bool { assert.True(t, proto.Equal(event, getReferenceWorkflowEv())) - return nil - } + return true + })).Return(nil) mockStore := &storage.DataStore{ ComposedProtobufStore: &storageMocks.ComposedProtobufStore{}, ReferenceConstructor: &storageMocks.ReferenceConstructor{}, @@ -50,16 +51,17 @@ func TestRecordWorkflowEvent_Success_ReferenceOutputs(t *testing.T) { eventRecorder: &eventRecorder, store: mockStore, } - err := recorder.RecordWorkflowEvent(context.TODO(), getReferenceWorkflowEv(), referenceEventConfig) + err := recorder.RecordWorkflowEvent(ctx, getReferenceWorkflowEv(), referenceEventConfig) assert.NoError(t, err) } func TestRecordWorkflowEvent_Success_InlineOutputs(t *testing.T) { - eventRecorder := mocks.MockRecorder{} - eventRecorder.RecordWorkflowEventCb = func(ctx context.Context, event *event.WorkflowExecutionEvent) error { + ctx := context.TODO() + eventRecorder := mocks.EventRecorder{} + eventRecorder.OnRecordWorkflowEventMatch(ctx, mock.MatchedBy(func(event *event.WorkflowExecutionEvent) bool { assert.True(t, proto.Equal(event, getRawOutputWorkflowEv())) - return nil - } + return true + })).Return(nil) pbStore := &storageMocks.ComposedProtobufStore{} pbStore.OnReadProtobufMatch(mock.Anything, mock.MatchedBy(func(ref storage.DataReference) bool { return ref.String() == referenceURI @@ -76,16 +78,17 @@ func TestRecordWorkflowEvent_Success_InlineOutputs(t *testing.T) { eventRecorder: &eventRecorder, store: mockStore, } - err := recorder.RecordWorkflowEvent(context.TODO(), getReferenceWorkflowEv(), inlineEventConfig) + err := recorder.RecordWorkflowEvent(ctx, getReferenceWorkflowEv(), inlineEventConfig) assert.NoError(t, err) } func TestRecordWorkflowEvent_Failure_FetchInlineOutputs(t *testing.T) { - eventRecorder := mocks.MockRecorder{} - eventRecorder.RecordWorkflowEventCb = func(ctx context.Context, event *event.WorkflowExecutionEvent) error { + ctx := context.TODO() + eventRecorder := mocks.EventRecorder{} + eventRecorder.OnRecordWorkflowEventMatch(ctx, mock.MatchedBy(func(event *event.WorkflowExecutionEvent) bool { assert.True(t, proto.Equal(event, getReferenceWorkflowEv())) - return nil - } + return true + })).Return(nil) pbStore := &storageMocks.ComposedProtobufStore{} pbStore.OnReadProtobufMatch(mock.Anything, mock.MatchedBy(func(ref storage.DataReference) bool { return ref.String() == referenceURI @@ -99,19 +102,19 @@ func TestRecordWorkflowEvent_Failure_FetchInlineOutputs(t *testing.T) { eventRecorder: &eventRecorder, store: mockStore, } - err := recorder.RecordWorkflowEvent(context.TODO(), getReferenceWorkflowEv(), inlineEventConfig) + err := recorder.RecordWorkflowEvent(ctx, getReferenceWorkflowEv(), inlineEventConfig) assert.NoError(t, err) } func TestRecordWorkflowEvent_Failure_FallbackReference_Retry(t *testing.T) { - eventRecorder := mocks.MockRecorder{} - eventRecorder.RecordWorkflowEventCb = func(ctx context.Context, event *event.WorkflowExecutionEvent) error { - if event.GetOutputData() != nil { - return status.Error(codes.ResourceExhausted, "message too large") - } - assert.True(t, proto.Equal(event, getReferenceWorkflowEv())) - return nil - } + ctx := context.TODO() + eventRecorder := mocks.EventRecorder{} + eventRecorder.OnRecordWorkflowEventMatch(ctx, mock.MatchedBy(func(event *event.WorkflowExecutionEvent) bool { + return event.GetOutputData() != nil + })).Return(status.Error(codes.ResourceExhausted, "message too large")) + eventRecorder.OnRecordWorkflowEventMatch(ctx, mock.MatchedBy(func(event *event.WorkflowExecutionEvent) bool { + return event.GetOutputData() == nil && proto.Equal(event, getReferenceWorkflowEv()) + })).Return(nil) pbStore := &storageMocks.ComposedProtobufStore{} pbStore.OnReadProtobufMatch(mock.Anything, mock.MatchedBy(func(ref storage.DataReference) bool { return ref.String() == referenceURI @@ -128,15 +131,14 @@ func TestRecordWorkflowEvent_Failure_FallbackReference_Retry(t *testing.T) { eventRecorder: &eventRecorder, store: mockStore, } - err := recorder.RecordWorkflowEvent(context.TODO(), getReferenceWorkflowEv(), inlineEventConfigFallback) + err := recorder.RecordWorkflowEvent(ctx, getReferenceWorkflowEv(), inlineEventConfigFallback) assert.NoError(t, err) } func TestRecordWorkflowEvent_Failure_FallbackReference_Unretriable(t *testing.T) { - eventRecorder := mocks.MockRecorder{} - eventRecorder.RecordWorkflowEventCb = func(ctx context.Context, event *event.WorkflowExecutionEvent) error { - return errors.New("foo") - } + ctx := context.TODO() + eventRecorder := mocks.EventRecorder{} + eventRecorder.OnRecordWorkflowEventMatch(ctx, mock.Anything).Return(errors.New("foo")) pbStore := &storageMocks.ComposedProtobufStore{} pbStore.OnReadProtobufMatch(mock.Anything, mock.MatchedBy(func(ref storage.DataReference) bool { return ref.String() == referenceURI diff --git a/go.mod b/go.mod index c6c7c87bc0..5ccd3a6995 100644 --- a/go.mod +++ b/go.mod @@ -25,6 +25,7 @@ require ( github.com/stretchr/testify v1.7.0 golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba google.golang.org/grpc v1.36.0 + google.golang.org/protobuf v1.25.0 k8s.io/api v0.20.2 k8s.io/apimachinery v0.20.2 k8s.io/client-go v0.20.2 diff --git a/pkg/controller/nodes/handler/mocks/node_execution_context.go b/pkg/controller/nodes/handler/mocks/node_execution_context.go index 650ae5b0a0..aaedf7c56f 100644 --- a/pkg/controller/nodes/handler/mocks/node_execution_context.go +++ b/pkg/controller/nodes/handler/mocks/node_execution_context.go @@ -5,7 +5,6 @@ package mocks import ( events "github.com/flyteorg/flytepropeller/events" executors "github.com/flyteorg/flytepropeller/pkg/controller/executors" - handler "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" io "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io"