google.golang.org/grpc@v1.72.2/internal/leakcheck/leakcheck.go

/*
 *
 * Copyright 2017 gRPC authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

// Package leakcheck contains functions to check for leaked goroutines and
// buffers.
//
// Call the following at the beginning of a test:
//
//	defer leakcheck.NewLeakChecker(t).Check()
package leakcheck

import (
	"context"
	"fmt"
	"runtime"
	"runtime/debug"
	"slices"
	"sort"
	"strconv"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"google.golang.org/grpc/internal"
	"google.golang.org/grpc/mem"
)

// failTestsOnLeakedBuffers is a special flag that will cause tests to fail if
// leaked buffers are detected, instead of merely logging them as an
// informational warning. It can be enabled with the "checkbuffers" build tag,
// e.g.:
//
//	go test -tags=checkbuffers
var failTestsOnLeakedBuffers = false

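// As a hedged illustration only (the companion file name and its exact
// contents are an assumption, not shown in this file), the "checkbuffers"
// build tag would typically be wired up by a small build-tagged sibling file
// that flips the flag above:
//
//	//go:build checkbuffers
//
//	package leakcheck
//
//	func init() { failTestsOnLeakedBuffers = true }
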
// init installs a swappable wrapper around the default buffer pool so that
// tests can later substitute a tracking pool via SetTrackingBufferPool.
func init() {
	defaultPool := mem.DefaultBufferPool()
	globalPool.Store(&defaultPool)
	(internal.SetDefaultBufferPoolForTesting.(func(mem.BufferPool)))(&globalPool)
}

var globalPool swappableBufferPool
var globalTimerTracker *timerFactory

// swappableBufferPool is a mem.BufferPool whose underlying implementation can
// be swapped atomically, which lets tests install and later remove a tracking
// pool while other goroutines continue to call Get and Put.
type swappableBufferPool struct {
	atomic.Pointer[mem.BufferPool]
}

func (b *swappableBufferPool) Get(length int) *[]byte {
	return (*b.Load()).Get(length)
}

func (b *swappableBufferPool) Put(buf *[]byte) {
	(*b.Load()).Put(buf)
}

// SetTrackingBufferPool replaces the default buffer pool in the mem package
// with one that tracks where buffers are allocated. CheckTrackingBufferPool
// should then be invoked at the end of the test to validate that all buffers
// pulled from the pool were returned.
func SetTrackingBufferPool(logger Logger) {
	newPool := mem.BufferPool(&trackingBufferPool{
		pool:             *globalPool.Load(),
		logger:           logger,
		allocatedBuffers: make(map[*[]byte][]uintptr),
	})
	globalPool.Store(&newPool)
}

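// A minimal usage sketch (the test name is hypothetical) for tests that call
// SetTrackingBufferPool and CheckTrackingBufferPool directly, without the
// LeakChecker helper defined below:
//
//	func TestBuffers(t *testing.T) {
//		leakcheck.SetTrackingBufferPool(t)
//		defer leakcheck.CheckTrackingBufferPool()
//		// ... exercise code that allocates from the mem buffer pool ...
//	}
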
// CheckTrackingBufferPool undoes the effects of SetTrackingBufferPool, and fails
// unit tests if not all buffers were returned. It is invalid to invoke this
// function without previously having invoked SetTrackingBufferPool.
func CheckTrackingBufferPool() {
	p := (*globalPool.Load()).(*trackingBufferPool)
	p.lock.Lock()
	defer p.lock.Unlock()

	globalPool.Store(&p.pool)

	type uniqueTrace struct {
		stack []uintptr
		count int
	}

	var totalLeakedBuffers int
	var uniqueTraces []uniqueTrace
	for _, stack := range p.allocatedBuffers {
		idx, ok := slices.BinarySearchFunc(uniqueTraces, stack, func(trace uniqueTrace, stack []uintptr) int {
			return slices.Compare(trace.stack, stack)
		})
		if !ok {
			uniqueTraces = slices.Insert(uniqueTraces, idx, uniqueTrace{stack: stack})
		}
		uniqueTraces[idx].count++
		totalLeakedBuffers++
	}

	for _, ut := range uniqueTraces {
		frames := runtime.CallersFrames(ut.stack)
		var trace strings.Builder
		for {
			f, ok := frames.Next()
			if !ok {
				break
			}
			trace.WriteString(f.Function)
			trace.WriteString("\n\t")
			trace.WriteString(f.File)
			trace.WriteString(":")
			trace.WriteString(strconv.Itoa(f.Line))
			trace.WriteString("\n")
		}
		format := "%d allocated buffers never freed:\n%s"
		args := []any{ut.count, trace.String()}
		if failTestsOnLeakedBuffers {
			p.logger.Errorf(format, args...)
		} else {
			p.logger.Logf("WARNING "+format, args...)
		}
	}

	if totalLeakedBuffers > 0 {
		// Multiply by 100 so the %g verb prints an actual percentage rather
		// than a fraction.
		p.logger.Logf("%g%% of buffers never freed", 100*float64(totalLeakedBuffers)/float64(p.bufferCount))
	}
}

type trackingBufferPool struct {
	pool   mem.BufferPool
	logger Logger

	lock             sync.Mutex
	bufferCount      int
	allocatedBuffers map[*[]byte][]uintptr
}

func (p *trackingBufferPool) Get(length int) *[]byte {
	p.lock.Lock()
	defer p.lock.Unlock()
	p.bufferCount++
	buf := p.pool.Get(length)
	p.allocatedBuffers[buf] = currentStack(2)
	return buf
}

func (p *trackingBufferPool) Put(buf *[]byte) {
	p.lock.Lock()
	defer p.lock.Unlock()

	if _, ok := p.allocatedBuffers[buf]; !ok {
		p.logger.Errorf("Unknown buffer freed:\n%s", string(debug.Stack()))
	} else {
		delete(p.allocatedBuffers, buf)
	}
	p.pool.Put(buf)
}

var goroutinesToIgnore = []string{
	"testing.Main(",
	"testing.tRunner(",
	"testing.(*M).",
	"runtime.goexit",
	"created by runtime.gc",
	"created by runtime/trace.Start",
	"interestingGoroutines",
	"runtime.MHeap_Scavenger",
	"signal.signal_recv",
	"sigterm.handler",
	"runtime_mcall",
	"(*loggingT).flushDaemon",
	"goroutine in C code",
	// Ignore the http read/write goroutines. gce metadata.OnGCE() was leaking
	// these, root cause unknown.
	//
	// https://github.com/grpc/grpc-go/issues/5171
	// https://github.com/grpc/grpc-go/issues/5173
	"created by net/http.(*Transport).dialConn",
}

// RegisterIgnoreGoroutine appends s to the list of ignored goroutines.
// Goroutines whose stack trace contains s will not be reported as leaked. This
// function is not thread-safe; only call it from init().
func RegisterIgnoreGoroutine(s string) {
	goroutinesToIgnore = append(goroutinesToIgnore, s)
}
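
// As an illustration only (the package and goroutine names are hypothetical),
// a test package that starts a long-lived background goroutine would register
// it from init so the leak checker does not flag it:
//
//	func init() {
//		leakcheck.RegisterIgnoreGoroutine("mypkg.(*watcher).run")
//	}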

func ignore(g string) bool {
	sl := strings.SplitN(g, "\n", 2)
	if len(sl) != 2 {
		return true
	}
	stack := strings.TrimSpace(sl[1])
	if strings.HasPrefix(stack, "testing.RunTests") {
		return true
	}

	if stack == "" {
		return true
	}

	for _, s := range goroutinesToIgnore {
		if strings.Contains(stack, s) {
			return true
		}
	}

	return false
}

// interestingGoroutines returns all goroutines we care about for the purpose of
// leak checking. It excludes testing or runtime ones.
func interestingGoroutines() (gs []string) {
	buf := make([]byte, 2<<20)
	buf = buf[:runtime.Stack(buf, true)]
	for _, g := range strings.Split(string(buf), "\n\n") {
		if !ignore(g) {
			gs = append(gs, g)
		}
	}
	sort.Strings(gs)
	return
}

// Logger is the interface that wraps the Logf and Errorf methods. It is a
// subset of testing.TB, which makes it easy to use this package from tests.
type Logger interface {
	Logf(format string, args ...any)
	Errorf(format string, args ...any)
}

// CheckGoroutines looks at the currently-running goroutines and checks if
// there are any interesting (created by gRPC) goroutines leaked. It polls
// until the passed context expires (callers typically use a 10 second
// timeout), and reports any interesting goroutines still running at that point
// as leaked.
func CheckGoroutines(ctx context.Context, logger Logger) {
	// Loop, waiting for goroutines to shut down.
	// Wait up to timeout, but finish as quickly as possible.
	var leaked []string
	for ctx.Err() == nil {
		if leaked = interestingGoroutines(); len(leaked) == 0 {
			return
		}
		time.Sleep(50 * time.Millisecond)
	}
	for _, g := range leaked {
		logger.Errorf("Leaked goroutine: %v", g)
	}
}

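// A minimal sketch of calling CheckGoroutines directly, outside the
// LeakChecker helper; the test name is hypothetical:
//
//	func TestSomething(t *testing.T) {
//		// ... exercise gRPC code ...
//		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
//		defer cancel()
//		leakcheck.CheckGoroutines(ctx, t)
//	}
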
// LeakChecker captures a Logger and is returned by NewLeakChecker as a
// convenient way to set up leak checks for a unit test.
type LeakChecker struct {
	logger Logger
}

// NewLeakChecker offers a convenient way to set up the leak checks for a
// specific unit test. It can be used as follows, at the beginning of tests:
//
//	defer leakcheck.NewLeakChecker(t).Check()
//
// It immediately invokes SetTrackingBufferPool to set up buffer tracking; the
// deferred LeakChecker.Check call then invokes CheckTrackingBufferPool and
// CheckGoroutines with a default timeout of 10 seconds.
func NewLeakChecker(logger Logger) *LeakChecker {
	SetTrackingBufferPool(logger)
	return &LeakChecker{logger: logger}
}

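// The LeakChecker.Check method used in the example above is not shown in this
// excerpt. Based solely on the doc comments here, a sketch of what the
// deferred call performs (not a verified copy of the implementation) would be:
//
//	func (lc *LeakChecker) Check() {
//		CheckTrackingBufferPool()
//		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
//		defer cancel()
//		CheckGoroutines(ctx, lc.logger)
//	}
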
// timerFactory produces tracked timers via timeAfterFunc and records the
// allocation stack of every timer that has neither fired nor been stopped.
type timerFactory struct {
	mu              sync.Mutex
	allocatedTimers map[internal.Timer][]uintptr
}

func (tf *timerFactory) timeAfterFunc(d time.Duration, f func()) internal.Timer {
	tf.mu.Lock()
	defer tf.mu.Unlock()
	ch := make(chan internal.Timer, 1)
	timer := time.AfterFunc(d, func() {
		f()
		tf.remove(<-ch)
	})
	ch <- timer
	tf.allocatedTimers[timer] = currentStack(2)
	return &trackingTimer{
		Timer:  timer,
		parent: tf,
	}
}

func (tf *timerFactory) remove(timer internal.Timer) {
	tf.mu.Lock()
	defer tf.mu.Unlock()
	delete(tf.allocatedTimers, timer)
}

func (tf *timerFactory) pendingTimers() []string {
	tf.mu.Lock()
	defer tf.mu.Unlock()
	leaked := []string{}
	for _, stack := range tf.allocatedTimers {
		leaked = append(leaked, fmt.Sprintf("Allocated timer never cancelled:\n%s", traceToString(stack)))
	}
	return leaked
}

// trackingTimer wraps an internal.Timer so that stopping it also removes it
// from its parent timerFactory's set of outstanding timers.
type trackingTimer struct {
	internal.Timer
	parent *timerFactory
}

func (t *trackingTimer) Stop() bool {
	t.parent.remove(t.Timer)
	return t.Timer.Stop()
}

// TrackTimers replaces internal.TimeAfterFunc with one that tracks timer
// creations, stoppages and expirations. CheckTimers should then be invoked at
// the end of the test to validate that all timers created have either executed
// or been cancelled.
func TrackTimers() {
	globalTimerTracker = &timerFactory{
		allocatedTimers: make(map[internal.Timer][]uintptr),
	}
	internal.TimeAfterFunc = globalTimerTracker.timeAfterFunc
}

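// A usage sketch (the test name is hypothetical) pairing TrackTimers with
// CheckTimers:
//
//	func TestTimers(t *testing.T) {
//		leakcheck.TrackTimers()
//		// ... run code that creates timers via internal.TimeAfterFunc ...
//		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
//		defer cancel()
//		leakcheck.CheckTimers(ctx, t)
//	}
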
// CheckTimers undoes the effects of TrackTimers, and fails unit tests if not
// all timers were cancelled or executed. It is invalid to invoke this function
// without previously having invoked TrackTimers.
func CheckTimers(ctx context.Context, logger Logger) {
	tt := globalTimerTracker

	// Restore the real timer constructor on every return path, so this
	// function undoes the effects of TrackTimers whether or not leaks are
	// found.
	defer func() {
		internal.TimeAfterFunc = func(d time.Duration, f func()) internal.Timer {
			return time.AfterFunc(d, f)
		}
	}()

	// Loop, waiting for timers to be cancelled.
	// Wait up to timeout, but finish as quickly as possible.
	var leaked []string
	for ctx.Err() == nil {
		if leaked = tt.pendingTimers(); len(leaked) == 0 {
			return
		}
		time.Sleep(50 * time.Millisecond)
	}
	for _, g := range leaked {
		logger.Errorf("Leaked timers: %v", g)
	}
}

// currentStack returns the program counters of the calling goroutine's stack,
// starting skip frames above currentStack itself (skip=1 starts at the
// immediate caller). It grows the buffer until the entire stack is captured.
func currentStack(skip int) []uintptr {
	var stackBuf [16]uintptr
	var stack []uintptr
	skip++
	for {
		n := runtime.Callers(skip, stackBuf[:])
		stack = append(stack, stackBuf[:n]...)
		if n < len(stackBuf) {
			break
		}
		skip += len(stackBuf)
	}
	return stack
}

// traceToString renders a stack captured by currentStack as a human-readable,
// multi-line trace of "function\n\tfile:line" entries.
func traceToString(stack []uintptr) string {
	frames := runtime.CallersFrames(stack)
	var trace strings.Builder
	for {
		f, ok := frames.Next()
		if !ok {
			break
		}
		trace.WriteString(f.Function)
		trace.WriteString("\n\t")
		trace.WriteString(f.File)
		trace.WriteString(":")
		trace.WriteString(strconv.Itoa(f.Line))
		trace.WriteString("\n")
	}
	return trace.String()
}