go.mway.dev/chrono@v0.6.1-0.20240126030049-189c5aef20d2/clock/fake_clock.go (about)

     1  // Copyright (c) 2023 Matt Way
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to
     5  // deal in the Software without restriction, including without limitation the
     6  // rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
     7  // sell copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in
    11  // all copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
    18  // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
     19  // IN THE SOFTWARE.
    20  
    21  package clock
    22  
    23  import (
    24  	"sort"
    25  	"sync"
    26  	"time"
    27  
    28  	"go.uber.org/atomic"
    29  )
    30  
    31  var _ Clock = (*FakeClock)(nil)
    32  
    33  // A FakeClock is a manually-adjusted clock useful for mocking the flow of time
    34  // in tests. It does not keep time by itself: use [FakeClock.Add],
    35  // [FakeClock.SetTime], and related functions to manage the clock's time.
    36  type FakeClock struct {
    37  	timers []*fakeTimer
    38  	now    atomic.Int64
    39  	mu     sync.Mutex
    40  	clk    monotonicClock
    41  }
    42  
    43  // NewFakeClock creates a new [FakeClock].
    44  func NewFakeClock() *FakeClock {
    45  	c := &FakeClock{}
    46  	c.clk = monotonicClock{
    47  		fn: func() int64 {
    48  			return c.now.Load()
    49  		},
    50  	}
    51  	return c
    52  }
    53  
    54  // Add adds d to the clock's internal time.
    55  func (c *FakeClock) Add(d time.Duration) {
    56  	c.checkTimers(c.now.Add(int64(d)))
    57  }
    58  
    59  // After returns a channel that receives the current time after d has elapsed.
    60  func (c *FakeClock) After(d time.Duration) <-chan time.Time {
    61  	return c.addTimer(d, nil).ch
    62  }
    63  
    64  // AfterFunc returns a timer that will invoke the given function after d has
    65  // elapsed. The timer may be stopped and reset.
    66  func (c *FakeClock) AfterFunc(d time.Duration, fn func()) *Timer {
    67  	x := c.addTimer(d, fn)
    68  	return &Timer{
    69  		C:    x.ch,
    70  		fake: x,
    71  	}
    72  }
    73  
    74  // Nanotime returns the clock's internal time as integer nanoseconds.
    75  func (c *FakeClock) Nanotime() int64 {
    76  	return c.clk.Nanotime()
    77  }
    78  
    79  // NewTicker returns a new [Ticker] that receives time ticks every d. If d is not
    80  // greater than zero, [NewTicker] will panic.
    81  func (c *FakeClock) NewTicker(d time.Duration) *Ticker {
    82  	if d <= 0 {
    83  		panic("non-positive interval for FakeClock.NewTicker")
    84  	}
    85  
    86  	x := c.addTicker(d)
    87  	return &Ticker{
    88  		C:    x.ch,
    89  		fake: x,
    90  	}
    91  }
    92  
    93  // NewTimer returns a new [Timer] that receives a time tick after d.
    94  func (c *FakeClock) NewTimer(d time.Duration) *Timer {
    95  	x := c.addTimer(d, nil)
    96  	return &Timer{
    97  		C:    x.ch,
    98  		fake: x,
    99  	}
   100  }
   101  
   102  // Now returns the clock's internal time as a [time.Time].
   103  func (c *FakeClock) Now() time.Time {
   104  	return c.clk.Now()
   105  }
   106  
   107  // SetTime sets the clock's time to t.
   108  func (c *FakeClock) SetTime(t time.Time) {
   109  	c.SetNanotime(t.UnixNano())
   110  }
   111  
   112  // SetNanotime sets the clock's time to ns.
   113  func (c *FakeClock) SetNanotime(ns int64) {
   114  	c.now.Store(ns)
   115  	c.checkTimers(ns)
   116  }
   117  
   118  // Since returns the amount of time that elapsed between the clock's internal
   119  // time and t.
   120  func (c *FakeClock) Since(t time.Time) time.Duration {
   121  	return c.SinceNanotime(t.UnixNano())
   122  }
   123  
   124  // SinceNanotime returns the amount of time that elapsed between the clock's
   125  // internal time and ns.
   126  func (c *FakeClock) SinceNanotime(ns int64) time.Duration {
   127  	return time.Duration(c.Nanotime() - ns)
   128  }
   129  
   130  // Sleep blocks for d.
   131  //
   132  // Note that Sleep must be called from a different goroutine than the clock's
   133  // time is being managed on, or the program will deadlock.
   134  func (c *FakeClock) Sleep(d time.Duration) {
   135  	timer := c.addTimer(d, nil)
   136  	defer c.removeTimer(timer)
   137  	<-timer.ch
   138  }
   139  
   140  // NewStopwatch returns a new [Stopwatch] that uses the current clock for
   141  // measuring time. The clock's current time is used as the stopwatch's epoch.
   142  func (c *FakeClock) NewStopwatch() *Stopwatch {
   143  	return newStopwatch(c)
   144  }
   145  
   146  // Tick returns a new channel that receives time ticks every d. It is
   147  // equivalent to writing c.NewTicker(d).C(). The given duration must be greater
   148  // than 0.
   149  func (c *FakeClock) Tick(d time.Duration) <-chan time.Time {
   150  	if d < 0 {
   151  		panic("non-positive interval for FakeClock.Tick")
   152  	}
   153  	return c.NewTicker(d).C
   154  }
   155  
   156  func (c *FakeClock) addTicker(d time.Duration) *fakeTimer {
   157  	fake := newFakeTicker(c, d)
   158  
   159  	c.mu.Lock()
   160  	defer c.mu.Unlock()
   161  
   162  	c.timers = append(c.timers, fake)
   163  	c.sortTimersNosync()
   164  
   165  	return fake
   166  }
   167  
   168  func (c *FakeClock) addTimer(d time.Duration, fn func()) *fakeTimer {
   169  	fake := newFakeTimer(c, d, fn)
   170  
   171  	c.mu.Lock()
   172  	defer c.mu.Unlock()
   173  
   174  	c.timers = append(c.timers, fake)
   175  	c.sortTimersNosync()
   176  
   177  	return fake
   178  }
   179  
   180  func (c *FakeClock) checkTimers(now int64) {
   181  	c.mu.Lock()
   182  	defer c.mu.Unlock()
   183  
   184  	num := len(c.timers)
   185  	for i := 0; i < num; /* noincr */ {
   186  		if when := c.timers[i].when; when < 0 || when > now {
   187  			return
   188  		}
   189  
   190  		// This timer should tick. If it has a function, the function should be
   191  		// called in its own goroutine; otherwise, the channel should receive a
   192  		// tick.
   193  		if c.timers[i].fn != nil {
   194  			go c.timers[i].fn()
   195  		} else {
   196  			tick(c.timers[i].ch, c.timers[i].when)
   197  		}
   198  
   199  		// If this is a ticker, extend when by period.
   200  		if c.timers[i].period != 0 {
   201  			c.timers[i].when = now + c.timers[i].period
   202  			i++
   203  			continue
   204  		}
   205  
   206  		// Otherwise, remove the timer since it just fired.
   207  		if i < len(c.timers)-1 {
   208  			copy(c.timers[i:], c.timers[i+1:])
   209  		}
   210  
   211  		c.timers[num-1] = nil
   212  		c.timers = c.timers[:num-1]
   213  		num--
   214  	}
   215  }
   216  
   217  func (c *FakeClock) resetTimer(fake *fakeTimer, d time.Duration) bool {
   218  	now := fake.clk.Nanotime()
   219  
   220  	c.mu.Lock()
   221  	defer c.mu.Unlock()
   222  
   223  	// Check if the timer exists using its previous value.
   224  	pos := c.insertPosNosync(fake.when)
   225  
   226  	fake.when = now + int64(d)
   227  	if fake.period != 0 {
   228  		fake.period = int64(d)
   229  	}
   230  
   231  	// The timer doesn't exist; insert it into its new position based on the
   232  	// current time and given duration.
   233  	if n := len(c.timers); n == 0 || pos >= n || c.timers[pos] != fake {
   234  		c.timers = append(c.timers, fake)
   235  		c.sortTimersNosync()
   236  		return false
   237  	}
   238  
   239  	return true
   240  }
   241  
   242  func (c *FakeClock) removeTimer(fake *fakeTimer) bool {
   243  	c.mu.Lock()
   244  	defer c.mu.Unlock()
   245  
   246  	if len(c.timers) == 0 {
   247  		return false
   248  	}
   249  
   250  	// Find a candidate timer's index based on the original expiration.
   251  	pos := c.insertPosNosync(fake.when)
   252  
   253  	// Pathological case: insertPosNosync will always return {0,2} for [2]
   254  	if pos > 0 && len(c.timers) == 2 {
   255  		pos--
   256  	}
   257  
   258  	// Ensure that this is the expected timer.
   259  	if pos >= len(c.timers) || c.timers[pos] != fake {
   260  		return false
   261  	}
   262  
   263  	if pos < len(c.timers)-1 {
   264  		copy(c.timers[pos:], c.timers[pos+1:])
   265  	}
   266  	c.timers = c.timers[:len(c.timers)-1]
   267  
   268  	return true
   269  }
   270  
   271  func (c *FakeClock) insertPosNosync(when int64) int {
   272  	// Inline the stdlib search for parity. Ref:
   273  	// https://cs.opensource.google/go/go/+/refs/tags/go1.18.1:src/sort/search.go;l=59-74
   274  	i, j := 0, len(c.timers)
   275  	for i < j {
   276  		h := int(uint(i+j) >> 1)
   277  		if cur := c.timers[i].when; cur >= 0 && cur < when {
   278  			i = h + 1
   279  		} else {
   280  			j = h
   281  		}
   282  	}
   283  
   284  	return i
   285  }
   286  
   287  func (c *FakeClock) sortTimersNosync() {
   288  	sort.Slice(c.timers, func(i int, j int) bool {
   289  		a, b := c.timers[i], c.timers[j]
   290  		return b.when < 0 || (a.when >= 0 && a.when < b.when)
   291  	})
   292  }
   293  
   294  type fakeTimer struct {
   295  	clk    *FakeClock
   296  	ch     chan time.Time
   297  	fn     func() // timer only
   298  	when   int64  // timer expiration or next tick
   299  	period int64  // ticker only
   300  }
   301  
   302  func newFakeTimer(clk *FakeClock, d time.Duration, fn func()) *fakeTimer {
   303  	return &fakeTimer{
   304  		clk:  clk,
   305  		ch:   make(chan time.Time, 1),
   306  		fn:   fn,
   307  		when: clk.Nanotime() + int64(d),
   308  	}
   309  }
   310  
   311  func newFakeTicker(clk *FakeClock, d time.Duration) *fakeTimer {
   312  	return &fakeTimer{
   313  		clk:    clk,
   314  		ch:     make(chan time.Time, 1),
   315  		when:   clk.Nanotime() + int64(d),
   316  		period: int64(d),
   317  	}
   318  }
   319  
   320  func (f *fakeTimer) resetTimer(d time.Duration) bool {
   321  	return f.clk.resetTimer(f, d)
   322  }
   323  
   324  func (f *fakeTimer) removeTimer() bool {
   325  	return f.clk.removeTimer(f)
   326  }
   327  
   328  func tick(ch chan time.Time, ns int64) {
   329  	select {
   330  	case ch <- time.Unix(0, ns):
   331  	default:
   332  	}
   333  }