diff --git a/clock.go b/clock.go index 40555b3..2230962 100644 --- a/clock.go +++ b/clock.go @@ -165,9 +165,17 @@ func (m *Mock) After(d time.Duration) <-chan time.Time { // AfterFunc waits for the duration to elapse and then executes a function. // A Timer is returned that can be stopped. func (m *Mock) AfterFunc(d time.Duration, f func()) *Timer { - t := m.Timer(d) - t.C = nil - t.fn = f + m.mu.Lock() + defer m.mu.Unlock() + ch := make(chan time.Time, 1) + t := &Timer{ + c: ch, + fn: f, + mock: m, + next: m.now.Add(d), + stopped: false, + } + m.timers = append(m.timers, (*internalTimer)(t)) return t } diff --git a/clock_test.go b/clock_test.go index 63be66c..b63649a 100644 --- a/clock_test.go +++ b/clock_test.go @@ -650,5 +650,44 @@ func TestMock_ReentrantDeadlock(t *testing.T) { mockedClock.Add(15 * time.Second) } +func TestMock_AddAfterFuncRace(t *testing.T) { + // start blocks the goroutines in this test + // until we're ready for them to race. + start := make(chan struct{}) + + var wg sync.WaitGroup + + mockedClock := NewMock() + + called := false + defer func() { + if !called { + t.Errorf("AfterFunc did not call the function") + } + }() + + wg.Add(1) + go func() { + defer wg.Done() + <-start + + mockedClock.AfterFunc(time.Millisecond, func() { + called = true + }) + }() + + wg.Add(1) + go func() { + defer wg.Done() + <-start + + mockedClock.Add(time.Millisecond) + mockedClock.Add(time.Millisecond) + }() + + close(start) // unblock the goroutines + wg.Wait() // and wait for them +} + func warn(v ...interface{}) { fmt.Fprintln(os.Stderr, v...) } func warnf(msg string, v ...interface{}) { fmt.Fprintf(os.Stderr, msg+"\n", v...) }