github.com/v2fly/tools@v0.100.0/internal/memoize/memoize.go (about)

     1  // Copyright 2019 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Package memoize supports memoizing the return values of functions with
     6  // idempotent results that are expensive to compute.
     7  //
// To use this package, create a Store, obtain a Generation from it with the
// Store's Generation method, and use the Generation's Bind method to acquire
// handles.
    10  //
    11  package memoize
    12  
    13  import (
    14  	"context"
    15  	"flag"
    16  	"fmt"
    17  	"reflect"
    18  	"sync"
    19  	"sync/atomic"
    20  
    21  	"github.com/v2fly/tools/internal/xcontext"
    22  )
    23  
    24  var (
    25  	panicOnDestroyed = flag.Bool("memoize_panic_on_destroyed", false,
    26  		"Panic when a destroyed generation is read rather than returning an error. "+
    27  			"Panicking may make it easier to debug lifetime errors, especially when "+
    28  			"used with GOTRACEBACK=crash to see all running goroutines.")
    29  )
    30  
// Store binds keys to functions, returning handles that can be used to access
// the functions' results.
//
// The zero Store is ready to use: its maps are allocated by the first call to
// its Generation method.
type Store struct {
	// mu guards handles and generations.
	mu sync.Mutex
	// handles maps each bound key to the single Handle for that key.
	handles map[interface{}]*Handle

	// generations is the set of generations live in this store.
	generations map[*Generation]struct{}
}
    41  
    42  // Generation creates a new Generation associated with s. Destroy must be
    43  // called on the returned Generation once it is no longer in use. name is
    44  // for debugging purposes only.
    45  func (s *Store) Generation(name string) *Generation {
    46  	s.mu.Lock()
    47  	defer s.mu.Unlock()
    48  	if s.handles == nil {
    49  		s.handles = map[interface{}]*Handle{}
    50  		s.generations = map[*Generation]struct{}{}
    51  	}
    52  	g := &Generation{store: s, name: name}
    53  	s.generations[g] = struct{}{}
    54  	return g
    55  }
    56  
// A Generation is a logical point in time of the cache life-cycle. Cache
// entries associated with a Generation will not be removed until the
// Generation is destroyed.
type Generation struct {
	// destroyed is 1 after the generation is destroyed; it is only accessed
	// atomically. store is the Store that created this generation, and name
	// appears only in panic and error messages.
	destroyed uint32
	store     *Store
	name      string
	// wg counts the outstanding references handed out by Acquire; Destroy
	// waits for it to drop to zero before tearing the generation down.
	wg sync.WaitGroup
}
    68  
    69  // Destroy waits for all operations referencing g to complete, then removes
    70  // all references to g from cache entries. Cache entries that no longer
    71  // reference any non-destroyed generation are removed. Destroy must be called
    72  // exactly once for each generation.
    73  func (g *Generation) Destroy() {
    74  	g.wg.Wait()
    75  	atomic.StoreUint32(&g.destroyed, 1)
    76  	g.store.mu.Lock()
    77  	defer g.store.mu.Unlock()
    78  	for k, e := range g.store.handles {
    79  		e.mu.Lock()
    80  		if _, ok := e.generations[g]; ok {
    81  			delete(e.generations, g) // delete even if it's dead, in case of dangling references to the entry.
    82  			if len(e.generations) == 0 {
    83  				delete(g.store.handles, k)
    84  				e.state = stateDestroyed
    85  				if e.cleanup != nil && e.value != nil {
    86  					e.cleanup(e.value)
    87  				}
    88  			}
    89  		}
    90  		e.mu.Unlock()
    91  	}
    92  	delete(g.store.generations, g)
    93  }
    94  
    95  // Acquire creates a new reference to g, and returns a func to release that
    96  // reference.
    97  func (g *Generation) Acquire(ctx context.Context) func() {
    98  	destroyed := atomic.LoadUint32(&g.destroyed)
    99  	if ctx.Err() != nil {
   100  		return func() {}
   101  	}
   102  	if destroyed != 0 {
   103  		panic("acquire on destroyed generation " + g.name)
   104  	}
   105  	g.wg.Add(1)
   106  	return g.wg.Done
   107  }
   108  
// Arg is a marker interface that can be embedded to indicate a type is
// intended for use as a Function argument. Because its only method is
// unexported, types outside this package satisfy it only by embedding Arg.
type Arg interface{ memoizeArg() }

// Function is the type for functions that can be memoized.
// The result must be a pointer.
//
// A Handle invokes its Function on a separate goroutine, passing a context
// derived from xcontext.Detach of the caller's context. That context is
// cancelled when every Get waiting on the result has given up.
type Function func(ctx context.Context, arg Arg) interface{}
   116  
   117  type state int
   118  
   119  const (
   120  	stateIdle = iota
   121  	stateRunning
   122  	stateCompleted
   123  	stateDestroyed
   124  )
   125  
