github.com/jhump/golang-x-tools@v0.0.0-20220218190644-4958d6d39439/internal/memoize/memoize.go (about)

     1  // Copyright 2019 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Package memoize supports memoizing the return values of functions with
     6  // idempotent results that are expensive to compute.
     7  //
// To use this package, build a Store, create a Generation from it, and use
// the Generation's Bind method to acquire handles.
    10  //
    11  package memoize
    12  
    13  import (
    14  	"context"
    15  	"flag"
    16  	"fmt"
    17  	"reflect"
    18  	"sync"
    19  	"sync/atomic"
    20  
    21  	"github.com/jhump/golang-x-tools/internal/xcontext"
    22  )
    23  
    24  var (
    25  	panicOnDestroyed = flag.Bool("memoize_panic_on_destroyed", false,
    26  		"Panic when a destroyed generation is read rather than returning an error. "+
    27  			"Panicking may make it easier to debug lifetime errors, especially when "+
    28  			"used with GOTRACEBACK=crash to see all running goroutines.")
    29  )
    30  
    31  // Store binds keys to functions, returning handles that can be used to access
    32  // the functions results.
    33  type Store struct {
    34  	mu sync.Mutex
    35  	// handles is the set of values stored.
    36  	handles map[interface{}]*Handle
    37  
    38  	// generations is the set of generations live in this store.
    39  	generations map[*Generation]struct{}
    40  }
    41  
    42  // Generation creates a new Generation associated with s. Destroy must be
    43  // called on the returned Generation once it is no longer in use. name is
    44  // for debugging purposes only.
    45  func (s *Store) Generation(name string) *Generation {
    46  	s.mu.Lock()
    47  	defer s.mu.Unlock()
    48  	if s.handles == nil {
    49  		s.handles = map[interface{}]*Handle{}
    50  		s.generations = map[*Generation]struct{}{}
    51  	}
    52  	g := &Generation{store: s, name: name}
    53  	s.generations[g] = struct{}{}
    54  	return g
    55  }
    56  
    57  // A Generation is a logical point in time of the cache life-cycle. Cache
    58  // entries associated with a Generation will not be removed until the
    59  // Generation is destroyed.
    60  type Generation struct {
    61  	// destroyed is 1 after the generation is destroyed. Atomic.
    62  	destroyed uint32
    63  	store     *Store
    64  	name      string
    65  	// destroyedBy describes the caller that togged destroyed from 0 to 1.
    66  	destroyedBy string
    67  	// wg tracks the reference count of this generation.
    68  	wg sync.WaitGroup
    69  }
    70  
    71  // Destroy waits for all operations referencing g to complete, then removes
    72  // all references to g from cache entries. Cache entries that no longer
    73  // reference any non-destroyed generation are removed. Destroy must be called
    74  // exactly once for each generation, and destroyedBy describes the caller.
    75  func (g *Generation) Destroy(destroyedBy string) {
    76  	g.wg.Wait()
    77  
    78  	prevDestroyedBy := g.destroyedBy
    79  	g.destroyedBy = destroyedBy
    80  	if ok := atomic.CompareAndSwapUint32(&g.destroyed, 0, 1); !ok {
    81  		panic("Destroy on generation " + g.name + " already destroyed by " + prevDestroyedBy)
    82  	}
    83  
    84  	g.store.mu.Lock()
    85  	defer g.store.mu.Unlock()
    86  	for k, e := range g.store.handles {
    87  		e.mu.Lock()
    88  		if _, ok := e.generations[g]; ok {
    89  			delete(e.generations, g) // delete even if it's dead, in case of dangling references to the entry.
    90  			if len(e.generations) == 0 {
    91  				delete(g.store.handles, k)
    92  				e.state = stateDestroyed
    93  				if e.cleanup != nil && e.value != nil {
    94  					e.cleanup(e.value)
    95  				}
    96  			}
    97  		}
    98  		e.mu.Unlock()
    99  	}
   100  	delete(g.store.generations, g)
   101  }
   102  
   103  // Acquire creates a new reference to g, and returns a func to release that
   104  // reference.
   105  func (g *Generation) Acquire() func() {
   106  	destroyed := atomic.LoadUint32(&g.destroyed)
   107  	if destroyed != 0 {
   108  		panic("acquire on generation " + g.name + " destroyed by " + g.destroyedBy)
   109  	}
   110  	g.wg.Add(1)
   111  	return g.wg.Done
   112  }
   113  
   114  // Arg is a marker interface that can be embedded to indicate a type is
   115  // intended for use as a Function argument.
   116  type Arg interface{ memoizeArg() }
   117  
