Skip to content

Commit

Permalink
Improving unit test coverage for simpler cases (flyteorg#117)
Browse files Browse the repository at this point in the history
  • Loading branch information
Ketan Umare authored Apr 19, 2020
1 parent 0595256 commit 64dd8ed
Show file tree
Hide file tree
Showing 13 changed files with 332 additions and 203 deletions.
4 changes: 2 additions & 2 deletions pkg/controller/nodes/end/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,11 @@ func (e endHandler) Handle(ctx context.Context, executionContext handler.NodeExe
return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoSuccess(nil)), nil
}

func (e endHandler) Abort(ctx context.Context, executionContext handler.NodeExecutionContext, reason string) error {
func (e endHandler) Abort(_ context.Context, _ handler.NodeExecutionContext, _ string) error {
return nil
}

func (e endHandler) Finalize(ctx context.Context, executionContext handler.NodeExecutionContext) error {
func (e endHandler) Finalize(_ context.Context, _ handler.NodeExecutionContext) error {
return nil
}

Expand Down
25 changes: 25 additions & 0 deletions pkg/controller/nodes/end/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package end

import (
"context"
"fmt"
"testing"

"github.com/golang/protobuf/proto"
Expand Down Expand Up @@ -86,6 +87,15 @@ func TestEndHandler_Handle(t *testing.T) {
return nCtx
}

t.Run("InputReadFailure", func(t *testing.T) {
ir := &mocks2.InputReader{}
ir.OnGetMatch(mock.Anything).Return(nil, fmt.Errorf("err"))
nCtx := &mocks.NodeExecutionContext{}
nCtx.OnInputReader().Return(ir)
_, err := e.Handle(ctx, nCtx)
assert.Error(t, err)
})

t.Run("NoInputs", func(t *testing.T) {
nCtx := createNodeCtx(nil, nil)
s, err := e.Handle(ctx, nCtx)
Expand Down Expand Up @@ -122,3 +132,18 @@ func TestEndHandler_Handle(t *testing.T) {
assert.Equal(t, handler.UnknownTransition, s)
})
}

func TestEndHandler_Abort(t *testing.T) {
e := New()
assert.NoError(t, e.Abort(context.TODO(), nil, ""))
}

func TestEndHandler_Finalize(t *testing.T) {
e := New()
assert.NoError(t, e.Finalize(context.TODO(), nil))
}

func TestEndHandler_FinalizeRequired(t *testing.T) {
e := New()
assert.False(t, e.FinalizeRequired())
}
16 changes: 0 additions & 16 deletions pkg/controller/nodes/handler/transition_info.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,22 +105,6 @@ func (p PhaseInfo) GetReason() string {
return p.reason
}

func (p *PhaseInfo) SetOcurredAt(t time.Time) {
p.occurredAt = t
}

func (p *PhaseInfo) SetErr(err *core.ExecutionError) {
p.err = err
}

func (p *PhaseInfo) SetInfo(info *ExecutionInfo) {
p.info = info
}

func (p *PhaseInfo) SetReason() string {
return p.reason
}

var PhaseInfoUndefined = PhaseInfo{p: EPhaseUndefined}

func phaseInfo(p EPhase, err *core.ExecutionError, info *ExecutionInfo, reason string) PhaseInfo {
Expand Down
141 changes: 141 additions & 0 deletions pkg/controller/nodes/handler/transition_info_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package handler
import (
"testing"

"github.com/lyft/flyteidl/gen/pb-go/flyteidl/core"
"github.com/stretchr/testify/assert"
)

Expand All @@ -17,10 +18,13 @@ func TestEPhase_String(t *testing.T) {
p EPhase
}{
{"queued", EPhaseQueued},
{"not-ready", EPhaseNotReady},
{"timedout", EPhaseTimedout},
{"undefined", EPhaseUndefined},
{"success", EPhaseSuccess},
{"skip", EPhaseSkip},
{"failed", EPhaseFailed},
{"running", EPhaseRunning},
{"retryable-fail", EPhaseRetryableFailure},
}
for _, tt := range tests {
Expand All @@ -31,3 +35,140 @@ func TestEPhase_String(t *testing.T) {
})
}
}

func TestEPhase_IsTerminal(t *testing.T) {
tests := []struct {
name string
p EPhase
want bool
}{
{"success", EPhaseSuccess, true},
{"failure", EPhaseFailed, true},
{"timeout", EPhaseTimedout, true},
{"skip", EPhaseSkip, true},
{"any", EPhaseQueued, false},
{"retryable", EPhaseRetryableFailure, false},
{"run", EPhaseRunning, false},
{"nr", EPhaseNotReady, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.p.IsTerminal(); got != tt.want {
t.Errorf("IsTerminal() = %v, want %v", got, tt.want)
}
})
}
}

func TestPhaseInfo(t *testing.T) {
t.Run("undefined", func(t *testing.T) {
assert.Equal(t, EPhaseUndefined, PhaseInfoUndefined.GetPhase())
})

t.Run("success", func(t *testing.T) {
i := &ExecutionInfo{}
p := PhaseInfoSuccess(i)
assert.Equal(t, EPhaseSuccess, p.GetPhase())
assert.Equal(t, i, p.GetInfo())
assert.Nil(t, p.GetErr())
assert.NotNil(t, p.GetOccurredAt())
})

t.Run("not-ready", func(t *testing.T) {
p := PhaseInfoNotReady("reason")
assert.Equal(t, EPhaseNotReady, p.GetPhase())
assert.Nil(t, p.GetErr())
assert.NotNil(t, p.GetOccurredAt())
assert.Equal(t, "reason", p.GetReason())
})

t.Run("queued", func(t *testing.T) {
p := PhaseInfoQueued("reason")
assert.Equal(t, EPhaseQueued, p.GetPhase())
assert.Nil(t, p.GetErr())
assert.NotNil(t, p.GetOccurredAt())
assert.Equal(t, "reason", p.GetReason())
})

t.Run("running", func(t *testing.T) {
i := &ExecutionInfo{}
p := PhaseInfoRunning(i)
assert.Equal(t, EPhaseRunning, p.GetPhase())
assert.Equal(t, i, p.GetInfo())
assert.Nil(t, p.GetErr())
assert.NotNil(t, p.GetOccurredAt())
})

t.Run("skip", func(t *testing.T) {
i := &ExecutionInfo{}
p := PhaseInfoSkip(i, "reason")
assert.Equal(t, EPhaseSkip, p.GetPhase())
assert.Equal(t, i, p.GetInfo())
assert.Nil(t, p.GetErr())
assert.NotNil(t, p.GetOccurredAt())
assert.Equal(t, "reason", p.GetReason())
})

t.Run("timeout", func(t *testing.T) {
i := &ExecutionInfo{}
p := PhaseInfoTimedOut(i, "reason")
assert.Equal(t, EPhaseTimedout, p.GetPhase())
assert.Equal(t, i, p.GetInfo())
assert.Nil(t, p.GetErr())
assert.NotNil(t, p.GetOccurredAt())
assert.Equal(t, "reason", p.GetReason())
})

t.Run("failure", func(t *testing.T) {
i := &ExecutionInfo{}
p := PhaseInfoFailure("code", "reason", i)
assert.Equal(t, EPhaseFailed, p.GetPhase())
assert.Equal(t, i, p.GetInfo())
if assert.NotNil(t, p.GetErr()) {
assert.Equal(t, "code", p.GetErr().Code)
assert.Equal(t, "reason", p.GetErr().Message)
}
assert.NotNil(t, p.GetOccurredAt())
})

t.Run("failure-err", func(t *testing.T) {
i := &ExecutionInfo{}
e := &core.ExecutionError{}
p := PhaseInfoFailureErr(e, i)
assert.Equal(t, EPhaseFailed, p.GetPhase())
assert.Equal(t, i, p.GetInfo())
assert.Equal(t, e, p.GetErr())
assert.NotNil(t, p.GetOccurredAt())
})

t.Run("failure-err", func(t *testing.T) {
i := &ExecutionInfo{}
p := PhaseInfoFailureErr(nil, i)
assert.Equal(t, EPhaseFailed, p.GetPhase())
assert.Equal(t, i, p.GetInfo())
assert.NotNil(t, p.GetErr())
assert.NotNil(t, p.GetOccurredAt())
})

t.Run("retryable-fail", func(t *testing.T) {
i := &ExecutionInfo{}
p := PhaseInfoRetryableFailure("code", "reason", i)
assert.Equal(t, EPhaseRetryableFailure, p.GetPhase())
assert.Equal(t, i, p.GetInfo())
if assert.NotNil(t, p.GetErr()) {
assert.Equal(t, "code", p.GetErr().Code)
assert.Equal(t, "reason", p.GetErr().Message)
}
assert.NotNil(t, p.GetOccurredAt())
})

t.Run("retryable-fail-err", func(t *testing.T) {
i := &ExecutionInfo{}
e := &core.ExecutionError{}
p := PhaseInfoRetryableFailureErr(e, i)
assert.Equal(t, EPhaseRetryableFailure, p.GetPhase())
assert.Equal(t, i, p.GetInfo())
assert.Equal(t, e, p.GetErr())
assert.NotNil(t, p.GetOccurredAt())
})
}
7 changes: 7 additions & 0 deletions pkg/controller/nodes/handler/transition_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,10 @@ func TestDoTransition(t *testing.T) {
assert.Equal(t, storage.DataReference("uri"), tr.Info().GetInfo().OutputInfo.OutputURI)
})
}

