golang.org/x/net@v0.25.1-0.20240516223405-c87a5b62e243/http2/testsync.go (about)

     1  // Copyright 2024 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  package http2
     5  
     6  import (
     7  	"context"
     8  	"sync"
     9  	"time"
    10  )
    11  
    12  // testSyncHooks coordinates goroutines in tests.
    13  //
    14  // For example, a call to ClientConn.RoundTrip involves several goroutines, including:
    15  //   - the goroutine running RoundTrip;
    16  //   - the clientStream.doRequest goroutine, which writes the request; and
    17  //   - the clientStream.readLoop goroutine, which reads the response.
    18  //
    19  // Using testSyncHooks, a test can start a RoundTrip and identify when all these goroutines
    20  // are blocked waiting for some condition such as reading the Request.Body or waiting for
    21  // flow control to become available.
    22  //
    23  // The testSyncHooks also manage timers and synthetic time in tests.
    24  // This permits us to, for example, start a request and cause it to time out waiting for
    25  // response headers without resorting to time.Sleep calls.
    26  type testSyncHooks struct {
    27  	// active/inactive act as a mutex and condition variable.
    28  	//
    29  	//  - neither chan contains a value: testSyncHooks is locked.
    30  	//  - active contains a value: unlocked, and at least one goroutine is not blocked
    31  	//  - inactive contains a value: unlocked, and all goroutines are blocked
    32  	active   chan struct{}
    33  	inactive chan struct{}
    34  
    35  	// goroutine counts
    36  	total    int                     // total goroutines
    37  	condwait map[*sync.Cond]int      // blocked in sync.Cond.Wait
    38  	blocked  []*testBlockedGoroutine // otherwise blocked
    39  
    40  	// fake time
    41  	now    time.Time
    42  	timers []*fakeTimer
    43  
    44  	// Transport testing: Report various events.
    45  	newclientconn func(*ClientConn)
    46  	newstream     func(*clientStream)
    47  }
    48  
    49  // testBlockedGoroutine is a blocked goroutine.
    50  type testBlockedGoroutine struct {
    51  	f  func() bool   // blocked until f returns true
    52  	ch chan struct{} // closed when unblocked
    53  }
    54  
    55  func newTestSyncHooks() *testSyncHooks {
    56  	h := &testSyncHooks{
    57  		active:   make(chan struct{}, 1),
    58  		inactive: make(chan struct{}, 1),
    59  		condwait: map[*sync.Cond]int{},
    60  	}
    61  	h.inactive <- struct{}{}
    62  	h.now = time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC)
    63  	return h
    64  }
    65  
    66  // lock acquires the testSyncHooks mutex.
    67  func (h *testSyncHooks) lock() {
    68  	select {
    69  	case <-h.active:
    70  	case <-h.inactive:
    71  	}
    72  }
    73  
    74  // waitInactive waits for all goroutines to become inactive.
    75  func (h *testSyncHooks) waitInactive() {
    76  	for {
    77  		<-h.inactive
    78  		if !h.unlock() {
    79  			break
    80  		}
    81  	}
    82  }
    83  
    84  // unlock releases the testSyncHooks mutex.
    85  // It reports whether any goroutines are active.
    86  func (h *testSyncHooks) unlock() (active bool) {
    87  	// Look for a blocked goroutine which can be unblocked.
    88  	blocked := h.blocked[:0]
    89  	unblocked := false
    90  	for _, b := range h.blocked {
    91  		if !unblocked && b.f() {
    92  			unblocked = true
    93  			close(b.ch)
    94  		} else {
    95  			blocked = append(blocked, b)
    96  		}
    97  	}
    98  	h.blocked = blocked
    99  
   100  	// Count goroutines blocked on condition variables.
   101  	condwait := 0
   102  	for _, count := range h.condwait {
   103  		condwait += count
   104  	}
   105  
   106  	if h.total > condwait+len(blocked) {
   107  		h.active <- struct{}{}
   108  		return true
   109  	} else {
   110  		h.inactive <- struct{}{}
   111  		return false
   112  	}
   113  }
   114  
   115  // goRun starts a new goroutine.
   116  func (h *testSyncHooks) goRun(f func()) {
   117  	h.lock()
   118  	h.total++
   119  	h.unlock()
   120  	go func() {
   121  		defer func() {
   122  			h.lock()
   123  			h.total--
   124  			h.unlock()
   125  		}()
   126  		f()
   127  	}()
   128  }
   129  
   130  // blockUntil indicates that a goroutine is blocked waiting for some condition to become true.
   131  // It waits until f returns true before proceeding.
   132  //
   133  // Example usage:
   134  //
   135  //	h.blockUntil(func() bool {
   136  //		// Is the context done yet?
   137  //		select {
   138  //		case <-ctx.Done():
   139  //		default:
   140  //			return false
   141  //		}
   142  //		return true
   143  //	})
   144  //	// Wait for the context to become done.
   145  //	<-ctx.Done()
   146  //
   147  // The function f passed to blockUntil must be non-blocking and idempotent.
   148  func (h *testSyncHooks) blockUntil(f func() bool) {
   149  	if f() {
   150  		return
   151  	}
   152  	ch := make(chan struct{})
   153  	h.lock()
   154  	h.blocked = append(h.blocked, &testBlockedGoroutine{
   155  		f:  f,
   156  		ch: ch,
   157  	})
   158  	h.unlock()
   159  	<-ch
   160  }
   161  
   162  // broadcast is sync.Cond.Broadcast.
   163  func (h *testSyncHooks) condBroadcast(cond *sync.Cond) {
   164  	h.lock()
   165  	delete(h.condwait, cond)
   166  	h.unlock()
   167  	cond.Broadcast()
   168  }
   169  
   170  // broadcast is sync.Cond.Wait.
   171  func (h *testSyncHooks) condWait(cond *sync.Cond) {
   172  	h.lock()
   173  	h.condwait[cond]++
   174  	h.unlock()
   175  }
   176  
   177  // newTimer creates a new fake timer.
   178  func (h *testSyncHooks) newTimer(d time.Duration) timer {
   179  	h.lock()
   180  	defer h.unlock()
   181  	t := &fakeTimer{
   182  		hooks: h,
   183  		when:  h.now.Add(d),
   184  		c:     make(chan time.Time),
   185  	}
   186  	h.timers = append(h.timers, t)
   187  	return t
   188  }
   189  
   190  // afterFunc creates a new fake AfterFunc timer.
   191  func (h *testSyncHooks) afterFunc(d time.Duration, f func()) timer {
   192  	h.lock()
   193  	defer h.unlock()
   194  	t := &fakeTimer{
   195  		hooks: h,
   196  		when:  h.now.Add(d),
   197  		f:     f,
   198  	}
   199  	h.timers = append(h.timers, t)
   200  	return t
   201  }
   202  
   203  func (h *testSyncHooks) contextWithTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc) {
   204  	ctx, cancel := context.WithCancel(ctx)
   205  	t := h.afterFunc(d, cancel)
   206  	return ctx, func() {
   207  		t.Stop()
   208  		cancel()
   209  	}
   210  }
   211  
   212  func (h *testSyncHooks) timeUntilEvent() time.Duration {
   213  	h.lock()
   214  	defer h.unlock()
   215  	var next time.Time
   216  	for _, t := range h.timers {
   217  		if next.IsZero() || t.when.Before(next) {
   218  			next = t.when
   219  		}
   220  	}
   221  	if d := next.Sub(h.now); d > 0 {
   222  		return d
   223  	}
   224  	return 0
   225  }
   226  
   227  // advance advances time and causes synthetic timers to fire.
   228  func (h *testSyncHooks) advance(d time.Duration) {
   229  	h.lock()
   230  	defer h.unlock()
   231  	h.now = h.now.Add(d)
   232  	timers := h.timers[:0]
   233  	for _, t := range h.timers {
   234  		t := t // remove after go.mod depends on go1.22
   235  		t.mu.Lock()
   236  		switch {
   237  		case t.when.After(h.now):
   238  			timers = append(timers, t)
   239  		case t.when.IsZero():
   240  			// stopped timer
   241  		default:
   242  			t.when = time.Time{}
   243  			if t.c != nil {
   244  				close(t.c)
   245  			}
   246  			if t.f != nil {
   247  				h.total++
   248  				go func() {
   249  					defer func() {
   250  						h.lock()
   251  						h.total--
   252  						h.unlock()
   253  					}()
   254  					t.f()
   255  				}()
   256  			}
   257  		}
   258  		t.mu.Unlock()
   259  	}
   260  	h.timers = timers
   261  }
   262  
   263  // A timer wraps a time.Timer, or a synthetic equivalent in tests.
   264  // Unlike time.Timer, timer is single-use: The timer channel is closed when the timer expires.
   265  type timer interface {
   266  	C() <-chan time.Time
   267  	Stop() bool
   268  	Reset(d time.Duration) bool
   269  }
   270  
   271  // timeTimer implements timer using real time.
   272  type timeTimer struct {
   273  	t *time.Timer
   274  	c chan time.Time
   275  }
   276  
   277  // newTimeTimer creates a new timer using real time.
   278  func newTimeTimer(d time.Duration) timer {
   279  	ch := make(chan time.Time)
   280  	t := time.AfterFunc(d, func() {
   281  		close(ch)
   282  	})
   283  	return &timeTimer{t, ch}
   284  }
   285  
   286  // newTimeAfterFunc creates an AfterFunc timer using real time.
   287  func newTimeAfterFunc(d time.Duration, f func()) timer {
   288  	return &timeTimer{
   289  		t: time.AfterFunc(d, f),
   290  	}
   291  }
   292  
   293  func (t timeTimer) C() <-chan time.Time        { return t.c }
   294  func (t timeTimer) Stop() bool                 { return t.t.Stop() }
   295  func (t timeTimer) Reset(d time.Duration) bool { return t.t.Reset(d) }
   296  
   297  // fakeTimer implements timer using fake time.
   298  type fakeTimer struct {
   299  	hooks *testSyncHooks
   300  
   301  	mu   sync.Mutex
   302  	when time.Time      // when the timer will fire
   303  	c    chan time.Time // closed when the timer fires; mutually exclusive with f
   304  	f    func()         // called when the timer fires; mutually exclusive with c
   305  }
   306  
   307  func (t *fakeTimer) C() <-chan time.Time { return t.c }
   308  
   309  func (t *fakeTimer) Stop() bool {
   310  	t.mu.Lock()
   311  	defer t.mu.Unlock()
   312  	stopped := t.when.IsZero()
   313  	t.when = time.Time{}
   314  	return stopped
   315  }
   316  
   317  func (t *fakeTimer) Reset(d time.Duration) bool {
   318  	if t.c != nil || t.f == nil {
   319  		panic("fakeTimer only supports Reset on AfterFunc timers")
   320  	}
   321  	t.mu.Lock()
   322  	defer t.mu.Unlock()
   323  	t.hooks.lock()
   324  	defer t.hooks.unlock()
   325  	active := !t.when.IsZero()
   326  	t.when = t.hooks.now.Add(d)
   327  	if !active {
   328  		t.hooks.timers = append(t.hooks.timers, t)
   329  	}
   330  	return active
   331  }