// Handle is returned from a store when a key is bound to a function.
// It is then used to access the results of that function.
//
// A Handle starts out in idle state, waiting for something to demand its
// evaluation. It then transitions into running state. While it's running,
// waiters tracks the number of Get calls waiting for a result, and the done
// channel is used to notify waiters of the next state transition. Once the
// evaluation finishes, value is set, state changes to completed, and done
// is closed, unblocking waiters. Alternatively, as Get calls are cancelled,
// they decrement waiters. If it drops to zero, the inner context is cancelled,
// computation is abandoned, and state resets to idle to start the process over
// again.
type Handle struct {
	// key is the key under which Bind stored this handle; it is never
	// reassigned. mu guards the remaining fields, except cleanup, which is
	// likewise only assigned in Bind.
	key interface{}
	mu  sync.Mutex

	// generations is the set of generations in which this handle is valid.
	generations map[*Generation]struct{}

	state state
	// done is set in running state, and closed when exiting it.
	done chan struct{}
	// cancel is set in running state. It cancels computation.
	cancel context.CancelFunc
	// waiters is the number of Gets outstanding. It is incremented by wait
	// and decremented only when a waiting Get's context is cancelled.
	waiters uint
	// the function that will be used to populate the value; cleared once the
	// value has been computed.
	function Function
	// value is set in completed state.
	value interface{}
	// cleanup, if non-nil, is used to perform any necessary clean-up on values
	// produced by function. It is read without holding mu.
	cleanup func(interface{})
}
   160  
   161  // Bind returns a handle for the given key and function.
   162  //
   163  // Each call to bind will return the same handle if it is already bound. Bind
   164  // will always return a valid handle, creating one if needed. Each key can
   165  // only have one handle at any given time. The value will be held at least
   166  // until the associated generation is destroyed. Bind does not cause the value
   167  // to be generated.
   168  //
   169  // If cleanup is non-nil, it will be called on any non-nil values produced by
   170  // function when they are no longer referenced.
   171  func (g *Generation) Bind(key interface{}, function Function, cleanup func(interface{})) *Handle {
   172  	// panic early if the function is nil
   173  	// it would panic later anyway, but in a way that was much harder to debug
   174  	if function == nil {
   175  		panic("the function passed to bind must not be nil")
   176  	}
   177  	if atomic.LoadUint32(&g.destroyed) != 0 {
   178  		panic("operation on destroyed generation " + g.name)
   179  	}
   180  	g.store.mu.Lock()
   181  	defer g.store.mu.Unlock()
   182  	h, ok := g.store.handles[key]
   183  	if !ok {
   184  		h := &Handle{
   185  			key:         key,
   186  			function:    function,
   187  			generations: map[*Generation]struct{}{g: {}},
   188  			cleanup:     cleanup,
   189  		}
   190  		g.store.handles[key] = h
   191  		return h
   192  	}
   193  	h.mu.Lock()
   194  	defer h.mu.Unlock()
   195  	if _, ok := h.generations[g]; !ok {
   196  		h.generations[g] = struct{}{}
   197  	}
   198  	return h
   199  }
   200  
   201  // Stats returns the number of each type of value in the store.
   202  func (s *Store) Stats() map[reflect.Type]int {
   203  	s.mu.Lock()
   204  	defer s.mu.Unlock()
   205  
   206  	result := map[reflect.Type]int{}
   207  	for k := range s.handles {
   208  		result[reflect.TypeOf(k)]++
   209  	}
   210  	return result
   211  }
   212  
   213  // DebugOnlyIterate iterates through all live cache entries and calls f on them.
   214  // It should only be used for debugging purposes.
   215  func (s *Store) DebugOnlyIterate(f func(k, v interface{})) {
   216  	s.mu.Lock()
   217  	defer s.mu.Unlock()
   218  
   219  	for k, e := range s.handles {
   220  		var v interface{}
   221  		e.mu.Lock()
   222  		if e.state == stateCompleted {
   223  			v = e.value
   224  		}
   225  		e.mu.Unlock()
   226  		if v == nil {
   227  			continue
   228  		}
   229  		f(k, v)
   230  	}
   231  }
   232  
   233  func (g *Generation) Inherit(hs ...*Handle) {
   234  	for _, h := range hs {
   235  		if atomic.LoadUint32(&g.destroyed) != 0 {
   236  			panic("inherit on destroyed generation " + g.name)
   237  		}
   238  
   239  		h.mu.Lock()
   240  		defer h.mu.Unlock()
   241  		if h.state == stateDestroyed {
   242  			panic(fmt.Sprintf("inheriting destroyed handle %#v (type %T) into generation %v", h.key, h.key, g.name))
   243  		}
   244  		h.generations[g] = struct{}{}
   245  	}
   246  }
   247  
   248  // Cached returns the value associated with a handle.
   249  //
   250  // It will never cause the value to be generated.
   251  // It will return the cached value, if present.
   252  func (h *Handle) Cached(g *Generation) interface{} {
   253  	h.mu.Lock()
   254  	defer h.mu.Unlock()
   255  	if _, ok := h.generations[g]; !ok {
   256  		return nil
   257  	}
   258  	if h.state == stateCompleted {
   259  		return h.value
   260  	}
   261  	return nil
   262  }
   263  
   264  // Get returns the value associated with a handle.
   265  //
   266  // If the value is not yet ready, the underlying function will be invoked.
   267  // If ctx is cancelled, Get returns nil.
   268  func (h *Handle) Get(ctx context.Context, g *Generation, arg Arg) (interface{}, error) {
   269  	release := g.Acquire(ctx)
   270  	defer release()
   271  
   272  	if ctx.Err() != nil {
   273  		return nil, ctx.Err()
   274  	}
   275  	h.mu.Lock()
   276  	if _, ok := h.generations[g]; !ok {
   277  		h.mu.Unlock()
   278  
   279  		err := fmt.Errorf("reading key %#v: generation %v is not known", h.key, g.name)
   280  		if *panicOnDestroyed && ctx.Err() != nil {
   281  			panic(err)
   282  		}
   283  		return nil, err
   284  	}
   285  	switch h.state {
   286  	case stateIdle:
   287  		return h.run(ctx, g, arg)
   288  	case stateRunning:
   289  		return h.wait(ctx)
   290  	case stateCompleted:
   291  		defer h.mu.Unlock()
   292  		return h.value, nil
   293  	case stateDestroyed:
   294  		h.mu.Unlock()
   295  		err := fmt.Errorf("Get on destroyed entry %#v (type %T) in generation %v", h.key, h.key, g.name)
   296  		if *panicOnDestroyed {
   297  			panic(err)
   298  		}
   299  		return nil, err
   300  	default:
   301  		panic("unknown state")
   302  	}
   303  }
   304  