func TestTransition_WithInfo(t *testing.T) {
tr := DoTransition(TransitionTypeEphemeral, PhaseInfoQueued("queued"))
assert.Equal(t, EPhaseQueued, tr.info.p)
tr = tr.WithInfo(PhaseInfoSuccess(&ExecutionInfo{}))
assert.Equal(t, EPhaseSuccess, tr.info.p)
}
17 changes: 16 additions & 1 deletion pkg/controller/nodes/start/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ func TestStartNodeHandler_Initialize(t *testing.T) {
assert.NoError(t, h.Setup(context.TODO(), nil))
}

func TestStartNodeHandler_StartNode(t *testing.T) {
func TestStartNodeHandler_Handle(t *testing.T) {
ctx := context.Background()
h := New()
t.Run("Any", func(t *testing.T) {
Expand All @@ -31,3 +31,18 @@ func TestStartNodeHandler_StartNode(t *testing.T) {
assert.Equal(t, handler.EPhaseSuccess, s.Info().GetPhase())
})
}

func TestEndHandler_Abort(t *testing.T) {
e := New()
assert.NoError(t, e.Abort(context.TODO(), nil, ""))
}

func TestEndHandler_Finalize(t *testing.T) {
e := New()
assert.NoError(t, e.Finalize(context.TODO(), nil))
}

func TestEndHandler_FinalizeRequired(t *testing.T) {
e := New()
assert.False(t, e.FinalizeRequired())
}
18 changes: 6 additions & 12 deletions pkg/utils/encoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,8 @@ func FixedLengthUniqueID(inputID string, maxLength int) (string, error) {
}

hasher := fnv.New32a()
_, err := hasher.Write([]byte(inputID))
if err != nil {
return "", err
}
// Using 32a an error can never happen, so this will always remain not covered by a unit test
_, _ = hasher.Write([]byte(inputID)) // #nosec
b := hasher.Sum(nil)
// expected length after this step is 8 chars (1 + 7 chars from base32Encoder.EncodeToString(b))
finalStr := "f" + base32Encoder.EncodeToString(b)
Expand All @@ -39,16 +37,12 @@ func FixedLengthUniqueIDForParts(maxLength int, parts ...string) (string, error)
b := strings.Builder{}
for i, p := range parts {
if i > 0 && b.Len() > 0 {
_, err := b.WriteRune('-')
if err != nil {
return "", err
}
// Ignoring the error as it always returns a nil error
_, _ = b.WriteRune('-') // #nosec
}

_, err := b.WriteString(p)
if err != nil {
return "", err
}
// Ignoring the error as this is always nil
_, _ = b.WriteString(p) // #nosec
}

return FixedLengthUniqueID(b.String(), maxLength)
Expand Down
32 changes: 0 additions & 32 deletions pkg/utils/event_helpers.go

This file was deleted.

5 changes: 3 additions & 2 deletions pkg/utils/failing_datastore_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ func TestFailingRawStore(t *testing.T) {
_, err = f.ReadRaw(ctx, "")
assert.Error(t, err)

err = f.WriteRaw(ctx, "", 0, storage.Options{}, bytes.NewReader(nil))
assert.Error(t, err)
assert.Error(t, f.WriteRaw(ctx, "", 0, storage.Options{}, bytes.NewReader(nil)))

assert.Error(t, f.CopyRaw(ctx, "", "", storage.Options{}))
}
11 changes: 0 additions & 11 deletions pkg/utils/k8s.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,17 +80,6 @@ func ToK8sResourceRequirements(resources *core.Resources) (*v1.ResourceRequireme
return res, nil
}

func GetWorkflowIDFromObject(obj metav1.Object) (v1alpha1.WorkflowID, error) {
controller := metav1.GetControllerOf(obj)
if controller == nil {
return "", NotTheOwnerError
}
if controller.Kind == v1alpha1.FlyteWorkflowKind {
return obj.GetNamespace() + "/" + controller.Name, nil
}
return "", NotTheOwnerError
}

func GetWorkflowIDFromOwner(reference *metav1.OwnerReference, namespace string) (v1alpha1.WorkflowID, error) {
if reference == nil {
return "", NotTheOwnerError
Expand Down
Loading

0 comments on commit 64dd8ed

Please sign in to comment.