github.com/clubpay/ronykit/kit@v0.14.4-0.20240515065620-d0dace45cbc7/utils/singleflight.go (about)

     1  package utils
     2  
     3  import (
     4  	"bytes"
     5  	"errors"
     6  	"fmt"
     7  	"runtime"
     8  	"runtime/debug"
     9  	"sync"
    10  )
    11  
    12  // Copyright 2013 The Go Authors. All rights reserved.
    13  // Use of this source code is governed by a BSD-style
    14  // license that can be found in the LICENSE file.
    15  
    16  // errGoexit indicates the runtime.Goexit was called in
    17  // the user given function.
    18  var errGoexit = errors.New("runtime.Goexit was called")
    19  
    20  // A panicError is an arbitrary value recovered from a panic
    21  // with the stack trace during the execution of given function.
    22  type panicError struct {
    23  	value any
    24  	stack []byte
    25  }
    26  
    27  // Error implements error interface.
    28  func (p *panicError) Error() string {
    29  	return fmt.Sprintf("%v\n\n%s", p.value, p.stack)
    30  }
    31  
    32  func (p *panicError) Unwrap() error {
    33  	err, ok := p.value.(error)
    34  	if !ok {
    35  		return nil
    36  	}
    37  
    38  	return err
    39  }
    40  
    41  func newPanicError(v any) error {
    42  	stack := debug.Stack()
    43  
    44  	// The first line of the stack trace is of the form "goroutine N [status]:"
    45  	// but by the time the panic reaches Do the goroutine may no longer exist
    46  	// and its status will have changed. Trim out the misleading line.
    47  	if line := bytes.IndexByte(stack[:], '\n'); line >= 0 {
    48  		stack = stack[line+1:]
    49  	}
    50  
    51  	return &panicError{value: v, stack: stack}
    52  }
    53  
    54  // call is an in-flight or completed singleflight.Do call
    55  type call struct {
    56  	wg sync.WaitGroup
    57  
    58  	// These fields are written once before the WaitGroup is done
    59  	// and are only read after the WaitGroup is done.
    60  	val any
    61  	err error
    62  
    63  	// These fields are read and written with the singleflight
    64  	// mutex held before the WaitGroup is done, and are read but
    65  	// not written after the WaitGroup is done.
    66  	dups  int
    67  	chans []chan<- Result
    68  }
    69  
    70  // Result holds the results of Do, so they can be passed
    71  // on a channel.
    72  type Result struct {
    73  	Val    any
    74  	Err    error
    75  	Shared bool
    76  }
    77  
    78  type SingleFlightCall[T any] func(fn func() (T, error)) (T, error)
    79  
    80  // SingleFlight executes and returns the results of the given function, making
    81  // sure that only one execution is in-flight for a given key at a
    82  // time. If a duplicate comes in, the duplicate caller waits for the
    83  // original to complete and receives the same results.
    84  // The return value shared indicates whether v was given to multiple cal
    85  func SingleFlight[T any]() SingleFlightCall[T] {
    86  	mu := sync.Mutex{}
    87  	var (
    88  		c     *call
    89  		ready = true
    90  	)
    91  
    92  	doCall := genDoCall[T](&mu, &ready)
    93  
    94  	return func(fn func() (T, error)) (T, error) {
    95  		mu.Lock()
    96  		if ready {
    97  			ready = false
    98  			c = new(call)
    99  			c.wg.Add(1)
   100  			mu.Unlock()
   101  
   102  			doCall(c, fn)
   103  
   104  			return c.val.(T), c.err //nolint:forcetypeassert
   105  		}
   106  
   107  		c.dups++
   108  		mu.Unlock()
   109  		c.wg.Wait()
   110  
   111  		var e *panicError
   112  		if c.err != nil && errors.As(c.err, &e) {
   113  			panic(e)
   114  		}
   115  
   116  		if errors.Is(c.err, errGoexit) {
   117  			runtime.Goexit()
   118  		}
   119  
   120  		return c.val.(T), c.err //nolint:forcetypeassert
   121  	}
   122  }
   123  
// genDoCall builds the function that runs fn for the leading caller of a
// SingleFlight and publishes the outcome to the callers waiting on it.
//
// mu is the mutex guarding the SingleFlight's state, and ready is the flag
// (guarded by mu) that says whether a new execution may start. The returned
// function stores fn's results in c.val and c.err. Then, with mu held, it calls
// c.wg.Done() to release waiters and sets *ready back to true so the next
// caller starts a fresh execution. The caller must therefore have called
// c.wg.Add(1) before invoking it.
//
// How fn terminated decides what happens after that:
//   - normal return: a Result is sent on every channel in c.chans;
//   - panic: c.err holds a *panicError and the panic is raised again. If
//     channels are registered, it is raised on a new goroutine so it cannot be
//     recovered, and the current goroutine blocks forever;
//   - runtime.Goexit: c.err is set to errGoexit and the goroutine simply
//     continues exiting.
//
//nolint:gocognit
func genDoCall[T any](mu *sync.Mutex, ready *bool) func(c *call, fn func() (T, error)) {
	return func(c *call, fn func() (T, error)) {
		// normalReturn is set once fn returns without panicking or calling
		// runtime.Goexit; recovered is set once a panic from fn has been
		// stopped by the inner recover below.
		normalReturn := false
		recovered := false

		// use double-defer to distinguish panic from runtime.Goexit,
		// more details see https://golang.org/cl/134395
		defer func() {
			// the given function invoked runtime.Goexit
			if !normalReturn && !recovered {
				c.err = errGoexit
			}

			mu.Lock()
			defer mu.Unlock()
			// Release waiters and reopen the slot before any re-panic below,
			// so blocked callers wake up even when fn panicked.
			c.wg.Done()
			*ready = true

			if e, ok := c.err.(*panicError); ok {
				// To prevent the waiting channels from being blocked forever,
				// needs to ensure that this panic cannot be recovered.
				if len(c.chans) > 0 {
					go panic(e)
					select {} // Keep this goroutine around so that it will appear in the crash dump.
				} else {
					panic(e)
				}
			} else if errors.Is(c.err, errGoexit) {
				// Already in the process of goexit, no need to call again
			} else {
				// Normal return
				for _, ch := range c.chans {
					ch <- Result{c.val, c.err, c.dups > 0}
				}
			}
		}()

		func() {
			defer func() {
				if !normalReturn {
					// Ideally, we would wait to take a stack trace until we've determined
					// whether this is a panic or a runtime.Goexit.
					//
					// Unfortunately, the only way we can distinguish the two is to see
					// whether the recover stopped the goroutine from terminating, and by
					// the time we know that, the part of the stack trace relevant to the
					// panic has been discarded.
					//
					// During runtime.Goexit recover returns nil, so c.err is left
					// untouched here and the outer defer classifies it instead.
					if r := recover(); r != nil {
						c.err = newPanicError(r)
					}
				}
			}()

			c.val, c.err = fn()
			normalReturn = true
		}()

		// Reaching this point without normalReturn means fn panicked and the
		// inner recover stopped the panic; under runtime.Goexit execution
		// never gets here.
		if !normalReturn {
			recovered = true
		}
	}
}