// run starts h.function and returns the result. h.mu must be locked.
//
// run moves h into the running state, launches the computation on a new
// goroutine, and then blocks in h.wait. Because wait unlocks h.mu, h.mu is
// unlocked when run returns.
//
// The computation gets a context derived from xcontext.Detach(ctx). Its own
// cancellation happens through h.cancel, which wait calls once every waiting
// Get has given up. NOTE(review): presumably Detach keeps ctx's values but not
// its cancellation; confirm against xcontext.
func (h *Handle) run(ctx context.Context, g *Generation, arg Arg) (interface{}, error) {
	childCtx, cancel := context.WithCancel(xcontext.Detach(ctx))
	h.cancel = cancel
	h.state = stateRunning
	h.done = make(chan struct{})
	function := h.function // Read under the lock

	// Make sure that the generation isn't destroyed while we're running in it.
	// NOTE(review): if ctx is already cancelled, Acquire takes no reference
	// and release is a no-op, so this protection does not apply then.
	release := g.Acquire(ctx)
	go func() {
		defer release()
		// Just in case the function does something expensive without checking
		// the context, double-check we're still alive.
		if childCtx.Err() != nil {
			return
		}
		v := function(childCtx, arg)
		if childCtx.Err() != nil {
			// It's possible that v was computed despite the context cancellation. In
			// this case we should ensure that it is cleaned up.
			// (h.cleanup is read without h.mu; it is only assigned in Bind.)
			if h.cleanup != nil && v != nil {
				h.cleanup(v)
			}
			return
		}

		h.mu.Lock()
		defer h.mu.Unlock()
		// It's theoretically possible that the handle has been cancelled out
		// of the run that started us, and then started running again since we
		// checked childCtx above. Even so, that should be harmless, since each
		// run should produce the same results.
		// The state may also be stateDestroyed if Destroy removed h meanwhile.
		if h.state != stateRunning {
			// v will never be used, so ensure that it is cleaned up.
			if h.cleanup != nil && v != nil {
				h.cleanup(v)
			}
			return
		}
		// At this point v will be cleaned up whenever h is destroyed.
		h.value = v
		// Drop the function so anything it references can be collected.
		h.function = nil
		h.state = stateCompleted
		// Wake every Get blocked in wait.
		close(h.done)
	}()

	return h.wait(ctx)
}
   354  
// wait waits for the value to be computed, or ctx to be cancelled. h.mu must be locked.
//
// wait registers the caller as a waiter and unlocks h.mu before blocking; h.mu
// is unlocked again when wait returns. If the computation completes, wait
// returns its value. If ctx is cancelled first, wait returns ctx.Err(). When
// the last waiter gives up while h is still running, the computation is
// cancelled and h is reset to idle, so a later Get starts over.
func (h *Handle) wait(ctx context.Context) (interface{}, error) {
	h.waiters++
	// Capture done while locked: h.done is replaced (set to nil, then re-made
	// by run) once the computation is abandoned.
	done := h.done
	h.mu.Unlock()

	select {
	case <-done:
		h.mu.Lock()
		defer h.mu.Unlock()
		if h.state == stateCompleted {
			return h.value, nil
		}
		// done was closed without a completed value; report neither a value
		// nor an error.
		return nil, nil
	case <-ctx.Done():
		h.mu.Lock()
		defer h.mu.Unlock()
		h.waiters--
		if h.waiters == 0 && h.state == stateRunning {
			// Nobody wants the result any more: stop the computation and
			// reset so the next Get runs it afresh.
			h.cancel()
			close(h.done)
			h.state = stateIdle
			h.done = nil
			h.cancel = nil
		}
		return nil, ctx.Err()
	}
}