// Function is the type for functions that can be memoized.
// The result must be a pointer.
//
// A Function is invoked on behalf of Handle.Get with the arg supplied by the
// caller that triggered evaluation. Its ctx is cancelled once every Get
// waiting on the result has itself been cancelled, so long-running functions
// should check ctx and return promptly.
type Function func(ctx context.Context, arg Arg) interface{}
   121  
   122  type state int
   123  
   124  const (
   125  	stateIdle = iota
   126  	stateRunning
   127  	stateCompleted
   128  	stateDestroyed
   129  )
   130  
   131  // Handle is returned from a store when a key is bound to a function.
   132  // It is then used to access the results of that function.
   133  //
   134  // A Handle starts out in idle state, waiting for something to demand its
   135  // evaluation. It then transitions into running state. While it's running,
   136  // waiters tracks the number of Get calls waiting for a result, and the done
   137  // channel is used to notify waiters of the next state transition. Once the
   138  // evaluation finishes, value is set, state changes to completed, and done
   139  // is closed, unblocking waiters. Alternatively, as Get calls are cancelled,
   140  // they decrement waiters. If it drops to zero, the inner context is cancelled,
   141  // computation is abandoned, and state resets to idle to start the process over
   142  // again.
   143  type Handle struct {
   144  	key interface{}
   145  	mu  sync.Mutex
   146  
   147  	// generations is the set of generations in which this handle is valid.
   148  	generations map[*Generation]struct{}
   149  
   150  	state state
   151  	// done is set in running state, and closed when exiting it.
   152  	done chan struct{}
   153  	// cancel is set in running state. It cancels computation.
   154  	cancel context.CancelFunc
   155  	// waiters is the number of Gets outstanding.
   156  	waiters uint
   157  	// the function that will be used to populate the value
   158  	function Function
   159  	// value is set in completed state.
   160  	value interface{}
   161  	// cleanup, if non-nil, is used to perform any necessary clean-up on values
   162  	// produced by function.
   163  	cleanup func(interface{})
   164  }
   165  
   166  // Bind returns a handle for the given key and function.
   167  //
   168  // Each call to bind will return the same handle if it is already bound. Bind
   169  // will always return a valid handle, creating one if needed. Each key can
   170  // only have one handle at any given time. The value will be held at least
   171  // until the associated generation is destroyed. Bind does not cause the value
   172  // to be generated.
   173  //
   174  // If cleanup is non-nil, it will be called on any non-nil values produced by
   175  // function when they are no longer referenced.
   176  func (g *Generation) Bind(key interface{}, function Function, cleanup func(interface{})) *Handle {
   177  	// panic early if the function is nil
   178  	// it would panic later anyway, but in a way that was much harder to debug
   179  	if function == nil {
   180  		panic("the function passed to bind must not be nil")
   181  	}
   182  	if atomic.LoadUint32(&g.destroyed) != 0 {
   183  		panic("operation on generation " + g.name + " destroyed by " + g.destroyedBy)
   184  	}
   185  	g.store.mu.Lock()
   186  	defer g.store.mu.Unlock()
   187  	h, ok := g.store.handles[key]
   188  	if !ok {
   189  		h := &Handle{
   190  			key:         key,
   191  			function:    function,
   192  			generations: map[*Generation]struct{}{g: {}},
   193  			cleanup:     cleanup,
   194  		}
   195  		g.store.handles[key] = h
   196  		return h
   197  	}
   198  	h.mu.Lock()
   199  	defer h.mu.Unlock()
   200  	if _, ok := h.generations[g]; !ok {
   201  		h.generations[g] = struct{}{}
   202  	}
   203  	return h
   204  }
   205  
   206  // Stats returns the number of each type of value in the store.
   207  func (s *Store) Stats() map[reflect.Type]int {
   208  	s.mu.Lock()
   209  	defer s.mu.Unlock()
   210  
   211  	result := map[reflect.Type]int{}
   212  	for k := range s.handles {
   213  		result[reflect.TypeOf(k)]++
   214  	}
   215  	return result
   216  }
   217  
   218  // DebugOnlyIterate iterates through all live cache entries and calls f on them.
   219  // It should only be used for debugging purposes.
   220  func (s *Store) DebugOnlyIterate(f func(k, v interface{})) {
   221  	s.mu.Lock()
   222  	defer s.mu.Unlock()
   223  
   224  	for k, e := range s.handles {
   225  		var v interface{}
   226  		e.mu.Lock()
   227  		if e.state == stateCompleted {
   228  			v = e.value
   229  		}
   230  		e.mu.Unlock()
   231  		if v == nil {
   232  			continue
   233  		}
   234  		f(k, v)
   235  	}
   236  }
   237  
   238  func (g *Generation) Inherit(hs ...*Handle) {
   239  	for _, h := range hs {
   240  		if atomic.LoadUint32(&g.destroyed) != 0 {
   241  			panic("inherit on generation " + g.name + " destroyed by " + g.destroyedBy)
   242  		}
   243  
   244  		h.mu.Lock()
   245  		defer h.mu.Unlock()
   246  		if h.state == stateDestroyed {
   247  			panic(fmt.Sprintf("inheriting destroyed handle %#v (type %T) into generation %v", h.key, h.key, g.name))
   248  		}
   249  		h.generations[g] = struct{}{}
   250  	}
   251  }
   252  
   253  // Cached returns the value associated with a handle.
   254  //
   255  // It will never cause the value to be generated.
   256  // It will return the cached value, if present.
   257  func (h *Handle) Cached(g *Generation) interface{} {
   258  	h.mu.Lock()
   259  	defer h.mu.Unlock()
   260  	if _, ok := h.generations[g]; !ok {
   261  		return nil
   262  	}
   263  	if h.state == stateCompleted {
   264  		return h.value
   265  	}
   266  	return nil
   267  }
   268  
   269  // Get returns the value associated with a handle.
   270  //
   271  // If the value is not yet ready, the underlying function will be invoked.
   272  // If ctx is cancelled, Get returns nil.
   273  func (h *Handle) Get(ctx context.Context, g *Generation, arg Arg) (interface{}, error) {
   274  	release := g.Acquire()
   275  	defer release()
   276  
   277  	if ctx.Err() != nil {
   278  		return nil, ctx.Err()
   279  	}
   280  	h.mu.Lock()
   281  	if _, ok := h.generations[g]; !ok {
   282  		h.mu.Unlock()
   283  
   284  		err := fmt.Errorf("reading key %#v: generation %v is not known", h.key, g.name)
   285  		if *panicOnDestroyed && ctx.Err() != nil {
   286  			panic(err)
   287  		}
   288  		return nil, err
   289  	}
   290  	switch h.state {
   291  	case stateIdle:
   292  		return h.run(ctx, g, arg)
   293  	case stateRunning:
   294  		return h.wait(ctx)
   295  	case stateCompleted:
   296  		defer h.mu.Unlock()
   297  		return h.value, nil
   298  	case stateDestroyed:
   299  		h.mu.Unlock()
   300  		err := fmt.Errorf("Get on destroyed entry %#v (type %T) in generation %v", h.key, h.key, g.name)
   301  		if *panicOnDestroyed {
   302  			panic(err)
   303  		}
   304  		return nil, err
   305  	default:
   306  		panic("unknown state")
   307  	}
   308  }
   309  
