From 9e1363fc04bedbe3558574a578e47974c2efc7ac Mon Sep 17 00:00:00 2001 From: Josiah Witt Date: Sun, 5 Feb 2023 16:30:59 -0500 Subject: [PATCH 1/2] Add Ensure to test context --- ensuring/ensuring.go | 6 +- ensuring/ensuring_test.go | 6 +- ensuring/init_test.go | 8 +-- ensuring/internal/testhelper/testhelper.go | 4 +- internal/mocks/mock_testctx/mock_testctx.go | 25 +++++++++ internal/testctx/testctx.go | 36 +++++++++++- internal/testctx/testctx_test.go | 62 ++++++++++++++++----- 7 files changed, 122 insertions(+), 25 deletions(-) diff --git a/ensuring/ensuring.go b/ensuring/ensuring.go index 2fec647..49da338 100644 --- a/ensuring/ensuring.go +++ b/ensuring/ensuring.go @@ -14,7 +14,7 @@ import ( ) //nolint:gochecknoglobals // This is stored as a variable so we can override it for tests in init_test.go. -var newTestContext = testctx.New +var newTestContextFunc = testctx.New // T implements a subset of methods on [testing.T]. // More methods may be added to T with a minor ensure release. @@ -118,3 +118,7 @@ func wrap(t T) E { return c } } + +func newTestContext(t T) testctx.Context { + return newTestContextFunc(t, func(t testctx.T) interface{} { return wrap(t) }) +} diff --git a/ensuring/ensuring_test.go b/ensuring/ensuring_test.go index 23ce5bd..ca86d01 100644 --- a/ensuring/ensuring_test.go +++ b/ensuring/ensuring_test.go @@ -188,7 +188,7 @@ func setupMockT(t *testing.T) *mock_testctx.MockT { ctrl := gomock.NewController(t) mockT := mock_testctx.NewMockT(ctrl) - testhelper.SetTestContext(t, mockT, testctx.New(mockT)) + testhelper.SetTestContext(t, mockT, testctx.New(mockT, wrapEnsure)) return mockT } @@ -205,3 +205,7 @@ func setupMockTWithCleanupCheck(t *testing.T) *mock_testctx.MockT { return mockT } + +func wrapEnsure(t testctx.T) interface{} { + return ensure.New(t) +} diff --git a/ensuring/init_test.go b/ensuring/init_test.go index d389df1..b61531e 100644 --- a/ensuring/init_test.go +++ b/ensuring/init_test.go @@ -5,8 +5,8 @@ import "github.com/JosiahWitt/ensure/ensuring/internal/testhelper" //nolint:gochecknoinits // Only to make testing easier. func init() { - // Initializes the unexported newTestContext variable to use the test implementation. - // This allows us to continue to keep the tests in the separate testing package and - // keep the newTestContext variable unexported. - newTestContext = testhelper.NewTestContext + // Initializes the unexported newTestContextFunc variable to use the test implementation. + // This allows us to continue to keep the tests in the separate testing package and keep + // the newTestContextFunc variable unexported. + newTestContextFunc = testhelper.NewTestContext } diff --git a/ensuring/internal/testhelper/testhelper.go b/ensuring/internal/testhelper/testhelper.go index 731089b..703a0ba 100644 --- a/ensuring/internal/testhelper/testhelper.go +++ b/ensuring/internal/testhelper/testhelper.go @@ -15,11 +15,11 @@ var ( // NewTestContext is called instead of [testctx.New] and is setup in ../../init_test.go. // This shouldn't be used by anything else. -func NewTestContext(t testctx.T) testctx.Context { +func NewTestContext(t testctx.T, wrapEnsure testctx.WrapEnsure) testctx.Context { ctx, ok := testContexts[t] if !ok { if allowAnyTestContexts { - return testctx.New(t) + return testctx.New(t, wrapEnsure) } panic("Missing mock test context") diff --git a/internal/mocks/mock_testctx/mock_testctx.go b/internal/mocks/mock_testctx/mock_testctx.go index 47b6b24..c2bb25a 100644 --- a/internal/mocks/mock_testctx/mock_testctx.go +++ b/internal/mocks/mock_testctx/mock_testctx.go @@ -239,6 +239,31 @@ func (m *MockContext) EXPECT() *MockContextMockRecorder { return m.recorder } +// Ensure mocks Ensure on Context. +func (m *MockContext) Ensure() interface{} { + m.ctrl.T.Helper() + inputs := []interface{}{} + ret := m.ctrl.Call(m, "Ensure", inputs...) + ret0, _ := ret[0].(interface{}) + return ret0 +} + +// Ensure sets up expectations for calls to Ensure. +// Calling this method multiple times allows expecting multiple calls to Ensure with a variety of parameters. +// +// Inputs: +// +// none +// +// Outputs: +// +// interface{} +func (mr *MockContextMockRecorder) Ensure() *gomock.Call { + mr.mock.ctrl.T.Helper() + inputs := []interface{}{} + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Ensure", reflect.TypeOf((*MockContext)(nil).Ensure), inputs...) +} + // GoMockController mocks GoMockController on Context. func (m *MockContext) GoMockController() *gomock.Controller { m.ctrl.T.Helper() diff --git a/internal/testctx/testctx.go b/internal/testctx/testctx.go index 18c4576..4c66609 100644 --- a/internal/testctx/testctx.go +++ b/internal/testctx/testctx.go @@ -19,24 +19,41 @@ type T interface { var _ T = &testing.T{} +// WrapEnsure is a function that returns the [ensuring.E] for the provided [T]. +// It returns an interface instead of the concrete type to avoid an import cycle. +type WrapEnsure func(T) interface{} + // Context contains scoped test helpers. type Context interface { + // T returns the currently in-scope [T]. T() T + + // Run wraps the [testing.T] Run method, making it mockable. Run(name string, fn func(Context)) + + // GoMockController returns the [gomock.Controller] relating to the in-scope [T]. + // It memoizes the value for subsequent calls. GoMockController() *gomock.Controller + + // Ensure returns the [ensuring.E] relating to the in-scope [T]. + // It returns an interface instead of the concrete type to avoid an import cycle. + // It memoizes the value for subsequent calls. + Ensure() interface{} } type baseContext struct { t T goMockController *gomock.Controller + wrapEnsure WrapEnsure + ensure interface{} } var _ Context = &baseContext{} // New creates a new [Context]. -func New(t T) Context { - return &baseContext{t: t} +func New(t T, wrapEnsure WrapEnsure) Context { + return &baseContext{t: t, wrapEnsure: wrapEnsure} } // T returns the currently in-scope [T]. @@ -50,7 +67,7 @@ func (ctx *baseContext) Run(name string, fn func(Context)) { ctx.t.Run(name, func(t *testing.T) { t.Helper() - wrappedCtx := New(t) + wrappedCtx := New(t, ctx.wrapEnsure) fn(wrappedCtx) }) } @@ -67,3 +84,16 @@ func (ctx *baseContext) GoMockController() *gomock.Controller { return ctx.goMockController } + +// Ensure returns the [ensuring.E] relating to the in-scope [T]. +// It returns an interface instead of the concrete type to avoid an import cycle. +// It memoizes the value for subsequent calls. +func (ctx *baseContext) Ensure() interface{} { + ctx.t.Helper() + + if ctx.ensure == nil { + ctx.ensure = ctx.wrapEnsure(ctx.t) + } + + return ctx.ensure +} diff --git a/internal/testctx/testctx_test.go b/internal/testctx/testctx_test.go index 6e315e8..1431c1a 100644 --- a/internal/testctx/testctx_test.go +++ b/internal/testctx/testctx_test.go @@ -1,6 +1,7 @@ package testctx_test import ( + "fmt" "reflect" "testing" @@ -11,22 +12,22 @@ import ( ) func TestNew(t *testing.T) { - mockT := struct { - testctx.T - unique string - }{unique: "hello"} + ctrl := gomock.NewController(t) + mockT := mock_testctx.NewMockT(ctrl) + mockT.EXPECT().Helper().AnyTimes() - ctx := testctx.New(mockT) - eq(t, ctx.T(), mockT) + wrappedT := MockT{T: mockT, unique: "hello"} + wrapEnsure := func(t testctx.T) interface{} { return t.(MockT).unique + " world" } + + ctx := testctx.New(wrappedT, wrapEnsure) + eq(t, ctx.T().(MockT).unique, "hello") + eq(t, ctx.Ensure(), "hello world") } func TestT(t *testing.T) { - mockT := struct { - testctx.T - unique string - }{unique: "hello"} + mockT := MockT{unique: "hello"} - ctx := testctx.New(mockT) + ctx := testctx.New(mockT, nil) eq(t, ctx.T(), mockT) } @@ -39,16 +40,22 @@ func TestRun(t *testing.T) { fn(&testing.T{}) }) - ctx := testctx.New(outerT) + wrapEnsure := func(t testctx.T) interface{} { return fmt.Sprintf("%T", t) } + ctx := testctx.New(outerT, wrapEnsure) var actualInnerT *testing.T + var actualInnerEnsure string ctx.Run("everything works", func(ctx testctx.Context) { actualInnerT = ctx.T().(*testing.T) + actualInnerEnsure = ctx.Ensure().(string) }) neq(t, actualInnerT, nil) // It shouldn't be nil, indicating the callback wasn't called neq(t, actualInnerT, &testing.T{}) // It shouldn't be empty, indicating Helper() wasn't called neq(t, actualInnerT, outerT) // It shouldn't be the outerT + + // Show wrapEnsure was promoted correctly + eq(t, actualInnerEnsure, "*testing.T") } func TestGoMockController(t *testing.T) { @@ -61,7 +68,7 @@ func TestGoMockController(t *testing.T) { fn() }).Times(2) // We call it once and gomock.NewController calls it once - ctx := testctx.New(mockT) + ctx := testctx.New(mockT, nil) mockCtrl := ctx.GoMockController() eq(t, mockCtrl.T, mockT) @@ -87,7 +94,7 @@ func TestGoMockController(t *testing.T) { mockT.EXPECT().Helper().AnyTimes() mockT.EXPECT().Cleanup(gomock.Any()).AnyTimes() - ctx := testctx.New(mockT) + ctx := testctx.New(mockT, nil) mockCtrl := ctx.GoMockController() // SomeMethod is never "called", and should be noticed during cleanup @@ -99,6 +106,28 @@ func TestGoMockController(t *testing.T) { }) } +func TestEnsure(t *testing.T) { + ctrl := gomock.NewController(t) + mockT := mock_testctx.NewMockT(ctrl) + mockT.EXPECT().Helper().AnyTimes() + + wrappedT := MockT{T: mockT, unique: "hello"} + + callCount := 0 + wrapEnsure := func(t testctx.T) interface{} { + callCount++ + return t.(MockT).unique + " world" + } + + ctx := testctx.New(wrappedT, wrapEnsure) + eq(t, ctx.Ensure(), "hello world") + + // Show it's memoized + ctx.Ensure() + ctx.Ensure() + eq(t, callCount, 1) +} + func eq(t *testing.T, a, b interface{}) { t.Helper() if !reflect.DeepEqual(a, b) { @@ -113,6 +142,11 @@ func neq(t *testing.T, a, b interface{}) { } } +type MockT struct { + testctx.T + unique string +} + type exampleTypeWithMethod struct{} func (*exampleTypeWithMethod) SomeMethod(param bool) {} From 8a80acef8b442ee3a434fbbfc9943c29f104f7ad Mon Sep 17 00:00:00 2001 From: Josiah Witt Date: Tue, 7 Feb 2023 21:07:28 -0500 Subject: [PATCH 2/2] Add optional ensuring.E parameter to SetupMocks --- .golangci.yml | 4 + cmd/ensure/.golangci.yml | 4 + ensuring/internal/testhelper/testhelper.go | 20 +++ ensuring/run_table_test.go | 47 ++++- internal/plugins/internal/id/id.go | 2 + internal/plugins/setupmocks/setupmocks.go | 56 ++++-- .../plugins/setupmocks/setupmocks_test.go | 168 ++++++++++++++++-- internal/reflectensure/reflectensure.go | 15 ++ internal/reflectensure/reflectensure_test.go | 48 +++++ 9 files changed, 340 insertions(+), 24 deletions(-) create mode 100644 internal/reflectensure/reflectensure.go create mode 100644 internal/reflectensure/reflectensure_test.go diff --git a/.golangci.yml b/.golangci.yml index b1ba1d1..49a9040 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -46,3 +46,7 @@ issues: - ifshort - thelper - maintidx + + - text: singleCaseSwitch + linters: + - gocritic diff --git a/cmd/ensure/.golangci.yml b/cmd/ensure/.golangci.yml index 76490c0..0871fe9 100644 --- a/cmd/ensure/.golangci.yml +++ b/cmd/ensure/.golangci.yml @@ -48,6 +48,10 @@ issues: - thelper - maintidx + - text: singleCaseSwitch + linters: + - gocritic + # Exclude some linters from fixtures and scenarios - path: .*(/fixtures/)|(/scenarios/).* linters: diff --git a/ensuring/internal/testhelper/testhelper.go b/ensuring/internal/testhelper/testhelper.go index 703a0ba..3f419eb 100644 --- a/ensuring/internal/testhelper/testhelper.go +++ b/ensuring/internal/testhelper/testhelper.go @@ -2,6 +2,7 @@ package testhelper import ( + "fmt" "testing" "github.com/JosiahWitt/ensure/internal/testctx" @@ -11,11 +12,15 @@ import ( var ( testContexts = map[testctx.T]testctx.Context{} allowAnyTestContexts = false + + checkingWrapEnsure = false // Used to prevent an infinite loop ) // NewTestContext is called instead of [testctx.New] and is setup in ../../init_test.go. // This shouldn't be used by anything else. func NewTestContext(t testctx.T, wrapEnsure testctx.WrapEnsure) testctx.Context { + checkWrapEnsure(t, wrapEnsure) + ctx, ok := testContexts[t] if !ok { if allowAnyTestContexts { @@ -47,3 +52,18 @@ func AllowAnyTestContexts(t *testing.T) { allowAnyTestContexts = false }) } + +func checkWrapEnsure(t testctx.T, wrapEnsure testctx.WrapEnsure) { + // Prevent an infinite loop, since NewTestContext will be called by wrapEnsure + if checkingWrapEnsure { + return + } + + checkingWrapEnsure = true + defer func() { checkingWrapEnsure = false }() + + ensure := wrapEnsure(t) + if ensure == nil || fmt.Sprintf("%T", ensure) != "ensuring.E" { + panic(fmt.Sprintf("wrapEnsure doesn't function correctly: %[1]v (%[1]T)", ensure)) + } +} diff --git a/ensuring/run_table_test.go b/ensuring/run_table_test.go index b956f7d..7033179 100644 --- a/ensuring/run_table_test.go +++ b/ensuring/run_table_test.go @@ -69,6 +69,7 @@ func TestERunTableByIndex(t *testing.T) { innerContext := mock_testctx.NewMockContext(ctrl) innerContext.EXPECT().T().Return(innerMockT).AnyTimes() innerContext.EXPECT().GoMockController().Return(gomock.NewController(innerMockT)).AnyTimes() + innerContext.EXPECT().Ensure().Return(ensure.New(innerMockT)).AnyTimes() testhelper.SetTestContext(t, innerMockT, innerContext) innerMockT.EXPECT().Fatalf(gomock.Any()).Do(func(msg string, args ...interface{}) { @@ -753,7 +754,7 @@ func (runTableTests) setupMocksField() runTableTestEntryGroup { Prefix: "SetupMocks field", Entries: []runTableTestEntry{ { - Name: "with valid function", + Name: "with valid function with one param", ExpectedNames: []string{"name 1", "name 2"}, Table: []struct { Name string @@ -787,6 +788,44 @@ func (runTableTests) setupMocksField() runTableTestEntryGroup { } }, }, + { + Name: "with valid function with two params", + ExpectedNames: []string{"name 1", "name 2"}, + FatalMessagesContain: []string{"first SetupMocks", "second SetupMocks"}, // Not actual failures; only to show ensure is passed in correctly + Table: []struct { + Name string + Mocks *TwoValidMocks + SetupMocks func(*TwoValidMocks, ensuring.E) + }{ + { + Name: "name 1", + SetupMocks: func(tvm *TwoValidMocks, ensure ensuring.E) { + tvm.Valid1.CustomField = "updated name 1" + ensure.Failf("first SetupMocks") + }, + }, + { + Name: "name 2", + SetupMocks: func(tvm *TwoValidMocks, ensure ensuring.E) { + tvm.Valid1.CustomField = "updated name 2" + ensure.Failf("second SetupMocks") + }, + }, + }, + + CheckEntry: func(t *testing.T, rawTable interface{}) { + table := rawTable.([]struct { + Name string + Mocks *TwoValidMocks + SetupMocks func(*TwoValidMocks, ensuring.E) + }) + + for _, entry := range table { + entry.Mocks.check(t) + isTrue(t, entry.Mocks.Valid1.CustomField == "updated "+entry.Name) + } + }, + }, { Name: "with function not present for one", @@ -839,7 +878,7 @@ func (runTableTests) setupMocksField() runTableTestEntryGroup { { Name: "function missing param", - FatalMessagesContain: []string{"expected SetupMocks field to be a func(*ensuring_test.TwoValidMocks)"}, + FatalMessagesContain: []string{"expected SetupMocks field to be one of the following:"}, Table: []struct { Name string Mocks *TwoValidMocks @@ -858,7 +897,7 @@ func (runTableTests) setupMocksField() runTableTestEntryGroup { { Name: "function with invalid param", - FatalMessagesContain: []string{"expected SetupMocks field to be a func(*ensuring_test.TwoValidMocks)"}, + FatalMessagesContain: []string{"expected SetupMocks field to be one of the following:"}, Table: []struct { Name string Mocks *TwoValidMocks @@ -877,7 +916,7 @@ func (runTableTests) setupMocksField() runTableTestEntryGroup { { Name: "function with a return", - FatalMessagesContain: []string{"expected SetupMocks field to be a func(*ensuring_test.TwoValidMocks)"}, + FatalMessagesContain: []string{"expected SetupMocks field to be one of the following:"}, Table: []struct { Name string Mocks *TwoValidMocks diff --git a/internal/plugins/internal/id/id.go b/internal/plugins/internal/id/id.go index 6aa68f0..c9bd127 100644 --- a/internal/plugins/internal/id/id.go +++ b/internal/plugins/internal/id/id.go @@ -2,6 +2,8 @@ package id const ( + EnsuringE = "ensuring.E" + Mocks = "Mocks" SetupMocks = "SetupMocks" Subject = "Subject" diff --git a/internal/plugins/setupmocks/setupmocks.go b/internal/plugins/setupmocks/setupmocks.go index 554ad3b..9b25e64 100644 --- a/internal/plugins/setupmocks/setupmocks.go +++ b/internal/plugins/setupmocks/setupmocks.go @@ -2,10 +2,12 @@ package setupmocks import ( + "fmt" "reflect" "github.com/JosiahWitt/ensure/internal/plugins" "github.com/JosiahWitt/ensure/internal/plugins/internal/id" + "github.com/JosiahWitt/ensure/internal/reflectensure" "github.com/JosiahWitt/ensure/internal/stringerr" "github.com/JosiahWitt/ensure/internal/testctx" ) @@ -33,42 +35,67 @@ func (t *TablePlugin) ParseEntryType(entryType reflect.Type) (plugins.TableEntry return nil, stringerr.Newf("%s field must be set on the table to use %s", id.Mocks, id.SetupMocks) } - if err := validateSetupMocksFieldType(&setupMocksFunc, &mocksStruct); err != nil { + funcType, err := parseSetupMocksField(&setupMocksFunc, &mocksStruct) + if err != nil { return nil, err } h.hasSetupMocks = true + h.funcType = funcType } return h, nil } -func validateSetupMocksFieldType(setupMocksFunc, mocksStruct *reflect.StructField) error { +func parseSetupMocksField(setupMocksFunc, mocksStruct *reflect.StructField) (funcType, error) { t := setupMocksFunc.Type generateError := func() error { - return stringerr.Newf("expected %s field to be a func(%v), got: %v", id.SetupMocks, mocksStruct.Type, t) + return stringerr.NewBlock( + fmt.Sprintf("expected %s field to be one of the following", id.SetupMocks), + []error{ + stringerr.Newf("func(m %v)", mocksStruct.Type), + stringerr.Newf("func(m %v, %s %s)", mocksStruct.Type, id.Ensure, id.EnsuringE), + }, + fmt.Sprintf("Got: %v", t), + ) } if t.Kind() != reflect.Func { - return generateError() + return 0, generateError() } - invalidIns := t.NumIn() != 1 || t.In(0) != mocksStruct.Type - invalidOuts := t.NumOut() != 0 + validDefaultIns := t.NumIn() == 1 && t.In(0) == mocksStruct.Type + validEnsureIns := t.NumIn() == 2 && t.In(0) == mocksStruct.Type && reflectensure.IsEnsuringE(t.In(1)) + validIns := validDefaultIns || validEnsureIns - if invalidIns || invalidOuts { - return generateError() + validOuts := t.NumOut() == 0 + + if !validIns || !validOuts { + return 0, generateError() } - return nil + switch { + case validEnsureIns: + return funcTypeEnsure, nil + default: + return funcTypeDefault, nil + } } +type funcType int + +const ( + funcTypeDefault funcType = iota + funcTypeEnsure +) + // TableEntryHooks exposes the before and after hooks for each entry in the table. type TableEntryHooks struct { plugins.NoopAfterEntry hasSetupMocks bool + funcType funcType } var _ plugins.TableEntryHooks = &TableEntryHooks{} @@ -88,7 +115,16 @@ func (h *TableEntryHooks) BeforeEntry(ctx testctx.Context, entryValue reflect.Va } mocksField := v.FieldByName(id.Mocks) - setupMocksFunc.Call([]reflect.Value{mocksField}) + + var ins []reflect.Value + switch h.funcType { + case funcTypeDefault: + ins = []reflect.Value{mocksField} + case funcTypeEnsure: + ins = []reflect.Value{mocksField, reflect.ValueOf(ctx.Ensure())} + } + + setupMocksFunc.Call(ins) return nil } diff --git a/internal/plugins/setupmocks/setupmocks_test.go b/internal/plugins/setupmocks/setupmocks_test.go index 8fc6260..42a7596 100644 --- a/internal/plugins/setupmocks/setupmocks_test.go +++ b/internal/plugins/setupmocks/setupmocks_test.go @@ -6,8 +6,10 @@ import ( "github.com/JosiahWitt/ensure" "github.com/JosiahWitt/ensure/ensuring" + "github.com/JosiahWitt/ensure/internal/mocks/mock_testctx" "github.com/JosiahWitt/ensure/internal/plugins/setupmocks" "github.com/JosiahWitt/ensure/internal/stringerr" + "github.com/golang/mock/gomock" ) func TestParseEntryType(t *testing.T) { @@ -42,6 +44,15 @@ func TestParseEntryType(t *testing.T) { SetupMocks func(*Mocks) }{}, }, + { + Name: "returns no errors when SetupMocks is provided with an optional ensuring.E parameter", + + Entry: struct { + Name string + Mocks *Mocks + SetupMocks func(*Mocks, ensuring.E) + }{}, + }, { Name: "returns error when SetupMocks is provided, but Mocks is not provided", @@ -61,7 +72,12 @@ func TestParseEntryType(t *testing.T) { SetupMocks *func(*Mocks) }{}, - ExpectedError: stringerr.Newf("expected SetupMocks field to be a func(*setupmocks_test.Mocks), got: *func(*setupmocks_test.Mocks)"), + ExpectedError: stringerr.Newf( + "expected SetupMocks field to be one of the following:\n" + + " - func(m *setupmocks_test.Mocks)\n" + + " - func(m *setupmocks_test.Mocks, ensure ensuring.E)\n" + + "Got: *func(*setupmocks_test.Mocks)", + ), }, { Name: "returns error when SetupMocks has no inputs", @@ -72,10 +88,15 @@ func TestParseEntryType(t *testing.T) { SetupMocks func() }{}, - ExpectedError: stringerr.Newf("expected SetupMocks field to be a func(*setupmocks_test.Mocks), got: func()"), + ExpectedError: stringerr.Newf( + "expected SetupMocks field to be one of the following:\n" + + " - func(m *setupmocks_test.Mocks)\n" + + " - func(m *setupmocks_test.Mocks, ensure ensuring.E)\n" + + "Got: func()", + ), }, { - Name: "returns error when SetupMocks has two inputs", + Name: "returns error when SetupMocks has two equal inputs", Entry: struct { Name string @@ -83,7 +104,12 @@ func TestParseEntryType(t *testing.T) { SetupMocks func(*Mocks, *Mocks) }{}, - ExpectedError: stringerr.Newf("expected SetupMocks field to be a func(*setupmocks_test.Mocks), got: func(*setupmocks_test.Mocks, *setupmocks_test.Mocks)"), + ExpectedError: stringerr.Newf( + "expected SetupMocks field to be one of the following:\n" + + " - func(m *setupmocks_test.Mocks)\n" + + " - func(m *setupmocks_test.Mocks, ensure ensuring.E)\n" + + "Got: func(*setupmocks_test.Mocks, *setupmocks_test.Mocks)", + ), }, { Name: "returns error when SetupMocks has an invalid input", @@ -94,7 +120,60 @@ func TestParseEntryType(t *testing.T) { SetupMocks func(Mocks) }{}, - ExpectedError: stringerr.Newf("expected SetupMocks field to be a func(*setupmocks_test.Mocks), got: func(setupmocks_test.Mocks)"), + ExpectedError: stringerr.Newf( + "expected SetupMocks field to be one of the following:\n" + + " - func(m *setupmocks_test.Mocks)\n" + + " - func(m *setupmocks_test.Mocks, ensure ensuring.E)\n" + + "Got: func(setupmocks_test.Mocks)", + ), + }, + { + Name: "returns error when SetupMocks has two inputs, and the first is invalid", + + Entry: struct { + Name string + Mocks *Mocks + SetupMocks func(Mocks, ensuring.E) + }{}, + + ExpectedError: stringerr.Newf( + "expected SetupMocks field to be one of the following:\n" + + " - func(m *setupmocks_test.Mocks)\n" + + " - func(m *setupmocks_test.Mocks, ensure ensuring.E)\n" + + "Got: func(setupmocks_test.Mocks, ensuring.E)", + ), + }, + { + Name: "returns error when SetupMocks has two inputs, and the second is invalid", + + Entry: struct { + Name string + Mocks *Mocks + SetupMocks func(*Mocks, *ensuring.E) + }{}, + + ExpectedError: stringerr.Newf( + "expected SetupMocks field to be one of the following:\n" + + " - func(m *setupmocks_test.Mocks)\n" + + " - func(m *setupmocks_test.Mocks, ensure ensuring.E)\n" + + "Got: func(*setupmocks_test.Mocks, *ensuring.E)", + ), + }, + { + Name: "returns error when SetupMocks has three inputs", + + Entry: struct { + Name string + Mocks *Mocks + SetupMocks func(*Mocks, ensuring.E, ensuring.E) + }{}, + + ExpectedError: stringerr.Newf( + "expected SetupMocks field to be one of the following:\n" + + " - func(m *setupmocks_test.Mocks)\n" + + " - func(m *setupmocks_test.Mocks, ensure ensuring.E)\n" + + "Got: func(*setupmocks_test.Mocks, ensuring.E, ensuring.E)", + ), }, { Name: "returns error when SetupMocks returns values", @@ -105,7 +184,12 @@ func TestParseEntryType(t *testing.T) { SetupMocks func(*Mocks) *Mocks }{}, - ExpectedError: stringerr.Newf("expected SetupMocks field to be a func(*setupmocks_test.Mocks), got: func(*setupmocks_test.Mocks) *setupmocks_test.Mocks"), + ExpectedError: stringerr.Newf( + "expected SetupMocks field to be one of the following:\n" + + " - func(m *setupmocks_test.Mocks)\n" + + " - func(m *setupmocks_test.Mocks, ensure ensuring.E)\n" + + "Got: func(*setupmocks_test.Mocks) *setupmocks_test.Mocks", + ), }, } @@ -125,7 +209,8 @@ func TestParseEntryValue(t *testing.T) { table := []struct { Name string - Table interface{} + Table interface{} + SetupMockT func(m *mock_testctx.MockT, i int) ExpectedTable interface{} }{ @@ -168,7 +253,7 @@ func TestParseEntryValue(t *testing.T) { }, }, { - Name: "executes SetupMocks when SetupMocks and Mocks are provided", + Name: "executes SetupMocks when SetupMocks(*Mocks) and Mocks are provided", Table: []struct { Name string @@ -256,6 +341,58 @@ func TestParseEntryValue(t *testing.T) { }, }, }, + { + Name: "executes SetupMocks when SetupMocks(*Mocks, ensuring.E) and Mocks are provided", + + SetupMockT: func(m *mock_testctx.MockT, i int) { + switch i { + case 0: + m.EXPECT().Fatalf("first fail") + case 1: + m.EXPECT().Fatalf("second fail") + } + }, + + Table: []struct { + Name string + Mocks *Mocks + SetupMocks func(*Mocks, ensuring.E) + }{ + { + Name: "first", + SetupMocks: func(m *Mocks, ensure ensuring.E) { + m.A = "first mocks" + ensure.Failf("first fail") // Show ensure is connected correctly + }, + }, + { + Name: "second", + SetupMocks: func(m *Mocks, ensure ensuring.E) { + m.A = "second mocks" + ensure.Failf("second fail") // Show ensure is connected correctly + }, + }, + }, + + ExpectedTable: []struct { + Name string + Mocks *Mocks + SetupMocks func(*Mocks, ensuring.E) + }{ + { + Name: "first", + Mocks: &Mocks{ + A: "first mocks", + }, + }, + { + Name: "second", + Mocks: &Mocks{ + A: "second mocks", + }, + }, + }, + }, } ensure.RunTableByIndex(table, func(ensure ensuring.E, i int) { @@ -273,8 +410,19 @@ func TestParseEntryValue(t *testing.T) { mocksField.Set(reflect.New(reflect.TypeOf(Mocks{}))) } - ensure(tableEntryHooks.BeforeEntry(nil, entryVal, i)).IsNotError() - ensure(tableEntryHooks.AfterEntry(nil, entryVal, i)).IsNotError() + mockT := mock_testctx.NewMockT(ensure.GoMockController()) + mockT.EXPECT().Helper().AnyTimes() + mockT.EXPECT().Cleanup(gomock.Any()).AnyTimes() + + if entry.SetupMockT != nil { + entry.SetupMockT(mockT, i) + } + + mockCtx := mock_testctx.NewMockContext(ensure.GoMockController()) + mockCtx.EXPECT().Ensure().Return(ensure.New(mockT)).AnyTimes() + + ensure(tableEntryHooks.BeforeEntry(mockCtx, entryVal, i)).IsNotError() + ensure(tableEntryHooks.AfterEntry(mockCtx, entryVal, i)).IsNotError() } ensure(entry.Table).Equals(entry.ExpectedTable) diff --git a/internal/reflectensure/reflectensure.go b/internal/reflectensure/reflectensure.go new file mode 100644 index 0000000..5d3640e --- /dev/null +++ b/internal/reflectensure/reflectensure.go @@ -0,0 +1,15 @@ +// Package reflectensure provides a helper for identifying ensure types via reflection. +// It is used to avoid import cycles. +package reflectensure + +import "reflect" + +const ( + ensuringPath = "github.com/JosiahWitt/ensure/ensuring" + ensuringE = "E" +) + +// IsEnsuringE returns true only when [ensuring.E] or any of its aliases is provided. +func IsEnsuringE(t reflect.Type) bool { + return t.PkgPath() == ensuringPath && t.Name() == ensuringE +} diff --git a/internal/reflectensure/reflectensure_test.go b/internal/reflectensure/reflectensure_test.go new file mode 100644 index 0000000..141b48a --- /dev/null +++ b/internal/reflectensure/reflectensure_test.go @@ -0,0 +1,48 @@ +package reflectensure_test + +import ( + "reflect" + "testing" + + "github.com/JosiahWitt/ensure" + "github.com/JosiahWitt/ensure/ensurepkg" //lint:ignore SA1019 To ensure compatibility + "github.com/JosiahWitt/ensure/ensuring" + "github.com/JosiahWitt/ensure/internal/reflectensure" +) + +func TestIsEnsuringE(t *testing.T) { + ensure := ensure.New(t) + + ensure.Run("when provided ensuring.E", func(ensure ensuring.E) { + t := reflect.TypeOf(ensuring.E(nil)) + ensure(reflectensure.IsEnsuringE(t)).IsTrue() + }) + + ensure.Run("when provided pointer to ensuring.E", func(ensure ensuring.E) { + e := ensuring.E(nil) + t := reflect.TypeOf(&e) + ensure(reflectensure.IsEnsuringE(t)).IsFalse() + }) + + ensure.Run("when provided ensurepkg.Ensure", func(ensure ensuring.E) { + t := reflect.TypeOf(ensurepkg.Ensure(nil)) //lint:ignore SA1019 To ensure compatibility + ensure(reflectensure.IsEnsuringE(t)).IsTrue() + }) + + ensure.Run("when provided another type implementing ensuring.E", func(ensure ensuring.E) { + type E ensuring.E + t := reflect.TypeOf(E(nil)) + ensure(reflectensure.IsEnsuringE(t)).IsFalse() + }) + + ensure.Run("when provided another type named E", func(ensure ensuring.E) { + type E func(interface{}) *ensuring.Chain + t := reflect.TypeOf(E(nil)) + ensure(reflectensure.IsEnsuringE(t)).IsFalse() + }) + + ensure.Run("when provided another type in ensuring", func(ensure ensuring.E) { + t := reflect.TypeOf(ensuring.Chain{}) + ensure(reflectensure.IsEnsuringE(t)).IsFalse() + }) +}