// run starts evaluating h.function in a new goroutine and then waits for the
// outcome via h.wait. h.mu must be locked by the caller; it is released
// inside h.wait before blocking.
//
// The goroutine stores the result in h.value, moves the handle to the
// completed state and closes h.done only if the child context is still live
// and the handle is still running when the function returns. Otherwise a
// non-nil result is handed to h.cleanup (when set) and discarded.
func (h *Handle) run(ctx context.Context, g *Generation, arg Arg) (interface{}, error) {
	// NOTE(review): xcontext.Detach presumably yields a context that is not
	// cancelled together with ctx, so the evaluation is stopped only through
	// h.cancel — confirm against the xcontext package.
	childCtx, cancel := context.WithCancel(xcontext.Detach(ctx))
	h.cancel = cancel
	h.state = stateRunning
	h.done = make(chan struct{})
	function := h.function // Read under the lock

	// Make sure that the generation isn't destroyed while we're running in it.
	release := g.Acquire()
	go func() {
		defer release()
		// Just in case the function does something expensive without checking
		// the context, double-check we're still alive.
		if childCtx.Err() != nil {
			return
		}
		v := function(childCtx, arg)
		if childCtx.Err() != nil {
			// It's possible that v was computed despite the context cancellation. In
			// this case we should ensure that it is cleaned up.
			if h.cleanup != nil && v != nil {
				h.cleanup(v)
			}
			return
		}

		h.mu.Lock()
		defer h.mu.Unlock()
		// It's theoretically possible that the handle has been cancelled out
		// of the run that started us, and then started running again since we
		// checked childCtx above. Even so, that should be harmless, since each
		// run should produce the same results.
		if h.state != stateRunning {
			// v will never be used, so ensure that it is cleaned up.
			if h.cleanup != nil && v != nil {
				h.cleanup(v)
			}
			return
		}
		// At this point v will be cleaned up whenever h is destroyed.
		h.value = v
		// The function is not needed again once the value is stored.
		h.function = nil
		h.state = stateCompleted
		// Closing done wakes every Get blocked in wait.
		close(h.done)
	}()

	// Register the caller as a waiter; this also releases h.mu.
	return h.wait(ctx)
}
   359  
   360  // wait waits for the value to be computed, or ctx to be cancelled. h.mu must be locked.
   361  func (h *Handle) wait(ctx context.Context) (interface{}, error) {
   362  	h.waiters++
   363  	done := h.done
   364  	h.mu.Unlock()
   365  
   366  	select {
   367  	case <-done:
   368  		h.mu.Lock()
   369  		defer h.mu.Unlock()
   370  		if h.state == stateCompleted {
   371  			return h.value, nil
   372  		}
   373  		return nil, nil
   374  	case <-ctx.Done():
   375  		h.mu.Lock()
   376  		defer h.mu.Unlock()
   377  		h.waiters--
   378  		if h.waiters == 0 && h.state == stateRunning {
   379  			h.cancel()
   380  			close(h.done)
   381  			h.state = stateIdle
   382  			h.done = nil
   383  			h.cancel = nil
   384  		}
   385  		return nil, ctx.Err()
   386  	}
   387  }